jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py CHANGED
@@ -1,30 +1,30 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copy
3
4
  import dataclasses
4
5
  import functools
5
6
  import pathlib
7
+ from collections.abc import Sequence
6
8
  from typing import Any
7
9
 
8
10
  import jax
9
11
  import jax.numpy as jnp
10
12
  import jax_dataclasses
11
- import jaxlie
12
13
  import rod
13
14
  from jax_dataclasses import Static
14
15
 
15
16
  import jaxsim.api as js
16
- import jaxsim.physics.algos.aba
17
- import jaxsim.physics.algos.crba
18
- import jaxsim.physics.algos.forward_kinematics
19
- import jaxsim.physics.algos.rnea
20
- import jaxsim.physics.model.physics_model
17
+ import jaxsim.exceptions
18
+ import jaxsim.terrain
21
19
  import jaxsim.typing as jtp
22
- from jaxsim.high_level.common import VelRepr
23
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
24
- from jaxsim.utils import JaxsimDataclass, Mutability
20
+ from jaxsim.math import Adjoint, Cross
21
+ from jaxsim.parsers.descriptions import ModelDescription
22
+ from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
25
23
 
24
+ from .common import VelRepr
26
25
 
27
- @jax_dataclasses.pytree_dataclass
26
+
27
+ @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
28
28
  class JaxSimModel(JaxsimDataclass):
29
29
  """
30
30
  The JaxSim model defining the kinematics and dynamics of a robot.
@@ -32,46 +32,87 @@ class JaxSimModel(JaxsimDataclass):
32
32
 
33
33
  model_name: Static[str]
34
34
 
35
- physics_model: jaxsim.physics.model.physics_model.PhysicsModel = dataclasses.field(
36
- repr=False
35
+ time_step: jaxsim.integrators.TimeStep = dataclasses.field(
36
+ default_factory=lambda: jnp.array(0.001, dtype=float),
37
+ )
38
+
39
+ terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
40
+ default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
41
+ )
42
+
43
+ # Note that this is the default contact model.
44
+ contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
45
+ default=None, repr=False
37
46
  )
38
47
 
39
- terrain: Static[Terrain] = dataclasses.field(default=FlatTerrain(), repr=False)
48
+ kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = (
49
+ dataclasses.field(default=None, repr=False)
50
+ )
40
51
 
41
52
  built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
42
- repr=False, default=None
53
+ default=None, repr=False
43
54
  )
44
55
 
45
- _number_of_links: Static[int] = dataclasses.field(
46
- init=False, repr=False, default=None
56
+ integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
57
+ default=None, repr=False
47
58
  )
48
59
 
49
- _number_of_joints: Static[int] = dataclasses.field(
50
- init=False, repr=False, default=None
60
+ _description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
61
+ dataclasses.field(default=None, repr=False)
51
62
  )
52
63
 
53
- def __post_init__(self):
64
+ @property
65
+ def description(self) -> ModelDescription:
66
+ """
67
+ Return the model description.
68
+ """
69
+ return self._description.get()
70
+
71
+ def __eq__(self, other: JaxSimModel) -> bool:
72
+
73
+ if not isinstance(other, JaxSimModel):
74
+ return False
75
+
76
+ if self.model_name != other.model_name:
77
+ return False
78
+
79
+ if self.time_step != other.time_step:
80
+ return False
54
81
 
55
- # These attributes are Static so that we can use `jax.vmap` and `jax.lax.scan`
56
- # over the all links and joints
57
- with self.mutable_context(
58
- mutability=Mutability.MUTABLE_NO_VALIDATION,
59
- restore_after_exception=False,
60
- ):
61
- self._number_of_links = len(self.physics_model.description.links_dict)
62
- self._number_of_joints = len(self.physics_model.description.joints_dict)
82
+ if self.kin_dyn_parameters != other.kin_dyn_parameters:
83
+ return False
84
+
85
+ return True
86
+
87
+ def __hash__(self) -> int:
88
+
89
+ return hash(
90
+ (
91
+ hash(self.model_name),
92
+ hash(float(self.time_step)),
93
+ hash(self.kin_dyn_parameters),
94
+ hash(self.contact_model),
95
+ )
96
+ )
63
97
 
64
98
  # ========================
65
99
  # Initialization and state
66
100
  # ========================
67
101
 
68
- @staticmethod
102
+ @classmethod
69
103
  def build_from_model_description(
104
+ cls,
70
105
  model_description: str | pathlib.Path | rod.Model,
106
+ *,
71
107
  model_name: str | None = None,
72
- gravity: jtp.Array = jaxsim.physics.default_gravity(),
108
+ time_step: jtp.FloatLike | None = None,
109
+ integrator: (
110
+ jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
111
+ ) = None,
112
+ terrain: jaxsim.terrain.Terrain | None = None,
113
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
73
114
  is_urdf: bool | None = None,
74
- considered_joints: list[str] | None = None,
115
+ considered_joints: Sequence[str] | None = None,
75
116
  ) -> JaxSimModel:
76
117
  """
77
118
  Build a Model object from a model description.
@@ -81,12 +122,21 @@ class JaxSimModel(JaxsimDataclass):
81
122
  A path to an SDF/URDF file, a string containing
82
123
  its content, or a pre-parsed/pre-built rod model.
83
124
  model_name:
84
- The optional name of the model that overrides the one in
85
- the description.
86
- gravity: The 3D gravity vector.
125
+ The name of the model. If not specified, it is read from the description.
126
+ time_step:
127
+ The default time step to consider for the simulation. It can be
128
+ manually overridden in the function that steps the simulation.
129
+ terrain: The terrain to consider (the default is a flat infinite plane).
130
+ contact_model:
131
+ The contact model to consider.
132
+ If not specified, a soft contacts model is used.
133
+ integrator:
134
+ The integrator to use. If not specified, a default one is used.
135
+ This argument can either be a pre-built integrator instance or one
136
+ of the integrator classes defined in JaxSim.
87
137
  is_urdf:
88
- Whether the model description is a URDF or an SDF. This is
89
- automatically inferred if the model description is a path to a file.
138
+ The optional flag to force the model description to be parsed as a URDF.
139
+ This is usually automatically inferred.
90
140
  considered_joints:
91
141
  The list of joints to consider. If None, all joints are considered.
92
142
 
@@ -97,7 +147,7 @@ class JaxSimModel(JaxsimDataclass):
97
147
  import jaxsim.parsers.rod
98
148
 
99
149
  # Parse the input resource (either a path to file or a string with the URDF/SDF)
100
- # and build the -intermediate- model description
150
+ # and build the -intermediate- model description.
101
151
  intermediate_description = jaxsim.parsers.rod.build_model_description(
102
152
  model_description=model_description, is_urdf=is_urdf
103
153
  )
@@ -109,44 +159,134 @@ class JaxSimModel(JaxsimDataclass):
109
159
  considered_joints=considered_joints
110
160
  )
111
161
 
112
- # Create the physics model from the model description
113
- physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
114
- model_description=intermediate_description, gravity=gravity
162
+ # Build the model.
163
+ model = cls.build(
164
+ model_description=intermediate_description,
165
+ model_name=model_name,
166
+ time_step=time_step,
167
+ integrator=integrator,
168
+ terrain=terrain,
169
+ contact_model=contact_model,
115
170
  )
116
171
 
117
- # Build the model
118
- model = JaxSimModel.build(physics_model=physics_model, model_name=model_name)
119
-
120
- # Store the origin of the model, in case downstream logic needs it
172
+ # Store the origin of the model, in case downstream logic needs it.
121
173
  with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
122
174
  model.built_from = model_description
123
175
 
124
176
  return model
125
177
 
126
- @staticmethod
178
+ @classmethod
127
179
  def build(
128
- physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
180
+ cls,
181
+ model_description: ModelDescription,
182
+ *,
129
183
  model_name: str | None = None,
184
+ time_step: jtp.FloatLike | None = None,
185
+ integrator: (
186
+ jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
187
+ ) = None,
188
+ terrain: jaxsim.terrain.Terrain | None = None,
189
+ contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
130
190
  ) -> JaxSimModel:
131
191
  """
132
- Build a Model object from a physics model.
192
+ Build a Model object from an intermediate model description.
133
193
 
134
194
  Args:
135
- physics_model: The physics model.
195
+ model_description:
196
+ The intermediate model description defining the kinematics and dynamics
197
+ of the model.
136
198
  model_name:
199
+ The name of the model. If not specified, it is read from the description.
200
+ time_step:
201
+ The default time step to consider for the simulation. It can be
202
+ manually overridden in the function that steps the simulation.
203
+ terrain: The terrain to consider (the default is a flat infinite plane).
137
204
  The optional name of the model overriding the physics model name.
205
+ integrator:
206
+ The integrator to use. If not specified, a default one is used.
207
+ This argument can either be a pre-built integrator instance or one
208
+ of the integrator classes defined in JaxSim.
209
+ contact_model:
210
+ The contact model to consider.
211
+ If not specified, a soft contacts model is used.
138
212
 
139
213
  Returns:
140
214
  The built Model object.
141
215
  """
142
216
 
143
- # Set the model name (if not provided, use the one from the model description)
144
- model_name = (
145
- model_name if model_name is not None else physics_model.description.name
217
+ # Set the model name (if not provided, use the one from the model description).
218
+ model_name = model_name if model_name is not None else model_description.name
219
+
220
+ # Consider the default terrain (a flat infinite plane) if not specified.
221
+ terrain = (
222
+ terrain
223
+ if terrain is not None
224
+ else JaxSimModel.__dataclass_fields__["terrain"].default_factory()
225
+ )
226
+
227
+ # Consider the default time step if not specified.
228
+ time_step = (
229
+ time_step
230
+ if time_step is not None
231
+ else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
232
+ )
233
+
234
+ # Create the default contact model.
235
+ # It will be populated with an initial estimation of good parameters.
236
+ # While these might not be the best, they are a good starting point.
237
+ contact_model = (
238
+ contact_model
239
+ if contact_model is not None
240
+ else jaxsim.rbda.contacts.SoftContacts.build()
146
241
  )
147
242
 
148
- # Build the model
149
- model = JaxSimModel(physics_model=physics_model, model_name=model_name) # noqa
243
+ # Build the integrator if not provided.
244
+ match integrator:
245
+
246
+ # If None, build a default integrator.
247
+ case None:
248
+
249
+ integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
250
+ dynamics=js.ode.wrap_system_dynamics_for_integration(
251
+ system_dynamics=js.ode.system_dynamics
252
+ )
253
+ )
254
+
255
+ # If it's a pre-built integrator (also a custom one from the user)
256
+ # just use it as is.
257
+ case _ if isinstance(integrator, jaxsim.integrators.Integrator):
258
+ pass
259
+
260
+ # If an integrator class is passed, assume that it is a JaxSim integrator
261
+ # and build it with the default system dynamics.
262
+ case _ if issubclass(integrator, jaxsim.integrators.Integrator):
263
+
264
+ integrator_cls = integrator
265
+ integrator = integrator_cls.build(
266
+ dynamics=js.ode.wrap_system_dynamics_for_integration(
267
+ system_dynamics=js.ode.system_dynamics
268
+ )
269
+ )
270
+
271
+ case _:
272
+ raise ValueError(f"Invalid integrator: {integrator}")
273
+
274
+ # Build the model.
275
+ model = cls(
276
+ model_name=model_name,
277
+ kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build(
278
+ model_description=model_description
279
+ ),
280
+ time_step=time_step,
281
+ terrain=terrain,
282
+ contact_model=contact_model,
283
+ integrator=integrator,
284
+ # The following is wrapped as hashless since it's a static argument, and we
285
+ # don't want to trigger recompilation if it changes. All relevant parameters
286
+ # needed to compute kinematics and dynamics quantities are stored in the
287
+ # kin_dyn_parameters attribute.
288
+ _description=wrappers.HashlessObject(obj=model_description),
289
+ )
150
290
 
