jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89):
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/math/conv.py DELETED
@@ -1,114 +0,0 @@
1
- import jax.numpy as jnp
2
-
3
- import jaxsim.typing as jtp
4
-
5
- from .skew import Skew
6
-
7
-
8
- class Convert:
9
- @staticmethod
10
- def coordinates_tf(X: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
11
- """
12
- Transform coordinates from one frame to another using a transformation matrix.
13
-
14
- Args:
15
- X (jtp.Matrix): The transformation matrix (4x4 or 6x6).
16
- p (jtp.Matrix): The coordinates to be transformed (3xN).
17
-
18
- Returns:
19
- jtp.Matrix: Transformed coordinates (3xN).
20
-
21
- Raises:
22
- ValueError: If the input matrix p does not have shape (3, N).
23
- """
24
- X = X.squeeze()
25
- p = p.squeeze()
26
-
27
- # If p has shape (X,), transform it to a column vector
28
- p = jnp.vstack(p) if len(p.shape) == 1 else p
29
- rows_p, cols_p = p.shape
30
-
31
- if rows_p != 3:
32
- raise ValueError(p.shape)
33
-
34
- R = X[0:3, 0:3]
35
- r = -Skew.vee(R.T @ X[0:3, 3:6])
36
-
37
- if cols_p > 1:
38
- r = jnp.tile(r, (1, cols_p))
39
-
40
- assert r.shape == p.shape, (r.shape, p.shape)
41
-
42
- xp = R @ (p - r)
43
- return jnp.vstack(xp)
44
-
45
- @staticmethod
46
- def velocities_threed(v_6d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
47
- """
48
- Compute 3D velocities based on 6D velocities and positions.
49
-
50
- Args:
51
- v_6d (jtp.Matrix): The 6D velocities (6xN).
52
- p (jtp.Matrix): The positions (3xN).
53
-
54
- Returns:
55
- jtp.Matrix: 3D velocities (3xN).
56
-
57
- Raises:
58
- ValueError: If the input matrices have incompatible shapes.
59
- """
60
- v = v_6d.squeeze()
61
- p = p.squeeze()
62
-
63
- # If the arrays have shape (X,), transform them to column vectors
64
- v = jnp.vstack(v) if len(v.shape) == 1 else v
65
- p = jnp.vstack(p) if len(p.shape) == 1 else p
66
-
67
- rows_v, cols_v = v.shape
68
- _, cols_p = p.shape
69
-
70
- if cols_v == 1 and cols_p > 1:
71
- v = jnp.repeat(v, cols_p, axis=1)
72
-
73
- if rows_v == 6:
74
- vp = v[0:3, :] + jnp.cross(v[3:6, :], p, axis=0)
75
- else:
76
- raise ValueError(v.shape)
77
-
78
- return jnp.vstack(vp)
79
-
80
- @staticmethod
81
- def forces_sixd(f_3d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix:
82
- """
83
- Compute 6D forces based on 3D forces and positions.
84
-
85
- Args:
86
- f_3d (jtp.Matrix): The 3D forces (3xN).
87
- p (jtp.Matrix): The positions (3xN).
88
-
89
- Returns:
90
- jtp.Matrix: 6D forces (6xN).
91
-
92
- Raises:
93
- ValueError: If the input matrices have incompatible shapes.
94
- """
95
- f = f_3d.squeeze()
96
- p = p.squeeze()
97
-
98
- # If the arrays have shape (X,), transform them to column vectors
99
- fp = jnp.vstack(f) if len(f.shape) == 1 else f
100
- p = jnp.vstack(p) if len(p.shape) == 1 else p
101
-
102
- _, cols_p = p.shape
103
- rows_fp, cols_fp = fp.shape
104
-
105
- # Number of columns must match
106
- if cols_p != cols_fp:
107
- raise ValueError(cols_p, cols_fp)
108
-
109
- if rows_fp == 3:
110
- f = jnp.vstack([fp, jnp.cross(p, fp, axis=0)])
111
- else:
112
- raise ValueError(fp.shape)
113
-
114
- return jnp.vstack(f)
jaxsim/math/joint.py DELETED
@@ -1,101 +0,0 @@
1
- from typing import Tuple, Union
2
-
3
- import jax.numpy as jnp
4
-
5
- import jaxsim.typing as jtp
6
- from jaxsim.parsers.descriptions import JointDescriptor, JointGenericAxis, JointType
7
-
8
- from .adjoint import Adjoint
9
- from .rotation import Rotation
10
-
11
-
12
- def jcalc(
13
- jtyp: Union[JointType, JointDescriptor], q: jtp.Float
14
- ) -> Tuple[jtp.Matrix, jtp.Vector]:
15
- """
16
- Compute the spatial transformation matrix and motion subspace vector for a joint.
17
-
18
- Args:
19
- jtyp (Union[JointType, JointDescriptor]): The type or descriptor of the joint.
20
- q (jtp.Float): The joint configuration parameter.
21
-
22
- Returns:
23
- Tuple[jtp.Matrix, jtp.Vector]: A tuple containing the spatial transformation matrix (6x6) and the motion subspace vector (6x1).
24
-
25
- Raises:
26
- ValueError: If the joint type or descriptor is not recognized.
27
- """
28
- if isinstance(jtyp, JointType):
29
- code = jtyp
30
- elif isinstance(jtyp, JointDescriptor):
31
- code = jtyp.code
32
- else:
33
- raise ValueError(jtyp)
34
-
35
- if code is JointType.F:
36
- raise ValueError("Fixed joints shouldn't be here")
37
-
38
- if code is JointType.R:
39
- jtyp: JointGenericAxis
40
-
41
- Xj = Adjoint.from_rotation_and_translation(
42
- rotation=Rotation.from_axis_angle(vector=q * jtyp.axis), inverse=True
43
- )
44
-
45
- S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()]))
46
-
47
- elif code is JointType.P:
48
- jtyp: JointGenericAxis
49
-
50
- Xj = Adjoint.from_rotation_and_translation(
51
- translation=jnp.array(q * jtyp.axis), inverse=True
52
- )
53
-
54
- S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)]))
55
-
56
- elif code is JointType.Rx:
57
- Xj = Adjoint.from_rotation_and_translation(
58
- rotation=Rotation.x(theta=q), inverse=True
59
- )
60
-
61
- S = jnp.vstack([0, 0, 0, 1.0, 0, 0])
62
-
63
- elif code is JointType.Ry:
64
- Xj = Adjoint.from_rotation_and_translation(
65
- rotation=Rotation.y(theta=q), inverse=True
66
- )
67
-
68
- S = jnp.vstack([0, 0, 0, 0, 1.0, 0])
69
-
70
- elif code is JointType.Rz:
71
- Xj = Adjoint.from_rotation_and_translation(
72
- rotation=Rotation.z(theta=q), inverse=True
73
- )
74
-
75
- S = jnp.vstack([0, 0, 0, 0, 0, 1.0])
76
-
77
- elif code is JointType.Px:
78
- Xj = Adjoint.from_rotation_and_translation(
79
- translation=jnp.array([q, 0.0, 0.0]), inverse=True
80
- )
81
-
82
- S = jnp.vstack([1.0, 0, 0, 0, 0, 0])
83
-
84
- elif code is JointType.Py:
85
- Xj = Adjoint.from_rotation_and_translation(
86
- translation=jnp.array([0.0, q, 0.0]), inverse=True
87
- )
88
-
89
- S = jnp.vstack([0, 1.0, 0, 0, 0, 0])
90
-
91
- elif code is JointType.Pz:
92
- Xj = Adjoint.from_rotation_and_translation(
93
- translation=jnp.array([0.0, 0.0, q]), inverse=True
94
- )
95
-
96
- S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
97
-
98
- else:
99
- raise ValueError(code)
100
-
101
- return Xj, S
jaxsim/math/plucker.py DELETED
@@ -1,100 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax.numpy as jnp
4
-
5
- import jaxsim.typing as jtp
6
-
7
- from .skew import Skew
8
-
9
-
10
- class Plucker:
11
- @staticmethod
12
- def from_rot_and_trans(dcm: jtp.Matrix, translation: jtp.Vector) -> jtp.Matrix:
13
- """
14
- Computes the Plücker matrix from a rotation matrix and a translation vector.
15
-
16
- Args:
17
- dcm: A 3x3 rotation matrix.
18
- translation: A 3x1 translation vector.
19
-
20
- Returns:
21
- A 6x6 Plücker matrix.
22
- """
23
- R = dcm
24
-
25
- X = jnp.block(
26
- [
27
- [R, -R @ Skew.wedge(vector=translation)],
28
- [jnp.zeros(shape=(3, 3)), R],
29
- ]
30
- )
31
-
32
- return X
33
-
34
- @staticmethod
35
- def to_rot_and_trans(adjoint: jtp.Matrix) -> Tuple[jtp.Matrix, jtp.Vector]:
36
- """
37
- Computes the rotation matrix and translation vector from a Plücker matrix.
38
-
39
- Args:
40
- adjoint: A 6x6 Plücker matrix.
41
-
42
- Returns:
43
- A tuple containing the 3x3 rotation matrix and the 3x1 translation vector.
44
- """
45
- X = adjoint
46
-
47
- R = X[0:3, 0:3]
48
- p = -Skew.vee(R.T @ X[0:3, 3:6])
49
-
50
- return R, p
51
-
52
- @staticmethod
53
- def from_transform(transform: jtp.Matrix) -> jtp.Matrix:
54
- """
55
- Computes the Plücker matrix from a homogeneous transformation matrix.
56
-
57
- Args:
58
- transform: A 4x4 homogeneous transformation matrix.
59
-
60
- Returns:
61
- A 6x6 Plücker matrix.
62
- """
63
- H = transform
64
-
65
- R = H[0:3, 0:3]
66
- p = H[0:3, 3]
67
-
68
- X = jnp.block(
69
- [
70
- [R, Skew.wedge(vector=p) @ R],
71
- [jnp.zeros(shape=(3, 3)), R],
72
- ]
73
- )
74
-
75
- return X
76
-
77
- @staticmethod
78
- def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:
79
- """
80
- Computes the homogeneous transformation matrix from a Plücker matrix.
81
-
82
- Args:
83
- adjoint: A 6x6 Plücker matrix.
84
-
85
- Returns:
86
- A 4x4 homogeneous transformation matrix.
87
- """
88
- X = adjoint
89
-
90
- R = X[0:3, 0:3]
91
- o_x_R = X[0:3, 3:6]
92
-
93
- H = jnp.vstack(
94
- [
95
- jnp.hstack([R, Skew.vee(matrix=o_x_R @ R.T)]),
96
- [0, 0, 0, 1],
97
- ]
98
- )
99
-
100
- return H
jaxsim/physics/__init__.py DELETED
@@ -1,12 +0,0 @@
1
- import numpy.typing
2
-
3
- from . import algos, model
4
-
5
-
6
- def default_gravity() -> numpy.typing.NDArray:
7
- import jax.numpy as jnp
8
-
9
- return jnp.array([0, 0, -9.80])
10
-
11
-
12
- # from . import dyn, models, spatial, threed, utils
jaxsim/physics/algos/__init__.py DELETED
File without changes
jaxsim/physics/algos/aba.py DELETED
@@ -1,256 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Tuple
4
-
5
- import jax
6
- import jax.numpy as jnp
7
- import numpy as np
8
-
9
- import jaxsim.typing as jtp
10
- from jaxsim.math.adjoint import Adjoint
11
- from jaxsim.math.cross import Cross
12
- from jaxsim.physics.model.physics_model import PhysicsModel
13
-
14
- from . import utils
15
-
16
-
17
- def aba(
18
- model: PhysicsModel,
19
- xfb: jtp.Vector,
20
- q: jtp.Vector,
21
- qd: jtp.Vector,
22
- tau: jtp.Vector,
23
- f_ext: jtp.Matrix | None = None,
24
- ) -> Tuple[jtp.Vector, jtp.Vector]:
25
- """
26
- Articulated Body Algorithm (ABA) algorithm for forward dynamics.
27
-
28
- Args:
29
- model: The physics model of the articulated body or robot.
30
- xfb: The floating base state vector containing quaternion (4D) and position (3D).
31
- q: Joint positions (Generalized coordinates).
32
- qd: Joint velocities.
33
- tau: Joint torques or forces.
34
- f_ext: External forces and torques acting on each link. Defaults to None.
35
-
36
- Returns:
37
- A tuple containing the resulting base acceleration (in inertial-fixed representation)
38
- and joint accelerations.
39
-
40
- Note:
41
- The ABA algorithm is used to compute the accelerations of the links in an articulated body or robot system given
42
- inputs such as joint positions, velocities, torques, and external forces. The algorithm involves multiple passes
43
- to calculate intermediate quantities required for simulating the motion of the robot.
44
- """
45
-
46
- x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
47
- physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext
48
- )
49
-
50
- # Extract data from the physics model
51
- pre_X_λi = model.tree_transforms
52
- M = model.spatial_inertias
53
- i_X_pre = model.joint_transforms(q=q)
54
- S = model.motion_subspaces(q=q)
55
- λ = model.parent_array()
56
-
57
- # Initialize buffers
58
- v = jnp.array([jnp.zeros([6, 1])] * model.NB)
59
- MA = jnp.array([jnp.zeros([6, 6])] * model.NB)
60
- pA = jnp.array([jnp.zeros([6, 1])] * model.NB)
61
- c = jnp.array([jnp.zeros([6, 1])] * model.NB)
62
- i_X_λi = jnp.zeros_like(i_X_pre)
63
-
64
- # Base pose B_X_W and velocity
65
- base_quat = jnp.vstack(x_fb[0:4])
66
- base_pos = jnp.vstack(x_fb[4:7])
67
- base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]]))
68
-
69
- # 6D transform of base velocity
70
- B_X_W = Adjoint.from_quaternion_and_translation(
71
- quaternion=base_quat,
72
- translation=base_pos,
73
- inverse=True,
74
- normalize_quaternion=True,
75
- )
76
- i_X_λi = i_X_λi.at[0].set(B_X_W)
77
-
78
- # Transforms link -> base
79
- i_X_0 = jnp.zeros_like(pre_X_λi)
80
- i_X_0 = i_X_0.at[0].set(jnp.eye(6))
81
-
82
- # Initialize base quantities
83
- if model.is_floating_base:
84
- # Base velocity v₀
85
- v_0 = B_X_W @ base_vel
86
- v = v.at[0].set(v_0)
87
-
88
- # AB inertia (Mᴬ) and AB bias forces (pᴬ)
89
- MA_0 = M[0]
90
- MA = MA.at[0].set(MA_0)
91
- pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse(
92
- B_X_W
93
- ).T @ jnp.vstack(f_ext[0])
94
- pA = pA.at[0].set(pA_0)
95
-
96
- Pass1Carry = Tuple[
97
- jtp.MatrixJax,
98
- jtp.MatrixJax,
99
- jtp.MatrixJax,
100
- jtp.MatrixJax,
101
- jtp.MatrixJax,
102
- jtp.MatrixJax,
103
- ]
104
-
105
- pass_1_carry = (i_X_λi, v, c, MA, pA, i_X_0)
106
-
107
- # Pass 1
108
- def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
109
- ii = i - 1
110
- i_X_λi, v, c, MA, pA, i_X_0 = carry
111
-
112
- # Compute parent-to-child transform
113
- i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
114
- i_X_λi = i_X_λi.at[i].set(i_X_λi_i)
115
-
116
- # Propagate link velocity
117
- vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0
118
-
119
- v_i = i_X_λi[i] @ v[λ[i]] + vJ
120
- v = v.at[i].set(v_i)
121
-
122
- c_i = Cross.vx(v[i]) @ vJ
123
- c = c.at[i].set(c_i)
124
-
125
- # Initialize articulated-body inertia
126
- MA_i = jnp.array(M[i])
127
- MA = MA.at[i].set(MA_i)
128
-
129
- # Initialize articulated-body bias forces
130
- i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
131
- i_X_0 = i_X_0.at[i].set(i_X_0_i)
132
- i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T
133
-
134
- pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
135
- pA = pA.at[i].set(pA_i)
136
-
137
- return (i_X_λi, v, c, MA, pA, i_X_0), None
138
-
139
- (i_X_λi, v, c, MA, pA, i_X_0), _ = jax.lax.scan(
140
- f=loop_body_pass1,
141
- init=pass_1_carry,
142
- xs=np.arange(start=1, stop=model.NB),
143
- )
144
-
145
- U = jnp.zeros_like(S)
146
- d = jnp.zeros(shape=(model.NB, 1))
147
- u = jnp.zeros(shape=(model.NB, 1))
148
-
149
- Pass2Carry = Tuple[
150
- jtp.MatrixJax,
151
- jtp.MatrixJax,
152
- jtp.MatrixJax,
153
- jtp.MatrixJax,
154
- jtp.MatrixJax,
155
- ]
156
-
157
- pass_2_carry = (U, d, u, MA, pA)
158
-
159
- # Pass 2
160
- def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
161
- ii = i - 1
162
- U, d, u, MA, pA = carry
163
-
164
- # Compute intermediate results
165
- U_i = MA[i] @ S[i]
166
- U = U.at[i].set(U_i)
167
-
168
- d_i = S[i].T @ U[i]
169
- d = d.at[i].set(d_i.squeeze())
170
-
171
- u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
172
- u = u.at[i].set(u_i.squeeze())
173
-
174
- # Compute the articulated-body inertia and bias forces of this link
175
- Ma = MA[i] - U[i] / d[i] @ U[i].T
176
- pa = pA[i] + Ma @ c[i] + U[i] * u[i] / d[i]
177
-
178
- # Propagate them to the parent, handling the base link
179
- def propagate(
180
- MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax]
181
- ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
182
- MA, pA = MA_pA
183
-
184
- MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
185
- MA = MA.at[λ[i]].set(MA_λi)
186
-
187
- pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
188
- pA = pA.at[λ[i]].set(pA_λi)
189
-
190
- return MA, pA
191
-
192
- MA, pA = jax.lax.cond(
193
- pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
194
- true_fun=propagate,
195
- false_fun=lambda MA_pA: MA_pA,
196
- operand=(MA, pA),
197
- )
198
-
199
- return (U, d, u, MA, pA), None
200
-
201
- (U, d, u, MA, pA), _ = jax.lax.scan(
202
- f=loop_body_pass2,
203
- init=pass_2_carry,
204
- xs=np.flip(np.arange(start=1, stop=model.NB)),
205
- )
206
-
207
- if model.is_floating_base:
208
- a0 = jnp.linalg.solve(-MA[0], pA[0])
209
- else:
210
- a0 = -B_X_W @ jnp.vstack(model.gravity)
211
-
212
- a = jnp.zeros_like(S)
213
- a = a.at[0].set(a0)
214
- qdd = jnp.zeros_like(q)
215
-
216
- Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax]
217
- pass_3_carry = (a, qdd)
218
-
219
- # Pass 3
220
- def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
221
- ii = i - 1
222
- a, qdd = carry
223
-
224
- # Propagate link accelerations
225
- a_i = i_X_λi[i] @ a[λ[i]] + c[i]
226
-
227
- # Compute joint accelerations
228
- qdd_ii = (u[i] - U[i].T @ a_i) / d[i]
229
- qdd = qdd.at[i - 1].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd
230
-
231
- a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
232
- a = a.at[i].set(a_i)
233
-
234
- return (a, qdd), None
235
-
236
- (a, qdd), _ = jax.lax.scan(
237
- f=loop_body_pass3,
238
- init=pass_3_carry,
239
- xs=np.arange(1, model.NB),
240
- )
241
-
242
- # Handle 1 DoF models
243
- qdd = jnp.atleast_1d(qdd.squeeze())
244
- qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1))
245
-
246
- # Get the resulting base acceleration (w/o gravity) in body-fixed representation
247
- B_a_WB = a[0]
248
-
249
- # Convert the base acceleration to inertial-fixed representation, and add gravity
250
- W_a_WB = jnp.vstack(
251
- jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity)
252
- if model.is_floating_base
253
- else jnp.zeros(6)
254
- )
255
-
256
- return W_a_WB, qdd