jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/math/conv.py DELETED
@@ -1,114 +0,0 @@
1
- import jax.numpy as jnp
2
-
3
- import jaxsim.typing as jtp
4
-
5
- from .skew import Skew
6
-
7
-
8
- class Convert:
9
- @staticmethod
10
- def coordinates_tf(X: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
11
- """
12
- Transform coordinates from one frame to another using a transformation matrix.
13
-
14
- Args:
15
- X (jtp.Matrix): The transformation matrix (4x4 or 6x6).
16
- p (jtp.Matrix): The coordinates to be transformed (3xN).
17
-
18
- Returns:
19
- jtp.Matrix: Transformed coordinates (3xN).
20
-
21
- Raises:
22
- ValueError: If the input matrix p does not have shape (3, N).
23
- """
24
- X = X.squeeze()
25
- p = p.squeeze()
26
-
27
- # If p has shape (X,), transform it to a column vector
28
- p = jnp.vstack(p) if len(p.shape) == 1 else p
29
- rows_p, cols_p = p.shape
30
-
31
- if rows_p != 3:
32
- raise ValueError(p.shape)
33
-
34
- R = X[0:3, 0:3]
35
- r = -Skew.vee(R.T @ X[0:3, 3:6])
36
-
37
- if cols_p > 1:
38
- r = jnp.tile(r, (1, cols_p))
39
-
40
- assert r.shape == p.shape, (r.shape, p.shape)
41
-
42
- xp = R @ (p - r)
43
- return jnp.vstack(xp)
44
-
45
- @staticmethod
46
- def velocities_threed(v_6d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
47
- """
48
- Compute 3D velocities based on 6D velocities and positions.
49
-
50
- Args:
51
- v_6d (jtp.Matrix): The 6D velocities (6xN).
52
- p (jtp.Matrix): The positions (3xN).
53
-
54
- Returns:
55
- jtp.Matrix: 3D velocities (3xN).
56
-
57
- Raises:
58
- ValueError: If the input matrices have incompatible shapes.
59
- """
60
- v = v_6d.squeeze()
61
- p = p.squeeze()
62
-
63
- # If the arrays have shape (X,), transform them to column vectors
64
- v = jnp.vstack(v) if len(v.shape) == 1 else v
65
- p = jnp.vstack(p) if len(p.shape) == 1 else p
66
-
67
- rows_v, cols_v = v.shape
68
- _, cols_p = p.shape
69
-
70
- if cols_v == 1 and cols_p > 1:
71
- v = jnp.repeat(v, cols_p, axis=1)
72
-
73
- if rows_v == 6:
74
- vp = v[0:3, :] + jnp.cross(v[3:6, :], p, axis=0)
75
- else:
76
- raise ValueError(v.shape)
77
-
78
- return jnp.vstack(vp)
79
-
80
- @staticmethod
81
- def forces_sixd(f_3d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
82
- """
83
- Compute 6D forces based on 3D forces and positions.
84
-
85
- Args:
86
- f_3d (jtp.Matrix): The 3D forces (3xN).
87
- p (jtp.Matrix): The positions (3xN).
88
-
89
- Returns:
90
- jtp.Matrix: 6D forces (6xN).
91
-
92
- Raises:
93
- ValueError: If the input matrices have incompatible shapes.
94
- """
95
- f = f_3d.squeeze()
96
- p = p.squeeze()
97
-
98
- # If the arrays have shape (X,), transform them to column vectors
99
- fp = jnp.vstack(f) if len(f.shape) == 1 else f
100
- p = jnp.vstack(p) if len(p.shape) == 1 else p
101
-
102
- _, cols_p = p.shape
103
- rows_fp, cols_fp = fp.shape
104
-
105
- # Number of columns must match
106
- if cols_p != cols_fp:
107
- raise ValueError(cols_p, cols_fp)
108
-
109
- if rows_fp == 3:
110
- f = jnp.vstack([fp, jnp.cross(p, fp, axis=0)])
111
- else:
112
- raise ValueError(fp.shape)
113
-
114
- return jnp.vstack(f)
jaxsim/math/joint.py DELETED
@@ -1,102 +0,0 @@
1
- from typing import Tuple, Union
2
-
3
- import jax.numpy as jnp
4
-
5
- import jaxsim.typing as jtp
6
- from jaxsim.parsers.descriptions import JointDescriptor, JointGenericAxis, JointType
7
-
8
- from .adjoint import Adjoint
9
- from .rotation import Rotation
10
-
11
-
12
- def jcalc(
13
- jtyp: Union[JointType, JointDescriptor], q: jtp.Float
14
- ) -> Tuple[jtp.Matrix, jtp.Vector]:
15
- """
16
- Compute the spatial transformation matrix and motion subspace vector for a joint.
17
-
18
- Args:
19
- jtyp (Union[JointType, JointDescriptor]): The type or descriptor of the joint.
20
- q (jtp.Float): The joint configuration parameter.
21
-
22
- Returns:
23
- Tuple[jtp.Matrix, jtp.Vector]: A tuple containing the spatial transformation matrix (6x6) and the motion subspace vector (6x1).
24
-
25
- Raises:
26
- ValueError: If the joint type or descriptor is not recognized.
27
- """
28
- if isinstance(jtyp, JointType):
29
- code = jtyp
30
- elif isinstance(jtyp, JointDescriptor):
31
- code = jtyp.code
32
- else:
33
- raise ValueError(jtyp)
34
-
35
- match code:
36
- case JointType.F:
37
- raise ValueError("Fixed joints shouldn't be here")
38
-
39
- case JointType.R:
40
- jtyp: JointGenericAxis
41
-
42
- Xj = Adjoint.from_rotation_and_translation(
43
- rotation=Rotation.from_axis_angle(vector=q * jtyp.axis), inverse=True
44
- )
45
-
46
- S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()]))
47
-
48
- case JointType.P:
49
- jtyp: JointGenericAxis
50
-
51
- Xj = Adjoint.from_rotation_and_translation(
52
- translation=jnp.array(q * jtyp.axis), inverse=True
53
- )
54
-
55
- S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)]))
56
-
57
- case JointType.Rx:
58
- Xj = Adjoint.from_rotation_and_translation(
59
- rotation=Rotation.x(theta=q), inverse=True
60
- )
61
-
62
- S = jnp.vstack([0, 0, 0, 1.0, 0, 0])
63
-
64
- case JointType.Ry:
65
- Xj = Adjoint.from_rotation_and_translation(
66
- rotation=Rotation.y(theta=q), inverse=True
67
- )
68
-
69
- S = jnp.vstack([0, 0, 0, 0, 1.0, 0])
70
-
71
- case JointType.Rz:
72
- Xj = Adjoint.from_rotation_and_translation(
73
- rotation=Rotation.z(theta=q), inverse=True
74
- )
75
-
76
- S = jnp.vstack([0, 0, 0, 0, 0, 1.0])
77
-
78
- case JointType.Px:
79
- Xj = Adjoint.from_rotation_and_translation(
80
- translation=jnp.array([q, 0.0, 0.0]), inverse=True
81
- )
82
-
83
- S = jnp.vstack([1.0, 0, 0, 0, 0, 0])
84
-
85
- case JointType.Py:
86
- Xj = Adjoint.from_rotation_and_translation(
87
- translation=jnp.array([0.0, q, 0.0]), inverse=True
88
- )
89
-
90
- S = jnp.vstack([0, 1.0, 0, 0, 0, 0])
91
-
92
- case JointType.Pz:
93
- Xj = Adjoint.from_rotation_and_translation(
94
- translation=jnp.array([0.0, 0.0, q]), inverse=True
95
- )
96
-
97
- S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
98
-
99
- case _:
100
- raise ValueError(code)
101
-
102
- return Xj, S
jaxsim/math/plucker.py DELETED
@@ -1,100 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax.numpy as jnp
4
-
5
- import jaxsim.typing as jtp
6
-
7
- from .skew import Skew
8
-
9
-
10
- class Plucker:
11
- @staticmethod
12
- def from_rot_and_trans(dcm: jtp.Matrix, translation: jtp.Vector) -> jtp.Matrix:
13
- """
14
- Computes the Plücker matrix from a rotation matrix and a translation vector.
15
-
16
- Args:
17
- dcm: A 3x3 rotation matrix.
18
- translation: A 3x1 translation vector.
19
-
20
- Returns:
21
- A 6x6 Plücker matrix.
22
- """
23
- R = dcm
24
-
25
- X = jnp.block(
26
- [
27
- [R, -R @ Skew.wedge(vector=translation)],
28
- [jnp.zeros(shape=(3, 3)), R],
29
- ]
30
- )
31
-
32
- return X
33
-
34
- @staticmethod
35
- def to_rot_and_trans(adjoint: jtp.Matrix) -> Tuple[jtp.Matrix, jtp.Vector]:
36
- """
37
- Computes the rotation matrix and translation vector from a Plücker matrix.
38
-
39
- Args:
40
- adjoint: A 6x6 Plücker matrix.
41
-
42
- Returns:
43
- A tuple containing the 3x3 rotation matrix and the 3x1 translation vector.
44
- """
45
- X = adjoint
46
-
47
- R = X[0:3, 0:3]
48
- p = -Skew.vee(R.T @ X[0:3, 3:6])
49
-
50
- return R, p
51
-
52
- @staticmethod
53
- def from_transform(transform: jtp.Matrix) -> jtp.Matrix:
54
- """
55
- Computes the Plücker matrix from a homogeneous transformation matrix.
56
-
57
- Args:
58
- transform: A 4x4 homogeneous transformation matrix.
59
-
60
- Returns:
61
- A 6x6 Plücker matrix.
62
- """
63
- H = transform
64
-
65
- R = H[0:3, 0:3]
66
- p = H[0:3, 3]
67
-
68
- X = jnp.block(
69
- [
70
- [R, Skew.wedge(vector=p) @ R],
71
- [jnp.zeros(shape=(3, 3)), R],
72
- ]
73
- )
74
-
75
- return X
76
-
77
- @staticmethod
78
- def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:
79
- """
80
- Computes the homogeneous transformation matrix from a Plücker matrix.
81
-
82
- Args:
83
- adjoint: A 6x6 Plücker matrix.
84
-
85
- Returns:
86
- A 4x4 homogeneous transformation matrix.
87
- """
88
- X = adjoint
89
-
90
- R = X[0:3, 0:3]
91
- o_x_R = X[0:3, 3:6]
92
-
93
- H = jnp.vstack(
94
- [
95
- jnp.hstack([R, Skew.vee(matrix=o_x_R @ R.T)]),
96
- [0, 0, 0, 1],
97
- ]
98
- )
99
-
100
- return H
jaxsim/physics/__init__.py DELETED
@@ -1,12 +0,0 @@
1
- import numpy.typing
2
-
3
- from . import algos, model
4
-
5
-
6
- def default_gravity() -> numpy.typing.NDArray:
7
- import jax.numpy as jnp
8
-
9
- return jnp.array([0, 0, -9.80])
10
-
11
-
12
- # from . import dyn, models, spatial, threed, utils
jaxsim/physics/algos/__init__.py DELETED
File without changes
jaxsim/physics/algos/aba.py DELETED
@@ -1,254 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
- def aba(
16
- model: PhysicsModel,
17
- xfb: jtp.Vector,
18
- q: jtp.Vector,
19
- qd: jtp.Vector,
20
- tau: jtp.Vector,
21
- f_ext: jtp.Matrix | None = None,
22
- ) -> Tuple[jtp.Vector, jtp.Vector]:
23
- """
24
- Articulated Body Algorithm (ABA) algorithm for forward dynamics.
25
-
26
- Args:
27
- model: The physics model of the articulated body or robot.
28
- xfb: The floating base state vector containing quaternion (4D) and position (3D).
29
- q: Joint positions (Generalized coordinates).
30
- qd: Joint velocities.
31
- tau: Joint torques or forces.
32
- f_ext: External forces and torques acting on each link. Defaults to None.
33
-
34
- Returns:
35
- A tuple containing the resulting base acceleration (in inertial-fixed representation)
36
- and joint accelerations.
37
-
38
- Note:
39
- The ABA algorithm is used to compute the accelerations of the links in an articulated body or robot system given
40
- inputs such as joint positions, velocities, torques, and external forces. The algorithm involves multiple passes
41
- to calculate intermediate quantities required for simulating the motion of the robot.
42
- """
43
-
44
- x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
45
- physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext
46
- )
47
-
48
- # Extract data from the physics model
49
- pre_X_λi = model.tree_transforms
50
- M = model.spatial_inertias
51
- i_X_pre = model.joint_transforms(q=q)
52
- S = model.motion_subspaces(q=q)
53
- λ = model.parent_array()
54
-
55
- # Initialize buffers
56
- v = jnp.array([jnp.zeros([6, 1])] * model.NB)
57
- MA = jnp.array([jnp.zeros([6, 6])] * model.NB)
58
- pA = jnp.array([jnp.zeros([6, 1])] * model.NB)
59
- c = jnp.array([jnp.zeros([6, 1])] * model.NB)
60
- i_X_λi = jnp.zeros_like(i_X_pre)
61
-
62
- # Base pose B_X_W and velocity
63
- base_quat = jnp.vstack(x_fb[0:4])
64
- base_pos = jnp.vstack(x_fb[4:7])
65
- base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]]))
66
-
67
- # 6D transform of base velocity
68
- B_X_W = Adjoint.from_quaternion_and_translation(
69
- quaternion=base_quat,
70
- translation=base_pos,
71
- inverse=True,
72
- normalize_quaternion=True,
73
- )
74
- i_X_λi = i_X_λi.at[0].set(B_X_W)
75
-
76
- # Transforms link -> base
77
- i_X_0 = jnp.zeros_like(pre_X_λi)
78
- i_X_0 = i_X_0.at[0].set(jnp.eye(6))
79
-
80
- # Initialize base quantities
81
- if model.is_floating_base:
82
- # Base velocity v₀
83
- v_0 = B_X_W @ base_vel
84
- v = v.at[0].set(v_0)
85
-
86
- # AB inertia (Mᴬ) and AB bias forces (pᴬ)
87
- MA_0 = M[0]
88
- MA = MA.at[0].set(MA_0)
89
- pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse(
90
- B_X_W
91
- ).T @ jnp.vstack(f_ext[0])
92
- pA = pA.at[0].set(pA_0)
93
-
94
- Pass1Carry = Tuple[
95
- jtp.MatrixJax,
96
- jtp.MatrixJax,
97
- jtp.MatrixJax,
98
- jtp.MatrixJax,
99
- jtp.MatrixJax,
100
- jtp.MatrixJax,
101
- ]
102
-
103
- pass_1_carry = (i_X_λi, v, c, MA, pA, i_X_0)
104
-
105
- # Pass 1
106
- def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
107
- ii = i - 1
108
- i_X_λi, v, c, MA, pA, i_X_0 = carry
109
-
110
- # Compute parent-to-child transform
111
- i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
112
- i_X_λi = i_X_λi.at[i].set(i_X_λi_i)
113
-
114
- # Propagate link velocity
115
- vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0
116
-
117
- v_i = i_X_λi[i] @ v[λ[i]] + vJ
118
- v = v.at[i].set(v_i)
119
-
120
- c_i = Cross.vx(v[i]) @ vJ
121
- c = c.at[i].set(c_i)
122
-
123
- # Initialize articulated-body inertia
124
- MA_i = jnp.array(M[i])
125
- MA = MA.at[i].set(MA_i)
126
-
127
- # Initialize articulated-body bias forces
128
- i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
129
- i_X_0 = i_X_0.at[i].set(i_X_0_i)
130
- i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T
131
-
132
- pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
133
- pA = pA.at[i].set(pA_i)
134
-
135
- return (i_X_λi, v, c, MA, pA, i_X_0), None
136
-
137
- (i_X_λi, v, c, MA, pA, i_X_0), _ = jax.lax.scan(
138
- f=loop_body_pass1,
139
- init=pass_1_carry,
140
- xs=np.arange(start=1, stop=model.NB),
141
- )
142
-
143
- U = jnp.zeros_like(S)
144
- d = jnp.zeros(shape=(model.NB, 1))
145
- u = jnp.zeros(shape=(model.NB, 1))
146
-
147
- Pass2Carry = Tuple[
148
- jtp.MatrixJax,
149
- jtp.MatrixJax,
150
- jtp.MatrixJax,
151
- jtp.MatrixJax,
152
- jtp.MatrixJax,
153
- ]
154
-
155
- pass_2_carry = (U, d, u, MA, pA)
156
-
157
- # Pass 2
158
- def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
159
- ii = i - 1
160
- U, d, u, MA, pA = carry
161
-
162
- # Compute intermediate results
163
- U_i = MA[i] @ S[i]
164
- U = U.at[i].set(U_i)
165
-
166
- d_i = S[i].T @ U[i]
167
- d = d.at[i].set(d_i.squeeze())
168
-
169
- u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
170
- u = u.at[i].set(u_i.squeeze())
171
-
172
- # Compute the articulated-body inertia and bias forces of this link
173
- Ma = MA[i] - U[i] / d[i] @ U[i].T
174
- pa = pA[i] + Ma @ c[i] + U[i] * u[i] / d[i]
175
-
176
- # Propagate them to the parent, handling the base link
177
- def propagate(
178
- MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax]
179
- ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
180
- MA, pA = MA_pA
181
-
182
- MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
183
- MA = MA.at[λ[i]].set(MA_λi)
184
-
185
- pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
186
- pA = pA.at[λ[i]].set(pA_λi)
187
-
188
- return MA, pA
189
-
190
- MA, pA = jax.lax.cond(
191
- pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
192
- true_fun=propagate,
193
- false_fun=lambda MA_pA: MA_pA,
194
- operand=(MA, pA),
195
- )
196
-
197
- return (U, d, u, MA, pA), None
198
-
199
- (U, d, u, MA, pA), _ = jax.lax.scan(
200
- f=loop_body_pass2,
201
- init=pass_2_carry,
202
- xs=np.flip(np.arange(start=1, stop=model.NB)),
203
- )
204
-
205
- if model.is_floating_base:
206
- a0 = jnp.linalg.solve(-MA[0], pA[0])
207
- else:
208
- a0 = -B_X_W @ jnp.vstack(model.gravity)
209
-
210
- a = jnp.zeros_like(S)
211
- a = a.at[0].set(a0)
212
- qdd = jnp.zeros_like(q)
213
-
214
- Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax]
215
- pass_3_carry = (a, qdd)
216
-
217
- # Pass 3
218
- def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
219
- ii = i - 1
220
- a, qdd = carry
221
-
222
- # Propagate link accelerations
223
- a_i = i_X_λi[i] @ a[λ[i]] + c[i]
224
-
225
- # Compute joint accelerations
226
- qdd_ii = (u[i] - U[i].T @ a_i) / d[i]
227
- qdd = qdd.at[i - 1].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd
228
-
229
- a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
230
- a = a.at[i].set(a_i)
231
-
232
- return (a, qdd), None
233
-
234
- (a, qdd), _ = jax.lax.scan(
235
- f=loop_body_pass3,
236
- init=pass_3_carry,
237
- xs=np.arange(1, model.NB),
238
- )
239
-
240
- # Handle 1 DoF models
241
- qdd = jnp.atleast_1d(qdd.squeeze())
242
- qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1))
243
-
244
- # Get the resulting base acceleration (w/o gravity) in body-fixed representation
245
- B_a_WB = a[0]
246
-
247
- # Convert the base acceleration to inertial-fixed representation, and add gravity
248
- W_a_WB = jnp.vstack(
249
- jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity)
250
- if model.is_floating_base
251
- else jnp.zeros(6)
252
- )
253
-
254
- return W_a_WB, qdd