jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,523 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
|
-
from typing import Tuple
|
5
|
-
|
6
|
-
import jax
|
7
|
-
import jax.flatten_util
|
8
|
-
import jax.numpy as jnp
|
9
|
-
import jax_dataclasses
|
10
|
-
import numpy as np
|
11
|
-
|
12
|
-
import jaxsim.physics.model.physics_model
|
13
|
-
import jaxsim.typing as jtp
|
14
|
-
from jaxsim.math.adjoint import Adjoint
|
15
|
-
from jaxsim.math.skew import Skew
|
16
|
-
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
|
17
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
18
|
-
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass
|
19
|
-
|
20
|
-
from . import utils
|
21
|
-
|
22
|
-
|
23
|
-
@jax_dataclasses.pytree_dataclass
class SoftContactsState(JaxsimDataclass):
    """
    State of the soft contacts model.

    Attributes:
        tangential_deformation:
            The tangential deformation of the material at each collidable point,
            stored column-wise with shape ``(3, number_of_collidable_points)``.
    """

    tangential_deformation: jtp.Matrix

    @staticmethod
    def build(
        tangential_deformation: jtp.Matrix | None = None,
        number_of_collidable_points: int | None = None,
    ) -> SoftContactsState:
        """
        Build a soft contacts state.

        Args:
            tangential_deformation:
                The tangential deformation of the material at each collidable point.
                If None, a zero deformation matrix is created.
            number_of_collidable_points:
                The number of collidable points. It is only used to size the zero
                deformation matrix when ``tangential_deformation`` is None.

        Returns:
            A SoftContactsState instance storing the deformation as a float array.

        Raises:
            ValueError: If both ``tangential_deformation`` and
                ``number_of_collidable_points`` are None.
        """

        if tangential_deformation is None:
            # Without an explicit deformation we need the number of points to
            # size the zero matrix; fail with a clear message instead of
            # letting jnp.zeros choke on a None dimension.
            if number_of_collidable_points is None:
                raise ValueError(
                    "Either 'tangential_deformation' or "
                    "'number_of_collidable_points' must be provided"
                )

            tangential_deformation = jnp.zeros(shape=(3, number_of_collidable_points))

        return SoftContactsState(
            tangential_deformation=jnp.array(tangential_deformation, dtype=float)
        )

    @staticmethod
    def build_from_physics_model(
        tangential_deformation: jtp.Matrix | None = None,
        physics_model: jaxsim.physics.model.physics_model.PhysicsModel | None = None,
    ) -> SoftContactsState:
        """
        Build a soft contacts state sized after a physics model.

        Args:
            tangential_deformation:
                The tangential deformation of the material at each collidable point.
                If None, a zero deformation sized after ``physics_model`` is created.
            physics_model:
                The physics model whose collidable points (``gc.body``) determine the
                number of columns of the zero deformation. It may be omitted only
                when ``tangential_deformation`` is provided.

        Returns:
            A SoftContactsState instance.

        Raises:
            ValueError: If both arguments are None.
        """

        # Only dereference the model when it is available, so that a state can be
        # built from an explicit deformation alone, as the default value suggests.
        number_of_collidable_points = (
            len(physics_model.gc.body) if physics_model is not None else None
        )

        return SoftContactsState.build(
            tangential_deformation=tangential_deformation,
            number_of_collidable_points=number_of_collidable_points,
        )

    @staticmethod
    def zero(
        physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
    ) -> SoftContactsState:
        """
        Create a new SoftContactsState instance with zero tangential deformation.

        Args:
            physics_model: The physics model.

        Returns:
            A SoftContactsState instance with zero tangential deformation.
        """

        return SoftContactsState.build_from_physics_model(physics_model=physics_model)

    def valid(
        self, physics_model: jaxsim.physics.model.physics_model.PhysicsModel
    ) -> bool:
        """
        Check if the soft contacts state has valid shape.

        Args:
            physics_model: The physics model.

        Returns:
            True if the state has a valid shape, otherwise False.
        """

        from jaxsim.simulation.utils import check_valid_shape

        return check_valid_shape(
            what="tangential_deformation",
            shape=self.tangential_deformation.shape,
            expected_shape=(3, len(physics_model.gc.body)),
            valid=True,
        )