151
291
  return model
152
292
 
@@ -164,7 +304,7 @@ class JaxSimModel(JaxsimDataclass):
164
304
 
165
305
  return self.model_name
166
306
 
167
- def number_of_links(self) -> jtp.Int:
307
+ def number_of_links(self) -> int:
168
308
  """
169
309
  Return the number of links in the model.
170
310
 
@@ -175,9 +315,9 @@ class JaxSimModel(JaxsimDataclass):
175
315
  The base link is included in the count and its index is always 0.
176
316
  """
177
317
 
178
- return self._number_of_links
318
+ return self.kin_dyn_parameters.number_of_links()
179
319
 
180
- def number_of_joints(self) -> jtp.Int:
320
+ def number_of_joints(self) -> int:
181
321
  """
182
322
  Return the number of joints in the model.
183
323
 
@@ -185,7 +325,18 @@ class JaxSimModel(JaxsimDataclass):
185
325
  The number of joints in the model.
186
326
  """
187
327
 
188
- return self._number_of_joints
328
+ return self.kin_dyn_parameters.number_of_joints()
329
+
330
+ def number_of_frames(self) -> int:
331
+ """
332
+ Return the number of frames in the model.
333
+
334
+ Returns:
335
+ The number of frames in the model.
336
+
337
+ """
338
+
339
+ return self.kin_dyn_parameters.number_of_frames()
189
340
 
190
341
  # =================
191
342
  # Base link methods
@@ -199,7 +350,7 @@ class JaxSimModel(JaxsimDataclass):
199
350
  True if the model is floating-base, False otherwise.
200
351
  """
201
352
 
202
- return self.physics_model.is_floating_base
353
+ return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
203
354
 
204
355
  def base_link(self) -> str:
205
356
  """
@@ -207,9 +358,12 @@ class JaxSimModel(JaxsimDataclass):
207
358
 
208
359
  Returns:
209
360
  The name of the base link.
361
+
362
+ Note:
363
+ By default, the base link is the root of the kinematic tree.
210
364
  """
211
365
 
212
- return self.physics_model.description.root.name
366
+ return self.link_names()[0]
213
367
 
214
368
  # =====================
215
369
  # Joint-related methods
@@ -227,7 +381,7 @@ class JaxSimModel(JaxsimDataclass):
227
381
  the number of joints. In the future, this could be different.
228
382
  """
229
383
 
230
- return len(self.physics_model.description.joints_dict)
384
+ return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
231
385
 
232
386
  def joint_names(self) -> tuple[str, ...]:
233
387
  """
@@ -237,7 +391,7 @@ class JaxSimModel(JaxsimDataclass):
237
391
  The names of the joints in the model.
238
392
  """
239
393
 
240
- return tuple(self.physics_model.description.joints_dict.keys())
394
+ return self.kin_dyn_parameters.joint_model.joint_names[1:]
241
395
 
242
396
  # ====================
243
397
  # Link-related methods
@@ -251,7 +405,21 @@ class JaxSimModel(JaxsimDataclass):
251
405
  The names of the links in the model.
252
406
  """
253
407
 
254
- return tuple(self.physics_model.description.links_dict.keys())
408
+ return self.kin_dyn_parameters.link_names
409
+
410
+ # =====================
411
+ # Frame-related methods
412
+ # =====================
413
+
414
+ def frame_names(self) -> tuple[str, ...]:
415
+ """
416
+ Return the names of the frames in the model.
417
+
418
+ Returns:
419
+ The names of the frames in the model.
420
+ """
421
+
422
+ return self.kin_dyn_parameters.frame_parameters.name
255
423
 
256
424
 
257
425
  # =====================
@@ -259,42 +427,63 @@ class JaxSimModel(JaxsimDataclass):
259
427
  # =====================
260
428
 
261
429
 
262
- def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel:
430
+ def reduce(
431
+ model: JaxSimModel,
432
+ considered_joints: tuple[str, ...],
433
+ locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
434
+ ) -> JaxSimModel:
263
435
  """
264
436
  Reduce the model by lumping together the links connected by removed joints.
265
437
 
266
438
  Args:
267
439
  model: The model to reduce.
268
440
  considered_joints: The sequence of joints to consider.
269
-
270
- Note:
271
- If considered_joints contains joints not existing in the model, the method
272
- will raise an exception. If considered_joints is empty, the method will
273
- return a copy of the input model.
441
+ locked_joint_positions:
442
+ A dictionary containing the positions of the joints to be considered
443
+ in the reduction process. The removed joints in the reduced model
444
+ will have their position locked to their value of this dictionary.
445
+ If a joint is not part of the dictionary, its position is set to zero.
274
446
  """
275
447
 
276
- if len(considered_joints) == 0:
277
- return model.copy()
448
+ locked_joint_positions = (
449
+ locked_joint_positions if locked_joint_positions is not None else {}
450
+ )
451
+
452
+ # If locked joints are passed, make sure that they are valid.
453
+ if not set(locked_joint_positions).issubset(model.joint_names()):
454
+ new_joints = set(model.joint_names()) - set(locked_joint_positions)
455
+ raise ValueError(f"Passed joints not existing in the model: {new_joints}")
456
+
457
+ # Operate on a deep copy of the model description in order to prevent problems
458
+ # when mutable attributes are updated.
459
+ intermediate_description = copy.deepcopy(model.description)
460
+
461
+ # Update the initial position of the joints.
462
+ # This is necessary to compute the correct pose of the link pairs connected
463
+ # to removed joints.
464
+ for joint_name in set(model.joint_names()) - set(considered_joints):
465
+ j = intermediate_description.joints_dict[joint_name]
466
+ with j.mutable_context():
467
+ j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))
278
468
 
279
469
  # Reduce the model description.
280
- # If considered_joints contains joints not existing in the model, the method
281
- # will raise an exception.
282
- reduced_intermediate_description = model.physics_model.description.reduce(
470
+ # If `considered_joints` contains joints not existing in the model,
471
+ # the method will raise an exception.
472
+ reduced_intermediate_description = intermediate_description.reduce(
283
473
  considered_joints=list(considered_joints)
284
474
  )
285
475
 
286
- # Create the physics model from the reduced model description
287
- physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from(
288
- model_description=reduced_intermediate_description,
289
- gravity=model.physics_model.gravity[0:3],
290
- )
291
-
292
- # Build the reduced model
476
+ # Build the reduced model.
293
477
  reduced_model = JaxSimModel.build(
294
- physics_model=physics_model, model_name=model.name()
478
+ model_description=reduced_intermediate_description,
479
+ model_name=model.name(),
480
+ time_step=model.time_step,
481
+ terrain=model.terrain,
482
+ contact_model=model.contact_model,
483
+ integrator=model.integrator,
295
484
  )
296
485
 
297
- # Store the origin of the model, in case downstream logic needs it
486
+ # Store the origin of the model, in case downstream logic needs it.
298
487
  with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
299
488
  reduced_model.built_from = model.built_from
300
489
 
@@ -307,6 +496,7 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode
307
496
 
308
497
 
309
498
  @jax.jit
499
+ @js.common.named_scope
310
500
  def total_mass(model: JaxSimModel) -> jtp.Float:
311
501
  """
312
502
  Compute the total mass of the model.
@@ -318,52 +508,25 @@ def total_mass(model: JaxSimModel) -> jtp.Float:
318
508
  The total mass of the model.
319
509
  """
320
510
 
321
- return (
322
- jax.vmap(lambda idx: js.link.mass(model=model, link_index=idx))(
323
- jnp.arange(model.number_of_links())
324
- )
325
- .sum()
326
- .astype(float)
327
- )
328
-
329
-
330
- # ==============
331
- # Center of mass
332
- # ==============
511
+ return model.kin_dyn_parameters.link_parameters.mass.sum().astype(float)
333
512
 
334
513
 
335
514
  @jax.jit
336
- def com_position(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
515
+ @js.common.named_scope
516
+ def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
337
517
  """
338
- Compute the position of the center of mass of the model.
518
+ Compute the spatial 6D inertia matrices of all links of the model.
339
519
 
340
520
  Args:
341
521
  model: The model to consider.
342
- data: The data of the considered model.
343
522
 
344
523
  Returns:
345
- The position of the center of mass of the model w.r.t. the world frame.
524
+ A 3D array containing the stacked spatial 6D inertia matrices of the links.
346
525
  """
347
526
 
348
- m = total_mass(model=model)
349
-
350
- W_H_L = forward_kinematics(model=model, data=data)
351
- W_H_B = data.base_transform()
352
- B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()
353
-
354
- def B_p̃_LCoM(i) -> jtp.Vector:
355
- m = js.link.mass(model=model, link_index=i)
356
- L_p_LCoM = js.link.com_position(
357
- model=model, data=data, link_index=i, in_link_frame=True
358
- )
359
- return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])
360
-
361
- com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))
362
-
363
- B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
364
- B_p̃_CoM = B_p̃_CoM.at[3].set(1)
365
-
366
- return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
527
+ return jax.vmap(js.kin_dyn_parameters.LinkParameters.spatial_inertia)(
528
+ model.kin_dyn_parameters.link_parameters
529
+ )
367
530
 
368
531
 
369
532
  # ==============================
@@ -372,6 +535,7 @@ def com_position(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vecto
372
535
 
373
536
 
374
537
  @jax.jit
538
+ @js.common.named_scope
375
539
  def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
376
540
  """
377
541
  Compute the SE(3) transforms from the world frame to the frames of all links.
@@ -385,10 +549,11 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp
385
549
  The first axis is the link index.
