jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py CHANGED
@@ -8,16 +8,37 @@ def _jnp_options() -> None:
8
8
 
9
9
  import jax
10
10
 
11
- # Enable by default
12
- if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"):
13
- logging.info("Enabling JAX to use 64bit precision")
11
+ # Check if running on TPU.
12
+ is_tpu = jax.devices()[0].platform == "tpu"
13
+
14
+ # Check if running on Metal.
15
+ is_metal = jax.devices()[0].platform == "METAL"
16
+
17
+ # Enable by default 64-bit precision to get accurate physics.
18
+ # Users can enforce 32-bit precision by setting the following variable to 0.
19
+ use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"
20
+
21
+ # Notify the user if unsupported 64-bit precision was enforced on TPU.
22
+ if (is_tpu or is_metal) and use_x64:
23
+ msg = f"64-bit precision is not allowed on {jax.devices()[0].platform.upper}. Enforcing 32bit precision."
24
+ logging.warning(msg)
25
+ use_x64 = False
26
+
27
+ if is_metal:
28
+ logging.warning(
29
+ "JAX Metal backend is experimental. Some functionalities may not be available."
30
+ )
31
+
32
+ # Enable 64-bit precision in JAX.
33
+ if use_x64:
34
+ logging.info("Enabling JAX to use 64-bit precision")
14
35
  jax.config.update("jax_enable_x64", True)
15
36
 
16
- import jax.numpy as jnp
17
- import numpy as np
18
-
19
- if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
20
- logging.warning("Failed to enable 64bit precision in JAX")
37
+ # Warn about experimental usage of 32-bit precision.
38
+ else:
39
+ logging.warning(
40
+ "Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
41
+ )
21
42
 
22
43
 
23
44
  def _np_options() -> None:
@@ -27,41 +48,71 @@ def _np_options() -> None:
27
48
 
28
49
 
29
50
  def _is_editable() -> bool:
51
+
30
52
  import importlib.util
31
53
  import pathlib
32
54
  import site
33
55
 
34
- # Get the ModuleSpec of jaxsim
56
+ # Get the ModuleSpec of jaxsim.
35
57
  jaxsim_spec = importlib.util.find_spec(name="jaxsim")
36
58
 
37
59
  # This can be None. If it's None, assume non-editable installation.
38
60
  if jaxsim_spec.origin is None:
39
61
  return False
40
62
 
41
- # Get the folder containing the jaxsim package
63
+ # Get the folder containing the jaxsim package.
42
64
  jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent)
43
65
 
44
- # The installation is editable if the package dir is not in any {site|dist}-packages
66
+ # The installation is editable if the package dir is not in any {site|dist}-packages.
45
67
  return jaxsim_package_dir not in site.getsitepackages()
46
68
 
47
69
 
48
- # Initialize the logging verbosity
49
- if _is_editable():
50
- logging.configure(level=logging.LoggingLevel.DEBUG)
51
- else:
52
- logging.configure(level=logging.LoggingLevel.WARNING)
70
+ def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
71
+ """
72
+ Get the default logging level.
73
+
74
+ Args:
75
+ env_var: The environment variable to check.
76
+
77
+ Returns:
78
+ The logging level to set.
79
+ """
80
+
81
+ import os
82
+
83
+ # Define the default logging level depending on the installation mode.
84
+ default_logging_level = (
85
+ logging.LoggingLevel.DEBUG
86
+ if _is_editable() # noqa: F821
87
+ else logging.LoggingLevel.WARNING
88
+ )
89
+
90
+ # Allow to override the default logging level with an environment variable.
91
+ try:
92
+ return logging.LoggingLevel[
93
+ os.environ.get(env_var, default_logging_level.name).upper()
94
+ ]
95
+
96
+ except KeyError as exc:
97
+ msg = f"Invalid logging level defined in {env_var}='{os.environ[env_var]}'"
98
+ raise RuntimeError(msg) from exc
99
+
100
+
101
+ # Configure the logger with the default logging level.
102
+ logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))
103
+
53
104
 
