jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -10,7 +10,7 @@ from jaxsim import logging
10
10
  from jaxsim.math.quaternion import Quaternion
11
11
  from jaxsim.parsers import descriptions, kinematic_graph
12
12
 
13
- from . import utils as utils
13
+ from . import utils
14
14
 
15
15
 
16
16
  class SDFData(NamedTuple):
@@ -135,11 +135,13 @@ def extract_model_data(
135
135
  parent=world_link,
136
136
  child=links_dict[j.child],
137
137
  jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
138
- axis=np.array(j.axis.xyz.xyz)
139
- if j.axis is not None
140
- and j.axis.xyz is not None
141
- and j.axis.xyz.xyz is not None
142
- else None,
138
+ axis=(
139
+ np.array(j.axis.xyz.xyz)
140
+ if j.axis is not None
141
+ and j.axis.xyz is not None
142
+ and j.axis.xyz.xyz is not None
143
+ else None
144
+ ),
143
145
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
144
146
  )
145
147
  for j in sdf_model.joints()
@@ -200,41 +202,55 @@ def extract_model_data(
200
202
  parent=links_dict[j.parent],
201
203
  child=links_dict[j.child],
202
204
  jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
203
- axis=np.array(j.axis.xyz.xyz)
204
- if j.axis is not None
205
- and j.axis.xyz is not None
206
- and j.axis.xyz.xyz is not None
207
- else None,
205
+ axis=(
206
+ np.array(j.axis.xyz.xyz)
207
+ if j.axis is not None
208
+ and j.axis.xyz is not None
209
+ and j.axis.xyz.xyz is not None
210
+ else None
211
+ ),
208
212
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
209
213
  initial_position=0.0,
210
214
  position_limit=(
211
- float(j.axis.limit.lower)
212
- if j.axis is not None and j.axis.limit is not None
213
- else np.finfo(float).min,
214
- float(j.axis.limit.upper)
215
- if j.axis is not None and j.axis.limit is not None
216
- else np.finfo(float).max,
215
+ (
216
+ float(j.axis.limit.lower)
217
+ if j.axis is not None and j.axis.limit is not None
218
+ else np.finfo(float).min
219
+ ),
220
+ (
221
+ float(j.axis.limit.upper)
222
+ if j.axis is not None and j.axis.limit is not None
223
+ else np.finfo(float).max
224
+ ),
225
+ ),
226
+ friction_static=(
227
+ j.axis.dynamics.friction
228
+ if j.axis is not None
229
+ and j.axis.dynamics is not None
230
+ and j.axis.dynamics.friction is not None
231
+ else 0.0
232
+ ),
233
+ friction_viscous=(
234
+ j.axis.dynamics.damping
235
+ if j.axis is not None
236
+ and j.axis.dynamics is not None
237
+ and j.axis.dynamics.damping is not None
238
+ else 0.0
239
+ ),
240
+ position_limit_damper=(
241
+ j.axis.limit.dissipation
242
+ if j.axis is not None
243
+ and j.axis.limit is not None
244
+ and j.axis.limit.dissipation is not None
245
+ else 0.0
246
+ ),
247
+ position_limit_spring=(
248
+ j.axis.limit.stiffness
249
+ if j.axis is not None
250
+ and j.axis.limit is not None
251
+ and j.axis.limit.stiffness is not None
252
+ else 0.0
217
253
  ),
218
- friction_static=j.axis.dynamics.friction
219
- if j.axis is not None
220
- and j.axis.dynamics is not None
221
- and j.axis.dynamics.friction is not None
222
- else 0.0,
223
- friction_viscous=j.axis.dynamics.damping
224
- if j.axis is not None
225
- and j.axis.dynamics is not None
226
- and j.axis.dynamics.damping is not None
227
- else 0.0,
228
- position_limit_damper=j.axis.limit.dissipation
229
- if j.axis is not None
230
- and j.axis.limit is not None
231
- and j.axis.limit.dissipation is not None
232
- else 0.0,
233
- position_limit_spring=j.axis.limit.stiffness
234
- if j.axis is not None
235
- and j.axis.limit is not None
236
- and j.axis.limit.stiffness is not None
237
- else 0.0,
238
254
  )
239
255
  for j in sdf_model.joints()
240
256
  if j.type in {"revolute", "prismatic", "fixed"}
@@ -341,6 +357,6 @@ def build_model_description(
341
357
  )
342
358
 
343
359
  # Store the parsed SDF tree as extra info
344
- model = dataclasses.replace(model, extra_info=dict(sdf_model=sdf_data.sdf_model))
360
+ model = dataclasses.replace(model, extra_info={"sdf_model": sdf_data.sdf_model})
345
361
 
346
362
  return model
@@ -1,15 +1,17 @@
1
1
  import os
2
2
  from typing import Union
3
3
 
4
- import jax.numpy as jnp
4
+ import jaxlie
5
5
  import numpy as np
6
6
  import numpy.typing as npt
7
7
  import rod
8
8
 