386
550
  """
387
551
 
388
- W_H_LL = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
389
- model=model.physics_model,
390
- q=data.state.physics_model.joint_positions,
391
- xfb=data.state.physics_model.xfb(),
552
+ W_H_LL = jaxsim.rbda.forward_kinematics_model(
553
+ model=model,
554
+ base_position=data.base_position(),
555
+ base_quaternion=data.base_orientation(dcm=False),
556
+ joint_positions=data.joint_positions(model=model),
392
557
  )
393
558
 
394
559
  return jnp.atleast_3d(W_H_LL).astype(float)
@@ -424,51 +589,296 @@ def generalized_free_floating_jacobian(
424
589
  output_vel_repr if output_vel_repr is not None else data.velocity_representation
425
590
  )
426
591
 
427
- # The body frame of the link.jacobian method is the link frame L.
428
- # In this method, we want instead to use the base link B as body frame.
429
- # Therefore, we always get the link jacobian having Inertial as output
430
- # representation, and then we convert it to the desired output representation.
592
+ # Compute the doubly-left free-floating full jacobian.
593
+ B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
594
+ model=model,
595
+ joint_positions=data.joint_positions(),
596
+ )
597
+
598
+ # ======================================================================
599
+ # Update the input velocity representation such that v_WL = J_WL_I @ I_ν
600
+ # ======================================================================
601
+
602
+ match data.velocity_representation:
603
+
604
+ case VelRepr.Inertial:
605
+
606
+ W_H_B = data.base_transform()
607
+ B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
608
+
609
+ B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
610
+ B_J_full_WX_B
611
+ @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
612
+ )
613
+
614
+ case VelRepr.Body:
615
+
616
+ B_J_full_WX_I = B_J_full_WX_B
617
+
618
+ case VelRepr.Mixed:
619
+
620
+ W_R_B = data.base_orientation(dcm=True)
621
+ BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
622
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
623
+
624
+ B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841
625
+ B_J_full_WX_B
626
+ @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
627
+ )
628
+
629
+ case _:
630
+ raise ValueError(data.velocity_representation)
631
+
632
+ # ====================================================================
633
+ # Create stacked Jacobian for each link by filtering the full Jacobian
634
+ # ====================================================================
635
+
636
+ κ_bool = model.kin_dyn_parameters.support_body_array_bool
637
+
638
+ # Keep only the columns of the full Jacobian corresponding to the support
639
+ # body array of each link.
640
+ B_J_WL_I = jax.vmap(
641
+ lambda κ: jnp.where(
642
+ jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I)
643
+ )
644
+ )(κ_bool)
645
+
646
+ # =======================================================================
647
+ # Update the output velocity representation such that O_v_WL = O_J_WL @ ν
648
+ # =======================================================================
649
+
431
650
  match output_vel_repr:
651
+
432
652
  case VelRepr.Inertial:
433
- to_output = lambda W_J_WL: W_J_WL
653
+
654
+ W_H_B = data.base_transform()
655
+ W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
656
+
657
+ O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
658
+ lambda B_J_WL_I: W_X_B @ B_J_WL_I
659
+ )(B_J_WL_I)
434
660
 
435
661
  case VelRepr.Body:
436
662
 
437
- def to_output(W_J_WL: jtp.Matrix) -> jtp.Matrix:
438
- W_H_B = data.base_transform()
439
- B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
440
- return B_X_W @ W_J_WL
663
+ O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841
664
+ lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
665
+ B_H_L, inverse=True
666
+ )
667
+ @ B_J_WL_I
668
+ )(B_H_L, B_J_WL_I)
441
669
 
442
670
  case VelRepr.Mixed:
443
671
 
444
- def to_output(W_J_WL: jtp.Matrix) -> jtp.Matrix:
445
- W_H_B = data.base_transform()
446
- W_H_BW = jnp.array(W_H_B).at[0:3, 0:3].set(jnp.eye(3))
447
- BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
448
- return BW_X_W @ W_J_WL
672
+ W_H_B = data.base_transform()
673
+
674
+ LW_H_L = jax.vmap(
675
+ lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3))
676
+ )(B_H_L)
677
+
678
+ LW_H_B = jax.vmap(
679
+ lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
680
+ )(LW_H_L, B_H_L)
681
+
682
+ O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841
683
+ lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
684
+ @ B_J_WL_I
685
+ )(LW_H_B, B_J_WL_I)
449
686
 
450
687
  case _:
451
688
  raise ValueError(output_vel_repr)
452
689
 
453
- # Compute first the link jacobians having the active representation of `data`
454
- # as input representation (matching the one of ν), and inertial as output
455
- # representation (i.e. W_J_WL_C where C is C_ν).
456
- # Then, with to_output, we convert this jacobian to the desired output
457
- # representation, that can either be W (inertial), B (body), or B[W] (mixed).
458
- # This is necessary because for example the body-fixed free-floating jacobian
459
- # of a link is L_J_WL, but here being inside model we need B_J_WL.
460
- J_free_floating = jax.vmap(
461
- lambda i: to_output(
462
- W_J_WL=js.link.jacobian(
463
- model=model,
464
- data=data,
465
- link_index=i,
466
- output_vel_repr=VelRepr.Inertial,
690
+ return O_J_WL_I
691
+
692
+
693
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
694
+ def generalized_free_floating_jacobian_derivative(
695
+ model: JaxSimModel,
696
+ data: js.data.JaxSimModelData,
697
+ *,
698
+ output_vel_repr: VelRepr | None = None,
699
+ ) -> jtp.Matrix:
700
+ """
701
+ Compute the free-floating jacobian derivatives of all links.
702
+
703
+ Args:
704
+ model: The model to consider.
705
+ data: The data of the considered model.
706
+ output_vel_repr:
707
+ The output velocity representation of the free-floating jacobian derivatives.
708
+
709
+ Returns:
710
+ The `(nL, 6, 6+dofs)` array containing the stacked free-floating
711
+ jacobian derivatives of the links. The first axis is the link index.
712
+ """
713
+
714
+ output_vel_repr = (
715
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
716
+ )
717
+
718
+ # Compute the derivative of the doubly-left free-floating full jacobian.
719
+ B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(
720
+ model=model,
721
+ joint_positions=data.joint_positions(),
722
+ joint_velocities=data.joint_velocities(),
723
+ )
724
+
725
+ # The derivative of the equation to change the input and output representations
726
+ # of the Jacobian derivative needs the computation of the plain link Jacobian.
727
+ B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
728
+ model=model,
729
+ joint_positions=data.joint_positions(),
730
+ )
731
+
732
+ # Compute the actual doubly-left free-floating jacobian derivative of the link
733
+ # by zeroing the columns not in the path π_B(L) using the boolean κ(i).
734
+ κb = model.kin_dyn_parameters.support_body_array_bool
735
+
736
+ # Compute the base transform.
737
+ W_H_B = data.base_transform()
738
+
739
+ # We add the 5 columns of ones to the Jacobian derivative to account for the
740
+ # base velocity and acceleration (5 + number of links = 6 + number of joints).
741
+ B_J̇_WL_B = (
742
+ jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis] * B_J̇_full_WX_B
743
+ )
744
+ B_J_WL_B = (
745
+ jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis] * B_J_full_WL_B
746
+ )
747
+
748
+ # =====================================================
749
+ # Compute quantities to adjust the input representation
750
+ # =====================================================
751
+
752
+ In = jnp.eye(model.dofs())
753
+ On = jnp.zeros(shape=(model.dofs(), model.dofs()))
754
+
755
+ match data.velocity_representation:
756
+
757
+ case VelRepr.Inertial:
758
+
759
+ B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
760
+
761
+ W_v_WB = data.base_velocity()
762
+ B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
763
+
764
+ # Compute the operator to change the representation of ν, and its
765
+ # time derivative.
766
+ T = jax.scipy.linalg.block_diag(B_X_W, In)
767
+ Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)
768
+
769
+ case VelRepr.Body:
770
+
771
+ B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
772
+ translation=jnp.zeros(3), rotation=jnp.eye(3)
773
+ )
774
+
775
+ B_Ẋ_B = jnp.zeros(shape=(6, 6))
776
+
777
+ # Compute the operator to change the representation of ν, and its
778
+ # time derivative.
779
+ T = jax.scipy.linalg.block_diag(B_X_B, In)
780
+ Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)
781
+
782
+ case VelRepr.Mixed:
783
+
784
+ BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
785
+ B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
786
+
787
+ BW_v_WB = data.base_velocity()
788
+ BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
789
+
790
+ BW_v_BW_B = BW_v_WB - BW_v_W_BW
791
+ B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
792
+
793
+ # Compute the operator to change the representation of ν, and its
794
+ # time derivative.
795
+ T = jax.scipy.linalg.block_diag(B_X_BW, In)
796
+ Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On)
797
+
798
+ case _:
799
+ raise ValueError(data.velocity_representation)
800
+
801
+ # ======================================================
802
+ # Compute quantities to adjust the output representation
803
+ # ======================================================
804
+
805
+ match output_vel_repr:
806
+
807
+ case VelRepr.Inertial:
808
+
809
+ O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
810
+
811
+ with data.switch_velocity_representation(VelRepr.Body):
812
+ B_v_WB = data.base_velocity()
813
+
814
+ O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
815
+
816
+ case VelRepr.Body:
817
+
818
+ O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
819
+ transform=B_H_L, inverse=True
820
+ )
821
+
822
+ B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
823
+
824
+ with data.switch_velocity_representation(VelRepr.Body):
825
+ B_v_WB = data.base_velocity()
826
+ L_v_WL = jnp.einsum(
827
+ "b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity()
828
+ )
829
+
830
+ O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
831
+ jnp.einsum("bij,bj->bi", B_X_L, L_v_WL) - B_v_WB
832
+ )
833
+
834
+ case VelRepr.Mixed:
835
+
836
+ W_H_L = W_H_B @ B_H_L
837
+ LW_H_L = W_H_L.at[:, 0:3, 3].set(jnp.zeros_like(W_H_L[:, 0:3, 3]))
838
+ LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
839
+
840
+ O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)
841
+
842
+ B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
843
+
844
+ with data.switch_velocity_representation(VelRepr.Body):
845
+ B_v_WB = data.base_velocity()
846
+
847
+ with data.switch_velocity_representation(VelRepr.Mixed):
848
+ BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
849
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
850
+ LW_v_WL = jnp.einsum(
851
+ "bij,bj->bi",
852
+ LW_X_B,
853
+ B_J_WL_B
854
+ @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
855
+ @ data.generalized_velocity(),
856
+ )
857
+
858
+ LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6]))
859
+
860
+ LW_v_LW_L = LW_v_WL - LW_v_W_LW
861
+ LW_v_B_LW = LW_v_WL - jnp.einsum("bij,j->bi", LW_X_B, B_v_WB) - LW_v_LW_L
862
+
863
+ O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
864
+ jnp.einsum("bij,bj->bi", B_X_LW, LW_v_B_LW)
467
865
  )
468
- )
469
- )(jnp.arange(model.number_of_links()))
470
866
 
471
- return J_free_floating
867
+ case _:
868
+ raise ValueError(output_vel_repr)
869
+
870
+ # =============================================================
871
+ # Express the Jacobian derivative in the target representations
872
+ # =============================================================
873
+
874
+ # Sum all the components that form the Jacobian derivative in the target
875
+ # input/output velocity representations.
876
+ O_J̇_WL_I = jnp.zeros_like(B_J̇_WL_B)
877
+ O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
878
+ O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
879
+ O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ
880
+
881
+ return O_J̇_WL_I
472
882
 
473
883
 
474
884
  @functools.partial(jax.jit, static_argnames=["prefer_aba"])
@@ -477,7 +887,7 @@ def forward_dynamics(
477
887
  data: js.data.JaxSimModelData,
478
888
  *,
479
889
  joint_forces: jtp.VectorLike | None = None,
480
- external_forces: jtp.MatrixLike | None = None,
890
+ link_forces: jtp.MatrixLike | None = None,
481
891
  prefer_aba: float = True,
482
892
  ) -> tuple[jtp.Vector, jtp.Vector]:
483
893
  """