54
- # Configure JAX
105
+ # Configure JAX.
55
106
  _jnp_options()
56
107
 
57
- # Initialize the numpy print options
108
+ # Initialize the numpy print options.
58
109
  _np_options()
59
110
 
60
111
  del _jnp_options
61
112
  del _np_options
113
+ del _get_default_logging_level
62
114
  del _is_editable
63
115
 
64
- from . import high_level, logging, math, simulation, sixd
65
- from .high_level.common import VelRepr
66
- from .simulation.ode_integration import IntegratorType
67
- from .simulation.simulator import JaxSim
116
+ from . import terrain # isort:skip
117
+ from . import api, integrators, logging, math, rbda
118
+ from .api.common import VelRepr
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.dev191'
16
- __version_tuple__ = version_tuple = (0, 2, 'dev191')
15
+ __version__ = version = '0.6.1.dev2'
16
+ __version_tuple__ = version_tuple = (0, 6, 1, 'dev2')
jaxsim/api/__init__.py CHANGED
@@ -1 +1,13 @@
1
- from . import contact, data, joint, link, model, ode
1
+ from . import common # isort:skip
2
+ from . import model, data # isort:skip
3
+ from . import (
4
+ com,
5
+ contact,
6
+ frame,
7
+ joint,
8
+ kin_dyn_parameters,
9
+ link,
10
+ ode,
11
+ ode_data,
12
+ references,
13
+ )
jaxsim/api/com.py ADDED
@@ -0,0 +1,423 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ import jaxsim.api as js
5
+ import jaxsim.math
6
+ import jaxsim.typing as jtp
7
+
8
+ from .common import VelRepr
9
+
10
+
11
+ @jax.jit
12
+ @js.common.named_scope
13
+ def com_position(
14
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
15
+ ) -> jtp.Vector:
16
+ """
17
+ Compute the position of the center of mass of the model.
18
+
19
+ Args:
20
+ model: The model to consider.
21
+ data: The data of the considered model.
22
+
23
+ Returns:
24
+ The position of the center of mass of the model w.r.t. the world frame.
25
+ """
26
+
27
+ m = js.model.total_mass(model=model)
28
+
29
+ W_H_L = js.model.forward_kinematics(model=model, data=data)
30
+ W_H_B = data.base_transform()
31
+ B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)
32
+
33
+ def B_p̃_LCoM(i) -> jtp.Vector:
34
+ m = js.link.mass(model=model, link_index=i)
35
+ L_p_LCoM = js.link.com_position(
36
+ model=model, data=data, link_index=i, in_link_frame=True
37
+ )
38
+ return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])
39
+
40
+ com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))
41
+
42
+ B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
43
+ B_p̃_CoM = B_p̃_CoM.at[3].set(1)
44
+
45
+ return (W_H_B @ B_p̃_CoM)[0:3].astype(float)
46
+
47
+
48
+ @jax.jit
49
+ @js.common.named_scope
50
+ def com_linear_velocity(
51
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
52
+ ) -> jtp.Vector:
53
+ r"""
54
+ Compute the linear velocity of the center of mass of the model.
55
+
56
+ Args:
57
+ model: The model to consider.
58
+ data: The data of the considered model.
59
+
60
+ Returns:
61
+ The linear velocity of the center of mass of the model in the
62
+ active representation.
63
+
64
+ Note:
65
+ The linear velocity of the center of mass is expressed in the mixed frame
66
+ :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
67
+ active velocity representation is either inertial-fixed or mixed,
68
+ and :math:`[C] = [B]` if the active velocity representation is body-fixed.
69
+ """
70
+
71
+ # Extract the linear component of the 6D average centroidal velocity.
72
+ # This is expressed in G[B] in body-fixed representation, and in G[W] in
73
+ # inertial-fixed or mixed representation.
74
+ G_vl_WG = average_centroidal_velocity(model=model, data=data)[0:3]
75
+
76
+ return G_vl_WG
77
+
78
+
79
+ @jax.jit
80
+ @js.common.named_scope
81
+ def centroidal_momentum(
82
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
83
+ ) -> jtp.Vector:
84
+ r"""
85
+ Compute the centroidal momentum of the model.
86
+
87
+ Args:
88
+ model: The model to consider.
89
+ data: The data of the considered model.
90
+
91
+ Returns:
92
+ The centroidal momentum of the model.
93
+
94
+ Note:
95
+ The centroidal momentum is expressed in the mixed frame
96
+ :math:`({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`C = W` if the
97
+ active velocity representation is either inertial-fixed or mixed,
98
+ and :math:`C = B` if the active velocity representation is body-fixed.
99
+ """
100
+
101
+ ν = data.generalized_velocity()
102
+ G_J = centroidal_momentum_jacobian(model=model, data=data)
103
+
104
+ return G_J @ ν
105
+
106
+
107
+ @jax.jit
108
+ @js.common.named_scope
109
+ def centroidal_momentum_jacobian(
110
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
111
+ ) -> jtp.Matrix:
112
+ r"""
113
+ Compute the Jacobian of the centroidal momentum of the model.
114
+
115
+ Args:
116
+ model: The model to consider.
117
+ data: The data of the considered model.
118
+
119
+ Returns:
120
+ The Jacobian of the centroidal momentum of the model.
121
+
122
+ Note:
123
+ The frame corresponding to the output representation of this Jacobian is either
124
+ :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
125
+ or :math:`G[B]`, if the active velocity representation is body-fixed.
126
+
127
+ Note:
128
+ This Jacobian is also known in the literature as Centroidal Momentum Matrix.
129
+ """
130
+
131
+ # Compute the Jacobian of the total momentum with body-fixed output representation.
132
+ # We convert the output representation either to G[W] or G[B] below.
133
+ B_Jh = js.model.total_momentum_jacobian(
134
+ model=model, data=data, output_vel_repr=VelRepr.Body
135
+ )
136
+
137
+ W_H_B = data.base_transform()
138
+ B_H_W = jaxsim.math.Transform.inverse(W_H_B)
139
+
140
+ W_p_CoM = com_position(model=model, data=data)
141
+
142
+ match data.velocity_representation:
143
+ case VelRepr.Inertial | VelRepr.Mixed:
144
+ W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
145
+ case VelRepr.Body:
146
+ W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
147
+ case _:
148
+ raise ValueError(data.velocity_representation)
149
+
150
+ # Compute the transform for 6D forces.
151
+ G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T
152
+
153
+ return G_Xf_B @ B_Jh
154
+
155
+
156
+ @jax.jit
157
+ @js.common.named_scope
158
+ def locked_centroidal_spatial_inertia(
159
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
160
+ ):
161
+ """
162
+ Compute the locked centroidal spatial inertia of the model.
163
+
164
+ Args:
165
+ model: The model to consider.
166
+ data: The data of the considered model.
167
+
168
+ Returns:
169
+ The locked centroidal spatial inertia of the model.
170
+ """
171
+
172
+ with data.switch_velocity_representation(VelRepr.Body):
173
+ B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)
174
+
175
+ W_H_B = data.base_transform()
176
+ W_p_CoM = com_position(model=model, data=data)
177
+
178
+ match data.velocity_representation:
179
+ case VelRepr.Inertial | VelRepr.Mixed:
180
+ W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
181
+ case VelRepr.Body:
182
+ W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
183
+ case _:
184
+ raise ValueError(data.velocity_representation)
185
+
186
+ B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G
187
+
188
+ B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)
189
+ G_Xf_B = B_Xv_G.transpose()
190
+
191
+ return G_Xf_B @ B_Mbb_B @ B_Xv_G
192
+
193
+
194
+ @jax.jit
195
+ @js.common.named_scope
196
+ def average_centroidal_velocity(
197
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
198
+ ) -> jtp.Vector:
199
+ r"""
200
+ Compute the average centroidal velocity of the model.
201
+
202
+ Args:
203
+ model: The model to consider.
204
+ data: The data of the considered model.
205
+
206
+ Returns:
207
+ The average centroidal velocity of the model.
208
+
209
+ Note:
210
+ The average velocity is expressed in the mixed frame
211
+ :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
212
+ active velocity representation is either inertial-fixed or mixed,
213
+ and :math:`[C] = [B]` if the active velocity representation is body-fixed.
214
+ """
215
+
216
+ ν = data.generalized_velocity()
217
+ G_J = average_centroidal_velocity_jacobian(model=model, data=data)
218
+
219
+ return G_J @ ν
220
+
221
+
222
+ @jax.jit
223
+ @js.common.named_scope
224
+ def average_centroidal_velocity_jacobian(
225
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
226
+ ) -> jtp.Matrix:
227
+ r"""
228
+ Compute the Jacobian of the average centroidal velocity of the model.
229
+
230
+ Args:
231
+ model: The model to consider.
232
+ data: The data of the considered model.
233
+
234
+ Returns:
235
+ The Jacobian of the average centroidal velocity of the model.
236
+
237
+ Note:
238
+ The frame corresponding to the output representation of this Jacobian is either
239
+ :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
240
+ or :math:`G[B]`, if the active velocity representation is body-fixed.
241
+ """
242
+
243
+ G_J = centroidal_momentum_jacobian(model=model, data=data)
244
+ G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)
245
+
246
+ return jnp.linalg.inv(G_Mbb) @ G_J
247
+
248
+
249
+ @jax.jit
250
+ @js.common.named_scope
251
+ def bias_acceleration(
252
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
253
+ ) -> jtp.Vector:
254
+ r"""
255
+ Compute the bias linear acceleration of the center of mass.
256
+
257
+ Args:
258
+ model: The model to consider.
259
+ data: The data of the considered model.
260
+
261
+ Returns:
262
+ The bias linear acceleration of the center of mass in the active representation.
263
+
264
+ Note:
265
+ The bias acceleration is expressed in the mixed frame
266
+ :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
267
+ active velocity representation is either inertial-fixed or mixed,
268
+ and :math:`[C] = [B]` if the active velocity representation is body-fixed.
269
+ """
270
+
271
+ # Compute the pose of all links with forward kinematics.
272
+ W_H_L = js.model.forward_kinematics(model=model, data=data)
273
+
274
+ # Compute the bias acceleration of all links by zeroing the generalized velocity
275
+ # in the active representation.
276
+ v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
277
+
278
+ def other_representation_to_body(
279
+ C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
280
+ ) -> jtp.Vector:
281
+ """
282
+ Convert the body-fixed representation of the link bias acceleration
283
+ C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
284
+ """
285
+
286
+ L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C)
287
+ C_X_L = jaxsim.math.Adjoint.inverse(L_X_C)
288
+
289
+ L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)
290
+ return L_v̇_WL
291
+
292
+ # We need here to get the body-fixed bias acceleration of the links.
293
+ # Since it's computed in the active representation, we need to convert it to body.
294
+ match data.velocity_representation:
295
+
296
+ case VelRepr.Body:
297
+ L_a_bias_WL = v̇_bias_WL
298
+
299
+ case VelRepr.Inertial:
300
+
301
+ C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
302
+ C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
303
+
304
+ L_H_C = L_H_W = jax.vmap( # noqa: F841
305
+ lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
306
+ )(W_H_L)
307
+
308
+ L_v_LC = L_v_LW = jax.vmap( # noqa: F841
309
+ lambda i: -js.link.velocity(
310
+ model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
311
+ )
312
+ )(jnp.arange(model.number_of_links()))
313
+
314
+ L_a_bias_WL = jax.vmap(
315
+ lambda i: other_representation_to_body(
316
+ C_v̇_WL=C_v̇_WL[i],
317
+ C_v_WC=C_v_WC,
318
+ L_H_C=L_H_C[i],
319
+ L_v_LC=L_v_LC[i],
320
+ )
321
+ )(jnp.arange(model.number_of_links()))
322
+
323
+ case VelRepr.Mixed:
324
+
325
+ C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841
326
+
327
+ C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
328
+ lambda i: js.link.velocity(
329
+ model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
330
+ )
331
+ .at[3:6]
332
+ .set(jnp.zeros(3))
333
+ )(jnp.arange(model.number_of_links()))
334
+
335
+ L_H_C = L_H_LW = jax.vmap( # noqa: F841
336
+ lambda W_H_L: jaxsim.math.Transform.inverse(
337
+ W_H_L.at[0:3, 3].set(jnp.zeros(3))
338
+ )
339
+ )(W_H_L)
340
+
341
+ L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
342
+ lambda i: -js.link.velocity(
343
+ model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
344
+ )
345
+ .at[0:3]
346
+ .set(jnp.zeros(3))
347
+ )(jnp.arange(model.number_of_links()))
348
+
349
+ L_a_bias_WL = jax.vmap(
350
+ lambda i: other_representation_to_body(
351
+ C_v̇_WL=C_v̇_WL[i],
352
+ C_v_WC=C_v_WC[i],
353
+ L_H_C=L_H_C[i],
354
+ L_v_LC=L_v_LC[i],
355
+ )
356
+ )(jnp.arange(model.number_of_links()))
357
+
358
+ case _:
359
+ raise ValueError(data.velocity_representation)
360
+
361
+ # Compute the bias of the 6D momentum derivative.
362
+ def bias_momentum_derivative_term(
363
+ link_index: jtp.Int, L_a_bias_WL: jtp.Vector
364
+ ) -> jtp.Vector:
365
+
366
+ # Get the body-fixed 6D inertia matrix.
367
+ L_M_L = js.link.spatial_inertia(model=model, link_index=link_index)
368
+
369
+ # Compute the body-fixed 6D velocity.
370
+ L_v_WL = js.link.velocity(
371
+ model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body
372
+ )
373
+
374
+ # Compute the world-to-link transformations for 6D forces.
375
+ W_Xf_L = jaxsim.math.Adjoint.from_transform(
376
+ transform=W_H_L[link_index], inverse=True
377
+ ).T
378
+
379
+ # Compute the contribution of the link to the bias acceleration of the CoM.
380
+ W_ḣ_bias_link_contribution = W_Xf_L @ (
381
+ L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
382
+ )
383
+
384
+ return W_ḣ_bias_link_contribution
385
+
386
+ # Sum the contributions of all links to the bias acceleration of the CoM.
387
+ W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(
388
+ jnp.arange(model.number_of_links()), L_a_bias_WL
389
+ ).sum(axis=0)
390
+
391
+ # Compute the total mass of the model.
392
+ m = js.model.total_mass(model=model)
393
+
394
+ # Compute the position of the CoM.
395
+ W_p_CoM = com_position(model=model, data=data)
396
+
397
+ match data.velocity_representation:
398
+
399
+ # G := G[W] = (W_p_CoM, [W])
400
+ case VelRepr.Inertial | VelRepr.Mixed:
401
+
402
+ W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
403
+ GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T
404
+
405
+ GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
406
+ GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m
407
+
408
+ return GW_v̇l_com_bias
409
+
410
+ # G := G[B] = (W_p_CoM, [B])
411
+ case VelRepr.Body:
412
+
413
+ GB_Xf_W = jaxsim.math.Adjoint.from_transform(
414
+ transform=data.base_transform().at[0:3].set(W_p_CoM)
415
+ ).T
416
+
417
+ GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
418
+ GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m
419
+
420
+ return GB_v̇l_com_bias
421
+
422
+ case _:
423
+ raise ValueError(data.velocity_representation)