|
101
|
-
|
102
|
-
|
103
|
-
def collidable_points_pos_vel(
    model: PhysicsModel,
    q: jtp.Vector,
    qd: jtp.Vector,
    xfb: jtp.Vector | None = None,
) -> Tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of collidable points in the world frame.

    Args:
        model: The physics model.
        q: The joint positions.
        qd: The joint velocities.
        xfb:
            The floating base state. Elements ``[0:4]`` hold the base quaternion,
            ``[4:7]`` the base position, ``[7:10]`` the angular velocity and
            ``[10:13]`` the linear velocity. Defaults to None, in which case the
            default produced by ``utils.process_inputs`` is used.

    Returns:
        A tuple ``(W_p_Ci, CW_v_WC)`` with, column-wise, the world position and the
        linear part of the mixed velocity of every collidable point. Each has shape
        ``(3, number_of_collidable_points)``.
    """

    # Make sure that shape and size are correct
    xfb, q, qd, _, _, _ = utils.process_inputs(physics_model=model, xfb=xfb, q=q, qd=qd)

    # Initialize buffers of link transforms (W_X_i) and 6D inertial velocities (W_v_Wi)
    W_X_i = jnp.zeros(shape=[model.NB, 6, 6])
    W_v_Wi = jnp.zeros(shape=[model.NB, 6, 1])

    # 6D transform of the base link
    W_X_0 = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4], translation=xfb[4:7], normalize_quaternion=True
    )
    W_X_i = W_X_i.at[0].set(W_X_0)

    # Store the 6D inertial velocity W_v_W0 of the base link as a (6, 1) column,
    # linear part first and angular part second
    W_v_W0 = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))
    W_v_Wi = W_v_Wi.at[0].set(W_v_W0)

    # Compute useful resources from the model
    S = model.motion_subspaces(q=q)

    # Get the 6D transform between the parent link λi and the joint's predecessor frame
    pre_X_λi = model.tree_transforms

    # Compute the 6D transform of the joints (from predecessor to successor)
    i_X_pre = model.joint_transforms(q=q)

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # ====================
    # Propagate kinematics
    # ====================

    PropagateTransformsCarry = Tuple[jtp.MatrixJax]
    propagate_transforms_carry: PropagateTransformsCarry = (W_X_i,)

    def propagate_transforms(
        carry: PropagateTransformsCarry, i: jtp.Int
    ) -> Tuple[PropagateTransformsCarry, None]:
        # Unpack the carry
        (W_X_i,) = carry

        # We need the inverse transforms (from parent to child direction)
        pre_Xi_i = Adjoint.inverse(i_X_pre[i])
        λi_Xi_pre = Adjoint.inverse(pre_X_λi[i])

        # Compute the parent to child 6D transform
        λi_X_i = λi_Xi_pre @ pre_Xi_i

        # Compute the world to child 6D transform
        W_Xi_i = W_X_i[λ[i]] @ λi_X_i
        W_X_i = W_X_i.at[i].set(W_Xi_i)

        # Pack and return the carry
        return (W_X_i,), None

    (W_X_i,), _ = jax.lax.scan(
        f=propagate_transforms,
        init=propagate_transforms_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    # ====================
    # Propagate velocities
    # ====================

    PropagateVelocitiesCarry = Tuple[jtp.MatrixJax]
    propagate_velocities_carry: PropagateVelocitiesCarry = (W_v_Wi,)

    def propagate_velocities(
        carry: PropagateVelocitiesCarry,
        joint_velocity_and_link_index: Tuple[jtp.FloatLike, jtp.Int],
    ) -> Tuple[PropagateVelocitiesCarry, None]:
        # Unpack the scanned data. The joint whose velocity is qd[ii] is the one
        # connecting link i = ii + 1 (its child) to its parent link λ(i).
        qd_ii, i = joint_velocity_and_link_index

        # Unpack the carry
        (W_v_Wi,) = carry

        # Propagate the 6D velocity from the parent link to the child link
        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * qd_ii)
        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)

        # Pack and return the carry
        return (W_v_Wi,), None

    # Scan the joint velocities together with the corresponding child link indices.
    # Passing them as a pytree avoids packing integer indices into a float array.
    # NOTE(review): assumes one DoF per joint, i.e. qd.size == model.NB - 1.
    (W_v_Wi,), _ = jax.lax.scan(
        f=propagate_velocities,
        init=propagate_velocities_carry,
        xs=(qd, jnp.arange(start=1, stop=qd.size + 1)),
    )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.VectorJax, parent_body: jtp.Int
    ) -> Tuple[jtp.VectorJax, jtp.VectorJax]:
        # Compute the position of the collidable point
        W_p_Ci = (
            Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
        )[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci},
        # i.e. v + ω × p = v - S(p) ω
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ W_v_Wi[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel
    W_p_Ci, CW_v_WC = jax.vmap(process_point_kinematics)(
        model.gc.point.T, np.array(model.gc.body, dtype=int)
    )

    return W_p_Ci.transpose(), CW_v_WC.transpose()
|
244
|
-
|
245
|
-
|
246
|
-
@jax_dataclasses.pytree_dataclass
class SoftContactsParams:
    """
    Parameters of the soft contacts model.

    Attributes:
        K: The stiffness of the contact material.
        D: The damping of the contact material.
        mu: The friction coefficient.
    """

    K: float = dataclasses.field(default_factory=lambda: jnp.array(1e6, dtype=float))
    D: float = dataclasses.field(default_factory=lambda: jnp.array(2000, dtype=float))
    mu: float = dataclasses.field(default_factory=lambda: jnp.array(0.5, dtype=float))

    @staticmethod
    def build(
        K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
    ) -> SoftContactsParams:
        """
        Create a SoftContactsParams instance from explicit values.

        Args:
            K: The stiffness parameter. Defaults to 1e6.
            D: The damping parameter. Defaults to 2000.
            mu: The friction coefficient. Defaults to 0.5.

        Returns:
            A SoftContactsParams instance storing each value as a float array.
        """

        stiffness, damping, friction = (
            jnp.array(value, dtype=float) for value in (K, D, mu)
        )

        return SoftContactsParams(K=stiffness, D=damping, mu=friction)

    @staticmethod
    def build_default_from_physics_model(
        physics_model: PhysicsModel,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
    ) -> SoftContactsParams:
        """
        Create a SoftContactsParams instance tuned on the given model.

        The stiffness is chosen so that the model weight, shared among the given
        number of active contacts, yields the requested steady-state penetration.
        The damping is the chosen fraction of the critical damping.

        Args:
            physics_model: The target physics model.
            static_friction_coefficient: The static friction coefficient.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state: The number of contacts
                supporting the weight of the model in steady state.
            damping_ratio: The ratio controlling the damping behavior.

        Returns:
            A SoftContactsParams instance with the computed parameters.

        Note:
            The `damping_ratio` parameter allows to operate on the following conditions:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Total mass of the model, summed over all its links
        link_masses = jnp.array(
            [link.mass for link in physics_model.description.links_dict.values()]
        )
        total_mass = link_masses.sum()

        # Gravity acceleration magnitude taken from the z component of the
        # linear part of the gravity vector
        gravity_z = -physics_model.gravity[0:3][-1]

        # Weight supported, on average, by each active collidable point
        average_force = (
            total_mass * gravity_z / number_of_active_collidable_points_steady_state
        )

        # At steady state the spring force is K * δ^(3/2): invert it to get K
        stiffness = average_force / jnp.power(max_penetration, 3 / 2)

        # Scale the critical damping by the requested damping ratio
        critical_damping = 2 * jnp.sqrt(stiffness * total_mass)
        damping = damping_ratio * critical_damping

        return SoftContactsParams.build(
            K=stiffness, D=damping, mu=static_friction_coefficient
        )
