jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py
CHANGED
@@ -1,230 +1,24 @@
|
|
1
|
-
from typing import Any, Protocol
|
2
|
-
|
3
1
|
import jax
|
4
2
|
import jax.numpy as jnp
|
5
3
|
|
6
4
|
import jaxsim.api as js
|
7
|
-
import jaxsim.rbda
|
8
5
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim.integrators import Time
|
10
|
-
from jaxsim.math import Quaternion
|
11
|
-
from jaxsim.rbda import contacts
|
6
|
+
from jaxsim.api.data import JaxSimModelData
|
7
|
+
from jaxsim.math import Quaternion, Skew
|
12
8
|
|
13
9
|
from .common import VelRepr
|
14
|
-
from .ode_data import ODEState
|
15
|
-
|
16
|
-
|
17
|
-
class SystemDynamicsFromModelAndData(Protocol):
|
18
|
-
"""
|
19
|
-
Protocol defining the signature of a function computing the system dynamics
|
20
|
-
given a model and data object.
|
21
|
-
"""
|
22
|
-
|
23
|
-
def __call__(
|
24
|
-
self,
|
25
|
-
model: js.model.JaxSimModel,
|
26
|
-
data: js.data.JaxSimModelData,
|
27
|
-
**kwargs: dict[str, Any],
|
28
|
-
) -> tuple[ODEState, dict[str, Any]]:
|
29
|
-
"""
|
30
|
-
Compute the system dynamics given a model and data object.
|
31
|
-
|
32
|
-
Args:
|
33
|
-
model: The model to consider.
|
34
|
-
data: The data of the considered model.
|
35
|
-
**kwargs: Additional keyword arguments.
|
36
|
-
|
37
|
-
Returns:
|
38
|
-
A tuple with an `ODEState` object storing in each of its attributes the
|
39
|
-
corresponding derivative, and the dictionary of auxiliary data returned
|
40
|
-
by the system dynamics evaluation.
|
41
|
-
"""
|
42
|
-
|
43
|
-
pass
|
44
|
-
|
45
|
-
|
46
|
-
def wrap_system_dynamics_for_integration(
|
47
|
-
*,
|
48
|
-
system_dynamics: SystemDynamicsFromModelAndData,
|
49
|
-
**kwargs: dict[str, Any],
|
50
|
-
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
|
51
|
-
"""
|
52
|
-
Wrap the system dynamics considered by JaxSim integrators in a generic
|
53
|
-
`f(x, t, **u, **parameters)` function.
|
54
|
-
|
55
|
-
Args:
|
56
|
-
system_dynamics: The system dynamics to wrap.
|
57
|
-
**kwargs: Additional kwargs to close over the system dynamics.
|
58
|
-
|
59
|
-
Returns:
|
60
|
-
The system dynamics closed over the additional kwargs to be used by
|
61
|
-
JaxSim integrators.
|
62
|
-
"""
|
63
|
-
|
64
|
-
# Close `system_dynamics` over additional kwargs.
|
65
|
-
# Similarly to what done in `jaxsim.api.model.step`, to be future-proof, we use the
|
66
|
-
# following logic to allow the caller to close over arguments having the same name
|
67
|
-
# of the ones used in the `wrap_system_dynamics_for_integration` function.
|
68
|
-
kwargs = kwargs.copy() if kwargs is not None else {}
|
69
|
-
colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
|
70
|
-
system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
|
71
|
-
|
72
|
-
# Remove `model` and `data` for backward compatibility.
|
73
|
-
# It's no longer necessary to close over them at this stage, as this is always
|
74
|
-
# done in `jaxsim.api.model.step`.
|
75
|
-
# We can remove the following lines in a few releases.
|
76
|
-
_ = system_dynamics_kwargs.pop("data", None)
|
77
|
-
_ = system_dynamics_kwargs.pop("model", None)
|
78
|
-
|
79
|
-
# Create the function with the signature expected by our generic integrators.
|
80
|
-
# Note that our system dynamics is time independent.
|
81
|
-
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
82
|
-
|
83
|
-
# Get the data and model objects from the kwargs.
|
84
|
-
data_f = kwargs_f.pop("data")
|
85
|
-
model_f = kwargs_f.pop("model")
|
86
|
-
|
87
|
-
# Update the state and time stored inside data.
|
88
|
-
with data_f.editable(validate=True) as data_rw:
|
89
|
-
data_rw.state = x
|
90
|
-
|
91
|
-
# Evaluate the system dynamics, allowing to override the kwargs originally
|
92
|
-
# passed when the closure was created.
|
93
|
-
return system_dynamics(
|
94
|
-
model=model_f,
|
95
|
-
data=data_rw,
|
96
|
-
**(system_dynamics_kwargs | kwargs_f),
|
97
|
-
)
|
98
|
-
|
99
|
-
f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
|
100
|
-
return f
|
101
|
-
|
102
10
|
|
103
11
|
# ==================================
|
104
12
|
# Functions defining system dynamics
|
105
13
|
# ==================================
|
106
14
|
|
107
15
|
|
108
|
-
@jax.jit
|
109
|
-
@js.common.named_scope
|
110
|
-
def system_velocity_dynamics(
|
111
|
-
model: js.model.JaxSimModel,
|
112
|
-
data: js.data.JaxSimModelData,
|
113
|
-
*,
|
114
|
-
link_forces: jtp.Vector | None = None,
|
115
|
-
joint_force_references: jtp.Vector | None = None,
|
116
|
-
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
|
117
|
-
"""
|
118
|
-
Compute the dynamics of the system velocity.
|
119
|
-
|
120
|
-
Args:
|
121
|
-
model: The model to consider.
|
122
|
-
data: The data of the considered model.
|
123
|
-
link_forces:
|
124
|
-
The 6D forces to apply to the links expressed in the frame corresponding to
|
125
|
-
the velocity representation of `data`.
|
126
|
-
joint_force_references: The joint force references to apply.
|
127
|
-
|
128
|
-
Returns:
|
129
|
-
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
130
|
-
representation, the derivative of the joint velocities, and auxiliary data
|
131
|
-
returned by the system dynamics evaluation.
|
132
|
-
"""
|
133
|
-
|
134
|
-
# Build link forces if not provided.
|
135
|
-
# These forces are expressed in the frame corresponding to the velocity
|
136
|
-
# representation of data.
|
137
|
-
O_f_L = (
|
138
|
-
jnp.atleast_2d(link_forces.squeeze())
|
139
|
-
if link_forces is not None
|
140
|
-
else jnp.zeros((model.number_of_links(), 6))
|
141
|
-
).astype(float)
|
142
|
-
|
143
|
-
# We expect that the 6D forces included in the `link_forces` argument are expressed
|
144
|
-
# in the frame corresponding to the velocity representation of `data`.
|
145
|
-
references = js.references.JaxSimModelReferences.build(
|
146
|
-
model=model,
|
147
|
-
link_forces=O_f_L,
|
148
|
-
joint_force_references=joint_force_references,
|
149
|
-
data=data,
|
150
|
-
velocity_representation=data.velocity_representation,
|
151
|
-
)
|
152
|
-
|
153
|
-
# ======================
|
154
|
-
# Compute contact forces
|
155
|
-
# ======================
|
156
|
-
|
157
|
-
# Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
158
|
-
# with the terrain.
|
159
|
-
W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float)
|
160
|
-
|
161
|
-
# Initialize a dictionary of auxiliary data.
|
162
|
-
# This dictionary is used to store additional data computed by the contact model.
|
163
|
-
aux_data = {}
|
164
|
-
|
165
|
-
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
166
|
-
|
167
|
-
with (
|
168
|
-
data.switch_velocity_representation(VelRepr.Inertial),
|
169
|
-
references.switch_velocity_representation(VelRepr.Inertial),
|
170
|
-
):
|
171
|
-
|
172
|
-
# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
|
173
|
-
# along with contact-specific auxiliary states.
|
174
|
-
W_f_C, aux_data = js.contact.collidable_point_dynamics(
|
175
|
-
model=model,
|
176
|
-
data=data,
|
177
|
-
link_forces=references.link_forces(model=model, data=data),
|
178
|
-
joint_force_references=references.joint_force_references(model=model),
|
179
|
-
)
|
180
|
-
|
181
|
-
# Compute the 6D forces applied to the links equivalent to the forces applied
|
182
|
-
# to the frames associated to the collidable points.
|
183
|
-
W_f_L_terrain = model.contact_model.link_forces_from_contact_forces(
|
184
|
-
model=model,
|
185
|
-
data=data,
|
186
|
-
contact_forces=W_f_C,
|
187
|
-
)
|
188
|
-
|
189
|
-
# ===========================
|
190
|
-
# Compute system acceleration
|
191
|
-
# ===========================
|
192
|
-
|
193
|
-
# Compute the total link forces.
|
194
|
-
with (
|
195
|
-
data.switch_velocity_representation(VelRepr.Inertial),
|
196
|
-
references.switch_velocity_representation(VelRepr.Inertial),
|
197
|
-
):
|
198
|
-
|
199
|
-
# Sum the contact forces just computed with the link forces applied by the user.
|
200
|
-
references = references.apply_link_forces(
|
201
|
-
model=model,
|
202
|
-
data=data,
|
203
|
-
forces=W_f_L_terrain,
|
204
|
-
additive=True,
|
205
|
-
)
|
206
|
-
|
207
|
-
# Get the link forces in inertial-fixed representation.
|
208
|
-
f_L_total = references.link_forces(model=model, data=data)
|
209
|
-
|
210
|
-
# Compute the system acceleration in inertial-fixed representation.
|
211
|
-
# This representation is useful for integration purpose.
|
212
|
-
W_v̇_WB, s̈ = system_acceleration(
|
213
|
-
model=model,
|
214
|
-
data=data,
|
215
|
-
joint_force_references=joint_force_references,
|
216
|
-
link_forces=f_L_total,
|
217
|
-
)
|
218
|
-
|
219
|
-
return W_v̇_WB, s̈, aux_data
|
220
|
-
|
221
|
-
|
222
16
|
def system_acceleration(
|
223
17
|
model: js.model.JaxSimModel,
|
224
18
|
data: js.data.JaxSimModelData,
|
225
19
|
*,
|
226
20
|
link_forces: jtp.MatrixLike | None = None,
|
227
|
-
|
21
|
+
joint_torques: jtp.VectorLike | None = None,
|
228
22
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
229
23
|
"""
|
230
24
|
Compute the system acceleration in the active representation.
|
@@ -235,7 +29,7 @@ def system_acceleration(
|
|
235
29
|
link_forces:
|
236
30
|
The 6D forces to apply to the links expressed in the same
|
237
31
|
velocity representation of data.
|
238
|
-
|
32
|
+
joint_torques: The joint torques applied to the joints.
|
239
33
|
|
240
34
|
Returns:
|
241
35
|
A tuple containing the base 6D acceleration in the active representation
|
@@ -253,80 +47,6 @@ def system_acceleration(
|
|
253
47
|
else jnp.zeros((model.number_of_links(), 6))
|
254
48
|
).astype(float)
|
255
49
|
|
256
|
-
# Build joint torques if not provided.
|
257
|
-
τ_references = (
|
258
|
-
jnp.atleast_1d(joint_force_references.squeeze())
|
259
|
-
if joint_force_references is not None
|
260
|
-
else jnp.zeros_like(data.joint_positions())
|
261
|
-
).astype(float)
|
262
|
-
|
263
|
-
# ====================
|
264
|
-
# Enforce joint limits
|
265
|
-
# ====================
|
266
|
-
|
267
|
-
τ_position_limit = jnp.zeros_like(τ_references).astype(float)
|
268
|
-
|
269
|
-
if model.dofs() > 0:
|
270
|
-
|
271
|
-
# Stiffness and damper parameters for the joint position limits.
|
272
|
-
k_j = jnp.array(
|
273
|
-
model.kin_dyn_parameters.joint_parameters.position_limit_spring
|
274
|
-
).astype(float)
|
275
|
-
d_j = jnp.array(
|
276
|
-
model.kin_dyn_parameters.joint_parameters.position_limit_damper
|
277
|
-
).astype(float)
|
278
|
-
|
279
|
-
# Compute the joint position limit violations.
|
280
|
-
lower_violation = jnp.clip(
|
281
|
-
data.state.physics_model.joint_positions
|
282
|
-
- model.kin_dyn_parameters.joint_parameters.position_limits_min,
|
283
|
-
max=0.0,
|
284
|
-
)
|
285
|
-
|
286
|
-
upper_violation = jnp.clip(
|
287
|
-
data.state.physics_model.joint_positions
|
288
|
-
- model.kin_dyn_parameters.joint_parameters.position_limits_max,
|
289
|
-
min=0.0,
|
290
|
-
)
|
291
|
-
|
292
|
-
# Compute the joint position limit torque.
|
293
|
-
τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
|
294
|
-
|
295
|
-
τ_position_limit -= (
|
296
|
-
jnp.positive(τ_position_limit)
|
297
|
-
* jnp.diag(d_j)
|
298
|
-
@ data.state.physics_model.joint_velocities
|
299
|
-
)
|
300
|
-
|
301
|
-
# ====================
|
302
|
-
# Joint friction model
|
303
|
-
# ====================
|
304
|
-
|
305
|
-
τ_friction = jnp.zeros_like(τ_references).astype(float)
|
306
|
-
|
307
|
-
if model.dofs() > 0:
|
308
|
-
|
309
|
-
# Static and viscous joint friction parameters
|
310
|
-
kc = jnp.array(
|
311
|
-
model.kin_dyn_parameters.joint_parameters.friction_static
|
312
|
-
).astype(float)
|
313
|
-
kv = jnp.array(
|
314
|
-
model.kin_dyn_parameters.joint_parameters.friction_viscous
|
315
|
-
).astype(float)
|
316
|
-
|
317
|
-
# Compute the joint friction torque.
|
318
|
-
τ_friction = -(
|
319
|
-
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
|
320
|
-
+ jnp.diag(kv) @ data.state.physics_model.joint_velocities
|
321
|
-
)
|
322
|
-
|
323
|
-
# ========================
|
324
|
-
# Compute forward dynamics
|
325
|
-
# ========================
|
326
|
-
|
327
|
-
# Compute the total joint forces.
|
328
|
-
τ_total = τ_references + τ_friction + τ_position_limit
|
329
|
-
|
330
50
|
# Store the link forces in a references object.
|
331
51
|
references = js.references.JaxSimModelReferences.build(
|
332
52
|
model=model,
|
@@ -345,7 +65,7 @@ def system_acceleration(
|
|
345
65
|
v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
346
66
|
model=model,
|
347
67
|
data=data,
|
348
|
-
joint_forces=τ_total,
|
68
|
+
joint_forces=joint_torques,
|
349
69
|
link_forces=references.link_forces(model=model, data=data),
|
350
70
|
)
|
351
71
|
|
@@ -359,7 +79,7 @@ def system_position_dynamics(
|
|
359
79
|
data: js.data.JaxSimModelData,
|
360
80
|
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
361
81
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
|
362
|
-
"""
|
82
|
+
r"""
|
363
83
|
Compute the dynamics of the system position.
|
364
84
|
|
365
85
|
Args:
|
@@ -371,16 +91,18 @@ def system_position_dynamics(
|
|
371
91
|
Returns:
|
372
92
|
A tuple containing the derivative of the base position, the derivative of the
|
373
93
|
base quaternion, and the derivative of the joint positions.
|
374
|
-
"""
|
375
94
|
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
95
|
+
Note:
|
96
|
+
In inertial-fixed representation, the linear component of the base velocity is not
|
97
|
+
the derivative of the base position. In fact, the base velocity is defined as:
|
98
|
+
:math:`{} ^W v_{W, B} = \begin{bmatrix} {} ^W \dot{p}_B - S({} ^W \omega_{W, B}) {} ^W p _B\\ {} ^W \omega_{W, B} \end{bmatrix}`.
|
99
|
+
Where :math:`S(\cdot)` is the skew-symmetric matrix operator.
|
100
|
+
"""
|
381
101
|
|
382
|
-
|
383
|
-
|
102
|
+
ṡ = data.joint_velocities
|
103
|
+
W_Q_B = data.base_orientation
|
104
|
+
W_ω_WB = data.base_velocity[3:6]
|
105
|
+
W_ṗ_B = data.base_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position
|
384
106
|
|
385
107
|
W_Q̇_B = Quaternion.derivative(
|
386
108
|
quaternion=W_Q_B,
|
@@ -399,9 +121,9 @@ def system_dynamics(
|
|
399
121
|
data: js.data.JaxSimModelData,
|
400
122
|
*,
|
401
123
|
link_forces: jtp.Vector | None = None,
|
402
|
-
|
124
|
+
joint_torques: jtp.Vector | None = None,
|
403
125
|
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
|
404
|
-
) ->
|
126
|
+
) -> JaxSimModelData:
|
405
127
|
"""
|
406
128
|
Compute the dynamics of the system.
|
407
129
|
|
@@ -411,57 +133,32 @@ def system_dynamics(
|
|
411
133
|
link_forces:
|
412
134
|
The 6D forces to apply to the links expressed in the frame corresponding to
|
413
135
|
the velocity representation of `data`.
|
414
|
-
|
136
|
+
joint_torques: The joint torques acting on the joints.
|
415
137
|
baumgarte_quaternion_regularization:
|
416
138
|
The Baumgarte regularization coefficient used to adjust the norm of the
|
417
139
|
quaternion (only used in integrators not operating on the SO(3) manifold).
|
418
140
|
|
419
141
|
Returns:
|
420
|
-
A tuple with an `ODEState` object storing in each of its attributes the
|
142
|
+
A tuple with a `JaxSimModelData` object storing in each of its attributes the
|
421
143
|
corresponding derivative, and the dictionary of auxiliary data returned
|
422
144
|
by the system dynamics evaluation.
|
423
145
|
"""
|
424
146
|
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
# Initialize the dictionary storing the derivative of the additional state variables
|
434
|
-
# that extend the state vector of the integrated ODE system.
|
435
|
-
extended_ode_state = {}
|
436
|
-
|
437
|
-
match model.contact_model:
|
438
|
-
|
439
|
-
case contacts.SoftContacts():
|
440
|
-
extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
|
441
|
-
|
442
|
-
case contacts.ViscoElasticContacts():
|
443
|
-
|
444
|
-
extended_ode_state["tangential_deformation"] = jnp.zeros_like(
|
445
|
-
data.state.extended["tangential_deformation"]
|
446
|
-
)
|
447
|
-
|
448
|
-
case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
|
449
|
-
pass
|
450
|
-
|
451
|
-
case _:
|
452
|
-
raise ValueError(f"Invalid contact model: {model.contact_model}")
|
147
|
+
with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
|
148
|
+
W_v̇_WB, s̈ = system_acceleration(
|
149
|
+
model=model,
|
150
|
+
data=data,
|
151
|
+
joint_torques=joint_torques,
|
152
|
+
link_forces=link_forces,
|
153
|
+
)
|
453
154
|
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
)
|
155
|
+
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
|
156
|
+
model=model,
|
157
|
+
data=data,
|
158
|
+
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
|
159
|
+
)
|
460
160
|
|
461
|
-
|
462
|
-
# Our integrators, operating on generic pytrees, will be able to handle it
|
463
|
-
# automatically as state derivative.
|
464
|
-
ode_state_derivative = ODEState.build_from_jaxsim_model(
|
161
|
+
ode_state_derivative = JaxSimModelData.build(
|
465
162
|
model=model,
|
466
163
|
base_position=W_ṗ_B,
|
467
164
|
base_quaternion=W_Q̇_B,
|
@@ -469,7 +166,6 @@ def system_dynamics(
|
|
469
166
|
base_linear_velocity=W_v̇_WB[0:3],
|
470
167
|
base_angular_velocity=W_v̇_WB[3:6],
|
471
168
|
joint_velocities=s̈,
|
472
|
-
**extended_ode_state,
|
473
169
|
)
|
474
170
|
|
475
|
-
return ode_state_derivative
|
171
|
+
return ode_state_derivative
|
jaxsim/api/references.py
CHANGED
@@ -242,7 +242,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
242
242
|
)(W_f_L, W_H_L)
|
243
243
|
|
244
244
|
# The f_L output is either L_f_L or LW_f_L, depending on the representation.
|
245
|
-
W_H_L =
|
245
|
+
W_H_L = data._link_transforms
|
246
246
|
f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])
|
247
247
|
|
248
248
|
return f_L
|
@@ -450,7 +450,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
450
450
|
)(f_L, W_H_L)
|
451
451
|
|
452
452
|
# The f_L input is either L_f_L or LW_f_L, depending on the representation.
|
453
|
-
W_H_L =
|
453
|
+
W_H_L = data._link_transforms
|
454
454
|
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
|
455
455
|
|
456
456
|
return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
|
jaxsim/exceptions.py
CHANGED
@@ -23,8 +23,8 @@ def raise_if(
|
|
23
23
|
|
24
24
|
# Disable host callback if running on unsupported hardware or if the user
|
25
25
|
# explicitly disabled it.
|
26
|
-
if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
|
27
|
-
"
|
26
|
+
if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get(
|
27
|
+
"JAXSIM_ENABLE_EXCEPTIONS", 0
|
28
28
|
):
|
29
29
|
return
|
30
30
|
|
jaxsim/math/__init__.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1
|
-
# Define the default standard gravity constant.
|
2
|
-
StandardGravity = 9.81
|
3
|
-
|
4
1
|
from .adjoint import Adjoint
|
5
2
|
from .cross import Cross
|
6
3
|
from .inertia import Inertia
|
@@ -11,3 +8,7 @@ from .transform import Transform
|
|
11
8
|
from .utils import safe_norm
|
12
9
|
|
13
10
|
from .joint_model import JointModel, supported_joint_motion # isort:skip
|
11
|
+
|
12
|
+
|
13
|
+
# Define the default standard gravity constant.
|
14
|
+
STANDARD_GRAVITY = -9.81
|
jaxsim/math/joint_model.py
CHANGED
@@ -7,12 +7,10 @@ import jaxlie
|
|
7
7
|
from jax_dataclasses import Static
|
8
8
|
|
9
9
|
import jaxsim.typing as jtp
|
10
|
+
from jaxsim.math import Rotation
|
10
11
|
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
|
11
12
|
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
|
12
13
|
|
13
|
-
from .rotation import Rotation
|
14
|
-
from .transform import Transform
|
15
|
-
|
16
14
|
|
17
15
|
@jax_dataclasses.pytree_dataclass
|
18
16
|
class JointModel:
|
@@ -113,60 +111,6 @@ class JointModel:
|
|
113
111
|
joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
|
114
112
|
)
|
115
113
|
|
116
|
-
def parent_H_child(
|
117
|
-
self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
|
118
|
-
) -> tuple[jtp.Matrix, jtp.Array]:
|
119
|
-
r"""
|
120
|
-
Compute the homogeneous transformation between the parent link and
|
121
|
-
the child link of a joint, and the corresponding motion subspace.
|
122
|
-
|
123
|
-
Args:
|
124
|
-
joint_index: The index of the joint.
|
125
|
-
joint_position: The position of the joint.
|
126
|
-
|
127
|
-
Returns:
|
128
|
-
A tuple containing the homogeneous transformation
|
129
|
-
:math:`{}^{\lambda(i)} \mathbf{H}_i(s)`
|
130
|
-
and the motion subspace :math:`\mathbf{S}(s)`.
|
131
|
-
"""
|
132
|
-
|
133
|
-
i = joint_index
|
134
|
-
s = joint_position
|
135
|
-
|
136
|
-
# Get the components of the joint model.
|
137
|
-
λ_Hi_pre = self.parent_H_predecessor(joint_index=i)
|
138
|
-
pre_Hi_suc, S = self.predecessor_H_successor(joint_index=i, joint_position=s)
|
139
|
-
suc_Hi_i = self.successor_H_child(joint_index=i)
|
140
|
-
|
141
|
-
# Compose all the transforms.
|
142
|
-
return λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, S
|
143
|
-
|
144
|
-
@jax.jit
|
145
|
-
def child_H_parent(
|
146
|
-
self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
|
147
|
-
) -> tuple[jtp.Matrix, jtp.Array]:
|
148
|
-
r"""
|
149
|
-
Compute the homogeneous transformation between the child link and
|
150
|
-
the parent link of a joint, and the corresponding motion subspace.
|
151
|
-
|
152
|
-
Args:
|
153
|
-
joint_index: The index of the joint.
|
154
|
-
joint_position: The position of the joint.
|
155
|
-
|
156
|
-
Returns:
|
157
|
-
A tuple containing the homogeneous transformation
|
158
|
-
:math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
|
159
|
-
and the motion subspace :math:`\mathbf{S}(s)`.
|
160
|
-
"""
|
161
|
-
|
162
|
-
λ_Hi_i, S = self.parent_H_child(
|
163
|
-
joint_index=joint_index, joint_position=joint_position
|
164
|
-
)
|
165
|
-
|
166
|
-
i_Hi_λ = Transform.inverse(λ_Hi_i)
|
167
|
-
|
168
|
-
return i_Hi_λ, S
|
169
|
-
|
170
114
|
def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
|
171
115
|
r"""
|
172
116
|
Return the homogeneous transformation between the parent link and
|
@@ -182,31 +126,6 @@ class JointModel:
|
|
182
126
|
|
183
127
|
return self.λ_H_pre[joint_index]
|
184
128
|
|
185
|
-
def predecessor_H_successor(
|
186
|
-
self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
|
187
|
-
) -> tuple[jtp.Matrix, jtp.Array]:
|
188
|
-
r"""
|
189
|
-
Compute the homogeneous transformation between the predecessor and
|
190
|
-
the successor frame of a joint, and the corresponding motion subspace.
|
191
|
-
|
192
|
-
Args:
|
193
|
-
joint_index: The index of the joint.
|
194
|
-
joint_position: The position of the joint.
|
195
|
-
|
196
|
-
Returns:
|
197
|
-
A tuple containing the homogeneous transformation
|
198
|
-
:math:`{}^{\text{pre}(i)} \mathbf{H}_{\text{suc}(i)}(s)`
|
199
|
-
and the motion subspace :math:`\mathbf{S}(s)`.
|
200
|
-
"""
|
201
|
-
|
202
|
-
pre_H_suc, S = supported_joint_motion(
|
203
|
-
self.joint_types[joint_index],
|
204
|
-
joint_position,
|
205
|
-
self.joint_axis[joint_index].axis,
|
206
|
-
)
|
207
|
-
|
208
|
-
return pre_H_suc, S
|
209
|
-
|
210
129
|
def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
|
211
130
|
r"""
|
212
131
|
Return the homogeneous transformation between the successor frame and
|
@@ -225,65 +144,56 @@ class JointModel:
|
|
225
144
|
|
226
145
|
@jax.jit
|
227
146
|
def supported_joint_motion(
|
228
|
-
|
229
|
-
|
230
|
-
joint_axis: jtp.VectorLike | None = None,
|
231
|
-
/,
|
232
|
-
) -> tuple[jtp.Matrix, jtp.Array]:
|
147
|
+
joint_types: jtp.Array, joint_positions: jtp.Matrix, joint_axes: jtp.Matrix
|
148
|
+
) -> jtp.Matrix:
|
233
149
|
"""
|
234
|
-
Compute the
|
150
|
+
Compute the transforms of the joints.
|
235
151
|
|
236
152
|
Args:
|
237
|
-
|
238
|
-
|
239
|
-
|
153
|
+
joint_types: The types of the joints.
|
154
|
+
joint_positions: The positions of the joints.
|
155
|
+
joint_axes: The axes of the joints.
|
240
156
|
|
241
157
|
Returns:
|
242
|
-
|
158
|
+
The transforms of the joints.
|
243
159
|
"""
|
244
160
|
|
245
161
|
# Prepare the joint position
|
246
|
-
s = jnp.array(
|
162
|
+
s = jnp.array(joint_positions).astype(float)
|
247
163
|
|
248
164
|
def compute_F() -> tuple[jtp.Matrix, jtp.Array]:
|
249
|
-
return jaxlie.SE3.identity()
|
165
|
+
return jaxlie.SE3.identity()
|
250
166
|
|
251
167
|
def compute_R() -> tuple[jtp.Matrix, jtp.Array]:
|
252
168
|
|
253
169
|
# Get the additional argument specifying the joint axis.
|
254
170
|
# This is a metadata required by only some joint types.
|
255
|
-
axis = jnp.array(
|
171
|
+
axis = jnp.array(joint_axes).astype(float).squeeze()
|
256
172
|
|
257
173
|
pre_H_suc = jaxlie.SE3.from_matrix(
|
258
174
|
matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis))
|
259
175
|
)
|
260
176
|
|
261
|
-
|
262
|
-
|
263
|
-
return pre_H_suc, S
|
177
|
+
return pre_H_suc
|
264
178
|
|
265
179
|
def compute_P() -> tuple[jtp.Matrix, jtp.Array]:
|
266
180
|
|
267
181
|
# Get the additional argument specifying the joint axis.
|
268
182
|
# This is a metadata required by only some joint types.
|
269
|
-
axis = jnp.array(
|
183
|
+
axis = jnp.array(joint_axes).astype(float).squeeze()
|
270
184
|
|
271
185
|
pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
|
272
186
|
rotation=jaxlie.SO3.identity(),
|
273
187
|
translation=jnp.array(s * axis),
|
274
188
|
)
|
275
189
|
|
276
|
-
|
190
|
+
return pre_H_suc
|
277
191
|
|
278
|
-
|
279
|
-
|
280
|
-
pre_H_suc, S = jax.lax.switch(
|
281
|
-
index=joint_type,
|
192
|
+
return jax.lax.switch(
|
193
|
+
index=joint_types,
|
282
194
|
branches=(
|
283
195
|
compute_F, # JointType.Fixed
|
284
196
|
compute_R, # JointType.Revolute
|
285
197
|
compute_P, # JointType.Prismatic
|
286
198
|
),
|
287
|
-
)
|
288
|
-
|
289
|
-
return pre_H_suc.as_matrix(), S
|
199
|
+
).as_matrix()
|