@@ -488,8 +898,8 @@ def forward_dynamics(
488
898
  data: The data of the considered model.
489
899
  joint_forces:
490
900
  The joint forces to consider as a vector of shape `(dofs,)`.
491
- external_forces:
492
- The external forces to consider as a matrix of shape `(nL, 6)`.
901
+ link_forces:
902
+ The link 6D forces consider as a matrix of shape `(nL, 6)`.
493
903
  The frame in which they are expressed must be `data.velocity_representation`.
494
904
  prefer_aba: Whether to prefer the ABA algorithm over the CRB one.
495
905
 
@@ -505,17 +915,18 @@ def forward_dynamics(
505
915
  model=model,
506
916
  data=data,
507
917
  joint_forces=joint_forces,
508
- external_forces=external_forces,
918
+ link_forces=link_forces,
509
919
  )
510
920
 
511
921
 
512
922
  @jax.jit
923
+ @js.common.named_scope
513
924
  def forward_dynamics_aba(
514
925
  model: JaxSimModel,
515
926
  data: js.data.JaxSimModelData,
516
927
  *,
517
928
  joint_forces: jtp.VectorLike | None = None,
518
- external_forces: jtp.MatrixLike | None = None,
929
+ link_forces: jtp.MatrixLike | None = None,
519
930
  ) -> tuple[jtp.Vector, jtp.Vector]:
520
931
  """
521
932
  Compute the forward dynamics of the model with the ABA algorithm.
@@ -525,8 +936,8 @@ def forward_dynamics_aba(
525
936
  data: The data of the considered model.
526
937
  joint_forces:
527
938
  The joint forces to consider as a vector of shape `(dofs,)`.
528
- external_forces:
529
- The external forces to consider as a matrix of shape `(nL, 6)`.
939
+ link_forces:
940
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
530
941
  The frame in which they are expressed must be `data.velocity_representation`.
531
942
 
532
943
  Returns:
@@ -535,86 +946,132 @@ def forward_dynamics_aba(
535
946
  considered joint forces and external forces.
536
947
  """
537
948
 
538
- # Build joint torques if not provided
949
+ # ============
950
+ # Prepare data
951
+ # ============
952
+
953
+ # Build joint forces, if not provided.
539
954
  τ = (
540
- joint_forces
955
+ jnp.atleast_1d(joint_forces.squeeze())
541
956
  if joint_forces is not None
542
957
  else jnp.zeros_like(data.joint_positions())
543
958
  )
544
959
 
545
- # Build external forces if not provided
546
- f_ext = (
547
- external_forces
548
- if external_forces is not None
960
+ # Build link forces, if not provided.
961
+ f_L = (
962
+ jnp.atleast_2d(link_forces.squeeze())
963
+ if link_forces is not None
549
964
  else jnp.zeros((model.number_of_links(), 6))
550
965
  )
551
966
 
552
- # Compute ABA
553
- W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba(
554
- model=model.physics_model,
555
- xfb=data.state.physics_model.xfb(),
556
- q=data.state.physics_model.joint_positions,
557
- qd=data.state.physics_model.joint_velocities,
558
- tau=τ,
559
- f_ext=f_ext,
967
+ # Create a references object that simplifies converting among representations.
968
+ references = js.references.JaxSimModelReferences.build(
969
+ model=model,
970
+ joint_force_references=τ,
971
+ link_forces=f_L,
972
+ data=data,
973
+ velocity_representation=data.velocity_representation,
560
974
  )
561
975
 
562
- def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):
563
- C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
976
+ # Extract the link and joint serializations.
977
+ link_names = model.link_names()
978
+ joint_names = model.joint_names()
979
+
980
+ # Extract the state in inertial-fixed representation.
981
+ with data.switch_velocity_representation(VelRepr.Inertial):
982
+ W_p_B = data.base_position()
983
+ W_v_WB = data.base_velocity()
984
+ W_Q_B = data.base_orientation(dcm=False)
985
+ s = data.joint_positions(model=model, joint_names=joint_names)
986
+ ṡ = data.joint_velocities(model=model, joint_names=joint_names)
987
+
988
+ # Extract the inputs in inertial-fixed representation.
989
+ with references.switch_velocity_representation(VelRepr.Inertial):
990
+ W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
991
+ τ = references.joint_force_references(model=model, joint_names=joint_names)
992
+
993
+ # ========================
994
+ # Compute forward dynamics
995
+ # ========================
996
+
997
+ W_v̇_WB, s̈ = jaxsim.rbda.aba(
998
+ model=model,
999
+ base_position=W_p_B,
1000
+ base_quaternion=W_Q_B,
1001
+ joint_positions=s,
1002
+ base_linear_velocity=W_v_WB[0:3],
1003
+ base_angular_velocity=W_v_WB[3:6],
1004
+ joint_velocities=ṡ,
1005
+ joint_forces=τ,
1006
+ link_forces=W_f_L,
1007
+ standard_gravity=data.standard_gravity(),
1008
+ )
564
1009
 
565
- if data.velocity_representation != VelRepr.Mixed:
566
- return C_X_W @ W_vd_WB
1010
+ # =============
1011
+ # Adjust output
1012
+ # =============
567
1013
 
568
- from jaxsim.math.cross import Cross
1014
+ def to_active(
1015
+ W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
1016
+ ) -> jtp.Vector:
1017
+ """
1018
+ Convert the inertial-fixed apparent base acceleration W_v̇_WB to
1019
+ another representation C_v̇_WB expressed in a generic frame C.
1020
+ """
569
1021
 
570
- W_v_WC = jnp.hstack([W_vl_WC, jnp.zeros(3)])
571
- return C_X_W @ (W_vd_WB - Cross.vx(W_v_WC) @ W_v_WB)
1022
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
1023
+ # In Inertial and Body representations, the cross product is always zero.
1024
+ C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
1025
+ return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
572
1026
 
573
1027
  match data.velocity_representation:
574
1028
  case VelRepr.Inertial:
575
- W_H_C = W_H_W = jnp.eye(4)
576
- W_vl_WC = W_vl_WW = jnp.zeros(3)
1029
+ # In this case C=W
1030
+ W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1031
+ W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
577
1032
 
578
1033
  case VelRepr.Body:
1034
+ # In this case C=B
579
1035
  W_H_C = W_H_B = data.base_transform()
580
- W_vl_WC = W_vl_WB = data.base_velocity()[0:3]
1036
+ W_v_WC = W_v_WB
581
1037
 
582
1038
  case VelRepr.Mixed:
1039
+ # In this case C=B[W]
583
1040
  W_H_B = data.base_transform()
584
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
585
- W_vl_WC = W_vl_W_BW = data.base_velocity()[0:3]
1041
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
1042
+ W_ṗ_B = data.base_velocity()[0:3]
1043
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
586
1044
 
587
1045
  case _:
588
1046
  raise ValueError(data.velocity_representation)
589
1047
 
590
- # We need to convert the derivative of the base acceleration to the active
1048
+ # We need to convert the derivative of the base velocity to the active
591
1049
  # representation. In Mixed representation, this conversion is not a plain
592
1050
  # transformation with just X, but it also involves a cross product in ℝ⁶.
593
1051
  C_v̇_WB = to_active(
594
- W_vd_WB=W_v̇_WB.squeeze(),
1052
+ W_v̇_WB=W_v̇_WB,
595
1053
  W_H_C=W_H_C,
596
- W_v_WB=jnp.hstack(
597
- [
598
- data.state.physics_model.base_linear_velocity,
599
- data.state.physics_model.base_angular_velocity,
600
- ]
601
- ),
602
- W_vl_WC=W_vl_WC,
1054
+ W_v_WB=W_v_WB,
1055
+ W_v_WC=W_v_WC,
603
1056
  )
604
1057
 
605
- # Adjust shape
606
- s̈ = jnp.atleast_1d(s̈.squeeze())
1058
+ # The ABA algorithm already returns a zero base 6D acceleration for
1059
+ # fixed-based models. However, the to_active function introduces an
1060
+ # additional acceleration component in Mixed representation.
1061
+ # Here below we make sure that the base acceleration is zero.
1062
+ C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6)
607
1063
 
608
- return C_v̇_WB, s̈
1064
+ return C_v̇_WB.astype(float), s̈.astype(float)
609
1065
 
610
1066
 
611
1067
  @jax.jit
1068
+ @js.common.named_scope
612
1069
  def forward_dynamics_crb(
613
1070
  model: JaxSimModel,
614
1071
  data: js.data.JaxSimModelData,
615
1072
  *,
616
1073
  joint_forces: jtp.VectorLike | None = None,
617
- external_forces: jtp.MatrixLike | None = None,
1074
+ link_forces: jtp.MatrixLike | None = None,
618
1075
  ) -> tuple[jtp.Vector, jtp.Vector]:
619
1076
  """
620
1077
  Compute the forward dynamics of the model with the CRB algorithm.
@@ -624,8 +1081,8 @@ def forward_dynamics_crb(
624
1081
  data: The data of the considered model.
625
1082
  joint_forces:
626
1083
  The joint forces to consider as a vector of shape `(dofs,)`.
627
- external_forces:
628
- The external forces to consider as a matrix of shape `(nL, 6)`.
1084
+ link_forces:
1085
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
629
1086
  The frame in which they are expressed must be `data.velocity_representation`.
630
1087
 
631
1088
  Returns:
@@ -638,21 +1095,25 @@ def forward_dynamics_crb(
638
1095
  models with a large number of degrees of freedom.
639
1096
  """
640
1097
 
641
- # Build joint torques if not provided
1098
+ # ============
1099
+ # Prepare data
1100
+ # ============
1101
+
1102
+ # Build joint torques if not provided.
642
1103
  τ = (
643
1104
  jnp.atleast_1d(joint_forces)
644
1105
  if joint_forces is not None
645
1106
  else jnp.zeros_like(data.joint_positions())
646
1107
  )
647
1108
 
648
- # Build external forces if not provided
1109
+ # Build external forces if not provided.
649
1110
  f = (
650
- jnp.atleast_2d(external_forces)
651
- if external_forces is not None
1111
+ jnp.atleast_2d(link_forces)
1112
+ if link_forces is not None
652
1113
  else jnp.zeros(shape=(model.number_of_links(), 6))
653
1114
  )
654
1115
 
655
- # Compute terms of the floating-base EoM
1116
+ # Compute terms of the floating-base EoM.
656
1117
  M = free_floating_mass_matrix(model=model, data=data)
657
1118
  h = free_floating_bias_forces(model=model, data=data)
658
1119
  S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T
@@ -660,6 +1121,10 @@ def forward_dynamics_crb(
660
1121
 
661
1122
  # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)
662
1123
 
1124
+ # ========================
1125
+ # Compute forward dynamics
1126
+ # ========================
1127
+
663
1128
  if model.floating_base():
664
1129
  # l: number of links.
665
1130
  # g: generalized coordinates, 6 + number of joints.
@@ -675,19 +1140,24 @@ def forward_dynamics_crb(
675
1140
  v̇_WB = jnp.zeros(6)
676
1141
  ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()])
677
1142
 
1143
+ # =============
1144
+ # Adjust output
1145
+ # =============
1146
+
678
1147
  # Extract the base acceleration in the active representation.
679
1148
  # Note that this is an apparent acceleration (relevant in Mixed representation),
680
1149
  # therefore it cannot be always expressed in different frames with just a
681
1150
  # 6D transformation X.
682
1151
  v̇_WB = ν̇[0:6].squeeze().astype(float)
683
1152
 
684
- # Extract the joint accelerations
1153
+ # Extract the joint accelerations.
685
1154
  s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float)
686
1155
 
687
1156
  return v̇_WB, s̈
688
1157
 
689
1158
 
690
1159
  @jax.jit
1160
+ @js.common.named_scope
691
1161
  def free_floating_mass_matrix(
692
1162
  model: JaxSimModel, data: js.data.JaxSimModelData
693
1163
  ) -> jtp.Matrix:
@@ -702,9 +1172,9 @@ def free_floating_mass_matrix(
702
1172
  The free-floating mass matrix of the model.
703
1173
  """
704
1174
 
705
- M_body = jaxsim.physics.algos.crba.crba(
706
- model=model.physics_model,
707
- q=data.state.physics_model.joint_positions,
1175
+ M_body = jaxsim.rbda.crba(
1176
+ model=model,
1177
+ joint_positions=data.state.physics_model.joint_positions,
708
1178
  )
709
1179
 
710
1180
  match data.velocity_representation:
@@ -712,29 +1182,19 @@ def free_floating_mass_matrix(
712
1182
  return M_body
713
1183
 
714
1184
  case VelRepr.Inertial:
715
- zero_6n = jnp.zeros(shape=(6, model.dofs()))
716
- B_X_W = jaxlie.SE3.from_matrix(data.base_transform()).inverse().adjoint()
717
1185
 
718
- invT = jnp.vstack(
719
- [
720
- jnp.block([B_X_W, zero_6n]),
721
- jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
722
- ]
1186
+ B_X_W = Adjoint.from_transform(
1187
+ transform=data.base_transform(), inverse=True
723
1188
  )
1189
+ invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
724
1190
 
725
1191
  return invT.T @ M_body @ invT
726
1192
 
727
1193
  case VelRepr.Mixed:
728
- zero_6n = jnp.zeros(shape=(6, model.dofs()))
729
- W_H_BW = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
730
- BW_X_W = jaxlie.SE3.from_matrix(W_H_BW).inverse().adjoint()
731
-
732
- invT = jnp.vstack(
733
- [
734
- jnp.block([BW_X_W, zero_6n]),
735
- jnp.block([zero_6n.T, jnp.eye(model.dofs())]),
736
- ]
737
- )
1194
+
1195
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1196
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1197
+ invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
738
1198
 
739
1199
  return invT.T @ M_body @ invT
740
1200
 
@@ -743,77 +1203,206 @@ def free_floating_mass_matrix(
743
1203
 
744
1204
 
745
1205
  @jax.jit
746
- def inverse_dynamics(
747
- model: JaxSimModel,
748
- data: js.data.JaxSimModelData,
749
- *,
750
- joint_accelerations: jtp.Vector | None = None,
751
- base_acceleration: jtp.Vector | None = None,
752
- external_forces: jtp.Matrix | None = None,
753
- ) -> tuple[jtp.Vector, jtp.Vector]:
1206
+ @js.common.named_scope
1207
+ def free_floating_coriolis_matrix(
1208
+ model: JaxSimModel, data: js.data.JaxSimModelData
1209
+ ) -> jtp.Matrix:
754
1210
  """
755
- Compute inverse dynamics with the RNEA algorithm.
1211
+ Compute the free-floating Coriolis matrix of the model.
756
1212
 
757
1213
  Args:
758
1214
  model: The model to consider.
759
1215
  data: The data of the considered model.
760
- joint_accelerations:
761
- The joint accelerations to consider as a vector of shape `(dofs,)`.
762
- base_acceleration:
763
- The base acceleration to consider as a vector of shape `(6,)`.
764
- external_forces:
765
- The external forces to consider as a matrix of shape `(nL, 6)`.
766
- The frame in which they are expressed must be `data.velocity_representation`.
767
1216
 
768
1217
  Returns:
769
- A tuple containing the 6D force in the active representation applied to the
770
- base to obtain the considered base acceleration, and the joint forces to apply
771
- to obtain the considered joint accelerations.
1218
+ The free-floating Coriolis matrix of the model.
1219
+
1220
+ Note:
1221
+ This function, contrarily to other quantities of the equations of motion,
1222
+ does not exploit any iterative algorithm. Therefore, the computation of
1223
+ the Coriolis matrix may be much slower than other quantities.
772
1224
  """
773
1225
 
774
- # Build joint accelerations if not provided
775
- joint_accelerations = (
776
- joint_accelerations
777
- if joint_accelerations is not None
778
- else jnp.zeros_like(data.joint_positions())
779
- )
1226
+ # We perform all the calculation in body-fixed.
1227
+ # The Coriolis matrix computed in this representation is converted later
1228
+ # to the active representation stored in data.
1229
+ with data.switch_velocity_representation(VelRepr.Body):
780
1230
 
781
- # Build base acceleration if not provided
782
- base_acceleration = (
783
- base_acceleration if base_acceleration is not None else jnp.zeros(6)
784
- )
1231
+ B_ν = data.generalized_velocity()
785
1232
 
786
- external_forces = (
787
- external_forces
788
- if external_forces is not None
789
- else jnp.zeros(shape=(model.number_of_links(), 6))
790
- )
1233
+ # Doubly-left free-floating Jacobian.
1234
+ L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
791
1235
 
792
- def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_vl_WC):
793
- W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
794
- C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
1236
+ # Doubly-left free-floating Jacobian derivative.
1237
+ L_J̇_WL_B = generalized_free_floating_jacobian_derivative(model=model, data=data)
795
1238
 
796
- if data.velocity_representation != VelRepr.Mixed:
797
- return W_X_C @ C_v̇_WB
798
- else:
799
- from jaxsim.math.cross import Cross
1239
+ L_M_L = link_spatial_inertia_matrices(model=model)
800
1240
 
801
- C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)])
802
- return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
1241
+ # Body-fixed link velocities.
1242
+ # Note: we could have called link.velocity() instead of computing it ourselves,
1243
+ # but since we need the link Jacobians later, we can save a double calculation.
1244
+ L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)
803
1245
 
804
- match data.velocity_representation:
805
- case VelRepr.Inertial:
806
- W_H_C = W_H_W = jnp.eye(4)
807
- W_vl_WC = W_vl_WW = jnp.zeros(3)
1246
+ # Compute the contribution of each link to the Coriolis matrix.
1247
+ def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
808
1248
 
809
- case VelRepr.Body:
810
- W_H_C = W_H_B = data.base_transform()
811
- W_vl_WC = W_vl_WB = data.base_velocity()[0:3]
1249
+ return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
812
1250
 
813
- case VelRepr.Mixed:
814
- W_H_B = data.base_transform()
815
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
816
- W_vl_WC = W_vl_W_BW = data.base_velocity()[0:3]
1251
+ C_B_links = jax.vmap(compute_link_contribution)(
1252
+ L_M_L,
1253
+ L_v_WL,
1254
+ L_J_WL_B,
1255
+ L_J̇_WL_B,
1256
+ )
1257
+
1258
+ # We need to adjust the Coriolis matrix for fixed-base models.
1259
+ # In this case, the base link does not contribute to the matrix, and we need to zero
1260
+ # the off-diagonal terms mapping joint quantities onto the base configuration.
1261
+ if model.floating_base():
1262
+ C_B = C_B_links.sum(axis=0)
1263
+ else:
1264
+ C_B = C_B_links[1:].sum(axis=0)
1265
+ C_B = C_B.at[0:6, 6:].set(0.0)
1266
+ C_B = C_B.at[6:, 0:6].set(0.0)
1267
+
1268
+ # Adjust the representation of the Coriolis matrix.
1269
+ # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
1270
+ match data.velocity_representation:
1271
+
1272
+ case VelRepr.Body:
1273
+ return C_B
1274
+
1275
+ case VelRepr.Inertial:
1276
+
1277
+ n = model.dofs()
1278
+ W_H_B = data.base_transform()
1279
+ B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
1280
+ B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))
1281
+
1282
+ with data.switch_velocity_representation(VelRepr.Inertial):
1283
+ W_v_WB = data.base_velocity()
1284
+ B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
1285
+
1286
+ B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))
1287
+
1288
+ with data.switch_velocity_representation(VelRepr.Body):
1289
+ M = free_floating_mass_matrix(model=model, data=data)
1290
+
1291
+ C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)
1292
+
1293
+ return C
1294
+
1295
+ case VelRepr.Mixed:
1296
+
1297
+ n = model.dofs()
1298
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1299
+ B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
1300
+ B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))
1301
+
1302
+ with data.switch_velocity_representation(VelRepr.Mixed):
1303
+ BW_v_WB = data.base_velocity()
1304
+ BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
1305
+
1306
+ BW_v_BW_B = BW_v_WB - BW_v_W_BW
1307
+ B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
1308
+
1309
+ B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))
1310
+
1311
+ with data.switch_velocity_representation(VelRepr.Body):
1312
+ M = free_floating_mass_matrix(model=model, data=data)
1313
+
1314
+ C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)
1315
+
1316
+ return C
1317
+
1318
+ case _:
1319
+ raise ValueError(data.velocity_representation)
1320
+
1321
+
1322
+ @jax.jit
1323
+ @js.common.named_scope
1324
+ def inverse_dynamics(
1325
+ model: JaxSimModel,
1326
+ data: js.data.JaxSimModelData,
1327
+ *,
1328
+ joint_accelerations: jtp.VectorLike | None = None,
1329
+ base_acceleration: jtp.VectorLike | None = None,
1330
+ link_forces: jtp.MatrixLike | None = None,
1331
+ ) -> tuple[jtp.Vector, jtp.Vector]:
1332
+ """
1333
+ Compute inverse dynamics with the RNEA algorithm.
1334
+
1335
+ Args:
1336
+ model: The model to consider.
1337
+ data: The data of the considered model.
1338
+ joint_accelerations:
1339
+ The joint accelerations to consider as a vector of shape `(dofs,)`.
1340
+ base_acceleration:
1341
+ The base acceleration to consider as a vector of shape `(6,)`.
1342
+ link_forces:
1343
+ The link 6D forces to consider as a matrix of shape `(nL, 6)`.
1344
+ The frame in which they are expressed must be `data.velocity_representation`.
1345
+
1346
+ Returns:
1347
+ A tuple containing the 6D force in the active representation applied to the
1348
+ base to obtain the considered base acceleration, and the joint forces to apply
1349
+ to obtain the considered joint accelerations.
1350
+ """
1351
+
1352
+ # ============
1353
+ # Prepare data
1354
+ # ============
1355
+
1356
+ # Build joint accelerations, if not provided.
1357
+ s̈ = (
1358
+ jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
1359
+ if joint_accelerations is not None
1360
+ else jnp.zeros_like(data.joint_positions())
1361
+ )
1362
+
1363
+ # Build base acceleration, if not provided.
1364
+ v̇_WB = (
1365
+ jnp.array(base_acceleration).squeeze()
1366
+ if base_acceleration is not None
1367
+ else jnp.zeros(6)
1368
+ )
1369
+
1370
+ # Build link forces, if not provided.
1371
+ f_L = (
1372
+ jnp.atleast_2d(jnp.array(link_forces).squeeze())
1373
+ if link_forces is not None
1374
+ else jnp.zeros(shape=(model.number_of_links(), 6))
1375
+ )
1376
+
1377
+ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
1378
+ """
1379
+ Convert the active representation of the base acceleration C_v̇_WB
1380
+ expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1381
+ """
1382
+
1383
+ W_X_C = Adjoint.from_transform(transform=W_H_C)
1384
+ C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
1385
+ C_v_WC = C_X_W @ W_v_WC
1386
+
1387
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
1388
+ # In Inertial and Body representations, the cross product is always zero.
1389
+ return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)
1390
+
1391
+ match data.velocity_representation:
1392
+ case VelRepr.Inertial:
1393
+ W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1394
+ W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1395
+
1396
+ case VelRepr.Body:
1397
+ W_H_C = W_H_B = data.base_transform()
1398
+ with data.switch_velocity_representation(VelRepr.Inertial):
1399
+ W_v_WC = W_v_WB = data.base_velocity()
1400
+
1401
+ case VelRepr.Mixed:
1402
+ W_H_B = data.base_transform()
1403
+ W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
1404
+ W_ṗ_B = data.base_velocity()[0:3]
1405
+ W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
817
1406
 
818
1407
  case _:
819
1408
  raise ValueError(data.velocity_representation)
@@ -822,35 +1411,60 @@ def inverse_dynamics(
822
1411
  # representation. In Mixed representation, this conversion is not a plain
823
1412
  # transformation with just X, but it also involves a cross product in ℝ⁶.
824
1413
  W_v̇_WB = to_inertial(
825
- C_v̇_WB=base_acceleration,
1414
+ C_v̇_WB=v̇_WB,
826
1415
  W_H_C=W_H_C,
827
1416
  C_v_WB=data.base_velocity(),
828
- W_vl_WC=W_vl_WC,
1417
+ W_v_WC=W_v_WC,
829
1418
  )
830
1419
 
1420
+ # Create a references object that simplifies converting among representations.
831
1421
  references = js.references.JaxSimModelReferences.build(
832
1422
  model=model,
833
1423
  data=data,
834
- link_forces=external_forces,
1424
+ link_forces=f_L,
835
1425
  velocity_representation=data.velocity_representation,
836
1426
  )
837
1427
 
838
- # Compute RNEA
1428
+ # Extract the link and joint serializations.
1429
+ link_names = model.link_names()
1430
+ joint_names = model.joint_names()
1431
+
1432
+ # Extract the state in inertial-fixed representation.
1433
+ with data.switch_velocity_representation(VelRepr.Inertial):
1434
+ W_p_B = data.base_position()
1435
+ W_v_WB = data.base_velocity()
1436
+ W_Q_B = data.base_orientation(dcm=False)
1437
+ s = data.joint_positions(model=model, joint_names=joint_names)
1438
+ ṡ = data.joint_velocities(model=model, joint_names=joint_names)
1439
+
1440
+ # Extract the inputs in inertial-fixed representation.
839
1441
  with references.switch_velocity_representation(VelRepr.Inertial):
840
- W_f_B, τ = jaxsim.physics.algos.rnea.rnea(
841
- model=model.physics_model,
842
- xfb=data.state.physics_model.xfb(),
843
- q=data.state.physics_model.joint_positions,
844
- qd=data.state.physics_model.joint_velocities,
845
- qdd=joint_accelerations,
846
- a0fb=W_v̇_WB,
847
- f_ext=references.link_forces(model=model, data=data),
848
- )
1442
+ W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
1443
+
1444
+ # ========================
1445
+ # Compute inverse dynamics
1446
+ # ========================
1447
+
1448
+ W_f_B, τ = jaxsim.rbda.rnea(
1449
+ model=model,
1450
+ base_position=W_p_B,
1451
+ base_quaternion=W_Q_B,
1452
+ joint_positions=s,
1453
+ base_linear_velocity=W_v_WB[0:3],
1454
+ base_angular_velocity=W_v_WB[3:6],
1455
+ joint_velocities=ṡ,
1456
+ base_linear_acceleration=W_v̇_WB[0:3],
1457
+ base_angular_acceleration=W_v̇_WB[3:6],
1458
+ joint_accelerations=s̈,
1459
+ link_forces=W_f_L,
1460
+ standard_gravity=data.standard_gravity(),
1461
+ )
849
1462
 
850
- # Adjust shape
851
- τ = jnp.atleast_1d(τ.squeeze())
1463
+ # =============
1464
+ # Adjust output
1465
+ # =============
852
1466
 
853
- # Express W_f_B in the active representation
1467
+ # Express W_f_B in the active representation.
854
1468
  f_B = js.data.JaxSimModelData.inertial_to_other_representation(
855
1469
  array=W_f_B,
856
1470
  other_representation=data.velocity_representation,
@@ -862,10 +1476,11 @@ def inverse_dynamics(
862
1476
 
863
1477
 
864
1478
  @jax.jit
1479
+ @js.common.named_scope
865
1480
  def free_floating_gravity_forces(
866
1481
  model: JaxSimModel, data: js.data.JaxSimModelData
867
1482
  ) -> jtp.Vector:
868
- """
1483
+ r"""
869
1484
  Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model.
870
1485
 
871
1486
  Args:
@@ -876,12 +1491,12 @@ def free_floating_gravity_forces(
876
1491
  The free-floating gravity forces of the model.
877
1492
  """
878
1493
 
879
- # Build a zeroed state
1494
+ # Build a zeroed state.
880
1495
  data_rnea = js.data.JaxSimModelData.zero(
881
1496
  model=model, velocity_representation=data.velocity_representation
882
1497
  )
883
1498
 
884
- # Set just the generalized position
1499
+ # Set just the generalized position.
885
1500
  with data_rnea.mutable_context(
886
1501
  mutability=Mutability.MUTABLE, restore_after_exception=False
887
1502
  ):
@@ -905,16 +1520,17 @@ def free_floating_gravity_forces(
905
1520
  # Set zero inputs:
906
1521
  joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
907
1522
  base_acceleration=jnp.zeros(6),
908
- external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
1523
+ link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
909
1524
  )
910
1525
  ).astype(float)