|
329
|
-
|
330
|
-
|
331
|
-
@jax_dataclasses.pytree_dataclass
class SoftContacts:
    """
    Soft contacts model.

    Attributes:
        parameters: The parameters of the soft contacts model.
        terrain: The terrain providing the ground height and normal.
    """

    parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)

    def contact_model(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        tangential_deformation: jtp.Vector,
    ) -> Tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact forces and material deformation rate.

        The normal force follows the non-linear spring-damper law
        ``sqrt(δ) * (K * δ + D * δ̇)``, clamped to be non-negative. When the friction
        coefficient is non-zero, a tangential force is computed from the material
        deformation. If it falls outside the friction cone, it is projected onto the
        cone boundary.

        Args:
            position: The position of the collidable point in the world frame.
            velocity: The linear velocity of the collidable point.
            tangential_deformation: The tangential deformation of the material.

        Returns:
            A tuple with two elements:
            - the 6D contact force in inertial representation, linear part first;
            - the material deformation rate.
        """

        # Short name of parameters
        K = self.parameters.K
        D = self.parameters.D
        μ = self.parameters.mu

        # Material 3D tangential deformation and its derivative
        m = tangential_deformation.squeeze()
        ṁ = jnp.zeros_like(m)

        # Note: all the small hardcoded tolerances in this method have been introduced
        # to allow jax differentiating through this algorithm. They should not affect
        # the accuracy of the simulation, although they might make it less readable.

        # ========================
        # Normal force computation
        # ========================

        # Unpack the position of the collidable point
        px, py, pz = W_p_C = position.squeeze()
        W_ṗ_C = velocity.squeeze()

        # Evaluate the terrain height once: it is needed both for the penetration
        # depth and for detecting active contacts below
        terrain_height = self.terrain.height(x=px, y=py)

        # Compute the terrain normal and the contact depth
        n̂ = self.terrain.normal(x=px, y=py).squeeze()
        h = jnp.array([0, 0, terrain_height - pz])

        # Compute the penetration depth normal to the terrain
        δ = jnp.maximum(0.0, jnp.dot(h, n̂))

        # Compute the penetration normal velocity
        δ̇ = -jnp.dot(W_ṗ_C, n̂)

        # Non-linear spring-damper model.
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = jax.lax.select(
            pred=δ >= 1e-9,
            on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
            on_false=jnp.array(0.0),
        )

        # Prevent negative normal forces that might occur when δ̇ is largely negative
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame
        force_normal = force_normal_mag * n̂

        # ====================================
        # No friction and no tangential forces
        # ====================================

        # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
        # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
        W_Xf_CW = jnp.vstack(
            [
                jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
                jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
            ]
        )

        def with_no_friction():
            # Compute 6D mixed force in C[W]
            CW_f_lin = force_normal
            CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])

            # Compute lin-ang 6D forces (inertial representation)
            W_f = W_Xf_CW @ CW_f

            return W_f, ṁ

        # =========================
        # Compute tangential forces
        # =========================

        def with_friction():
            # Initialize the tangential deformation rate ṁ.
            # For inactive contacts with m≠0, this is the dynamics of the material
            # relaxation converging exponentially to steady state.
            ṁ = (-K / D) * m

            # Check if the collidable point is below ground.
            # Note: when δ=0, we consider the point still not in contact such that
            # we prevent divisions by 0 in the computations below.
            active_contact = pz < terrain_height

            def above_terrain():
                return jnp.zeros(6), ṁ

            def below_terrain():
                # Decompose the velocity in normal and tangential components
                v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
                v_tangential = W_ṗ_C - v_normal

                # Compute the tangential force. If it lies inside the friction cone,
                # the contact is sticking; otherwise it is slipping and the force
                # gets projected onto the cone boundary.
                f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

                def sticking_contact():
                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_stick = force_normal + f_tangential
                    CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])

                    # In this case the 3D material deformation is the tangential velocity
                    ṁ = v_tangential

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                def slipping_contact():
                    # Clip the tangential force if too small, allowing jax to
                    # differentiate through the norm computation
                    f_tangential_no_nan = jax.lax.select(
                        pred=f_tangential.dot(f_tangential) >= 1e-9**2,
                        on_true=f_tangential,
                        on_false=jnp.array([1e-12, 0, 0]),
                    )

                    # Project the force to the friction cone boundary
                    f_tangential_projected = (μ * force_normal_mag) * (
                        f_tangential / jnp.linalg.norm(f_tangential_no_nan)
                    )

                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_slip = force_normal + f_tangential_projected
                    CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])

                    # Correct the material deformation derivative for slipping contacts.
                    # Since f = α m + β ṁ, we solve for the ṁ producing the projected
                    # force on the cone given the current (m, δ).
                    ε = 1e-9
                    δε = jnp.maximum(δ, ε)
                    α = -K * jnp.sqrt(δε)
                    β = -D * jnp.sqrt(δε)
                    ṁ = (f_tangential_projected - α * m) / β

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                CW_f, ṁ = jax.lax.cond(
                    pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
                    true_fun=lambda _: slipping_contact(),
                    false_fun=lambda _: sticking_contact(),
                    operand=None,
                )

                # Express the 6D force in the world frame
                W_f = W_Xf_CW @ CW_f

                # Return the 6D force in the world frame and the deformation derivative
                return W_f, ṁ

            # (W_f, ṁ)
            return jax.lax.cond(
                pred=active_contact,
                true_fun=lambda _: below_terrain(),
                false_fun=lambda _: above_terrain(),
                operand=None,
            )

        # (W_f, ṁ)
        return jax.lax.cond(
            pred=(μ == 0.0),
            true_fun=lambda _: with_no_friction(),
            false_fun=lambda _: with_friction(),
            operand=None,
        )
|
jaxsim/physics/algos/terrain.py
DELETED
@@ -1,78 +0,0 @@
|
|
1
|
-
import abc
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
import jax_dataclasses
|
5
|
-
|
6
|
-
import jaxsim.typing as jtp
|
7
|
-
|
8
|
-
|
9
|
-
class Terrain(abc.ABC):
    """
    Abstract terrain model.

    Subclasses provide `height`. The surface normal is then estimated numerically
    with central finite differences of step `delta`.
    """

    # Step of the central finite differences used by `normal`
    delta = 0.010

    @abc.abstractmethod
    def height(self, x: float, y: float) -> float:
        """
        Compute the height of the terrain at a specific (x, y) location.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The terrain height at the given location.
        """

    def normal(self, x: float, y: float) -> jtp.Vector:
        """
        Compute the normal vector of the terrain at a specific (x, y) location.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The unit normal vector of the terrain surface at the given location.
        """

        # https://stackoverflow.com/a/5282364
        step = self.delta
        span = 2 * step

        height_x_plus = self.height(x=x + step, y=y)
        height_x_minus = self.height(x=x - step, y=y)
        height_y_plus = self.height(x=x, y=y + step)
        height_y_minus = self.height(x=x, y=y - step)

        # The unnormalized normal of z = h(x, y) is (-∂h/∂x, -∂h/∂y, 1)
        minus_dh_dx = (height_x_minus - height_x_plus) / span
        minus_dh_dy = (height_y_minus - height_y_plus) / span
        unnormalized = jnp.array([minus_dh_dx, minus_dh_dy, 1.0])

        return unnormalized / jnp.linalg.norm(unnormalized)
|
39
|
-
|
40
|
-
|
41
|
-
@jax_dataclasses.pytree_dataclass
class FlatTerrain(Terrain):
    """Horizontal terrain with zero height everywhere."""

    def height(self, x: float, y: float) -> float:
        """
        Return the terrain height, which does not depend on the location.

        Args:
            x: The x-coordinate of the location (unused).
            y: The y-coordinate of the location (unused).

        Returns:
            Always 0.0.
        """

        flat_height = 0.0
        return flat_height
|
45
|
-
|
46
|
-
|
47
|
-
@jax_dataclasses.pytree_dataclass
class PlaneTerrain(Terrain):
    """Planar terrain passing through the origin, defined by its normal vector."""

    # The (a, b, c) components of the plane normal; defaults to a horizontal plane
    plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0])

    @staticmethod
    def build(plane_normal: list) -> "PlaneTerrain":
        """
        Create a PlaneTerrain instance from a plane normal vector.

        Args:
            plane_normal: The normal vector of the terrain plane.

        Returns:
            A PlaneTerrain instance.
        """

        terrain = PlaneTerrain(plane_normal=plane_normal)
        return terrain

    def height(self, x: float, y: float) -> float:
        """
        Compute the height of the plane at a specific (x, y) location.

        The plane a*x + b*y + c*z = 0 is solved for z, so the third component of
        the normal must be non-zero.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The height of the plane at the given location.
        """

        normal_x, normal_y, normal_z = self.plane_normal
        return -(normal_x * x + normal_y * y) / normal_z
|
jaxsim/physics/algos/utils.py
DELETED
@@ -1,69 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
|
5
|
-
import jaxsim.typing as jtp
|
6
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
7
|
-
|
8
|
-
|
9
|
-
def process_inputs(
    physics_model: PhysicsModel,
    xfb: jtp.Vector | None = None,
    q: jtp.Vector | None = None,
    qd: jtp.Vector | None = None,
    qdd: jtp.Vector | None = None,
    tau: jtp.Vector | None = None,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Matrix]:
    """
    Adjust the inputs to the physics model.

    Missing inputs are replaced by defaults:
    - zeros for the joint quantities;
    - a 13-element base state that is zero except for its first element, set to 1;
    - a ``(NB, 6)`` matrix of zero external forces.

    Extra singleton dimensions are removed and the resulting shapes are validated.

    Args:
        physics_model: The physics model.
        xfb: The variables of the base link.
        q: The generalized coordinates.
        qd: The generalized velocities.
        qdd: The generalized accelerations.
        tau: The generalized forces.
        f_ext: The link external forces.

    Returns:
        The adjusted inputs, as the tuple ``(xfb, q, qd, qdd, tau, f_ext)``.

    Raises:
        ValueError: If an input has a shape inconsistent with the physics model.
    """

    # Query the model dimensions once
    dofs = physics_model.dofs()
    nb = physics_model.NB

    # Remove extra dimensions, falling back to defaults for missing inputs
    q = q.squeeze() if q is not None else jnp.zeros(dofs)
    qd = qd.squeeze() if qd is not None else jnp.zeros(dofs)
    qdd = qdd.squeeze() if qdd is not None else jnp.zeros(dofs)
    tau = tau.squeeze() if tau is not None else jnp.zeros(dofs)
    xfb = xfb.squeeze() if xfb is not None else jnp.zeros(13).at[0].set(1)
    f_ext = f_ext.squeeze() if f_ext is not None else jnp.zeros(shape=(nb, 6)).squeeze()

    # Fix case with just 1 DoF (squeeze would have produced 0-d arrays)
    q, qd, qdd, tau = jnp.atleast_1d(q, qd, qdd, tau)

    # Fix case with just 1 body (squeeze would have produced a 1-d array)
    f_ext = jnp.atleast_2d(f_ext)

    # Validate dimensions. All inputs are arrays at this point, so no None checks
    # are needed.
    if xfb.shape[0] != 13:
        raise ValueError(f"Expected 'xfb' with 13 elements, got shape {xfb.shape}")

    for name, value in (("q", q), ("qd", qd), ("qdd", qdd), ("tau", tau)):
        if value.shape[0] != dofs:
            raise ValueError(
                f"Expected '{name}' with {dofs} elements, got shape {value.shape}"
            )

    if f_ext.shape != (nb, 6):
        raise ValueError(
            f"Expected 'f_ext' with shape {(nb, 6)}, got shape {f_ext.shape}"
        )

    return xfb, q, qd, qdd, tau, f_ext
|
jaxsim/physics/model/__init__.py
DELETED
File without changes
|