9
+ import jaxsim.typing as jtp
10
+ from jaxsim.math.inertia import Inertia
9
11
  from jaxsim.parsers import descriptions
10
12
 
11
13
 
12
- def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
14
+ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
13
15
  """
14
16
  Extract the 6D inertia matrix from an SDF inertial element.
15
17
 
@@ -20,9 +22,6 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
20
22
  The 6D inertia matrix of the link expressed in the link frame.
21
23
  """
22
24
 
23
- from jaxsim.math.inertia import Inertia
24
- from jaxsim.sixd import se3
25
-
26
25
  # Extract the "mass" element
27
26
  m = inertial.mass
28
27
 
@@ -52,13 +51,13 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray:
52
51
  L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
53
52
 
54
53
  # We need its inverse
55
- CoM_H_L = se3.SE3.from_matrix(matrix=L_H_CoM).inverse()
56
- CoM_X_L: npt.NDArray = CoM_H_L.adjoint()
54
+ CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse()
55
+ CoM_X_L = CoM_H_L.adjoint()
57
56
 
58
57
  # Express the CoM inertia matrix in the link frame L
59
58
  M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
60
59
 
61
- return jnp.array(M_L)
60
+ return M_L.astype(dtype=float)
62
61
 
63
62
 
64
63
  def axis_to_jtype(
@@ -0,0 +1,7 @@
1
+ from .aba import aba
2
+ from .collidable_points import collidable_points_pos_vel
3
+ from .crba import crba
4
+ from .forward_kinematics import forward_kinematics, forward_kinematics_model
5
+ from .jacobian import jacobian, jacobian_full_doubly_left
6
+ from .rnea import rnea
7
+ from .soft_contacts import SoftContacts, SoftContactsParams
jaxsim/rbda/aba.py ADDED
@@ -0,0 +1,295 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
8
+
9
+ from . import utils
10
+
11
+
12
def aba(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity: jtp.VectorLike,
    base_angular_velocity: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics using the Articulated Body Algorithm (ABA).

    The computation is organized in three passes over the links:
    an outward pass (Pass 1) propagating velocities and initializing the
    articulated-body inertias and bias forces, an inward pass (Pass 2)
    accumulating them into the parent links, and a final outward pass (Pass 3)
    computing the link and joint accelerations.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        joint_forces: The forces applied to the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the base acceleration in inertial-fixed representation
        and the joint accelerations that result from the applications of the given
        joint and link forces. For fixed-base models, the returned base
        acceleration is a vector of zeros.

    Note:
        The algorithm expects a quaternion with unit norm.
    """

    # Pre-process the inputs with the shared rbda helper. The two discarded
    # outputs presumably correspond to the acceleration inputs, which ABA does
    # not take (they are passed as None) — TODO confirm against utils.
    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=None,
        base_angular_acceleration=None,
        joint_accelerations=None,
        joint_forces=joint_forces,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Turn gravity and base velocity into column vectors, matching the
    # (..., 6, 1) layout of the per-link buffers allocated below.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffers.
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    pA = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))

    # Allocate the buffer of transforms link -> base.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities.
    # For fixed-base models these buffers stay zero for the base link, since the
    # base is never solved for (see the propagation guard in Pass 2).
    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Initialize the articulated-body inertia (Mᴬ) of base link.
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)

        # Initialize the articulated-body bias force (pᴬ) of the base link.
        # The external force is given in W; W_X_B.T maps it to the base frame.
        pA_0 = Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
        pA = pA.at[0].set(pA_0)

    # ======
    # Pass 1
    # ======

    Pass1Carry = tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]

    pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)

    # Propagate kinematics and initialize AB inertia and AB bias forces.
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:

        # Link i is the child of joint i - 1 (joint-indexed arrays skip the base).
        ii = i - 1
        v, c, MA, pA, i_X_0 = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Velocity-product (bias) acceleration term of link i.
        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)

        # Initialize the articulated-body inertia.
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Compute the link-to-base transform.
        i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_Xi_0)

        # Compute link-to-world transform for the 6D force.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Initialize articulated-body bias force.
        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])
        pA = pA.at[i].set(pA_i)

        return (v, c, MA, pA, i_X_0), None

    # Only run the scan when there is at least one link besides the base.
    (v, c, MA, pA, i_X_0), _ = (
        jax.lax.scan(
            f=loop_body_pass1,
            init=pass_1_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, c, MA, pA, i_X_0), None]
    )

    # ======
    # Pass 2
    # ======

    U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.number_of_links(), 1))
    u = jnp.zeros(shape=(model.number_of_links(), 1))

    Pass2Carry = tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]

    pass_2_carry: Pass2Carry = (U, d, u, MA, pA)

    # Visit the links from the leaves towards the base (see the flipped xs below).
    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:

        ii = i - 1
        U, d, u, MA, pA = carry

        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        # d[i] is stored as a scalar per link (squeezed); it is used as a divisor below.
        d_i = S[i].T @ U[i]
        d = d.at[i].set(d_i.squeeze())

        u_i = τ[ii] - S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # Compute the articulated-body inertia and bias force of this link.
        Ma = MA[i] - U[i] / d[i] @ U[i].T
        pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])

        # Propagate them to the parent, handling the base link.
        def propagate(
            MA_pA: tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> tuple[jtp.MatrixJax, jtp.MatrixJax]:

            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Skip the propagation into the base (λ(i) == 0) for fixed-base models:
        # their base acceleration is prescribed in Pass 3, not solved for.
        MA, pA = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, d, u, MA, pA), None

    (U, d, u, MA, pA), _ = (
        jax.lax.scan(
            f=loop_body_pass2,
            init=pass_2_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(U, d, u, MA, pA), None]
    )

    # ======
    # Pass 3
    # ======

    if model.floating_base():
        # Solve the base articulated-body equation. Gravity is not part of pA,
        # so it is added back to the base acceleration at the end.
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        # Seed the fixed base with -g expressed in the base frame, so that
        # gravity enters the outward propagation as a fictitious base acceleration.
        a0 = -B_X_W @ W_g

    s̈ = jnp.zeros_like(s)
    a = jnp.zeros_like(v).at[0].set(a0)

    Pass3Carry = tuple[jtp.MatrixJax, jtp.VectorJax]
    pass_3_carry = (a, s̈)

    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:

        ii = i - 1
        a, s̈ = carry

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute the joint acceleration.
        s̈_ii = (u[i] - U[i].T @ a_i) / d[i]
        s̈ = s̈.at[ii].set(s̈_ii.squeeze())

        # Sum the joint acceleration to the parent link acceleration.
        a_i = a_i + S[i] * s̈[ii]
        a = a.at[i].set(a_i)

        return (a, s̈), None

    (a, s̈), _ = (
        jax.lax.scan(
            f=loop_body_pass3,
            init=pass_3_carry,
            xs=jnp.arange(1, model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(a, s̈), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # TODO: remove vstack and shape=(6, 1)?
    if model.floating_base():
        # Convert the base acceleration to inertial-fixed representation,
        # and add gravity.
        B_a_WB = a[0]
        W_a_WB = W_X_B @ B_a_WB + W_g
    else:
        W_a_WB = jnp.zeros(6)

    return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())
@@ -0,0 +1,142 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Quaternion, Skew
8
+
9
+ from . import utils
10
+
11
+
12
def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of collidable points in the world frame.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.

    Returns:
        A tuple containing the position and linear velocity of collidable points.
        Both arrays have one row per collidable point and three columns; the
        velocity is the linear part of the mixed velocity of each point.
        When the model has no collidable points, both arrays have shape (0, 3).
    """

    if len(model.kin_dyn_parameters.contact_parameters.body) == 0:
        # Return empty arrays with the same (n_points, 3) layout as the general
        # case, instead of a 0-d scalar and a 1-d empty vector, so that callers
        # do not need to special-case models without collidable points.
        empty = jnp.zeros(shape=(0, 3), dtype=float)
        return empty, empty

    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffer of transforms world -> link and initialize the base pose.
    W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))

    # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity.
    W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6))
    W_v_Wi = W_v_Wi.at[0].set(W_v_WB)

    # ====================
    # Propagate kinematics
    # ====================

    PropagateTransformsCarry = tuple[jtp.MatrixJax, jtp.Matrix]
    propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)

    def propagate_kinematics(
        carry: PropagateTransformsCarry, i: jtp.Int
    ) -> tuple[PropagateTransformsCarry, None]:

        # Link i is the child of joint i - 1 (joint-indexed arrays skip the base).
        ii = i - 1
        W_X_i, W_v_Wi = carry

        # Compute the parent to child 6D transform.
        λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i])

        # Compute the world to child 6D transform.
        W_Xi_i = W_X_i[λ[i]] @ λi_X_i
        W_X_i = W_X_i.at[i].set(W_Xi_i)

        # Propagate the 6D velocity: in inertial-fixed representation, the child
        # velocity is the parent velocity plus the joint velocity mapped to W.
        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)

        return (W_X_i, W_v_Wi), None

    # Only run the scan when there is at least one link besides the base.
    (W_X_i, W_v_Wi), _ = (
        jax.lax.scan(
            f=propagate_kinematics,
            init=propagate_transforms_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(W_X_i, W_v_Wi), None]
    )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.VectorJax, parent_body: jtp.Int
    ) -> tuple[jtp.VectorJax, jtp.VectorJax]:
        # Compute the position of the collidable point
        # (homogeneous transform of the point expressed in its parent link frame).
        W_p_Ci = (
            Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
        )[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}:
        # v + ω × p, written as [I, -S(p)] @ [v; ω].
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ W_v_Wi[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel
    W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
        model.kin_dyn_parameters.contact_parameters.point,
        jnp.array(model.kin_dyn_parameters.contact_parameters.body),
    )

    return W_p_Ci, CW_vl_WC