911
1526
 
912
1527
 
913
1528
  @jax.jit
1529
+ @js.common.named_scope
914
1530
  def free_floating_bias_forces(
915
1531
  model: JaxSimModel, data: js.data.JaxSimModelData
916
1532
  ) -> jtp.Vector:
917
- """
1533
+ r"""
918
1534
  Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})`
919
1535
  of the model.
920
1536
 
@@ -926,12 +1542,12 @@ def free_floating_bias_forces(
926
1542
  The free-floating bias forces of the model.
927
1543
  """
928
1544
 
929
- # Build a zeroed state
1545
+ # Build a zeroed state.
930
1546
  data_rnea = js.data.JaxSimModelData.zero(
931
1547
  model=model, velocity_representation=data.velocity_representation
932
1548
  )
933
1549
 
934
- # Set the generalized position and generalized velocity
1550
+ # Set the generalized position and generalized velocity.
935
1551
  with data_rnea.mutable_context(
936
1552
  mutability=Mutability.MUTABLE, restore_after_exception=False
937
1553
  ):
@@ -948,18 +1564,20 @@ def free_floating_bias_forces(
948
1564
  data.state.physics_model.joint_positions
949
1565
  )
950
1566
 
951
- data_rnea.state.physics_model.base_linear_velocity = (
952
- data.state.physics_model.base_linear_velocity
953
- )
954
-
955
- data_rnea.state.physics_model.base_angular_velocity = (
956
- data.state.physics_model.base_angular_velocity
957
- )
958
-
959
1567
  data_rnea.state.physics_model.joint_velocities = (
960
1568
  data.state.physics_model.joint_velocities
961
1569
  )
962
1570
 
1571
+ # Make sure that the base velocity is zero for fixed-base models.
1572
+ if model.floating_base():
1573
+ data_rnea.state.physics_model.base_linear_velocity = (
1574
+ data.state.physics_model.base_linear_velocity
1575
+ )
1576
+
1577
+ data_rnea.state.physics_model.base_angular_velocity = (
1578
+ data.state.physics_model.base_angular_velocity
1579
+ )
1580
+
963
1581
  return jnp.hstack(
964
1582
  inverse_dynamics(
965
1583
  model=model,
@@ -967,7 +1585,7 @@ def free_floating_bias_forces(
967
1585
  # Set zero inputs:
968
1586
  joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),
969
1587
  base_acceleration=jnp.zeros(6),
970
- external_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
1588
+ link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),
971
1589
  )
972
1590
  ).astype(float)
973
1591
 
@@ -978,6 +1596,26 @@ def free_floating_bias_forces(
978
1596
 
979
1597
 
980
1598
  @jax.jit
1599
+ @js.common.named_scope
1600
+ def locked_spatial_inertia(
1601
+ model: JaxSimModel, data: js.data.JaxSimModelData
1602
+ ) -> jtp.Matrix:
1603
+ """
1604
+ Compute the locked 6D inertia matrix of the model.
1605
+
1606
+ Args:
1607
+ model: The model to consider.
1608
+ data: The data of the considered model.
1609
+
1610
+ Returns:
1611
+ The locked 6D inertia matrix of the model.
1612
+ """
1613
+
1614
+ return total_momentum_jacobian(model=model, data=data)[:, 0:6]
1615
+
1616
+
1617
+ @jax.jit
1618
+ @js.common.named_scope
981
1619
  def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
982
1620
  """
983
1621
  Compute the total momentum of the model.
@@ -987,35 +1625,453 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
987
1625
  data: The data of the considered model.
988
1626
 
989
1627
  Returns:
990
- The total momentum of the model.
1628
+ The total momentum of the model in the active velocity representation.
991
1629
  """
992
1630
 
993
- # Compute the momentum in body-fixed velocity representation.
994
- # Note: the first 6 rows of the mass matrix define the jacobian of the
995
- # floating-base momentum.
996
- with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
997
- B_ν = data.generalized_velocity()
998
- M_B = free_floating_mass_matrix(model=model, data=data)
1631
+ ν = data.generalized_velocity()
1632
+ Jh = total_momentum_jacobian(model=model, data=data)
1633
+
1634
+ return Jh @ ν
1635
+
1636
+
1637
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
1638
+ def total_momentum_jacobian(
1639
+ model: JaxSimModel,
1640
+ data: js.data.JaxSimModelData,
1641
+ *,
1642
+ output_vel_repr: VelRepr | None = None,
1643
+ ) -> jtp.Matrix:
1644
+ """
1645
+ Compute the jacobian of the total momentum.
1646
+
1647
+ Args:
1648
+ model: The model to consider.
1649
+ data: The data of the considered model.
1650
+ output_vel_repr: The output velocity representation of the jacobian.
1651
+
1652
+ Returns:
1653
+ The jacobian of the total momentum of the model in the desired representation.
1654
+ """
1655
+
1656
+ output_vel_repr = (
1657
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
1658
+ )
1659
+
1660
+ if output_vel_repr is data.velocity_representation:
1661
+ return free_floating_mass_matrix(model=model, data=data)[0:6]
1662
+
1663
+ with data.switch_velocity_representation(VelRepr.Body):
1664
+ B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6]
1665
+
1666
+ match data.velocity_representation:
1667
+ case VelRepr.Body:
1668
+ B_Jh = B_Jh_B
1669
+
1670
+ case VelRepr.Inertial:
1671
+ B_X_W = Adjoint.from_transform(
1672
+ transform=data.base_transform(), inverse=True
1673
+ )
1674
+ B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
1675
+
1676
+ case VelRepr.Mixed:
1677
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1678
+ B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1679
+ B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
1680
+
1681
+ case _:
1682
+ raise ValueError(data.velocity_representation)
1683
+
1684
+ match output_vel_repr:
1685
+ case VelRepr.Body:
1686
+ return B_Jh
1687
+
1688
+ case VelRepr.Inertial:
1689
+ W_H_B = data.base_transform()
1690
+ B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
1691
+ W_Xf_B = B_Xv_W.T
1692
+ W_Jh = W_Xf_B @ B_Jh
1693
+ return W_Jh
1694
+
1695
+ case VelRepr.Mixed:
1696
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1697
+ B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1698
+ BW_Xf_B = B_Xv_BW.T
1699
+ BW_Jh = BW_Xf_B @ B_Jh
1700
+ return BW_Jh
1701
+
1702
+ case _:
1703
+ raise ValueError(output_vel_repr)
1704
+
1705
+
1706
+ @jax.jit
1707
+ @js.common.named_scope
1708
+ def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
1709
+ """
1710
+ Compute the average velocity of the model.
1711
+
1712
+ Args:
1713
+ model: The model to consider.
1714
+ data: The data of the considered model.
1715
+
1716
+ Returns:
1717
+ The average velocity of the model computed in the base frame and expressed
1718
+ in the active representation.
1719
+ """
1720
+
1721
+ ν = data.generalized_velocity()
1722
+ J = average_velocity_jacobian(model=model, data=data)
1723
+
1724
+ return J @ ν
1725
+
1726
+
1727
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
1728
+ def average_velocity_jacobian(
1729
+ model: JaxSimModel,
1730
+ data: js.data.JaxSimModelData,
1731
+ *,
1732
+ output_vel_repr: VelRepr | None = None,
1733
+ ) -> jtp.Matrix:
1734
+ """
1735
+ Compute the Jacobian of the average velocity of the model.
1736
+
1737
+ Args:
1738
+ model: The model to consider.
1739
+ data: The data of the considered model.
1740
+ output_vel_repr: The output velocity representation of the jacobian.
1741
+
1742
+ Returns:
1743
+ The Jacobian of the average velocity of the model in the desired
1744
+ representation.
1745
+ """
1746
+
1747
+ output_vel_repr = (
1748
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
1749
+ )
1750
+
1751
+ # Depending on the velocity representation, the frame G is either G[W] or G[B].
1752
+ G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)
1753
+
1754
+ match output_vel_repr:
1755
+
1756
+ case VelRepr.Inertial:
1757
+
1758
+ GW_J = G_J
1759
+ W_p_CoM = js.com.com_position(model=model, data=data)
1760
+
1761
+ W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
1762
+ W_X_GW = Adjoint.from_transform(transform=W_H_GW)
1763
+
1764
+ return W_X_GW @ GW_J
1765
+
1766
+ case VelRepr.Body:
1767
+
1768
+ GB_J = G_J
1769
+ W_p_B = data.base_position()
1770
+ W_p_CoM = js.com.com_position(model=model, data=data)
1771
+ B_R_W = data.base_orientation(dcm=True).transpose()
1772
+
1773
+ B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
1774
+ B_X_GB = Adjoint.from_transform(transform=B_H_GB)
1775
+
1776
+ return B_X_GB @ GB_J
1777
+
1778
+ case VelRepr.Mixed:
999
1779
 
1000
- # Compute the total momentum expressed in the base frame
1001
- B_h = M_B[0:6, :] @ B_ν
1780
+ GW_J = G_J
1781
+ W_p_B = data.base_position()
1782
+ W_p_CoM = js.com.com_position(model=model, data=data)
1002
1783
 
1003
- # Compute the 6D transformation matrix
1784
+ BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
1785
+ BW_X_GW = Adjoint.from_transform(transform=BW_H_GW)
1786
+
1787
+ return BW_X_GW @ GW_J
1788
+
1789
+
1790
+ # ========================
1791
+ # Other dynamic quantities
1792
+ # ========================
1793
+
1794
+
1795
+ @jax.jit
1796
+ @js.common.named_scope
1797
+ def link_bias_accelerations(
1798
+ model: JaxSimModel,
1799
+ data: js.data.JaxSimModelData,
1800
+ ) -> jtp.Vector:
1801
+ r"""
1802
+ Compute the bias accelerations of the links of the model.
1803
+
1804
+ Args:
1805
+ model: The model to consider.
1806
+ data: The data of the considered model.
1807
+
1808
+ Returns:
1809
+ The bias accelerations of the links of the model.
1810
+
1811
+ Note:
1812
+ This function computes the component of the total 6D acceleration not due to
1813
+ the joint or base acceleration.
1814
+ It is often called :math:`\dot{J} \boldsymbol{\nu}`.
1815
+ """
1816
+
1817
+ # ================================================
1818
+ # Compute the body-fixed zero base 6D acceleration
1819
+ # ================================================
1820
+
1821
+ # Compute the base transform.
1004
1822
  W_H_B = data.base_transform()
1005
- B_X_W: jtp.Array = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
1006
1823
 
1007
- # Convert to inertial-fixed representation
1008
- # (its coordinates transform like 6D forces)
1009
- W_h = B_X_W.T @ B_h
1824
+ def other_representation_to_inertial(
1825
+ C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
1826
+ ) -> jtp.Vector:
1827
+ """
1828
+ Convert the active representation of the base acceleration C_v̇_WB
1829
+ expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
1830
+ """
1831
+
1832
+ W_X_C = Adjoint.from_transform(transform=W_H_C)
1833
+ C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
1010
1834
 
1011
- # Convert to the active representation of the model
1012
- return js.data.JaxSimModelData.inertial_to_other_representation(
1013
- array=W_h,
1014
- other_representation=data.velocity_representation,
1015
- transform=W_H_B,
1016
- is_force=True,
1835
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
1836
+ # In Inertial and Body representations, the cross product is always zero.
1837
+ return W_X_C @ (C_v̇_WB + jaxsim.math.Cross.vx(C_X_W @ W_v_WC) @ C_v_WB)
1838
+
1839
+ # Here we initialize a zero 6D acceleration in the active representation, and
1840
+ # convert it to inertial-fixed. This is a useful intermediate representation
1841
+ # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration
1842
+ # W_a_WB, and intrinsic accelerations can be expressed in different frames through
1843
+ # a simple C_X_W 6D transform.
1844
+ match data.velocity_representation:
1845
+ case VelRepr.Inertial:
1846
+ W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1847
+ W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1848
+ with data.switch_velocity_representation(VelRepr.Inertial):
1849
+ C_v_WB = W_v_WB = data.base_velocity()
1850
+
1851
+ case VelRepr.Body:
1852
+ W_H_C = W_H_B
1853
+ with data.switch_velocity_representation(VelRepr.Inertial):
1854
+ W_v_WC = W_v_WB = data.base_velocity() # noqa: F841
1855
+ with data.switch_velocity_representation(VelRepr.Body):
1856
+ C_v_WB = B_v_WB = data.base_velocity()
1857
+
1858
+ case VelRepr.Mixed:
1859
+ W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
1860
+ W_H_C = W_H_BW
1861
+ with data.switch_velocity_representation(VelRepr.Mixed):
1862
+ W_ṗ_B = data.base_velocity()[0:3]
1863
+ BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1864
+ W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
1865
+ W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
1866
+ with data.switch_velocity_representation(VelRepr.Mixed):
1867
+ C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841
1868
+
1869
+ case _:
1870
+ raise ValueError(data.velocity_representation)
1871
+
1872
+ # Convert a zero 6D acceleration from the active representation to inertial-fixed.
1873
+ W_v̇_WB = other_representation_to_inertial(
1874
+ C_v̇_WB=jnp.zeros(6), C_v_WB=C_v_WB, W_H_C=W_H_C, W_v_WC=W_v_WC
1875
+ )
1876
+
1877
+ # ===================================
1878
+ # Initialize buffers and prepare data
1879
+ # ===================================
1880
+
1881
+ # Get the parent array λ(i).
1882
+ # Note: λ(0) must not be used, it's initialized to -1.
1883
+ λ = model.kin_dyn_parameters.parent_array
1884
+
1885
+ # Compute 6D transforms of the base velocity.
1886
+ B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
1887
+
1888
+ # Compute the parent-to-child adjoints and the motion subspaces of the joints.
1889
+ # These transforms define the relative kinematics of the entire model, including
1890
+ # the base transform for both floating-base and fixed-base models.
1891
+ i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
1892
+ joint_positions=data.joint_positions(), base_transform=W_H_B
1893
+ )
1894
+
1895
+ # Allocate the buffer to store the body-fixed link velocities.
1896
+ L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))
1897
+
1898
+ # Store the base velocity.
1899
+ with data.switch_velocity_representation(VelRepr.Body):
1900
+ B_v_WB = data.base_velocity()
1901
+ L_v_WL = L_v_WL.at[0].set(B_v_WB)
1902
+
1903
+ # Get the joint velocities.
1904
+ ṡ = data.joint_velocities(model=model, joint_names=model.joint_names())
1905
+
1906
+ # Allocate the buffer to store the body-fixed link accelerations,
1907
+ # and initialize the base acceleration.
1908
+ L_v̇_WL = jnp.zeros(shape=(model.number_of_links(), 6))
1909
+ L_v̇_WL = L_v̇_WL.at[0].set(B_X_W @ W_v̇_WB)
1910
+
1911
+ # ======================================
1912
+ # Propagate accelerations and velocities
1913
+ # ======================================
1914
+
1915
+ # The computation of the bias accelerations is similar to the forward pass of RNEA,
1916
+ # this time with zero base and joint accelerations. Furthermore, here we do
1917
+ # not remove gravity during the propagation.
1918
+
1919
+ # Initialize the loop.
1920
+ Carry = tuple[jtp.Matrix, jtp.Matrix]
1921
+ carry0: Carry = (L_v_WL, L_v̇_WL)
1922
+
1923
+ def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
1924
+ # Initialize index and unpack the carry.
1925
+ ii = i - 1
1926
+ v, a = carry
1927
+
1928
+ # Get the motion subspace of the joint.
1929
+ Si = S[i].squeeze()
1930
+
1931
+ # Project the joint velocity into its motion subspace.
1932
+ vJ = Si * ṡ[ii]
1933
+
1934
+ # Propagate the link body-fixed velocity.
1935
+ v_i = i_X_λi[i] @ v[λ[i]] + vJ
1936
+ v = v.at[i].set(v_i)
1937
+
1938
+ # Propagate the link body-fixed acceleration considering zero joint acceleration.
1939
+ s̈ = 0.0
1940
+ a_i = i_X_λi[i] @ a[λ[i]] + Si * s̈ + jaxsim.math.Cross.vx(v[i]) @ vJ
1941
+ a = a.at[i].set(a_i)
1942
+
1943
+ return (v, a), None
1944
+
1945
+ # Compute the body-fixed velocity and body-fixed apparent acceleration of the links.
1946
+ (L_v_WL, L_v̇_WL), _ = (
1947
+ jax.lax.scan(
1948
+ f=propagate_accelerations,
1949
+ init=carry0,
1950
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
1951
+ )
1952
+ if model.number_of_links() > 1
1953
+ else [(L_v_WL, L_v̇_WL), None]
1954
+ )
1955
+
1956
+ # ===================================================================
1957
+ # Convert the body-fixed 6D acceleration to the active representation
1958
+ # ===================================================================
1959
+
1960
+ def body_to_other_representation(
1961
+ L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
1962
+ ) -> jtp.Vector:
1963
+ """
1964
+ Convert the body-fixed apparent acceleration L_v̇_WL to
1965
+ another representation C_v̇_WL expressed in a generic frame C.
1966
+ """
1967
+
1968
+ # In Mixed representation, we need to include a cross product in ℝ⁶.
1969
+ # In Inertial and Body representations, the cross product is always zero.
1970
+ C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L)
1971
+ return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL)
1972
+
1973
+ match data.velocity_representation:
1974
+ case VelRepr.Body:
1975
+ C_H_L = L_H_L = jnp.stack( # noqa: F841
1976
+ [jnp.eye(4)] * model.number_of_links()
1977
+ )
1978
+ L_v_CL = L_v_LL = jnp.zeros( # noqa: F841
1979
+ shape=(model.number_of_links(), 6)
1980
+ )
1981
+
1982
+ case VelRepr.Inertial:
1983
+ C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
1984
+ L_v_CL = L_v_WL
1985
+
1986
+ case VelRepr.Mixed:
1987
+ W_H_L = js.model.forward_kinematics(model=model, data=data)
1988
+ LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
1989
+ C_H_L = LW_H_L
1990
+ L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
1991
+ lambda v: v.at[0:3].set(jnp.zeros(3))
1992
+ )(L_v_WL)
1993
+
1994
+ case _:
1995
+ raise ValueError(data.velocity_representation)
1996
+
1997
+ # Convert from body-fixed to the active representation.
1998
+ O_v̇_WL = jax.vmap(body_to_other_representation)(
1999
+ L_v̇_WL=L_v̇_WL, L_v_WL=L_v_WL, C_H_L=C_H_L, L_v_CL=L_v_CL
2000
+ )
2001
+
2002
+ return O_v̇_WL
2003
+
2004
+
2005
+ @jax.jit
2006
+ @js.common.named_scope
2007
+ def link_contact_forces(
2008
+ model: js.model.JaxSimModel,
2009
+ data: js.data.JaxSimModelData,
2010
+ *,
2011
+ link_forces: jtp.MatrixLike | None = None,
2012
+ joint_force_references: jtp.VectorLike | None = None,
2013
+ **kwargs,
2014
+ ) -> jtp.Matrix:
2015
+ """
2016
+ Compute the 6D contact forces of all links of the model.
2017
+
2018
+ Args:
2019
+ model: The model to consider.
2020
+ data: The data of the considered model.
2021
+ link_forces:
2022
+ The 6D external forces to apply to the links expressed in the same
2023
+ representation of data.
2024
+ joint_force_references:
2025
+ The joint force references to apply to the joints.
2026
+ kwargs: Additional keyword arguments to pass to the active contact model.
2027
+
2028
+ Returns:
2029
+ A `(nL, 6)` array containing the stacked 6D contact forces of the links,
2030
+ expressed in the frame corresponding to the active representation.
2031
+ """
2032
+
2033
+ # Note: the following code should be kept in sync with the function
2034
+ # `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
2035
+ # there we need to get also aux_data.
2036
+
2037
+ # Build link forces if not provided.
2038
+ # These forces are expressed in the frame corresponding to the velocity
2039
+ # representation of data.
2040
+ O_f_L = (
2041
+ jnp.atleast_2d(link_forces.squeeze())
2042
+ if link_forces is not None
2043
+ else jnp.zeros((model.number_of_links(), 6))
1017
2044
  ).astype(float)
