jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1rc0.dist-info/METADATA +0 -167
- jaxsim-0.1rc0.dist-info/RECORD +0 -64
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,68 +1,68 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
1
|
import jax
|
4
2
|
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
3
|
|
4
|
+
import jaxsim.api as js
|
7
5
|
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
9
6
|
|
10
7
|
from . import utils
|
11
8
|
|
12
9
|
|
13
|
-
def crba(model:
|
10
|
+
def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix:
|
14
11
|
"""
|
15
|
-
Compute the Composite Rigid-Body
|
12
|
+
Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA).
|
16
13
|
|
17
14
|
Args:
|
18
|
-
model
|
19
|
-
|
15
|
+
model: The model to consider.
|
16
|
+
joint_positions: The positions of the joints.
|
20
17
|
|
21
18
|
Returns:
|
22
|
-
|
19
|
+
The free-floating mass matrix of the model in body-fixed representation.
|
23
20
|
"""
|
24
21
|
|
25
|
-
_,
|
26
|
-
|
22
|
+
_, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
|
23
|
+
model=model, joint_positions=joint_positions
|
27
24
|
)
|
28
25
|
|
29
|
-
|
30
|
-
Mc = model.
|
31
|
-
S = model.motion_subspaces(q=q)
|
32
|
-
Xj = model.joint_transforms(q=q)
|
26
|
+
# Get the 6D spatial inertia matrices of all links.
|
27
|
+
Mc = js.model.link_spatial_inertia_matrices(model=model)
|
33
28
|
|
34
|
-
|
35
|
-
|
36
|
-
|
29
|
+
# Get the parent array λ(i).
|
30
|
+
# Note: λ(0) must not be used, it's initialized to -1.
|
31
|
+
λ = model.kin_dyn_parameters.parent_array
|
32
|
+
|
33
|
+
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
34
|
+
# These transforms define the relative kinematics of the entire model, including
|
35
|
+
# the base transform for both floating-base and fixed-base models.
|
36
|
+
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
37
|
+
joint_positions=s, base_transform=jnp.eye(4)
|
38
|
+
)
|
37
39
|
|
38
|
-
#
|
39
|
-
|
40
|
-
|
40
|
+
# Allocate the buffer of transforms link -> base.
|
41
|
+
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
42
|
+
i_X_0 = i_X_0.at[0].set(jnp.eye(6))
|
41
43
|
|
42
44
|
# ====================
|
43
45
|
# Propagate kinematics
|
44
46
|
# ====================
|
45
47
|
|
46
|
-
ForwardPassCarry =
|
47
|
-
forward_pass_carry = (
|
48
|
+
ForwardPassCarry = tuple[jtp.MatrixJax]
|
49
|
+
forward_pass_carry: ForwardPassCarry = (i_X_0,)
|
48
50
|
|
49
51
|
def propagate_kinematics(
|
50
52
|
carry: ForwardPassCarry, i: jtp.Int
|
51
|
-
) ->
|
52
|
-
Xup, i_X_0 = carry
|
53
|
+
) -> tuple[ForwardPassCarry, None]:
|
53
54
|
|
54
|
-
|
55
|
-
Xup = Xup.at[i].set(Xup_i)
|
55
|
+
(i_X_0,) = carry
|
56
56
|
|
57
|
-
i_X_0_i =
|
57
|
+
i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
|
58
58
|
i_X_0 = i_X_0.at[i].set(i_X_0_i)
|
59
59
|
|
60
|
-
return (
|
60
|
+
return (i_X_0,), None
|
61
61
|
|
62
|
-
(
|
62
|
+
(i_X_0,), _ = jax.lax.scan(
|
63
63
|
f=propagate_kinematics,
|
64
64
|
init=forward_pass_carry,
|
65
|
-
xs=
|
65
|
+
xs=jnp.arange(start=1, stop=model.number_of_links()),
|
66
66
|
)
|
67
67
|
|
68
68
|
# ===================
|
@@ -71,16 +71,17 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
|
|
71
71
|
|
72
72
|
M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))
|
73
73
|
|
74
|
-
BackwardPassCarry =
|
75
|
-
backward_pass_carry = (Mc, M)
|
74
|
+
BackwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
|
75
|
+
backward_pass_carry: BackwardPassCarry = (Mc, M)
|
76
76
|
|
77
77
|
def backward_pass(
|
78
78
|
carry: BackwardPassCarry, i: jtp.Int
|
79
|
-
) ->
|
79
|
+
) -> tuple[BackwardPassCarry, None]:
|
80
|
+
|
80
81
|
ii = i - 1
|
81
82
|
Mc, M = carry
|
82
83
|
|
83
|
-
Mc_λi = Mc[λ[i]] +
|
84
|
+
Mc_λi = Mc[λ[i]] + i_X_λi[i].T @ Mc[i] @ i_X_λi[i]
|
84
85
|
Mc = Mc.at[λ[i]].set(Mc_λi)
|
85
86
|
|
86
87
|
Fi = Mc[i] @ S[i]
|
@@ -89,13 +90,13 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
|
|
89
90
|
|
90
91
|
j = i
|
91
92
|
|
92
|
-
CarryInnerFn =
|
93
|
+
CarryInnerFn = tuple[jtp.Int, jtp.MatrixJax, jtp.MatrixJax]
|
93
94
|
carry_inner_fn = (j, Fi, M)
|
94
95
|
|
95
96
|
def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn:
|
96
97
|
j, Fi, M = carry
|
97
98
|
|
98
|
-
Fi =
|
99
|
+
Fi = i_X_λi[j].T @ Fi
|
99
100
|
j = λ[j]
|
100
101
|
jj = j - 1
|
101
102
|
|
@@ -108,8 +109,8 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
|
|
108
109
|
|
109
110
|
# The following functions are part of a (rather messy) workaround for computing
|
110
111
|
# a while loop using a for loop with fixed number of iterations.
|
111
|
-
def inner_fn(carry: CarryInnerFn, k: jtp.Int) ->
|
112
|
-
def compute_inner(carry: CarryInnerFn) ->
|
112
|
+
def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]:
|
113
|
+
def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:
|
113
114
|
j, Fi, M = carry
|
114
115
|
out = jax.lax.cond(
|
115
116
|
pred=(λ[j] > 0),
|
@@ -130,7 +131,7 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
|
|
130
131
|
(j, Fi, M), _ = jax.lax.scan(
|
131
132
|
f=inner_fn,
|
132
133
|
init=carry_inner_fn,
|
133
|
-
xs=
|
134
|
+
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
|
134
135
|
)
|
135
136
|
|
136
137
|
Fi = i_X_0[j].T @ Fi
|
@@ -145,10 +146,10 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
|
|
145
146
|
(Mc, M), _ = jax.lax.scan(
|
146
147
|
f=backward_pass,
|
147
148
|
init=backward_pass_carry,
|
148
|
-
xs=
|
149
|
+
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
|
149
150
|
)
|
150
151
|
|
151
|
-
# Store the locked 6D rigid-body inertia matrix Mbb ∈
|
152
|
+
# Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
|
152
153
|
M = M.at[0:6, 0:6].set(Mc[0])
|
153
154
|
|
154
155
|
return M
|
@@ -0,0 +1,113 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
from jaxsim.math import Adjoint, Quaternion
|
8
|
+
|
9
|
+
from . import utils
|
10
|
+
|
11
|
+
|
12
|
+
def forward_kinematics_model(
|
13
|
+
model: js.model.JaxSimModel,
|
14
|
+
*,
|
15
|
+
base_position: jtp.VectorLike,
|
16
|
+
base_quaternion: jtp.VectorLike,
|
17
|
+
joint_positions: jtp.VectorLike,
|
18
|
+
) -> jtp.Array:
|
19
|
+
"""
|
20
|
+
Compute the forward kinematics.
|
21
|
+
|
22
|
+
Args:
|
23
|
+
model: The model to consider.
|
24
|
+
base_position: The position of the base link.
|
25
|
+
base_quaternion: The quaternion of the base link.
|
26
|
+
joint_positions: The positions of the joints.
|
27
|
+
|
28
|
+
Returns:
|
29
|
+
A 3D array containing the SE(3) transforms of all links belonging to the model.
|
30
|
+
"""
|
31
|
+
|
32
|
+
W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs(
|
33
|
+
model=model,
|
34
|
+
base_position=base_position,
|
35
|
+
base_quaternion=base_quaternion,
|
36
|
+
joint_positions=joint_positions,
|
37
|
+
)
|
38
|
+
|
39
|
+
# Get the parent array λ(i).
|
40
|
+
# Note: λ(0) must not be used, it's initialized to -1.
|
41
|
+
λ = model.kin_dyn_parameters.parent_array
|
42
|
+
|
43
|
+
# Compute the base transform.
|
44
|
+
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
45
|
+
rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
|
46
|
+
translation=W_p_B,
|
47
|
+
)
|
48
|
+
|
49
|
+
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
50
|
+
# These transforms define the relative kinematics of the entire model, including
|
51
|
+
# the base transform for both floating-base and fixed-base models.
|
52
|
+
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
53
|
+
joint_positions=s, base_transform=W_H_B.as_matrix()
|
54
|
+
)
|
55
|
+
|
56
|
+
# Allocate the buffer of transforms world -> link and initialize the base pose.
|
57
|
+
W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
58
|
+
W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))
|
59
|
+
|
60
|
+
# ========================
|
61
|
+
# Propagate the kinematics
|
62
|
+
# ========================
|
63
|
+
|
64
|
+
PropagateKinematicsCarry = tuple[jtp.MatrixJax]
|
65
|
+
propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i,)
|
66
|
+
|
67
|
+
def propagate_kinematics(
|
68
|
+
carry: PropagateKinematicsCarry, i: jtp.Int
|
69
|
+
) -> tuple[PropagateKinematicsCarry, None]:
|
70
|
+
|
71
|
+
(W_X_i,) = carry
|
72
|
+
|
73
|
+
W_X_i_i = W_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
|
74
|
+
W_X_i = W_X_i.at[i].set(W_X_i_i)
|
75
|
+
|
76
|
+
return (W_X_i,), None
|
77
|
+
|
78
|
+
(W_X_i,), _ = jax.lax.scan(
|
79
|
+
f=propagate_kinematics,
|
80
|
+
init=propagate_kinematics_carry,
|
81
|
+
xs=jnp.arange(start=1, stop=model.number_of_links()),
|
82
|
+
)
|
83
|
+
|
84
|
+
return jax.vmap(Adjoint.to_transform)(W_X_i)
|
85
|
+
|
86
|
+
|
87
|
+
def forward_kinematics(
|
88
|
+
model: js.model.JaxSimModel,
|
89
|
+
link_index: jtp.Int,
|
90
|
+
base_position: jtp.VectorLike,
|
91
|
+
base_quaternion: jtp.VectorLike,
|
92
|
+
joint_positions: jtp.VectorLike,
|
93
|
+
) -> jtp.Matrix:
|
94
|
+
"""
|
95
|
+
Compute the forward kinematics of a specific link.
|
96
|
+
|
97
|
+
Args:
|
98
|
+
model: The model to consider.
|
99
|
+
link_index: The index of the link to consider.
|
100
|
+
base_position: The position of the base link.
|
101
|
+
base_quaternion: The quaternion of the base link.
|
102
|
+
joint_positions: The positions of the joints.
|
103
|
+
|
104
|
+
Returns:
|
105
|
+
The SE(3) transform of the link.
|
106
|
+
"""
|
107
|
+
|
108
|
+
return forward_kinematics_model(
|
109
|
+
model=model,
|
110
|
+
base_position=base_position,
|
111
|
+
base_quaternion=base_quaternion,
|
112
|
+
joint_positions=joint_positions,
|
113
|
+
)[link_index]
|
jaxsim/rbda/jacobian.py
ADDED
@@ -0,0 +1,201 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
from jaxsim.math import Adjoint
|
8
|
+
|
9
|
+
from . import utils
|
10
|
+
|
11
|
+
|
12
|
+
def jacobian(
|
13
|
+
model: js.model.JaxSimModel,
|
14
|
+
*,
|
15
|
+
link_index: jtp.Int,
|
16
|
+
joint_positions: jtp.VectorLike,
|
17
|
+
) -> jtp.Matrix:
|
18
|
+
"""
|
19
|
+
Compute the free-floating Jacobian of a link.
|
20
|
+
|
21
|
+
Args:
|
22
|
+
model: The model to consider.
|
23
|
+
link_index: The index of the link for which to compute the Jacobian matrix.
|
24
|
+
joint_positions: The positions of the joints.
|
25
|
+
|
26
|
+
Returns:
|
27
|
+
The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`.
|
28
|
+
"""
|
29
|
+
|
30
|
+
_, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
|
31
|
+
model=model, joint_positions=joint_positions
|
32
|
+
)
|
33
|
+
|
34
|
+
# Get the parent array λ(i).
|
35
|
+
# Note: λ(0) must not be used, it's initialized to -1.
|
36
|
+
λ = model.kin_dyn_parameters.parent_array
|
37
|
+
|
38
|
+
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
39
|
+
# These transforms define the relative kinematics of the entire model, including
|
40
|
+
# the base transform for both floating-base and fixed-base models.
|
41
|
+
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
42
|
+
joint_positions=s, base_transform=jnp.eye(4)
|
43
|
+
)
|
44
|
+
|
45
|
+
# Allocate the buffer of transforms link -> base.
|
46
|
+
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
47
|
+
i_X_0 = i_X_0.at[0].set(jnp.eye(6))
|
48
|
+
|
49
|
+
# ====================
|
50
|
+
# Propagate kinematics
|
51
|
+
# ====================
|
52
|
+
|
53
|
+
PropagateKinematicsCarry = tuple[jtp.MatrixJax]
|
54
|
+
propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,)
|
55
|
+
|
56
|
+
def propagate_kinematics(
|
57
|
+
carry: PropagateKinematicsCarry, i: jtp.Int
|
58
|
+
) -> tuple[PropagateKinematicsCarry, None]:
|
59
|
+
|
60
|
+
(i_X_0,) = carry
|
61
|
+
|
62
|
+
# Compute the base (0) to link (i) adjoint matrix.
|
63
|
+
# This works fine since we traverse the kinematic tree following the link
|
64
|
+
# indices assigned with BFS.
|
65
|
+
i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
|
66
|
+
i_X_0 = i_X_0.at[i].set(i_X_0_i)
|
67
|
+
|
68
|
+
return (i_X_0,), None
|
69
|
+
|
70
|
+
(i_X_0,), _ = jax.lax.scan(
|
71
|
+
f=propagate_kinematics,
|
72
|
+
init=propagate_kinematics_carry,
|
73
|
+
xs=np.arange(start=1, stop=model.number_of_links()),
|
74
|
+
)
|
75
|
+
|
76
|
+
# ============================
|
77
|
+
# Compute doubly-left Jacobian
|
78
|
+
# ============================
|
79
|
+
|
80
|
+
J = jnp.zeros(shape=(6, 6 + model.dofs()))
|
81
|
+
|
82
|
+
Jb = i_X_0[link_index]
|
83
|
+
J = J.at[0:6, 0:6].set(Jb)
|
84
|
+
|
85
|
+
# To make JIT happy, we operate on a boolean version of κ(i).
|
86
|
+
# Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.
|
87
|
+
κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index]
|
88
|
+
|
89
|
+
def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> tuple[jtp.MatrixJax, None]:
|
90
|
+
|
91
|
+
def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax:
|
92
|
+
|
93
|
+
ii = i - 1
|
94
|
+
|
95
|
+
Js_i = i_X_0[link_index] @ Adjoint.inverse(i_X_0[i]) @ S[i]
|
96
|
+
J = J.at[0:6, 6 + ii].set(Js_i.squeeze())
|
97
|
+
|
98
|
+
return J
|
99
|
+
|
100
|
+
J = jax.lax.select(
|
101
|
+
pred=κ_bool[i],
|
102
|
+
on_true=update_jacobian(J, i),
|
103
|
+
on_false=J,
|
104
|
+
)
|
105
|
+
|
106
|
+
return J, None
|
107
|
+
|
108
|
+
L_J_WL_B, _ = jax.lax.scan(
|
109
|
+
f=compute_jacobian,
|
110
|
+
init=J,
|
111
|
+
xs=np.arange(start=1, stop=model.number_of_links()),
|
112
|
+
)
|
113
|
+
|
114
|
+
return L_J_WL_B
|
115
|
+
|
116
|
+
|
117
|
+
@jax.jit
|
118
|
+
def jacobian_full_doubly_left(
|
119
|
+
model: js.model.JaxSimModel,
|
120
|
+
*,
|
121
|
+
joint_positions: jtp.VectorLike,
|
122
|
+
) -> tuple[jtp.Matrix, jtp.Array]:
|
123
|
+
r"""
|
124
|
+
Compute the doubly-left full free-floating Jacobian of a model.
|
125
|
+
|
126
|
+
The full Jacobian is a 6x(6+n) matrix with all the columns filled.
|
127
|
+
It is useful to run the algorithm once, and then extract the link Jacobian by
|
128
|
+
filtering the columns of the full Jacobian using the support parent array
|
129
|
+
:math:`\kappa(i)` of the link.
|
130
|
+
|
131
|
+
Args:
|
132
|
+
model: The model to consider.
|
133
|
+
joint_positions: The positions of the joints.
|
134
|
+
|
135
|
+
Returns:
|
136
|
+
The doubly-left full free-floating Jacobian of a model.
|
137
|
+
"""
|
138
|
+
|
139
|
+
_, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
|
140
|
+
model=model, joint_positions=joint_positions
|
141
|
+
)
|
142
|
+
|
143
|
+
# Get the parent array λ(i).
|
144
|
+
# Note: λ(0) must not be used, it's initialized to -1.
|
145
|
+
λ = model.kin_dyn_parameters.parent_array
|
146
|
+
|
147
|
+
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
|
148
|
+
# These transforms define the relative kinematics of the entire model, including
|
149
|
+
# the base transform for both floating-base and fixed-base models.
|
150
|
+
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
|
151
|
+
joint_positions=s, base_transform=jnp.eye(4)
|
152
|
+
)
|
153
|
+
|
154
|
+
# Allocate the buffer of transforms base -> link.
|
155
|
+
B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
|
156
|
+
B_X_i = B_X_i.at[0].set(jnp.eye(6))
|
157
|
+
|
158
|
+
# =============================
|
159
|
+
# Compute doubly-left Jacobian
|
160
|
+
# =============================
|
161
|
+
|
162
|
+
# Allocate the Jacobian matrix.
|
163
|
+
# The Jbb section of the doubly-left Jacobian is an identity matrix.
|
164
|
+
J = jnp.zeros(shape=(6, 6 + model.dofs()))
|
165
|
+
J = J.at[0:6, 0:6].set(jnp.eye(6))
|
166
|
+
|
167
|
+
ComputeFullJacobianCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
|
168
|
+
compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J)
|
169
|
+
|
170
|
+
def compute_full_jacobian(
|
171
|
+
carry: ComputeFullJacobianCarry, i: jtp.Int
|
172
|
+
) -> tuple[ComputeFullJacobianCarry, None]:
|
173
|
+
|
174
|
+
ii = i - 1
|
175
|
+
B_X_i, J = carry
|
176
|
+
|
177
|
+
# Compute the base (0) to link (i) adjoint matrix.
|
178
|
+
B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
|
179
|
+
B_X_i = B_X_i.at[i].set(B_Xi_i)
|
180
|
+
|
181
|
+
# Compute the ii-th column of the B_S_BL(s) matrix.
|
182
|
+
B_Sii_BL = B_Xi_i @ S[i]
|
183
|
+
J = J.at[0:6, 6 + ii].set(B_Sii_BL.squeeze())
|
184
|
+
|
185
|
+
return (B_X_i, J), None
|
186
|
+
|
187
|
+
(B_X_i, J), _ = jax.lax.scan(
|
188
|
+
f=compute_full_jacobian,
|
189
|
+
init=compute_full_jacobian_carry,
|
190
|
+
xs=np.arange(start=1, stop=model.number_of_links()),
|
191
|
+
)
|
192
|
+
|
193
|
+
# Convert adjoints to SE(3) transforms.
|
194
|
+
# Returning them here prevents calling FK in case the output representation
|
195
|
+
# of the Jacobian needs to be changed.
|
196
|
+
B_H_L = jax.vmap(lambda B_X_L: Adjoint.to_transform(B_X_L))(B_X_i)
|
197
|
+
|
198
|
+
# Adjust shape of doubly-left free-floating full Jacobian.
|
199
|
+
B_J_full_WL_B = J.squeeze().astype(float)
|
200
|
+
|
201
|
+
return B_J_full_WL_B, B_H_L
|