jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev366.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
@@ -1,284 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
def aba(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    tau: jtp.Vector,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Articulated Body Algorithm (ABA) with motor dynamics, for forward dynamics.

    Args:
        model: The physics model. Besides the kinematic and inertial data, it
            must provide the motor parameters ``_joint_motor_gear_ratio``,
            ``_joint_motor_inertia`` and ``_joint_motor_viscous_friction``.
            Each must hold one value per joint, in joint order (the
            concatenation of ``S[1:]`` with the gear ratios requires this).
        xfb: The floating-base state. The quaternion is read from ``[0:4]``
            and the position from ``[4:7]``. The 6D base velocity is assembled
            as ``xfb[10:13]`` followed by ``xfb[7:10]``.
        q: The joint positions.
        qd: The joint velocities.
        tau: The joint torques.
        f_ext: The optional external 6D forces, one row per link. They are
            mapped into each link frame with the force transform of
            ``i_X_0 @ B_X_W``, so presumably they are world-frame forces.

    Returns:
        A tuple ``(W_a_WB, qdd)``:
        - the base acceleration in inertial-fixed representation with gravity
          added (zeros for fixed-base models);
        - the joint accelerations as a column vector (shape ``(0, 1)`` when
          the model has no joints).
    """

    x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext
    )

    # Extract data from the physics model
    pre_X_λi = model.tree_transforms
    M = model.spatial_inertias
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    λ = model.parent_array()

    # Extract motor parameters from the physics model.
    # They are stored per joint: entry ii = i - 1 belongs to the joint that
    # connects link i to its parent.
    Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
    I_m = jnp.array([*model._joint_motor_inertia.values()])
    K̅ᵥ = Γ.T * jnp.array([*model._joint_motor_viscous_friction.values()]) * Γ

    # Motor inertias and motor motion subspaces indexed by link (like S and M).
    # The base (link 0) has no motor, hence a zero inertia entry.
    IM = jnp.concatenate(
        [jnp.zeros(shape=(1, 6, 6)), I_m.reshape(-1, 1, 1) * jnp.eye(6)], axis=0
    )
    m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

    # Initialize buffers
    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    MA = jnp.array([jnp.zeros([6, 6])] * model.NB)
    pA = jnp.array([jnp.zeros([6, 1])] * model.NB)
    c = jnp.array([jnp.zeros([6, 1])] * model.NB)
    i_X_λi = jnp.zeros_like(i_X_pre)

    m_v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    m_c = jnp.array([jnp.zeros([6, 1])] * model.NB)
    pR = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # Base pose B_X_W and velocity
    base_quat = jnp.vstack(x_fb[0:4])
    base_pos = jnp.vstack(x_fb[4:7])
    base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]]))

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=base_quat,
        translation=base_pos,
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Transforms link -> base
    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities
    if model.is_floating_base:
        # Base velocity v₀
        v_0 = B_X_W @ base_vel
        v = v.at[0].set(v_0)

        # AB inertia (Mᴬ) and AB bias forces (pᴬ)
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)
        pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse(
            B_X_W
        ).T @ jnp.vstack(f_ext[0])
        pA = pA.at[0].set(pA_0)

    Pass1Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_1_carry = (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0)

    # Pass 1: propagate velocities from the base to the leaves
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
        ii = i - 1
        i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0 = carry

        # Compute parent-to-child transform
        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        # Propagate link velocity (the joint term vanishes for 0-DoF models)
        vJ = S[i] * qd[ii] * (qd.size != 0)
        m_vJ = m_S[i] * qd[ii] * (qd.size != 0)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        m_v_i = i_X_λi[i] @ v[λ[i]] + m_vJ
        m_v = m_v.at[i].set(m_v_i)

        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)
        m_c_i = Cross.vx(m_v[i]) @ m_vJ
        m_c = m_c.at[i].set(m_c_i)

        # Initialize articulated-body inertia
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Initialize articulated-body bias forces
        i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
        pA = pA.at[i].set(pA_i)

        # Motor bias force: the viscous friction is a per-joint quantity,
        # hence the joint index ii.
        pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i] - K̅ᵥ[ii] * m_v[i]
        pR = pR.at[i].set(pR_i)

        return (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), None

    (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), _ = jax.lax.scan(
        f=loop_body_pass1,
        init=pass_1_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    U = jnp.zeros_like(S)
    m_U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.NB, 1))
    u = jnp.zeros(shape=(model.NB, 1))
    m_u = jnp.zeros(shape=(model.NB, 1))

    Pass2Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_2_carry = (U, m_U, d, u, m_u, MA, pA)

    # Pass 2: accumulate articulated-body quantities from the leaves to the base
    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
        ii = i - 1
        U, m_U, d, u, m_u, MA, pA = carry

        # Compute intermediate results
        u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # NOTE(review): this is True when the gear ratio is ~1, which looks
        # inverted with respect to its name; kept as-is, confirm intent.
        has_motors = jnp.allclose(Γ[ii], 1.0)

        m_u_i = (
            tau[ii] / Γ[ii] * has_motors - m_S[i].T @ pR[i]
            if tau.size != 0
            else -m_S[i].T @ pR[i]
        )
        m_u = m_u.at[i].set(m_u_i.squeeze())

        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        m_U_i = IM[i] @ m_S[i]
        m_U = m_U.at[i].set(m_U_i)

        d_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i]
        d = d.at[i].set(d_i.squeeze())

        # Compute the articulated-body inertia and bias forces of this link.
        # Ma is the full 6x6 articulated inertia, so it multiplies c[i] as a
        # whole (indexing it would select a single row).
        Ma = MA[i] + IM[i] - U[i] / d[i] @ U[i].T - m_U[i] / d[i] @ m_U[i].T
        pa = (
            pA[i]
            + pR[i]
            + Ma @ c[i]
            + IM[i] @ m_c[i]
            + U[i] / d[i] * u[i]
            + m_U[i] / d[i] * m_u[i]
        )

        # Propagate them to the parent, handling the base link
        def propagate(
            MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Fixed-base models do not propagate into the (fixed) base link.
        MA, pA = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, m_U, d, u, m_u, MA, pA), None

    (U, m_U, d, u, m_u, MA, pA), _ = jax.lax.scan(
        f=loop_body_pass2,
        init=pass_2_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    if model.is_floating_base:
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ jnp.vstack(model.gravity)

    a = jnp.zeros_like(S)
    a = a.at[0].set(a0)
    qdd = jnp.zeros_like(q)

    Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax]
    pass_3_carry = (a, qdd)

    # Pass 3: compute accelerations from the base to the leaves
    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
        ii = i - 1
        a, qdd = carry

        # Propagate link accelerations
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute joint accelerations
        qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i]
        qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

        a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
        a = a.at[i].set(a_i)

        return (a, qdd), None

    (a, qdd), _ = jax.lax.scan(
        f=loop_body_pass3,
        init=pass_3_carry,
        xs=np.arange(1, model.NB),
    )

    # Handle 1 DoF models
    qdd = jnp.atleast_1d(qdd.squeeze())
    qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1))

    # Get the resulting base acceleration (w/o gravity) in body-fixed representation
    B_a_WB = a[0]

    # Convert the base acceleration to inertial-fixed representation, and add gravity
    W_a_WB = jnp.vstack(
        jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity)
        if model.is_floating_base
        else jnp.zeros(6)
    )

    return W_a_WB, qdd
@@ -1,79 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.physics.model.physics_model import PhysicsModel
10
-
11
- from . import utils
12
-
13
-
14
def forward_kinematics_model(
    model: PhysicsModel, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Array:
    """
    Compute the world-to-link homogeneous transforms of every link of the model.

    Args:
        model: The physics model of the articulated system.
        q: The joint positions (generalized coordinates).
        xfb: The base pose vector. The quaternion is read from ``xfb[0:4]``
            and the translation from ``xfb[4:7]``.

    Returns:
        A stacked array holding one transform per link, with link ``i`` at
        index ``i``. Each is obtained by converting the link's world adjoint
        with ``Adjoint.to_transform``.
    """

    base_state, joint_positions, _, _, _, _ = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=None, tau=None, f_ext=None
    )

    # Adjoint of the base link (index 0) with respect to the world frame.
    world_X_base = Adjoint.from_quaternion_and_translation(
        quaternion=base_state[0:4], translation=base_state[4:7]
    )

    joint_X = model.joint_transforms(q=joint_positions)
    tree_X = model.tree_transforms

    # parent[i] is the parent link λ(i); parent[0] is a placeholder (-1).
    parent = model.parent

    # Buffers filled during the traversal: the world adjoint of every link and
    # the parent-to-child adjoint of every link.
    world_X_links = jnp.zeros(shape=[model.NB, 6, 6]).at[0].set(world_X_base)
    child_X_parent = jnp.zeros_like(joint_X)

    def visit_link(
        carry: Tuple[jtp.MatrixJax, jtp.MatrixJax], link: jtp.Int
    ) -> Tuple[Tuple[jtp.MatrixJax, jtp.MatrixJax], None]:
        child_X_parent, world_X_links = carry

        child_X_parent = child_X_parent.at[link].set(joint_X[link] @ tree_X[link])

        # Chain the parent's world adjoint with the child-to-parent adjoint.
        world_X_link = world_X_links[parent[link]] @ Adjoint.inverse(
            child_X_parent[link]
        )
        world_X_links = world_X_links.at[link].set(world_X_link)

        return (child_X_parent, world_X_links), None

    link_indices = np.arange(start=1, stop=model.NB)
    (_, world_X_links), _ = jax.lax.scan(
        f=visit_link, init=(child_X_parent, world_X_links), xs=link_indices
    )

    transforms = [Adjoint.to_transform(adjoint=X) for X in list(world_X_links)]
    return jnp.stack(transforms)
74
-
75
-
76
def forward_kinematics(
    model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Matrix:
    """
    Compute the world-to-link transform of a single link.

    The forward kinematics of the whole model is evaluated with
    ``forward_kinematics_model``, and the entry of ``body_index`` is returned.

    Args:
        model: The physics model of the articulated system.
        body_index: The index of the link whose transform is requested.
        q: The joint positions (generalized coordinates).
        xfb: The base pose vector (quaternion in ``[0:4]``, translation in ``[4:7]``).

    Returns:
        The transform of the selected link.
    """

    all_link_transforms = forward_kinematics_model(model=model, q=q, xfb=xfb)
    return all_link_transforms[body_index]
@@ -1,98 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.physics.model.physics_model import PhysicsModel
10
-
11
- from . import utils
12
-
13
-
14
def jacobian(model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector) -> jtp.Matrix:
    """
    Compute the doubly-left (body-fixed) Jacobian of a link.

    Args:
        model: The physics model of the articulated system.
        body_index: The index of the link whose Jacobian is computed.
        q: The joint positions (generalized coordinates).

    Returns:
        A ``6 x (6 + model.dofs())`` matrix. The first six columns hold the
        base-to-link adjoint of ``body_index``. Column ``6 + j`` holds the
        motion subspace of joint ``j`` expressed in the link frame when that
        joint's child link supports ``body_index`` (according to
        ``model.support_body_array_bool``). Otherwise the column is zero.
    """

    _, joint_positions, _, _, _, _ = utils.process_inputs(
        physics_model=model, q=q
    )

    S = model.motion_subspaces(q=joint_positions)
    joint_X = model.joint_transforms(q=joint_positions)
    tree_X = model.tree_transforms

    # parent[i] is the parent link λ(i); parent[0] is a placeholder (-1).
    parent = model.parent

    link_indices = np.arange(start=1, stop=model.NB)

    # -------------------------------------------------------------
    # Kinematic pass: base-to-link adjoints i_X_0 for all the links
    # -------------------------------------------------------------

    child_X_parent = jnp.zeros_like(joint_X)
    link_X_base = jnp.zeros_like(joint_X).at[0].set(jnp.eye(6))

    def kinematic_step(
        carry: Tuple[jtp.MatrixJax, jtp.MatrixJax], link: jtp.Int
    ) -> Tuple[Tuple[jtp.MatrixJax, jtp.MatrixJax], None]:
        child_X_parent, link_X_base = carry

        child_X_parent = child_X_parent.at[link].set(joint_X[link] @ tree_X[link])

        # Links are visited in BFS order, so the parent entry is already final.
        link_X_base = link_X_base.at[link].set(
            child_X_parent[link] @ link_X_base[parent[link]]
        )

        return (child_X_parent, link_X_base), None

    (_, link_X_base), _ = jax.lax.scan(
        f=kinematic_step, init=(child_X_parent, link_X_base), xs=link_indices
    )

    # ----------------------------------------
    # Assemble the columns of the Jacobian
    # ----------------------------------------

    body_X_base = link_X_base[body_index]
    J_init = jnp.zeros(shape=(6, 6 + model.dofs())).at[0:6, 0:6].set(body_X_base)

    # Boolean support mask: entry j is True iff link j is in κ(body_index).
    # A boolean mask keeps the membership test JIT-friendly.
    supports_body = model.support_body_array_bool(body_index=body_index)

    def column_step(J: jtp.MatrixJax, link: jtp.Int) -> Tuple[jtp.MatrixJax, None]:
        column = body_X_base @ Adjoint.inverse(link_X_base[link]) @ S[link]
        J_with_column = J.at[0:6, 6 + link - 1].set(column.squeeze())

        J = jax.lax.select(
            pred=supports_body[link], on_true=J_with_column, on_false=J
        )
        return J, None

    J, _ = jax.lax.scan(f=column_step, init=J_init, xs=link_indices)

    return J
@@ -1,180 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector | None = None,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Perform inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    This function calculates the base force and the joint torques required to
    achieve a desired motion, given the configuration, velocities,
    accelerations and external forces.

    Args:
        model: The physics model containing the dynamic parameters.
        xfb: The floating-base state. The quaternion is read from ``[0:4]``
            and the position from ``[4:7]``. For floating-base models, the
            base velocity is assembled as ``xfb[10:13]`` followed by
            ``xfb[7:10]``.
        q: The joint positions.
        qd: The joint velocities.
        qdd: The joint accelerations.
        a0fb: The base acceleration, with 6 elements (after squeezing).
            Defaults to zeros when ``None``. It is only used for
            floating-base models.
        f_ext: The optional external forces, one row per link.

    Returns:
        A tuple ``(W_f0, tau)``:
        - the base 6D force expressed in the world frame;
        - the joint torques as a column vector (shape ``(0, 1)`` when the
          model has no joints).

    Raises:
        ValueError: If ``a0fb`` does not have exactly 6 elements.
    """

    # Validate the base acceleration before doing any work. Checking the full
    # shape, and not only its first dimension, rejects inputs such as (6, 6).
    # It also avoids an IndexError on 0-d inputs.
    a0fb = jnp.zeros(6) if a0fb is None else jnp.asarray(a0fb).squeeze()

    if a0fb.shape != (6,):
        raise ValueError(
            f"The base acceleration must have 6 elements, got shape {a0fb.shape}"
        )

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    gravity = model.gravity.squeeze()

    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    i_X_λi = jnp.zeros_like(pre_X_λi)

    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Gravity enters the recursion as a fictitious base acceleration.
    a_0 = -B_X_W @ jnp.vstack(gravity)
    a = a.at[0].set(a_0)

    if model.is_floating_base:
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))

        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
        a = a.at[0].set(a_0)

        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )
        f = f.at[0].set(f_0)

    ForwardPassCarry = Tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]
    forward_pass_carry = (i_X_λi, v, a, i_X_0, f)

    # Forward pass: velocities, accelerations and link forces, base to leaves.
    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:
        ii = i - 1
        i_X_λi, v, a, i_X_0, f = carry

        vJ = S[i] * qd[ii]
        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(f_ext[i])
        )
        f = f.at[i].set(f_i)

        return (i_X_λi, v, a, i_X_0, f), None

    (i_X_λi, v, a, i_X_0, f), _ = jax.lax.scan(
        f=forward_pass,
        init=forward_pass_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    tau = jnp.zeros_like(q)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry = (tau, f)

    # Backward pass: project forces on the joints and accumulate them on parents.
    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
        ii = i - 1
        tau, f = carry

        value = S[i].T @ f[i]
        tau = tau.at[ii].set(value.squeeze())

        def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:
            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)
            return f

        # Fixed-base models do not accumulate forces on the (fixed) base link.
        f = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=update_f,
            false_fun=lambda f: f,
            operand=f,
        )

        return (tau, f), None

    (tau, f), _ = jax.lax.scan(
        f=backward_pass,
        init=backward_pass_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    # Handle 1 DoF models
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))

    # Express the base 6D force in the world frame
    W_f0 = B_X_W.T @ jnp.vstack(f[0])

    return W_f0, tau
- return W_f0, tau