jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,284 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
def aba(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    tau: jtp.Vector,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Articulated Body Algorithm (ABA) with motor dynamics for forward dynamics.

    Args:
        model: The physics model of the articulated body or robot.
        xfb: The base state. The base quaternion is read from ``xfb[0:4]``, the
            base position from ``xfb[4:7]`` and the base velocity from
            ``xfb[7:13]`` (reordered below as ``[xfb[10:13]; xfb[7:10]]``).
        q: Joint positions.
        qd: Joint velocities.
        tau: Joint torques (may be empty).
        f_ext: Optional external 6D forces, one per link. They are mapped into
            each link frame through ``i_Xf_W``, so they are presumably expressed
            in the world frame (TODO confirm against ``utils.process_inputs``).

    Returns:
        A tuple ``(W_a_WB, qdd)``: the 6x1 base acceleration in inertial-fixed
        representation with gravity added (zeros for fixed-base models), and
        the joint accelerations as a column vector.
    """

    x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext
    )

    # Extract data from the physics model
    pre_X_λi = model.tree_transforms
    M = model.spatial_inertias
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    λ = model.parent_array()

    # Extract motor parameters from the physics model.
    # The model stores one entry per joint, and joint ii = i - 1 drives link i.
    # A neutral entry for the base (link 0) is prepended so that every motor
    # array is indexed by the link index i, consistently with S, M and i_X_λi.
    # Previously Γ, K̅ᵥ and IM were indexed by i while holding per-joint data,
    # which shifted them by one joint and read out of bounds (silently clamped
    # by JAX) for the last link.
    Γ = jnp.concatenate(
        [jnp.ones(1), jnp.array([*model._joint_motor_gear_ratio.values()])]
    )
    rotor_inertia = jnp.concatenate(
        [jnp.zeros(1), jnp.array([*model._joint_motor_inertia.values()])]
    )
    friction = jnp.concatenate(
        [jnp.zeros(1), jnp.array([*model._joint_motor_viscous_friction.values()])]
    )
    IM = rotor_inertia[:, None, None] * jnp.eye(6)
    K̅ᵥ = Γ * friction * Γ

    # Motor motion subspaces: joint subspaces scaled by the gear ratio.
    # The base entry is unchanged since Γ[0] = 1.
    m_S = S * Γ[:, None, None]

    # Initialize buffers
    v = jnp.zeros(shape=(model.NB, 6, 1))
    MA = jnp.zeros(shape=(model.NB, 6, 6))
    pA = jnp.zeros(shape=(model.NB, 6, 1))
    c = jnp.zeros(shape=(model.NB, 6, 1))
    i_X_λi = jnp.zeros_like(i_X_pre)

    m_v = jnp.zeros(shape=(model.NB, 6, 1))
    m_c = jnp.zeros(shape=(model.NB, 6, 1))
    pR = jnp.zeros(shape=(model.NB, 6, 1))

    # Base pose B_X_W and velocity
    base_quat = jnp.vstack(x_fb[0:4])
    base_pos = jnp.vstack(x_fb[4:7])
    base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]]))

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=base_quat,
        translation=base_pos,
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Transforms link -> base
    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities
    if model.is_floating_base:
        # Base velocity v₀
        v_0 = B_X_W @ base_vel
        v = v.at[0].set(v_0)

        # AB inertia (Mᴬ) and AB bias forces (pᴬ)
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)
        pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse(
            B_X_W
        ).T @ jnp.vstack(f_ext[0])
        pA = pA.at[0].set(pA_0)

    Pass1Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_1_carry = (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0)

    # Pass 1: propagate velocities and initialize AB quantities (root to leaves)
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
        ii = i - 1
        i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0 = carry

        # Compute parent-to-child transform
        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        # Propagate link velocity (the factor zeroes the joint term if qd is empty)
        vJ = S[i] * qd[ii] * (qd.size != 0)
        m_vJ = m_S[i] * qd[ii] * (qd.size != 0)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        m_v_i = i_X_λi[i] @ v[λ[i]] + m_vJ
        m_v = m_v.at[i].set(m_v_i)

        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)
        m_c_i = Cross.vx(m_v[i]) @ m_vJ
        m_c = m_c.at[i].set(m_c_i)

        # Initialize articulated-body inertia
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Initialize articulated-body bias forces
        i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
        pA = pA.at[i].set(pA_i)

        # Motor bias force (rotor gyroscopic term minus viscous friction)
        pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i] - K̅ᵥ[i] * m_v[i]
        pR = pR.at[i].set(pR_i)

        return (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), None

    (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), _ = jax.lax.scan(
        f=loop_body_pass1,
        init=pass_1_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    U = jnp.zeros_like(S)
    m_U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.NB, 1))
    u = jnp.zeros(shape=(model.NB, 1))
    m_u = jnp.zeros(shape=(model.NB, 1))

    Pass2Carry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]

    pass_2_carry = (U, m_U, d, u, m_u, MA, pA)

    # Pass 2: accumulate AB inertias and bias forces (leaves to root)
    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
        ii = i - 1
        U, m_U, d, u, m_u, MA, pA = carry

        # Compute intermediate results
        u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # NOTE(review): `has_motors` is True when the gear ratio equals 1, which
        # reads inverted with respect to its name. Kept unchanged pending
        # confirmation of the intended motor model.
        has_motors = jnp.allclose(Γ[i], 1.0)

        m_u_i = (
            tau[ii] / Γ[i] * has_motors - m_S[i].T @ pR[i]
            if tau.size != 0
            else -m_S[i].T @ pR[i]
        )
        m_u = m_u.at[i].set(m_u_i.squeeze())

        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        m_U_i = IM[i] @ m_S[i]
        m_U = m_U.at[i].set(m_U_i)

        d_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i]
        d = d.at[i].set(d_i.squeeze())

        # Compute the articulated-body inertia and bias forces of this link.
        # Ma is a 6x6 matrix: it multiplies c[i] as a whole (indexing a row of
        # it, as done previously, was a bug).
        Ma = MA[i] + IM[i] - U[i] / d[i] @ U[i].T - m_U[i] / d[i] @ m_U[i].T
        pa = (
            pA[i]
            + pR[i]
            + Ma @ c[i]
            + IM[i] @ m_c[i]
            + U[i] / d[i] * u[i]
            + m_U[i] / d[i] * m_u[i]
        )

        # Propagate them to the parent, handling the base link
        def propagate(
            MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Skip propagation into a fixed base (parent 0 of a fixed-base model)
        MA, pA = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, m_U, d, u, m_u, MA, pA), None

    (U, m_U, d, u, m_u, MA, pA), _ = jax.lax.scan(
        f=loop_body_pass2,
        init=pass_2_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    if model.is_floating_base:
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ jnp.vstack(model.gravity)

    a = jnp.zeros_like(S)
    a = a.at[0].set(a0)
    qdd = jnp.zeros_like(q)

    Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax]
    pass_3_carry = (a, qdd)

    # Pass 3: propagate accelerations and solve for joint accelerations
    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
        ii = i - 1
        a, qdd = carry

        # Propagate link accelerations
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute joint accelerations
        qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i]
        qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

        a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
        a = a.at[i].set(a_i)

        return (a, qdd), None

    (a, qdd), _ = jax.lax.scan(
        f=loop_body_pass3,
        init=pass_3_carry,
        xs=np.arange(1, model.NB),
    )

    # Handle 1 DoF models
    qdd = jnp.atleast_1d(qdd.squeeze())
    qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1))

    # Get the resulting base acceleration (w/o gravity) in body-fixed representation
    B_a_WB = a[0]

    # Convert the base acceleration to inertial-fixed representation, and add gravity
    W_a_WB = jnp.vstack(
        jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity)
        if model.is_floating_base
        else jnp.zeros(6)
    )

    return W_a_WB, qdd
@@ -1,154 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.physics.model.physics_model import PhysicsModel
9
-
10
- from . import utils
11
-
12
-
13
def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
    """
    Compute the mass matrix with the Composite Rigid-Body Algorithm (CRBA).

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        q (jtp.Vector): Joint positions (generalized coordinates).

    Returns:
        jtp.Matrix: The (6 + n) x (6 + n) mass matrix, with n = ``model.dofs()``.
        The top-left 6x6 block is the composite inertia accumulated on the base
        link; the base/joint coupling blocks are mapped to the base link frame
        through ``i_X_0``.
    """
    _, q, _, _, _, _ = utils.process_inputs(
        physics_model=model, xfb=None, q=q, qd=None, tau=None, f_ext=None
    )

    X_tree = model.tree_transforms
    X_joint = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent

    # Link indices 1..NB-1 (the base link 0 has no joint), in both orders.
    links = np.arange(start=1, stop=model.NB)
    links_backward = np.flip(links)

    # --------------------------------------------------------------------
    # Forward pass: parent-to-child (i_X_λi) and base-to-link (i_X_0) adjoints
    # --------------------------------------------------------------------

    def forward_step(
        carry: Tuple[jtp.MatrixJax, jtp.MatrixJax], i: jtp.Int
    ) -> Tuple[Tuple[jtp.MatrixJax, jtp.MatrixJax], None]:
        i_X_λi, i_X_0 = carry
        i_X_λi = i_X_λi.at[i].set(X_joint[i] @ X_tree[i])
        i_X_0 = i_X_0.at[i].set(i_X_λi[i] @ i_X_0[λ[i]])
        return (i_X_λi, i_X_0), None

    forward_init = (
        jnp.zeros_like(X_tree),
        jnp.zeros_like(X_tree).at[0].set(jnp.eye(6)),
    )
    (i_X_λi, i_X_0), _ = jax.lax.scan(f=forward_step, init=forward_init, xs=links)

    # --------------------------------------------------------------------
    # Backward pass: composite inertias and mass-matrix entries
    # --------------------------------------------------------------------

    n = 6 + model.dofs()

    def backward_step(
        carry: Tuple[jtp.MatrixJax, jtp.MatrixJax], i: jtp.Int
    ) -> Tuple[Tuple[jtp.MatrixJax, jtp.MatrixJax], None]:
        I_c, M = carry
        row = 6 + (i - 1)

        # Add the composite inertia of link i to its parent. Links are visited
        # in decreasing index order, so (assuming BFS link numbering, λ[i] < i)
        # the children of i have already contributed to I_c[i].
        I_c = I_c.at[λ[i]].set(I_c[λ[i]] + i_X_λi[i].T @ I_c[i] @ i_X_λi[i])

        # Diagonal joint entry
        F = I_c[i] @ S[i]
        M = M.at[row, row].set((S[i].T @ F).squeeze())

        def climb(state):
            # Move F from link j to its parent and fill the joint/joint entries.
            j, F, M = state
            F = i_X_λi[j].T @ F
            j = λ[j]
            col = 6 + (j - 1)
            M_ij = (F.T @ S[j]).squeeze()
            M = M.at[row, col].set(M_ij).at[col, row].set(M_ij)
            return j, F, M

        # Bounded emulation of `while λ[j] > 0: climb(...)` with a fixed-length
        # scan: indices are visited in decreasing order, so (with λ[j] < j) the
        # current j is always reached later in the sequence.
        def maybe_climb(state, k):
            j = state[0]
            state = jax.lax.cond(
                pred=jnp.logical_and(k == j, λ[j] > 0),
                true_fun=climb,
                false_fun=lambda s: s,
                operand=state,
            )
            return state, None

        (j, F, M), _ = jax.lax.scan(
            f=maybe_climb, init=(i, F, M), xs=links_backward
        )

        # j is the ancestor of i (possibly i itself) whose parent is the base:
        # map F to the base frame and fill the base/joint coupling entries.
        F = i_X_0[j].T @ F
        M = M.at[0:6, row].set(F.squeeze()).at[row, 0:6].set(F.squeeze())

        return (I_c, M), None

    (I_c, M), _ = jax.lax.scan(
        f=backward_step,
        init=(model.spatial_inertias, jnp.zeros(shape=(n, n))),
        xs=links_backward,
    )

    # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶
    return M.at[0:6, 0:6].set(I_c[0])
@@ -1,79 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.physics.model.physics_model import PhysicsModel
10
-
11
- from . import utils
12
-
13
-
14
def forward_kinematics_model(
    model: PhysicsModel, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Array:
    """
    Compute the forward kinematics of every link of an articulated body or robot.

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        q (jtp.Vector): Joint positions (generalized coordinates).
        xfb (jtp.Vector): The base state; the base quaternion is read from
            ``xfb[0:4]`` and the base position from ``xfb[4:7]``.

    Returns:
        jtp.Array: A stacked array with, for each link index, the homogeneous
        transform obtained with ``Adjoint.to_transform`` from the link-to-world
        6D transform ``W_X_i``.
    """
    x_fb, q, _, _, _, _ = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=None, tau=None, f_ext=None
    )

    i_X_pre = model.joint_transforms(q=q)
    pre_X_λi = model.tree_transforms

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent

    # 6D transform from the base link frame to the world frame
    W_X_0 = Adjoint.from_quaternion_and_translation(
        quaternion=x_fb[0:4], translation=x_fb[4:7]
    )

    def chain_to_world(W_X: jtp.MatrixJax, i: jtp.Int) -> Tuple[jtp.MatrixJax, None]:
        # Parent-to-child transform of link i, inverted to go child -> parent,
        # then composed with the parent's link-to-world transform.
        i_X_λi = i_X_pre[i] @ pre_X_λi[i]
        W_X = W_X.at[i].set(W_X[λ[i]] @ Adjoint.inverse(i_X_λi))
        return W_X, None

    W_X_init = jnp.zeros(shape=[model.NB, 6, 6]).at[0].set(W_X_0)
    W_X_i, _ = jax.lax.scan(
        f=chain_to_world,
        init=W_X_init,
        xs=np.arange(start=1, stop=model.NB),
    )

    link_transforms = [
        Adjoint.to_transform(adjoint=W_X_i[k]) for k in range(model.NB)
    ]
    return jnp.stack(link_transforms)
74
-
75
-
76
def forward_kinematics(
    model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Matrix:
    """
    Compute the forward kinematics transform of a single link.

    All links are evaluated with ``forward_kinematics_model`` and the entry at
    ``body_index`` is returned.

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        body_index (jtp.Int): Index of the link of interest.
        q (jtp.Vector): Joint positions (generalized coordinates).
        xfb (jtp.Vector): The base state (quaternion and translation first).

    Returns:
        jtp.Matrix: The transform of the selected link.
    """
    all_link_transforms = forward_kinematics_model(model=model, q=q, xfb=xfb)
    return all_link_transforms[body_index]
@@ -1,98 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.physics.model.physics_model import PhysicsModel
10
-
11
- from . import utils
12
-
13
-
14
def jacobian(model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector) -> jtp.Matrix:
    """
    Compute the doubly-left (body-fixed) Jacobian of a link.

    Args:
        model (PhysicsModel): The physics model of the articulated body or robot.
        body_index (jtp.Int): Index of the link whose Jacobian is computed.
        q (jtp.Vector): Joint positions (generalized coordinates).

    Returns:
        jtp.Matrix: A 6 x (6 + n) matrix with n = ``model.dofs()``. The first six
        columns hold the base-to-link adjoint ``i_X_0[body_index]``; column
        ``6 + (i - 1)`` holds the joint contribution of link ``i`` when ``i``
        supports ``body_index``, and stays zero otherwise.
    """
    _, q, _, _, _, _ = utils.process_inputs(physics_model=model, q=q)

    S = model.motion_subspaces(q=q)
    i_X_pre = model.joint_transforms(q=q)
    pre_X_λi = model.tree_transforms

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent

    links = np.arange(start=1, stop=model.NB)

    # Base (0) to link (i) adjoints. Traversing the links following the BFS
    # indices guarantees that the parent's entry is already available.
    def chain_from_base(
        i_X_0: jtp.MatrixJax, i: jtp.Int
    ) -> Tuple[jtp.MatrixJax, None]:
        i_X_λi = i_X_pre[i] @ pre_X_λi[i]
        return i_X_0.at[i].set(i_X_λi @ i_X_0[λ[i]]), None

    i_X_0, _ = jax.lax.scan(
        f=chain_from_base,
        init=jnp.zeros_like(i_X_pre).at[0].set(jnp.eye(6)),
        xs=links,
    )

    B_X_0 = i_X_0[body_index]

    # Boolean version of the support set κ(body_index), usable under JIT:
    # j ∈ κ(body_index) is equivalent to κ_bool[j] being True.
    κ_bool = model.support_body_array_bool(body_index=body_index)

    def fill_joint_column(J: jtp.MatrixJax, i: jtp.Int) -> Tuple[jtp.MatrixJax, None]:
        column = B_X_0 @ Adjoint.inverse(i_X_0[i]) @ S[i]
        J_with_column = J.at[0:6, 6 + i - 1].set(column.squeeze())
        return jax.lax.select(pred=κ_bool[i], on_true=J_with_column, on_false=J), None

    J_init = jnp.zeros(shape=(6, 6 + model.dofs())).at[0:6, 0:6].set(B_X_0)
    J, _ = jax.lax.scan(f=fill_joint_column, init=J_init, xs=links)

    return J