jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/references.py
CHANGED
@@ -7,10 +7,11 @@ import jax.numpy as jnp
|
|
7
7
|
import jax_dataclasses
|
8
8
|
|
9
9
|
import jaxsim.api as js
|
10
|
-
import jaxsim.physics.model.physics_model_state
|
11
10
|
import jaxsim.typing as jtp
|
12
|
-
from jaxsim import
|
13
|
-
from jaxsim.
|
11
|
+
from jaxsim import exceptions
|
12
|
+
from jaxsim.utils.tracing import not_tracing
|
13
|
+
|
14
|
+
from .common import VelRepr
|
14
15
|
|
15
16
|
try:
|
16
17
|
from typing import Self
|
@@ -22,13 +23,19 @@ except ImportError:
|
|
22
23
|
class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
23
24
|
"""
|
24
25
|
Class containing the references for a `JaxSimModel` object.
|
26
|
+
|
27
|
+
Attributes:
|
28
|
+
_link_forces: The link 6D forces in inertial-fixed representation.
|
29
|
+
_joint_force_references: The joint force references.
|
25
30
|
"""
|
26
31
|
|
27
|
-
|
32
|
+
_link_forces: jtp.Matrix
|
33
|
+
_joint_force_references: jtp.Vector
|
28
34
|
|
29
35
|
@staticmethod
|
30
36
|
def zero(
|
31
37
|
model: js.model.JaxSimModel,
|
38
|
+
data: js.data.JaxSimModelData | None = None,
|
32
39
|
velocity_representation: VelRepr = VelRepr.Inertial,
|
33
40
|
) -> JaxSimModelReferences:
|
34
41
|
"""
|
@@ -36,6 +43,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
36
43
|
|
37
44
|
Args:
|
38
45
|
model: The model for which to create the zero references.
|
46
|
+
data:
|
47
|
+
The data of the model, only needed if the velocity representation is
|
48
|
+
not inertial-fixed.
|
39
49
|
velocity_representation: The velocity representation to use.
|
40
50
|
|
41
51
|
Returns:
|
@@ -43,14 +53,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
43
53
|
"""
|
44
54
|
|
45
55
|
return JaxSimModelReferences.build(
|
46
|
-
model=model, velocity_representation=velocity_representation
|
56
|
+
model=model, data=data, velocity_representation=velocity_representation
|
47
57
|
)
|
48
58
|
|
49
59
|
@staticmethod
|
50
60
|
def build(
|
51
61
|
model: js.model.JaxSimModel,
|
52
|
-
joint_force_references: jtp.
|
53
|
-
link_forces: jtp.
|
62
|
+
joint_force_references: jtp.VectorLike | None = None,
|
63
|
+
link_forces: jtp.MatrixLike | None = None,
|
54
64
|
data: js.data.JaxSimModelData | None = None,
|
55
65
|
velocity_representation: VelRepr | None = None,
|
56
66
|
) -> JaxSimModelReferences:
|
@@ -72,14 +82,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
72
82
|
|
73
83
|
# Create or adjust joint force references.
|
74
84
|
joint_force_references = jnp.atleast_1d(
|
75
|
-
joint_force_references.squeeze()
|
85
|
+
jnp.array(joint_force_references, dtype=float).squeeze()
|
76
86
|
if joint_force_references is not None
|
77
87
|
else jnp.zeros(model.dofs())
|
78
88
|
).astype(float)
|
79
89
|
|
80
90
|
# Create or adjust link forces.
|
81
91
|
f_L = jnp.atleast_2d(
|
82
|
-
link_forces.squeeze()
|
92
|
+
jnp.array(link_forces, dtype=float).squeeze()
|
83
93
|
if link_forces is not None
|
84
94
|
else jnp.zeros((model.number_of_links(), 6))
|
85
95
|
).astype(float)
|
@@ -88,17 +98,21 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
88
98
|
velocity_representation = (
|
89
99
|
velocity_representation
|
90
100
|
if velocity_representation is not None
|
91
|
-
else (
|
92
|
-
data.velocity_representation if data is not None else VelRepr.Inertial
|
93
|
-
)
|
101
|
+
else getattr(data, "velocity_representation", VelRepr.Inertial)
|
94
102
|
)
|
95
103
|
|
96
104
|
# Create a zero references object.
|
97
105
|
references = JaxSimModelReferences(
|
98
|
-
|
106
|
+
_link_forces=f_L,
|
107
|
+
_joint_force_references=joint_force_references,
|
99
108
|
velocity_representation=velocity_representation,
|
100
109
|
)
|
101
110
|
|
111
|
+
# If the velocity representation is inertial-fixed, we can return
|
112
|
+
# the references directly, as we store the link forces in this frame.
|
113
|
+
if velocity_representation is VelRepr.Inertial:
|
114
|
+
return references
|
115
|
+
|
102
116
|
# Store the joint force references.
|
103
117
|
references = references.set_joint_force_references(
|
104
118
|
forces=joint_force_references,
|
@@ -129,17 +143,27 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
129
143
|
`False` otherwise.
|
130
144
|
"""
|
131
145
|
|
132
|
-
|
146
|
+
if model is None:
|
147
|
+
return True
|
148
|
+
|
149
|
+
shape = self._joint_force_references.shape
|
150
|
+
expected_shape = (model.dofs(),)
|
151
|
+
|
152
|
+
if shape != expected_shape:
|
153
|
+
return False
|
154
|
+
|
155
|
+
shape = self._link_forces.shape
|
156
|
+
expected_shape = (model.number_of_links(), 6)
|
133
157
|
|
134
|
-
if
|
135
|
-
|
158
|
+
if shape != expected_shape:
|
159
|
+
return False
|
136
160
|
|
137
|
-
return
|
161
|
+
return True
|
138
162
|
|
139
163
|
# ==================
|
140
164
|
# Extract quantities
|
141
165
|
# ==================
|
142
|
-
|
166
|
+
@js.common.named_scope
|
143
167
|
@functools.partial(jax.jit, static_argnames=["link_names"])
|
144
168
|
def link_forces(
|
145
169
|
self,
|
@@ -172,7 +196,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
172
196
|
e.g. to the contact model and other kinematic constraints.
|
173
197
|
"""
|
174
198
|
|
175
|
-
W_f_L = self.
|
199
|
+
W_f_L = self._link_forces
|
176
200
|
|
177
201
|
# Return all link forces in inertial-fixed representation using the implicit
|
178
202
|
# serialization.
|
@@ -184,11 +208,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
184
208
|
if link_names is not None:
|
185
209
|
raise ValueError("Link names cannot be provided without a model")
|
186
210
|
|
187
|
-
return
|
211
|
+
return W_f_L
|
188
212
|
|
189
213
|
# If we have the model, we can extract the link names, if not provided.
|
190
|
-
|
191
|
-
|
214
|
+
link_idxs = (
|
215
|
+
js.link.names_to_idxs(link_names=link_names, model=model)
|
216
|
+
if link_names is not None
|
217
|
+
else jnp.arange(model.number_of_links())
|
218
|
+
)
|
192
219
|
|
193
220
|
# In inertial-fixed representation, we already have the link forces.
|
194
221
|
if self.velocity_representation is VelRepr.Inertial:
|
@@ -198,20 +225,25 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
198
225
|
msg = "Missing model data to use a representation different from {}"
|
199
226
|
raise ValueError(msg.format(VelRepr.Inertial.name))
|
200
227
|
|
201
|
-
if not data.valid(model=model):
|
228
|
+
if not_tracing(self._link_forces) and not data.valid(model=model):
|
202
229
|
raise ValueError("The provided data is not valid for the model")
|
203
230
|
|
204
|
-
# Helper function to convert a single 6D force to the active representation
|
205
|
-
|
206
|
-
|
207
|
-
array=f_L,
|
208
|
-
other_representation=self.velocity_representation,
|
209
|
-
transform=data.base_transform(),
|
210
|
-
is_force=True,
|
211
|
-
)
|
231
|
+
# Helper function to convert a single 6D force to the active representation
|
232
|
+
# considering as body the link (i.e. L_f_L and LW_f_L).
|
233
|
+
def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:
|
212
234
|
|
213
|
-
|
214
|
-
|
235
|
+
return jax.vmap(
|
236
|
+
lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation(
|
237
|
+
array=W_f_L,
|
238
|
+
other_representation=self.velocity_representation,
|
239
|
+
transform=W_H_L,
|
240
|
+
is_force=True,
|
241
|
+
)
|
242
|
+
)(W_f_L, W_H_L)
|
243
|
+
|
244
|
+
# The f_L output is either L_f_L or LW_f_L, depending on the representation.
|
245
|
+
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
246
|
+
f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])
|
215
247
|
|
216
248
|
return f_L
|
217
249
|
|
@@ -250,23 +282,26 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
250
282
|
if joint_names is not None:
|
251
283
|
raise ValueError("Joint names cannot be provided without a model")
|
252
284
|
|
253
|
-
return self.
|
285
|
+
return self._joint_force_references
|
254
286
|
|
255
|
-
if not self.valid(model=model):
|
287
|
+
if not_tracing(self._joint_force_references) and not self.valid(model=model):
|
256
288
|
msg = "The actuation object is not compatible with the provided model"
|
257
289
|
raise ValueError(msg)
|
258
290
|
|
259
|
-
|
260
|
-
|
291
|
+
joint_idxs = (
|
292
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
293
|
+
if joint_names is not None
|
294
|
+
else jnp.arange(model.number_of_joints())
|
295
|
+
)
|
261
296
|
|
262
297
|
return jnp.atleast_1d(
|
263
|
-
self.
|
298
|
+
self._joint_force_references[joint_idxs].squeeze()
|
264
299
|
).astype(float)
|
265
300
|
|
266
301
|
# ================
|
267
302
|
# Store quantities
|
268
303
|
# ================
|
269
|
-
|
304
|
+
@js.common.named_scope
|
270
305
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
271
306
|
def set_joint_force_references(
|
272
307
|
self,
|
@@ -288,37 +323,37 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
288
323
|
A new `JaxSimModelReferences` object with the given joint force references.
|
289
324
|
"""
|
290
325
|
|
291
|
-
forces = jnp.array(forces)
|
326
|
+
forces = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze())
|
292
327
|
|
293
|
-
def replace(forces: jtp.
|
328
|
+
def replace(forces: jtp.Vector) -> JaxSimModelReferences:
|
294
329
|
return self.replace(
|
295
330
|
validate=True,
|
296
|
-
|
297
|
-
physics_model=self.input.physics_model.replace(
|
298
|
-
tau=jnp.atleast_1d(forces.squeeze()).astype(float)
|
299
|
-
)
|
300
|
-
),
|
331
|
+
_joint_force_references=jnp.atleast_1d(forces.squeeze()).astype(float),
|
301
332
|
)
|
302
333
|
|
303
334
|
if model is None:
|
304
335
|
return replace(forces=forces)
|
305
336
|
|
306
|
-
if not self.valid(model=model):
|
337
|
+
if not_tracing(forces) and not self.valid(model=model):
|
307
338
|
msg = "The references object is not compatible with the provided model"
|
308
339
|
raise ValueError(msg)
|
309
340
|
|
310
|
-
|
311
|
-
|
341
|
+
joint_idxs = (
|
342
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
343
|
+
if joint_names is not None
|
344
|
+
else jnp.arange(model.number_of_joints())
|
345
|
+
)
|
312
346
|
|
313
|
-
return replace(forces=self.
|
347
|
+
return replace(forces=self._joint_force_references.at[joint_idxs].set(forces))
|
314
348
|
|
349
|
+
@js.common.named_scope
|
315
350
|
@functools.partial(jax.jit, static_argnames=["link_names", "additive"])
|
316
351
|
def apply_link_forces(
|
317
352
|
self,
|
318
353
|
forces: jtp.MatrixLike,
|
319
354
|
model: js.model.JaxSimModel | None = None,
|
320
355
|
data: js.data.JaxSimModelData | None = None,
|
321
|
-
link_names: tuple[str, ...] | None = None,
|
356
|
+
link_names: tuple[str, ...] | str | None = None,
|
322
357
|
additive: bool = False,
|
323
358
|
) -> Self:
|
324
359
|
"""
|
@@ -344,17 +379,13 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
344
379
|
Then, we always convert and store forces in inertial-fixed representation.
|
345
380
|
"""
|
346
381
|
|
347
|
-
f_L = jnp.
|
382
|
+
f_L = jnp.atleast_2d(forces).astype(float)
|
348
383
|
|
349
384
|
# Helper function to replace the link forces.
|
350
385
|
def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
|
351
386
|
return self.replace(
|
352
387
|
validate=True,
|
353
|
-
|
354
|
-
physics_model=self.input.physics_model.replace(
|
355
|
-
f_ext=jnp.atleast_2d(forces.squeeze()).astype(float)
|
356
|
-
)
|
357
|
-
),
|
388
|
+
_link_forces=jnp.atleast_2d(forces.squeeze()).astype(float),
|
358
389
|
)
|
359
390
|
|
360
391
|
# In this case, we allow only to set the inertial 6D forces to all links
|
@@ -369,52 +400,157 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
369
400
|
|
370
401
|
W_f_L = f_L
|
371
402
|
|
372
|
-
W_f0_L = (
|
373
|
-
jnp.zeros_like(W_f_L)
|
374
|
-
if not additive
|
375
|
-
else self.input.physics_model.f_ext
|
376
|
-
)
|
403
|
+
W_f0_L = jnp.zeros_like(W_f_L) if not additive else self._link_forces
|
377
404
|
|
378
405
|
return replace(forces=W_f0_L + W_f_L)
|
379
406
|
|
380
|
-
|
381
|
-
|
382
|
-
|
407
|
+
if link_names is not None and len(link_names) != f_L.shape[0]:
|
408
|
+
msg = "The number of link names ({}) must match the number of forces ({})"
|
409
|
+
raise ValueError(msg.format(len(link_names), f_L.shape[0]))
|
410
|
+
|
411
|
+
# Extract the link indices.
|
412
|
+
link_idxs = (
|
413
|
+
js.link.names_to_idxs(link_names=link_names, model=model)
|
414
|
+
if link_names is not None
|
415
|
+
else jnp.arange(model.number_of_links())
|
416
|
+
)
|
383
417
|
|
384
418
|
# Compute the bias depending on whether we either set or add the link forces.
|
385
419
|
W_f0_L = (
|
386
|
-
jnp.zeros_like(f_L)
|
387
|
-
if not additive
|
388
|
-
else self.input.physics_model.f_ext[link_idxs, :]
|
420
|
+
jnp.zeros_like(f_L) if not additive else self._link_forces[link_idxs, :]
|
389
421
|
)
|
390
422
|
|
391
423
|
# If inertial-fixed representation, we can directly store the link forces.
|
392
424
|
if self.velocity_representation is VelRepr.Inertial:
|
393
425
|
W_f_L = f_L
|
394
426
|
return replace(
|
395
|
-
forces=self.
|
396
|
-
W_f0_L + W_f_L
|
397
|
-
)
|
427
|
+
forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)
|
398
428
|
)
|
399
429
|
|
400
430
|
if data is None:
|
401
431
|
msg = "Missing model data to use a representation different from {}"
|
402
432
|
raise ValueError(msg.format(VelRepr.Inertial.name))
|
403
433
|
|
404
|
-
if not data.valid(model=model):
|
434
|
+
if not_tracing(forces) and not data.valid(model=model):
|
405
435
|
raise ValueError("The provided data is not valid for the model")
|
406
436
|
|
407
|
-
# Helper function to convert a single 6D force to the inertial representation
|
408
|
-
|
437
|
+
# Helper function to convert a single 6D force to the inertial representation
|
438
|
+
# considering as body the link (i.e. L_f_L and LW_f_L).
|
439
|
+
def convert_using_link_frame(
|
440
|
+
f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
|
441
|
+
) -> jtp.Matrix:
|
442
|
+
|
443
|
+
return jax.vmap(
|
444
|
+
lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
|
445
|
+
array=f_L,
|
446
|
+
other_representation=self.velocity_representation,
|
447
|
+
transform=W_H_L,
|
448
|
+
is_force=True,
|
449
|
+
)
|
450
|
+
)(f_L, W_H_L)
|
451
|
+
|
452
|
+
# The f_L input is either L_f_L or LW_f_L, depending on the representation.
|
453
|
+
W_H_L = js.model.forward_kinematics(model=model, data=data)
|
454
|
+
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
|
455
|
+
|
456
|
+
return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
|
457
|
+
|
458
|
+
def apply_frame_forces(
|
459
|
+
self,
|
460
|
+
forces: jtp.MatrixLike,
|
461
|
+
model: js.model.JaxSimModel,
|
462
|
+
data: js.data.JaxSimModelData,
|
463
|
+
frame_names: tuple[str, ...] | str | None = None,
|
464
|
+
additive: bool = False,
|
465
|
+
) -> Self:
|
466
|
+
"""
|
467
|
+
Apply the frame forces.
|
468
|
+
|
469
|
+
Args:
|
470
|
+
forces: The frame 6D forces in the active representation.
|
471
|
+
model:
|
472
|
+
The model to consider, only needed if a frame serialization different
|
473
|
+
from the implicit one is used.
|
474
|
+
data:
|
475
|
+
The data of the considered model, only needed if the velocity
|
476
|
+
representation is not inertial-fixed.
|
477
|
+
frame_names: The names of the frames corresponding to the forces.
|
478
|
+
additive:
|
479
|
+
Whether to add the forces to the existing ones instead of replacing them.
|
480
|
+
|
481
|
+
Returns:
|
482
|
+
A new `JaxSimModelReferences` object with the given frame forces.
|
483
|
+
|
484
|
+
Note:
|
485
|
+
The frame forces must be expressed in the active representation.
|
486
|
+
Then, we always convert and store forces in inertial-fixed representation.
|
487
|
+
"""
|
488
|
+
|
489
|
+
f_F = jnp.atleast_2d(forces).astype(float)
|
490
|
+
|
491
|
+
if len(frame_names) != f_F.shape[0]:
|
492
|
+
msg = "The number of frame names ({}) must match the number of forces ({})"
|
493
|
+
raise ValueError(msg.format(len(frame_names), f_F.shape[0]))
|
494
|
+
|
495
|
+
# Extract the frame indices.
|
496
|
+
frame_idxs = (
|
497
|
+
js.frame.names_to_idxs(frame_names=frame_names, model=model)
|
498
|
+
if frame_names is not None
|
499
|
+
else jnp.arange(len(model.frame_names()))
|
500
|
+
)
|
501
|
+
|
502
|
+
parent_link_idxs = jnp.array(model.kin_dyn_parameters.frame_parameters.body)[
|
503
|
+
frame_idxs - model.number_of_links()
|
504
|
+
]
|
505
|
+
|
506
|
+
exceptions.raise_value_error_if(
|
507
|
+
condition=~data.valid(model=model),
|
508
|
+
msg="The provided data is not valid for the model",
|
509
|
+
)
|
510
|
+
W_H_Fi = jax.vmap(
|
511
|
+
lambda frame_idx: js.frame.transform(
|
512
|
+
model=model, data=data, frame_index=frame_idx
|
513
|
+
)
|
514
|
+
)(frame_idxs)
|
515
|
+
|
516
|
+
# Helper function to convert a single 6D force to the inertial representation
|
517
|
+
# considering as body the frame (i.e. L_f_F and LW_f_F).
|
518
|
+
def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:
|
409
519
|
return JaxSimModelReferences.other_representation_to_inertial(
|
410
|
-
array=
|
520
|
+
array=f_F,
|
411
521
|
other_representation=self.velocity_representation,
|
412
|
-
transform=
|
522
|
+
transform=W_H_F,
|
413
523
|
is_force=True,
|
414
524
|
)
|
415
525
|
|
416
|
-
|
526
|
+
match self.velocity_representation:
|
527
|
+
case VelRepr.Inertial:
|
528
|
+
W_f_F = f_F
|
417
529
|
|
418
|
-
|
419
|
-
|
420
|
-
|
530
|
+
case VelRepr.Body | VelRepr.Mixed:
|
531
|
+
W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi)
|
532
|
+
|
533
|
+
case _:
|
534
|
+
raise ValueError("Invalid velocity representation.")
|
535
|
+
|
536
|
+
# Sum the forces on the parent links.
|
537
|
+
mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
|
538
|
+
W_f_L = mask.T @ W_f_F
|
539
|
+
|
540
|
+
with self.switch_velocity_representation(
|
541
|
+
velocity_representation=VelRepr.Inertial
|
542
|
+
):
|
543
|
+
references = self.apply_link_forces(
|
544
|
+
model=model,
|
545
|
+
data=data,
|
546
|
+
link_names=js.link.idxs_to_names(
|
547
|
+
model=model, link_indices=parent_link_idxs
|
548
|
+
),
|
549
|
+
forces=W_f_L,
|
550
|
+
additive=additive,
|
551
|
+
)
|
552
|
+
|
553
|
+
with references.switch_velocity_representation(
|
554
|
+
velocity_representation=self.velocity_representation
|
555
|
+
):
|
556
|
+
return references
|
jaxsim/exceptions.py
ADDED
@@ -0,0 +1,80 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
import jax
|
4
|
+
|
5
|
+
|
6
|
+
def raise_if(
|
7
|
+
condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
|
8
|
+
) -> None:
|
9
|
+
"""
|
10
|
+
Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
|
11
|
+
|
12
|
+
Args:
|
13
|
+
condition:
|
14
|
+
The boolean condition of the evaluated expression that triggers
|
15
|
+
the exception during runtime.
|
16
|
+
exception: The type of exception to raise.
|
17
|
+
msg:
|
18
|
+
The message to display when the exception is raised. The message can be a
|
19
|
+
format string (fmt), whose fields are filled with the args and kwargs.
|
20
|
+
*args: The arguments to fill the format string.
|
21
|
+
**kwargs: The keyword arguments to fill the format string
|
22
|
+
"""
|
23
|
+
|
24
|
+
# Disable host callback if running on unsupported hardware or if the user
|
25
|
+
# explicitly disabled it.
|
26
|
+
if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
|
27
|
+
"JAXSIM_DISABLE_EXCEPTIONS", 0
|
28
|
+
):
|
29
|
+
return
|
30
|
+
|
31
|
+
# Check early that the format string is well-formed.
|
32
|
+
try:
|
33
|
+
_ = msg.format(*args, **kwargs)
|
34
|
+
except Exception as e:
|
35
|
+
msg = "Error in formatting exception message with args={} and kwargs={}"
|
36
|
+
raise ValueError(msg.format(args, kwargs)) from e
|
37
|
+
|
38
|
+
def _raise_exception(condition: bool, *args, **kwargs) -> None:
|
39
|
+
"""The function called by the JAX callback."""
|
40
|
+
|
41
|
+
if condition:
|
42
|
+
raise exception(msg.format(*args, **kwargs))
|
43
|
+
|
44
|
+
def _callback(args, kwargs) -> None:
|
45
|
+
"""The function that calls the JAX callback, executed only when needed."""
|
46
|
+
|
47
|
+
jax.debug.callback(_raise_exception, condition, *args, **kwargs)
|
48
|
+
|
49
|
+
# Since running a callable on the host is expensive, we prevent its execution
|
50
|
+
# if the condition is False with a low-level conditional expression.
|
51
|
+
def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
|
52
|
+
return jax.lax.cond(
|
53
|
+
condition,
|
54
|
+
_callback,
|
55
|
+
lambda args, kwargs: None,
|
56
|
+
args,
|
57
|
+
kwargs,
|
58
|
+
)
|
59
|
+
|
60
|
+
return _run_callback_only_if_condition_is_true(*args, **kwargs)
|
61
|
+
|
62
|
+
|
63
|
+
def raise_runtime_error_if(
|
64
|
+
condition: bool | jax.Array, msg: str, *args, **kwargs
|
65
|
+
) -> None:
|
66
|
+
"""
|
67
|
+
Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.
|
68
|
+
"""
|
69
|
+
|
70
|
+
return raise_if(condition, RuntimeError, msg, *args, **kwargs)
|
71
|
+
|
72
|
+
|
73
|
+
def raise_value_error_if(
|
74
|
+
condition: bool | jax.Array, msg: str, *args, **kwargs
|
75
|
+
) -> None:
|
76
|
+
"""
|
77
|
+
Raise a ValueError if a condition is met. Useful in jit-compiled functions.
|
78
|
+
"""
|
79
|
+
|
80
|
+
return raise_if(condition, ValueError, msg, *args, **kwargs)
|
jaxsim/integrators/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
from . import fixed_step
|
2
|
-
from .common import Integrator, Time, TimeStep
|
1
|
+
from . import fixed_step, variable_step
|
2
|
+
from .common import Integrator, SystemDynamics, Time, TimeStep
|