1018
2045
 
2046
+ # Build joint force references if not provided.
2047
+ joint_force_references = (
2048
+ jnp.atleast_1d(joint_force_references)
2049
+ if joint_force_references is not None
2050
+ else jnp.zeros(model.dofs())
2051
+ )
2052
+
2053
+ # We expect that the 6D forces included in the `link_forces` argument are expressed
2054
+ # in the frame corresponding to the velocity representation of `data`.
2055
+ input_references = js.references.JaxSimModelReferences.build(
2056
+ model=model,
2057
+ data=data,
2058
+ velocity_representation=data.velocity_representation,
2059
+ link_forces=O_f_L,
2060
+ joint_force_references=joint_force_references,
2061
+ )
2062
+
2063
+ # Compute the 6D forces applied to the links equivalent to the forces applied
2064
+ # to the frames associated to the collidable points.
2065
+ f_L, _ = model.contact_model.compute_link_contact_forces(
2066
+ model=model,
2067
+ data=data,
2068
+ link_forces=input_references.link_forces(model=model, data=data),
2069
+ joint_force_references=input_references.joint_force_references(),
2070
+ **kwargs,
2071
+ )
2072
+
2073
+ return f_L
2074
+
1019
2075
 
1020
2076
  # ======
1021
2077
  # Energy
