jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/rbda/crba.py ADDED
@@ -0,0 +1,167 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ import jaxsim.api as js
5
+ import jaxsim.typing as jtp
6
+
7
+ from . import utils
8
+
9
+
10
def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix:
    """
    Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA).

    The returned matrix has size ``6 + model.dofs()``:
    the first 6 rows/columns refer to the base, the remaining ones to the joints
    (joint ``k`` is the joint attached to link ``k + 1``).

    Args:
        model: The model to consider.
        joint_positions: The positions of the joints.

    Returns:
        The free-floating mass matrix of the model in body-fixed representation.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    # Get the 6D spatial inertia matrices of all links.
    # These are progressively accumulated into composite inertias by the backward pass.
    Mc = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Allocate the buffer of transforms link -> base.
    # Entry 0 (the base itself) is the identity.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # ====================
    # Propagate kinematics
    # ====================

    ForwardPassCarry = tuple[jtp.Matrix]
    forward_pass_carry: ForwardPassCarry = (i_X_0,)

    def propagate_kinematics(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> tuple[ForwardPassCarry, None]:

        (i_X_0,) = carry

        # Chain the parent's adjoint; this assumes λ[i] < i, so that the parent
        # entry has already been filled when visiting link i.
        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        return (i_X_0,), None

    # The scan is skipped for models with a single link (nothing to propagate).
    (i_X_0,), _ = (
        jax.lax.scan(
            f=propagate_kinematics,
            init=forward_pass_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(i_X_0,), None]
    )

    # ===================
    # Compute mass matrix
    # ===================

    M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))

    BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix]
    backward_pass_carry: BackwardPassCarry = (Mc, M)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> tuple[BackwardPassCarry, None]:

        # Index of the joint attached to link i (the base has no joint).
        ii = i - 1
        Mc, M = carry

        # Accumulate the composite inertia of link i into its parent, after
        # transforming it to the parent frame (X^T I X).
        Mc_λi = Mc[λ[i]] + i_X_λi[i].T @ Mc[i] @ i_X_λi[i]
        Mc = Mc.at[λ[i]].set(Mc_λi)

        # Diagonal joint-space entry of joint ii.
        Fi = Mc[i] @ S[i]
        M_ii = S[i].T @ Fi
        M = M.at[ii + 6, ii + 6].set(M_ii.squeeze())

        j = i

        FakeWhileCarry = tuple[jtp.Int, jtp.Vector, jtp.Matrix]
        fake_while_carry = (j, Fi, M)

        # This internal for loop implements the while loop of the CRBA algorithm
        # to compute off-diagonal blocks of the mass matrix M.
        # In pseudocode it is implemented as a while loop. However, in order to enable
        # applying reverse-mode AD, we implement it as a nested for loop with a fixed
        # number of iterations and a branching model to skip for loop iterations.
        # Note: the parameter `i` below shadows the outer link index. Inside this
        # function it is the candidate ancestor index, scanned in decreasing order,
        # while `ii` (captured from the enclosing scope) still refers to the outer
        # link's joint.
        def fake_while_loop(
            carry: FakeWhileCarry, i: jtp.Int
        ) -> tuple[FakeWhileCarry, None]:

            # Move one step up the tree: express Fi in the parent frame of j and
            # fill the symmetric pair of off-diagonal joint entries.
            def compute(carry: FakeWhileCarry) -> FakeWhileCarry:

                j, Fi, M = carry

                Fi = i_X_λi[j].T @ Fi
                j = λ[j]

                M_ij = Fi.T @ S[j]

                jj = j - 1
                M = M.at[ii + 6, jj + 6].set(M_ij.squeeze())
                M = M.at[jj + 6, ii + 6].set(M_ij.squeeze())

                return j, Fi, M

            j, _, _ = carry

            # Step only when the scanned index is the parent of the current j and
            # that parent is not the base (λ[j] > 0); otherwise pass the carry through.
            j, Fi, M = jax.lax.cond(
                pred=jnp.logical_and(i == λ[j], λ[j] > 0),
                true_fun=compute,
                false_fun=lambda carry: carry,
                operand=carry,
            )

            return (j, Fi, M), None

        (j, Fi, M), _ = (
            jax.lax.scan(
                f=fake_while_loop,
                init=fake_while_carry,
                xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
            )
            if model.number_of_links() > 1
            else [(j, Fi, M), None]
        )

        # Assuming parents have smaller indices than their children, j is now the
        # ancestor of link i whose parent is the base (or i itself when link i is a
        # direct child of the base). Transform Fi to the base frame to obtain the
        # base/joint coupling column and row.
        Fi = i_X_0[j].T @ Fi

        M = M.at[0:6, ii + 6].set(Fi.squeeze())
        M = M.at[ii + 6, 0:6].set(Fi.squeeze())

        return (Mc, M), None

    # This scan performs the backward pass to compute Mbj, Mjb and Mjj, that
    # also includes a fake while loop implemented with a scan and two cond.
    # Links are visited from the last one down to link 1 (children before parents).
    (Mc, M), _ = (
        jax.lax.scan(
            f=backward_pass,
            init=backward_pass_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(Mc, M), None]
    )

    # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
    # After the backward pass Mc[0] holds the composite inertia of the whole model.
    M = M.at[0:6, 0:6].set(Mc[0])

    return M
jaxsim/rbda/forward_kinematics.py ADDED
@@ -0,0 +1,117 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint
8
+
9
+ from . import utils
10
+
11
+
12
def forward_kinematics_model(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
) -> jtp.Array:
    """
    Compute the world pose of every link of the model.

    The base pose is assembled from ``base_position`` and ``base_quaternion``.
    The poses of the other links are then obtained by chaining, from each parent
    to its child, the joint transforms evaluated at ``joint_positions``.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link (``wxyz`` ordering).
        joint_positions: The positions of the joints.

    Returns:
        A 3D array stacking along its first axis the SE(3) transforms of all
        links belonging to the model, ordered by link index.
    """

    W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
    )

    n_links = model.number_of_links()

    # Parent index of each link; the base entry (index 0) is a -1 placeholder
    # and is never read.
    parent_of = model.kin_dyn_parameters.parent_array

    # Homogeneous transform of the base w.r.t. the world frame.
    base_pose = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    ).as_matrix()

    # Parent-to-child adjoints. Entry 0 encodes the base transform passed in,
    # which makes the same code path work for fixed- and floating-base models.
    child_X_parent, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=base_pose
    )

    # World-from-link adjoints, seeded with the base pose.
    W_X_all = jnp.zeros(shape=(n_links, 6, 6))
    W_X_all = W_X_all.at[0].set(Adjoint.inverse(child_X_parent[0]))

    def chain_to_world(W_X_all: jtp.Matrix, link: jtp.Int) -> tuple[jtp.Matrix, None]:
        # Assumes parents have smaller indices than children, so the parent's
        # entry is already filled when this link is visited.
        W_X_link = W_X_all[parent_of[link]] @ Adjoint.inverse(child_X_parent[link])
        return W_X_all.at[link].set(W_X_link), None

    # A single-link model has nothing left to propagate.
    if n_links > 1:
        W_X_all, _ = jax.lax.scan(
            f=chain_to_world,
            init=W_X_all,
            xs=jnp.arange(start=1, stop=n_links),
        )

    return jax.vmap(Adjoint.to_transform)(W_X_all)
89
+
90
+
91
def forward_kinematics(
    model: js.model.JaxSimModel,
    link_index: jtp.Int,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
) -> jtp.Matrix:
    """
    Compute the world pose of a single link.

    This runs the forward kinematics of the whole model and then selects the
    requested link. Callers needing several links should prefer
    ``forward_kinematics_model`` to avoid repeating the full computation.

    Args:
        model: The model to consider.
        link_index: The index of the link to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.

    Returns:
        The SE(3) transform of the link.
    """

    all_link_poses = forward_kinematics_model(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
    )

    link_pose = all_link_poses[link_index]
    return link_pose
jaxsim/rbda/jacobian.py ADDED
@@ -0,0 +1,330 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Cross
8
+
9
+ from . import utils
10
+
11
+
12
def jacobian(
    model: js.model.JaxSimModel,
    *,
    link_index: jtp.Int,
    joint_positions: jtp.VectorLike,
) -> jtp.Matrix:
    """
    Compute the free-floating Jacobian of a link.

    Args:
        model: The model to consider.
        link_index: The index of the link for which to compute the Jacobian matrix.
        joint_positions: The positions of the joints.

    Returns:
        The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`.
        It is a ``6 x (6 + n)`` matrix whose joint columns are non-zero only for
        the joints that support the link.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    n_links = model.number_of_links()
    link_range = np.arange(start=1, stop=n_links)

    # Parent index of each link; the base entry (index 0) is a -1 placeholder.
    parent_of = model.kin_dyn_parameters.parent_array

    # Parent-to-child adjoints and joint motion subspaces, with an identity base
    # transform so that everything is expressed relative to the base.
    child_X_parent, motion_subspace = (
        model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
            joint_positions=s, base_transform=jnp.eye(4)
        )
    )

    # Base-to-link adjoints i_X_B (mapping base coordinates to link coordinates).
    i_X_B = jnp.zeros(shape=(n_links, 6, 6))
    i_X_B = i_X_B.at[0].set(jnp.eye(6))

    def chain_to_base(i_X_B: jtp.Matrix, link: jtp.Int) -> tuple[jtp.Matrix, None]:
        # Relies on parents being visited before their children (BFS link ordering).
        return i_X_B.at[link].set(child_X_parent[link] @ i_X_B[parent_of[link]]), None

    if n_links > 1:
        i_X_B, _ = jax.lax.scan(f=chain_to_base, init=i_X_B, xs=link_range)

    L_X_B = i_X_B[link_index]

    # Boolean mask over links: True for links in the support set κ of the
    # considered link. A mask is used instead of a set so that it works under JIT.
    supports_link = model.kin_dyn_parameters.support_body_array_bool[link_index]

    # The base block of the Jacobian is the base-to-link adjoint itself.
    L_J = jnp.zeros(shape=(6, 6 + model.dofs()))
    L_J = L_J.at[0:6, 0:6].set(L_X_B)

    def fill_joint_column(L_J: jtp.Matrix, link: jtp.Int) -> tuple[jtp.Matrix, None]:
        joint = link - 1
        L_S_joint = L_X_B @ Adjoint.inverse(i_X_B[link]) @ motion_subspace[link]
        L_J_filled = L_J.at[0:6, 6 + joint].set(L_S_joint.squeeze())
        # Both candidates are computed; the mask only picks which one survives.
        return jax.lax.select(supports_link[link], L_J_filled, L_J), None

    if n_links > 1:
        L_J, _ = jax.lax.scan(f=fill_joint_column, init=L_J, xs=link_range)

    return L_J
123
+
124
+
125
@jax.jit
def jacobian_full_doubly_left(
    model: js.model.JaxSimModel,
    *,
    joint_positions: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
    r"""
    Compute the doubly-left full free-floating Jacobian of a model.

    The full Jacobian is a 6x(6+n) matrix with all the columns filled.
    It is useful to run the algorithm once, and then extract the link Jacobian by
    filtering the columns of the full Jacobian using the support parent array
    :math:`\kappa(i)` of the link.

    Args:
        model: The model to consider.
        joint_positions: The positions of the joints.

    Returns:
        A tuple with two elements:

        - the doubly-left full free-floating Jacobian of the model;
        - the stacked base-to-link SE(3) transforms of all links.

        The transforms are returned so that callers converting the Jacobian to
        another output representation do not need to run forward kinematics again.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    n_links = model.number_of_links()

    # Parent index of each link; the base entry (index 0) is a -1 placeholder.
    parent_of = model.kin_dyn_parameters.parent_array

    # Parent-to-child adjoints and joint motion subspaces relative to the base.
    child_X_parent, motion_subspace = (
        model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
            joint_positions=s, base_transform=jnp.eye(4)
        )
    )

    # Adjoints mapping link coordinates to base coordinates; B_X_B is the identity.
    B_X_all = jnp.zeros(shape=(n_links, 6, 6)).at[0].set(jnp.eye(6))

    # The Jbb block of the doubly-left Jacobian is the 6x6 identity.
    B_J = jnp.zeros(shape=(6, 6 + model.dofs())).at[0:6, 0:6].set(jnp.eye(6))

    def add_link(
        carry: tuple[jtp.Matrix, jtp.Matrix], link: jtp.Int
    ) -> tuple[tuple[jtp.Matrix, jtp.Matrix], None]:
        B_X_all, B_J = carry
        joint = link - 1

        # Chain the parent's adjoint (parents are visited before children).
        B_X_link = B_X_all[parent_of[link]] @ Adjoint.inverse(child_X_parent[link])

        # Column of the joint attached to this link, expressed in the base frame.
        B_S_joint = B_X_link @ motion_subspace[link]

        new_carry = (
            B_X_all.at[link].set(B_X_link),
            B_J.at[0:6, 6 + joint].set(B_S_joint.squeeze()),
        )
        return new_carry, None

    if n_links > 1:
        (B_X_all, B_J), _ = jax.lax.scan(
            f=add_link,
            init=(B_X_all, B_J),
            xs=np.arange(start=1, stop=n_links),
        )

    B_H_L = jax.vmap(Adjoint.to_transform)(B_X_all)

    return B_J.squeeze().astype(float), B_H_L
214
+
215
+
216
def jacobian_derivative_full_doubly_left(
    model: js.model.JaxSimModel,
    *,
    joint_positions: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
    r"""
    Compute the derivative of the doubly-left full free-floating Jacobian of a model.

    The derivative of the full Jacobian is a 6x(6+n) matrix with all the columns filled.
    It is useful to run the algorithm once, and then extract the link Jacobian
    derivative by filtering the columns of the full Jacobian using the support
    parent array :math:`\kappa(i)` of the link.

    Its first 6 columns are always zero, because the base-to-base adjoint is the
    constant identity.

    Args:
        model: The model to consider.
        joint_positions: The positions of the joints.
        joint_velocities: The velocities of the joints.

    Returns:
        A tuple with two elements:

        - the derivative of the doubly-left full free-floating Jacobian of the model;
        - the stacked base-to-link SE(3) transforms of all links.

        The transforms are returned so that callers converting the output
        representation do not need to run forward kinematics again.
    """

    _, _, s, _, ṡ, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions, joint_velocities=joint_velocities
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Allocate the buffer of 6D adjoints B_X_i, mapping link-i coordinates to
    # base coordinates. The base entry is the identity.
    B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    B_X_i = B_X_i.at[0].set(jnp.eye(6))

    # Allocate the buffer of the time derivatives of B_X_i.
    # The base entry stays zero since B_X_B is constant.
    B_Ẋ_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))

    # Allocate the buffer of the 6D velocities of the links relative to the base,
    # expressed in the base frame (the base entry stays zero).
    B_v_Bi = jnp.zeros(shape=(model.number_of_links(), 6))

    # Helper to compute the time derivative of the adjoint matrix:
    # d/dt(A_X_B) = A_X_B @ (B_v_AB ×), with the velocity expressed in B.
    def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix:
        return A_X_B @ Cross.vx(B_v_AB).squeeze()

    # ============================================
    # Compute doubly-left full Jacobian derivative
    # ============================================

    # Allocate the Jacobian derivative matrix.
    J̇ = jnp.zeros(shape=(6, 6 + model.dofs()))

    ComputeFullJacobianDerivativeCarry = tuple[
        jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix
    ]

    compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (
        B_v_Bi,
        B_X_i,
        B_Ẋ_i,
        J̇,
    )

    def compute_full_jacobian_derivative(
        carry: ComputeFullJacobianDerivativeCarry, i: jtp.Int
    ) -> tuple[ComputeFullJacobianDerivativeCarry, None]:

        # Index of the joint attached to link i.
        ii = i - 1
        B_v_Bi, B_X_i, B_Ẋ_i, J̇ = carry

        # Compute the adjoint mapping link-i coordinates to base coordinates,
        # chaining the parent's one (parents are visited before children).
        B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
        B_X_i = B_X_i.at[i].set(B_Xi_i)

        # Velocity of link i relative to the base, expressed in the base frame:
        # the parent's velocity plus the contribution of joint ii.
        B_vi_Bi = B_v_Bi[λ[i]] + B_X_i[i] @ S[i].squeeze() * ṡ[ii]
        B_v_Bi = B_v_Bi.at[i].set(B_vi_Bi)

        # Time derivative of B_X_i. The helper needs the velocity expressed in
        # the link frame, hence the change of coordinates through i_X_B.
        i_Xi_B = Adjoint.inverse(B_Xi_i)
        B_Ẋi_i = A_Ẋ_B(A_X_B=B_Xi_i, B_v_AB=i_Xi_B @ B_vi_Bi)
        B_Ẋ_i = B_Ẋ_i.at[i].set(B_Ẋi_i)

        # Compute the ii-th column of the B_Ṡ_BL(s) matrix.
        B_Ṡii_BL = B_Ẋ_i[i] @ S[i]
        J̇ = J̇.at[0:6, 6 + ii].set(B_Ṡii_BL.squeeze())

        return (B_v_Bi, B_X_i, B_Ẋ_i, J̇), None

    # For models with only the base link there are no joint columns to fill,
    # so the freshly allocated buffers are returned unchanged.
    (_, B_X_i, B_Ẋ_i, J̇), _ = (
        jax.lax.scan(
            f=compute_full_jacobian_derivative,
            init=compute_full_jacobian_derivative_carry,
            xs=np.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(B_v_Bi, B_X_i, B_Ẋ_i, J̇), None]
    )

    # Convert adjoints to SE(3) transforms.
    # Returning them here prevents calling FK in case the output representation
    # of the Jacobian needs to be changed.
    B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)

    # Adjust shape of doubly-left free-floating full Jacobian derivative.
    B_J̇_full_WL_B = J̇.squeeze().astype(float)

    return B_J̇_full_WL_B, B_H_L