@@ -1023,6 +2079,7 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
1023
2079
 
1024
2080
 
1025
2081
  @jax.jit
2082
+ @js.common.named_scope
1026
2083
  def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
1027
2084
  """
1028
2085
  Compute the mechanical energy of the model.
@@ -1042,6 +2099,7 @@ def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.
1042
2099
 
1043
2100
 
1044
2101
  @jax.jit
2102
+ @js.common.named_scope
1045
2103
  def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
1046
2104
  """
1047
2105
  Compute the kinetic energy of the model.
@@ -1063,6 +2121,7 @@ def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Flo
1063
2121
 
1064
2122
 
1065
2123
  @jax.jit
2124
+ @js.common.named_scope
1066
2125
  def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
1067
2126
  """
1068
2127
  Compute the potential energy of the model.
@@ -1077,7 +2136,7 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
1077
2136
 
1078
2137
  m = total_mass(model=model)
1079
2138
  gravity = data.gravity.squeeze()
1080
- W_p̃_CoM = jnp.hstack([com_position(model=model, data=data), 1])
2139
+ W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1])
1081
2140
 
1082
2141
  U = -jnp.hstack([gravity, 0]) @ (m * W_p̃_CoM)
1083
2142
  return U.squeeze().astype(float)
@@ -1089,15 +2148,18 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
1089
2148
 
1090
2149
 
1091
2150
  @jax.jit
2151
+ @js.common.named_scope
1092
2152
  def step(
1093
2153
  model: JaxSimModel,
1094
2154
  data: js.data.JaxSimModelData,
1095
2155
  *,
1096
- dt: jtp.FloatLike,
1097
- integrator: jaxsim.integrators.Integrator,
1098
- integrator_state: dict[str, Any] | None = None,
1099
- joint_forces: jtp.VectorLike | None = None,
1100
- external_forces: jtp.MatrixLike | None = None,
2156
+ t0: jtp.FloatLike = 0.0,
2157
+ dt: jtp.FloatLike | None = None,
2158
+ integrator: jaxsim.integrators.Integrator | None = None,
2159
+ integrator_metadata: dict[str, Any] | None = None,
2160
+ link_forces: jtp.MatrixLike | None = None,
2161
+ joint_force_references: jtp.VectorLike | None = None,
2162
+ **kwargs,
1101
2163
  ) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
1102
2164
  """
1103
2165
  Perform a simulation step.
@@ -1105,40 +2167,188 @@ def step(
1105
2167
  Args:
1106
2168
  model: The model to consider.
1107
2169
  data: The data of the considered model.
1108
- dt: The time step to consider.
1109
2170
  integrator: The integrator to use.
1110
- integrator_state: The state of the integrator.
1111
- joint_forces: The joint forces to consider.
1112
- external_forces:
1113
- The external forces to consider.
1114
- The frame in which they are expressed must be `data.velocity_representation`.
2171
+ integrator_metadata: The metadata of the integrator, if needed.
2172
+ t0: The initial time to consider. Only relevant for time-dependent dynamics.
2173
+ dt: The time step to consider. If not specified, it is read from the model.
2174
+ link_forces:
2175
+ The 6D forces to apply to the links expressed in the frame corresponding to
2176
+ the velocity representation of `data`.
2177
+ joint_force_references: The joint force references to consider.
2178
+ kwargs: Additional kwargs to pass to the integrator.
1115
2179
 
1116
2180
  Returns:
1117
- A tuple containing the new data of the model
1118
- and the new state of the integrator.
2181
+ A tuple containing the new data of the model and a dictionary of auxiliary
2182
+ data computed during the step. If the integrator has metadata, the dictionary
2183
+ will contain the new metadata stored in the `integrator_metadata` key.
2184
+
2185
+ Note:
2186
+ In order to reduce the occurrences of frame conversions performed internally,
2187
+ it is recommended to use inertial-fixed velocity representation. This can be
2188
+ particularly useful for automatically differentiated logic.
1119
2189
  """
1120
2190
 
1121
- integrator_state = integrator_state if integrator_state is not None else dict()
2191
+ # Extract the integrator kwargs.
2192
+ # The following logic allows using integrators having kwargs colliding with the
2193
+ # kwargs of this step function.
2194
+ kwargs = kwargs if kwargs is not None else {}
2195
+ integrator_kwargs = kwargs.pop("integrator_kwargs", {})
2196
+ integrator_kwargs = kwargs | integrator_kwargs
2197
+
2198
+ # Extract the integrator and the optional metadata.
2199
+ integrator_metadata_t0 = integrator_metadata
2200
+ integrator = integrator if integrator is not None else model.integrator
2201
+
2202
+ # Initialize the time-related variables.
2203
+ state_t0 = data.state
2204
+ t0 = jnp.array(t0, dtype=float)
2205
+ dt = jnp.array(dt if dt is not None else model.time_step).astype(float)
2206
+
2207
+ # The visco-elastic contacts operate at best with their own integrator.
2208
+ # They can be used with Euler-like integrators, paying the price of ignoring
2209
+ # some of the benefits of continuous-time integration on the system position.
2210
+ # Furthermore, the requirement to know the Δt used by the integrator is not
2211
+ # compatible with high-order integrators, that use advanced RK stages to evaluate
2212
+ # the dynamics at intermediate times.
2213
+ module = jaxsim.rbda.contacts.visco_elastic.step.__module__
2214
+ name = jaxsim.rbda.contacts.visco_elastic.step.__name__
2215
+ msg = "You need to use the custom '{}.{}' function with this contact model."
2216
+ jaxsim.exceptions.raise_runtime_error_if(
2217
+ condition=(
2218
+ isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts)
2219
+ & (
2220
+ ~jnp.allclose(dt, model.time_step)
2221
+ | ~int(
2222
+ isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
2223
+ )
2224
+ )
2225
+ ),
2226
+ msg=msg.format(module, name),
2227
+ )
1122
2228
 
1123
- # Extract the initial resources.
1124
- t0_ns = data.time_ns
1125
- state_x0 = data.state
1126
- integrator_state_x0 = integrator_state
2229
+ # =================
2230
+ # Phase 1: pre-step
2231
+ # =================
2232
+
2233
+ # TODO: some contact models here may want to perform a dynamic filtering of
2234
+ # the enabled collidable points.
2235
+
2236
+ # Build the references object.
2237
+ # We assume that the link forces are expressed in the frame corresponding to the
2238
+ # velocity representation of the data.
2239
+ references = js.references.JaxSimModelReferences.build(
2240
+ model=model,
2241
+ data=data,
2242
+ velocity_representation=data.velocity_representation,
2243
+ link_forces=link_forces,
2244
+ joint_force_references=joint_force_references,
2245
+ )
2246
+
2247
+ # =============
2248
+ # Phase 2: step
2249
+ # =============
2250
+
2251
+ # Prepare the references to pass.
2252
+ with references.switch_velocity_representation(data.velocity_representation):
2253
+
2254
+ f_L = references.link_forces(model=model, data=data)
2255
+ τ_references = references.joint_force_references(model=model)
1127
2256
 
1128
2257
  # Step the dynamics forward.
1129
- state_xf, integrator_state_xf = integrator.step(
1130
- x0=state_x0,
1131
- t0=jnp.array(t0_ns * 1e9).astype(float),
2258
+ state_tf, integrator_metadata_tf = integrator.step(
2259
+ x0=state_t0,
2260
+ t0=t0,
1132
2261
  dt=dt,
1133
- params=integrator_state_x0,
1134
- **dict(joint_forces=joint_forces, external_forces=external_forces),
2262
+ metadata=integrator_metadata_t0,
2263
+ # Always inject the current (model, data) pair into the system dynamics
2264
+ # considered by the integrator, and include the input variables represented
2265
+ # by the pair (f_L, τ_references).
2266
+ # Note that the wrapper of the system dynamics will override (state_t0, t0)
2267
+ # inside the passed data even if it is not strictly needed. This logic is
2268
+ # necessary to reuse the jit-compiled step function of compatible pytrees
2269
+ # of model and data produced e.g. by parameterized applications.
2270
+ **(
2271
+ dict(
2272
+ model=model,
2273
+ data=data,
2274
+ link_forces=f_L,
2275
+ joint_force_references=τ_references,
2276
+ )
2277
+ | integrator_kwargs
2278
+ ),
1135
2279
  )
1136
2280
 
1137
- return (
1138
- # Store the new state of the model and the new time.
1139
- data.replace(
1140
- state=state_xf,
1141
- time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
1142
- ),
1143
- integrator_state_xf,
2281
+ # Store the new state of the model.
2282
+ data_tf = data.replace(state=state_tf)
2283
+
2284
+ # ==================
2285
+ # Phase 3: post-step
2286
+ # ==================
2287
+
2288
+ # Post process the simulation state, if needed.
2289
+ match model.contact_model:
2290
+
2291
+ # Rigid contact models use an impact model that produces discontinuous model velocities.
2292
+ # Hence, here we need to reset the velocity after each impact to guarantee that
2293
+ # the linear velocity of the active collidable points is zero.
2294
+ case jaxsim.rbda.contacts.RigidContacts():
2295
+
2296
+ # Raise a runtime error in the unsupported case in which rigid contacts with
2297
+ # Baumgarte stabilization enabled are used with a ForwardEuler integrator.
2298
+ jaxsim.exceptions.raise_runtime_error_if(
2299
+ condition=isinstance(
2300
+ integrator,
2301
+ jaxsim.integrators.fixed_step.ForwardEuler
2302
+ | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2303
+ )
2304
+ & ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
2305
+ msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
2306
+ )
2307
+
2308
+ # Extract the indices corresponding to the enabled collidable points.
2309
+ indices_of_enabled_collidable_points = (
2310
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
2311
+ )
2312
+
2313
+ W_p_C = js.contact.collidable_point_positions(model, data_tf)[
2314
+ indices_of_enabled_collidable_points
2315
+ ]
2316
+
2317
+ # Compute the penetration depth of the collidable points.
2318
+ δ, *_ = jax.vmap(
2319
+ jaxsim.rbda.contacts.common.compute_penetration_data,
2320
+ in_axes=(0, 0, None),
2321
+ )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
2322
+
2323
+ with data_tf.switch_velocity_representation(VelRepr.Mixed):
2324
+ J_WC = js.contact.jacobian(model, data_tf)[
2325
+ indices_of_enabled_collidable_points
2326
+ ]
2327
+ M = js.model.free_floating_mass_matrix(model, data_tf)
2328
+ BW_ν_pre_impact = data_tf.generalized_velocity()
2329
+
2330
+ # Compute the impact velocity.
2331
+ # It may be discontinuous in case new contacts are made.
2332
+ BW_ν_post_impact = (
2333
+ jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
2334
+ generalized_velocity=BW_ν_pre_impact,
2335
+ inactive_collidable_points=(δ <= 0),
2336
+ M=M,
2337
+ J_WC=J_WC,
2338
+ )
2339
+ )
2340
+
2341
+ # Reset the generalized velocity.
2342
+ data_tf = data_tf.reset_base_velocity(BW_ν_post_impact[0:6])
2343
+ data_tf = data_tf.reset_joint_velocities(BW_ν_post_impact[6:])
2344
+
2345
+ # Restore the input velocity representation.
2346
+ data_tf = data_tf.replace(
2347
+ velocity_representation=data.velocity_representation, validate=False
2348
+ )
2349
+
2350
+ return data_tf, {} | (
2351
+ dict(integrator_metadata=integrator_metadata_tf)
2352
+ if integrator_metadata is not None
2353
+ else {}
1144
2354
  )