jaxsim 0.3.1.dev62__py3-none-any.whl → 0.3.1.dev94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. jaxsim/__init__.py +5 -5
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/com.py +3 -4
  4. jaxsim/api/common.py +11 -11
  5. jaxsim/api/contact.py +11 -3
  6. jaxsim/api/data.py +3 -6
  7. jaxsim/api/frame.py +9 -10
  8. jaxsim/api/kin_dyn_parameters.py +25 -28
  9. jaxsim/api/link.py +12 -12
  10. jaxsim/api/model.py +47 -43
  11. jaxsim/api/ode.py +19 -12
  12. jaxsim/api/ode_data.py +11 -11
  13. jaxsim/integrators/common.py +19 -29
  14. jaxsim/integrators/fixed_step.py +10 -10
  15. jaxsim/integrators/variable_step.py +13 -13
  16. jaxsim/math/__init__.py +2 -1
  17. jaxsim/math/joint_model.py +2 -1
  18. jaxsim/math/quaternion.py +3 -9
  19. jaxsim/math/transform.py +2 -2
  20. jaxsim/mujoco/loaders.py +5 -5
  21. jaxsim/mujoco/model.py +6 -6
  22. jaxsim/mujoco/visualizer.py +3 -0
  23. jaxsim/parsers/__init__.py +0 -1
  24. jaxsim/parsers/descriptions/joint.py +1 -1
  25. jaxsim/parsers/descriptions/link.py +3 -4
  26. jaxsim/parsers/descriptions/model.py +1 -1
  27. jaxsim/parsers/kinematic_graph.py +38 -39
  28. jaxsim/parsers/rod/parser.py +14 -14
  29. jaxsim/parsers/rod/utils.py +9 -11
  30. jaxsim/rbda/aba.py +6 -12
  31. jaxsim/rbda/collidable_points.py +8 -7
  32. jaxsim/rbda/contacts/soft.py +29 -27
  33. jaxsim/rbda/crba.py +3 -3
  34. jaxsim/rbda/forward_kinematics.py +1 -1
  35. jaxsim/rbda/jacobian.py +8 -8
  36. jaxsim/rbda/rnea.py +3 -3
  37. jaxsim/rbda/utils.py +1 -1
  38. jaxsim/terrain/terrain.py +100 -22
  39. jaxsim/typing.py +14 -22
  40. jaxsim/utils/jaxsim_dataclass.py +4 -4
  41. jaxsim/utils/wrappers.py +5 -1
  42. {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/METADATA +1 -1
  43. jaxsim-0.3.1.dev94.dist-info/RECORD +68 -0
  44. {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/WHEEL +1 -1
  45. jaxsim-0.3.1.dev62.dist-info/RECORD +0 -68
  46. {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/LICENSE +0 -0
  47. {jaxsim-0.3.1.dev62.dist-info → jaxsim-0.3.1.dev94.dist-info}/top_level.txt +0 -0
jaxsim/parsers/rod/utils.py CHANGED
@@ -1,12 +1,11 @@
1
1
  import os
2
2
 
3
- import jaxlie
4
3
  import numpy as np
5
4
  import numpy.typing as npt
6
5
  import rod
7
6
 
8
7
  import jaxsim.typing as jtp
9
- from jaxsim.math import Inertia
8
+ from jaxsim.math import Adjoint, Inertia
10
9
  from jaxsim.parsers import descriptions
11
10
 
12
11
 
@@ -21,10 +20,10 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
21
20
  The 6D inertia matrix of the link expressed in the link frame.
22
21
  """
23
22
 
24
- # Extract the "mass" element
23
+ # Extract the "mass" element.
25
24
  m = inertial.mass
26
25
 
27
- # Extract the "inertia" element
26
+ # Extract the "inertia" element.
28
27
  inertia_element = inertial.inertia
29
28
 
30
29
  ixx = inertia_element.ixx
@@ -34,7 +33,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
34
33
  ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0
35
34
  iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0
36
35
 
37
- # Build the 3x3 inertia matrix expressed in the CoM
36
+ # Build the 3x3 inertia matrix expressed in the CoM.
38
37
  I_CoM = np.array(
39
38
  [
40
39
  [ixx, ixy, ixz],
@@ -43,17 +42,16 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
43
42
  ]
44
43
  )
45
44
 
46
- # Build the 6x6 generalized inertia at the CoM
45
+ # Build the 6x6 generalized inertia at the CoM.
47
46
  M_CoM = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM)
48
47
 
49
- # Compute the transform from the inertial frame (CoM) to the link frame
48
+ # Compute the transform from the inertial frame (CoM) to the link frame.
50
49
  L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)
51
50
 
52
- # We need its inverse
53
- CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse()
54
- CoM_X_L = CoM_H_L.adjoint()
51
+ # We need its inverse.
52
+ CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True)
55
53
 
56
- # Express the CoM inertia matrix in the link frame L
54
+ # Express the CoM inertia matrix in the link frame L.
57
55
  M_L = CoM_X_L.T @ M_CoM @ CoM_X_L
58
56
 
59
57
  return M_L.astype(dtype=float)
jaxsim/rbda/aba.py CHANGED
@@ -102,7 +102,7 @@ def aba(
102
102
  i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
103
103
  i_X_0 = i_X_0.at[0].set(jnp.eye(6))
104
104
 
105
- # Initialize base quantities
105
+ # Initialize base quantities.
106
106
  if model.floating_base():
107
107
 
108
108
  # Base velocity v₀ in body-fixed representation.
@@ -121,10 +121,7 @@ def aba(
121
121
  # Pass 1
122
122
  # ======
123
123
 
124
- Pass1Carry = tuple[
125
- jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
126
- ]
127
-
124
+ Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
128
125
  pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)
129
126
 
130
127
  # Propagate kinematics and initialize AB inertia and AB bias forces.
@@ -178,10 +175,7 @@ def aba(
178
175
  d = jnp.zeros(shape=(model.number_of_links(), 1))
179
176
  u = jnp.zeros(shape=(model.number_of_links(), 1))
180
177
 
181
- Pass2Carry = tuple[
182
- jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
183
- ]
184
-
178
+ Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
185
179
  pass_2_carry: Pass2Carry = (U, d, u, MA, pA)
186
180
 
187
181
  def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
@@ -204,8 +198,8 @@ def aba(
204
198
 
205
199
  # Propagate them to the parent, handling the base link.
206
200
  def propagate(
207
- MA_pA: tuple[jtp.MatrixJax, jtp.MatrixJax]
208
- ) -> tuple[jtp.MatrixJax, jtp.MatrixJax]:
201
+ MA_pA: tuple[jtp.Matrix, jtp.Matrix]
202
+ ) -> tuple[jtp.Matrix, jtp.Matrix]:
209
203
 
210
204
  MA, pA = MA_pA
211
205
 
@@ -248,7 +242,7 @@ def aba(
248
242
  s̈ = jnp.zeros_like(s)
249
243
  a = jnp.zeros_like(v).at[0].set(a0)
250
244
 
251
- Pass3Carry = tuple[jtp.MatrixJax, jtp.VectorJax]
245
+ Pass3Carry = tuple[jtp.Matrix, jtp.Vector]
252
246
  pass_3_carry = (a, s̈)
253
247
 
254
248
  def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
jaxsim/rbda/collidable_points.py CHANGED
@@ -80,7 +80,7 @@ def collidable_points_pos_vel(
80
80
  # Propagate kinematics
81
81
  # ====================
82
82
 
83
- PropagateTransformsCarry = tuple[jtp.MatrixJax, jtp.Matrix]
83
+ PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix]
84
84
  propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)
85
85
 
86
86
  def propagate_kinematics(
@@ -97,7 +97,7 @@ def collidable_points_pos_vel(
97
97
  W_Xi_i = W_X_i[λ[i]] @ λi_X_i
98
98
  W_X_i = W_X_i.at[i].set(W_Xi_i)
99
99
 
100
- # Propagate the 6D velocity
100
+ # Propagate the 6D velocity.
101
101
  W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
102
102
  W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)
103
103
 
@@ -118,14 +118,15 @@ def collidable_points_pos_vel(
118
118
  # ==================================================
119
119
 
120
120
  def process_point_kinematics(
121
- Li_p_C: jtp.VectorJax, parent_body: jtp.Int
122
- ) -> tuple[jtp.VectorJax, jtp.VectorJax]:
123
- # Compute the position of the collidable point
121
+ Li_p_C: jtp.Vector, parent_body: jtp.Int
122
+ ) -> tuple[jtp.Vector, jtp.Vector]:
123
+
124
+ # Compute the position of the collidable point.
124
125
  W_p_Ci = (
125
126
  Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
126
127
  )[0:3]
127
128
 
128
- # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}
129
+ # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.
129
130
  CW_vl_WCi = (
130
131
  jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
131
132
  @ W_v_Wi[parent_body].squeeze()
@@ -133,7 +134,7 @@ def collidable_points_pos_vel(
133
134
 
134
135
  return W_p_Ci, CW_vl_WCi
135
136
 
136
- # Process all the collidable points in parallel
137
+ # Process all the collidable points in parallel.
137
138
  W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
138
139
  model.kin_dyn_parameters.contact_parameters.point,
139
140
  jnp.array(model.kin_dyn_parameters.contact_parameters.body),
jaxsim/rbda/contacts/soft.py CHANGED
@@ -105,24 +105,24 @@ class SoftContactsParams(ContactsParams):
105
105
  - ξ < 1.0: under-damped
106
106
  """
107
107
 
108
- # Use symbols for input parameters
108
+ # Use symbols for input parameters.
109
109
  ξ = damping_ratio
110
110
  δ_max = max_penetration
111
111
  μc = static_friction_coefficient
112
112
 
113
- # Compute the total mass of the model
113
+ # Compute the total mass of the model.
114
114
  m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()
115
115
 
116
- # Rename the standard gravity
116
+ # Rename the standard gravity.
117
117
  g = standard_gravity
118
118
 
119
- # Compute the average support force on each collidable point
119
+ # Compute the average support force on each collidable point.
120
120
  f_average = m * g / number_of_active_collidable_points_steady_state
121
121
 
122
- # Compute the stiffness to get the desired steady-state penetration
122
+ # Compute the stiffness to get the desired steady-state penetration.
123
123
  K = f_average / jnp.power(δ_max, 3 / 2)
124
124
 
125
- # Compute the damping using the damping ratio
125
+ # Compute the damping using the damping ratio.
126
126
  critical_damping = 2 * jnp.sqrt(K * m)
127
127
  D = ξ * critical_damping
128
128
 
@@ -151,14 +151,16 @@ class SoftContacts(ContactModel):
151
151
  default_factory=SoftContactsParams
152
152
  )
153
153
 
154
- terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)
154
+ terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
155
+ default_factory=FlatTerrain
156
+ )
155
157
 
156
158
  def compute_contact_forces(
157
159
  self,
158
160
  position: jtp.Vector,
159
161
  velocity: jtp.Vector,
160
162
  tangential_deformation: jtp.Vector,
161
- ) -> tuple[jtp.Vector, tuple[jtp.Vector, None]]:
163
+ ) -> tuple[jtp.Vector, tuple[jtp.Vector]]:
162
164
  """
163
165
  Compute the contact forces and material deformation rate.
164
166
 
@@ -188,18 +190,18 @@ class SoftContacts(ContactModel):
188
190
  # Normal force computation
189
191
  # ========================
190
192
 
191
- # Unpack the position of the collidable point
193
+ # Unpack the position of the collidable point.
192
194
  px, py, pz = W_p_C = position.squeeze()
193
195
  vx, vy, vz = W_ṗ_C = velocity.squeeze()
194
196
 
195
- # Compute the terrain normal and the contact depth
197
+ # Compute the terrain normal and the contact depth.
196
198
  n̂ = self.terrain.normal(x=px, y=py).squeeze()
197
199
  h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])
198
200
 
199
- # Compute the penetration depth normal to the terrain
201
+ # Compute the penetration depth normal to the terrain.
200
202
  δ = jnp.maximum(0.0, jnp.dot(h, n̂))
201
203
 
202
- # Compute the penetration normal velocity
204
+ # Compute the penetration normal velocity.
203
205
  δ̇ = -jnp.dot(W_ṗ_C, n̂)
204
206
 
205
207
  # Non-linear spring-damper model.
@@ -210,10 +212,10 @@ class SoftContacts(ContactModel):
210
212
  on_false=jnp.array(0.0),
211
213
  )
212
214
 
213
- # Prevent negative normal forces that might occur when δ̇ is largely negative
215
+ # Prevent negative normal forces that might occur when δ̇ is largely negative.
214
216
  force_normal_mag = jnp.maximum(0.0, force_normal_mag)
215
217
 
216
- # Compute the 3D linear force in C[W] frame
218
+ # Compute the 3D linear force in C[W] frame.
217
219
  force_normal = force_normal_mag * n̂
218
220
 
219
221
  # ====================================
@@ -230,11 +232,11 @@ class SoftContacts(ContactModel):
230
232
  )
231
233
 
232
234
  def with_no_friction():
233
- # Compute 6D mixed force in C[W]
235
+ # Compute 6D mixed force in C[W].
234
236
  CW_f_lin = force_normal
235
237
  CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])
236
238
 
237
- # Compute lin-ang 6D forces (inertial representation)
239
+ # Compute lin-ang 6D forces (inertial representation).
238
240
  W_f = W_Xf_CW @ CW_f
239
241
 
240
242
  return W_f, (ṁ,)
@@ -258,32 +260,32 @@ class SoftContacts(ContactModel):
258
260
  return jnp.zeros(6), (ṁ,)
259
261
 
260
262
  def below_terrain():
261
- # Decompose the velocity in normal and tangential components
263
+ # Decompose the velocity in normal and tangential components.
262
264
  v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
263
265
  v_tangential = W_ṗ_C - v_normal
264
266
 
265
- # Compute the tangential force. If inside the friction cone, the contact
267
+ # Compute the tangential force. If inside the friction cone, the contact.
266
268
  f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)
267
269
 
268
270
  def sticking_contact():
269
- # Sum the normal and tangential forces, and create the 6D force
271
+ # Sum the normal and tangential forces, and create the 6D force.
270
272
  CW_f_stick = force_normal + f_tangential
271
273
  CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])
272
274
 
273
- # In this case the 3D material deformation is the tangential velocity
275
+ # In this case the 3D material deformation is the tangential velocity.
274
276
  ṁ = v_tangential
275
277
 
276
278
  # Return the 6D force in the contact frame and
277
- # the deformation derivative
279
+ # the deformation derivative.
278
280
  return CW_f, ṁ
279
281
 
280
282
  def slipping_contact():
281
- # Project the force to the friction cone boundary
283
+ # Project the force to the friction cone boundary.
282
284
  f_tangential_projected = (μ * force_normal_mag) * (
283
285
  f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9)
284
286
  )
285
287
 
286
- # Sum the normal and tangential forces, and create the 6D force
288
+ # Sum the normal and tangential forces, and create the 6D force.
287
289
  CW_f_slip = force_normal + f_tangential_projected
288
290
  CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])
289
291
 
@@ -297,7 +299,7 @@ class SoftContacts(ContactModel):
297
299
  ṁ = (f_tangential_projected - α * m) / β
298
300
 
299
301
  # Return the 6D force in the contact frame and
300
- # the deformation derivative
302
+ # the deformation derivative.
301
303
  return CW_f, ṁ
302
304
 
303
305
  CW_f, ṁ = jax.lax.cond(
@@ -307,10 +309,10 @@ class SoftContacts(ContactModel):
307
309
  operand=None,
308
310
  )
309
311
 
310
- # Express the 6D force in the world frame
312
+ # Express the 6D force in the world frame.
311
313
  W_f = W_Xf_CW @ CW_f
312
314
 
313
- # Return the 6D force in the world frame and the deformation derivative
315
+ # Return the 6D force in the world frame and the deformation derivative.
314
316
  return W_f, (ṁ,)
315
317
 
316
318
  # (W_f, (ṁ,))
@@ -321,7 +323,7 @@ class SoftContacts(ContactModel):
321
323
  operand=None,
322
324
  )
323
325
 
324
- # (W_f, ṁ)
326
+ # (W_f, (ṁ,))
325
327
  return jax.lax.cond(
326
328
  pred=(μ == 0.0),
327
329
  true_fun=lambda _: with_no_friction(),
jaxsim/rbda/crba.py CHANGED
@@ -45,7 +45,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
45
45
  # Propagate kinematics
46
46
  # ====================
47
47
 
48
- ForwardPassCarry = tuple[jtp.MatrixJax]
48
+ ForwardPassCarry = tuple[jtp.Matrix]
49
49
  forward_pass_carry: ForwardPassCarry = (i_X_0,)
50
50
 
51
51
  def propagate_kinematics(
@@ -71,7 +71,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
71
71
 
72
72
  M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))
73
73
 
74
- BackwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
74
+ BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix]
75
75
  backward_pass_carry: BackwardPassCarry = (Mc, M)
76
76
 
77
77
  def backward_pass(
@@ -90,7 +90,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
90
90
 
91
91
  j = i
92
92
 
93
- CarryInnerFn = tuple[jtp.Int, jtp.MatrixJax, jtp.MatrixJax]
93
+ CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix]
94
94
  carry_inner_fn = (j, Fi, M)
95
95
 
96
96
  def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn:
jaxsim/rbda/forward_kinematics.py CHANGED
@@ -61,7 +61,7 @@ def forward_kinematics_model(
61
61
  # Propagate the kinematics
62
62
  # ========================
63
63
 
64
- PropagateKinematicsCarry = tuple[jtp.MatrixJax]
64
+ PropagateKinematicsCarry = tuple[jtp.Matrix]
65
65
  propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i,)
66
66
 
67
67
  def propagate_kinematics(
jaxsim/rbda/jacobian.py CHANGED
@@ -50,7 +50,7 @@ def jacobian(
50
50
  # Propagate kinematics
51
51
  # ====================
52
52
 
53
- PropagateKinematicsCarry = tuple[jtp.MatrixJax]
53
+ PropagateKinematicsCarry = tuple[jtp.Matrix]
54
54
  propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,)
55
55
 
56
56
  def propagate_kinematics(
@@ -86,9 +86,9 @@ def jacobian(
86
86
  # Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.
87
87
  κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index]
88
88
 
89
- def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> tuple[jtp.MatrixJax, None]:
89
+ def compute_jacobian(J: jtp.Matrix, i: jtp.Int) -> tuple[jtp.Matrix, None]:
90
90
 
91
- def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax:
91
+ def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix:
92
92
 
93
93
  ii = i - 1
94
94
 
@@ -155,16 +155,16 @@ def jacobian_full_doubly_left(
155
155
  B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
156
156
  B_X_i = B_X_i.at[0].set(jnp.eye(6))
157
157
 
158
- # =============================
159
- # Compute doubly-left Jacobian
160
- # =============================
158
+ # =================================
159
+ # Compute doubly-left full Jacobian
160
+ # =================================
161
161
 
162
162
  # Allocate the Jacobian matrix.
163
163
  # The Jbb section of the doubly-left Jacobian is an identity matrix.
164
164
  J = jnp.zeros(shape=(6, 6 + model.dofs()))
165
165
  J = J.at[0:6, 0:6].set(jnp.eye(6))
166
166
 
167
- ComputeFullJacobianCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
167
+ ComputeFullJacobianCarry = tuple[jtp.Matrix, jtp.Matrix]
168
168
  compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J)
169
169
 
170
170
  def compute_full_jacobian(
@@ -261,7 +261,7 @@ def jacobian_derivative_full_doubly_left(
261
261
  J̇ = jnp.zeros(shape=(6, 6 + model.dofs()))
262
262
 
263
263
  ComputeFullJacobianDerivativeCarry = tuple[
264
- jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
264
+ jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix
265
265
  ]
266
266
 
267
267
  compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (
jaxsim/rbda/rnea.py CHANGED
@@ -132,7 +132,7 @@ def rnea(
132
132
  # Pass 1
133
133
  # ======
134
134
 
135
- ForwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
135
+ ForwardPassCarry = Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
136
136
  forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)
137
137
 
138
138
  def forward_pass(
@@ -186,7 +186,7 @@ def rnea(
186
186
 
187
187
  τ = jnp.zeros_like(s)
188
188
 
189
- BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
189
+ BackwardPassCarry = Tuple[jtp.Vector, jtp.Matrix]
190
190
  backward_pass_carry: BackwardPassCarry = (τ, f)
191
191
 
192
192
  def backward_pass(
@@ -201,7 +201,7 @@ def rnea(
201
201
  τ = τ.at[ii].set(τ_i.squeeze())
202
202
 
203
203
  # Propagate the force to the parent link.
204
- def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:
204
+ def update_f(f: jtp.Matrix) -> jtp.Matrix:
205
205
 
206
206
  f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
207
207
  f = f.at[λ[i]].set(f_λi)
jaxsim/rbda/utils.py CHANGED
@@ -19,7 +19,7 @@ def process_inputs(
19
19
  joint_accelerations: jtp.VectorLike | None = None,
20
20
  joint_forces: jtp.VectorLike | None = None,
21
21
  link_forces: jtp.MatrixLike | None = None,
22
- standard_gravity: jtp.VectorLike | None = None,
22
+ standard_gravity: jtp.ScalarLike | None = None,
23
23
  ) -> tuple[
24
24
  jtp.Vector,
25
25
  jtp.Vector,
jaxsim/terrain/terrain.py CHANGED
@@ -1,4 +1,7 @@
1
+ from __future__ import annotations
2
+
1
3
  import abc
4
+ import dataclasses
2
5
 
3
6
  import jax.numpy as jnp
4
7
  import jax_dataclasses
@@ -7,22 +10,23 @@ import jaxsim.typing as jtp
7
10
 
8
11
 
9
12
  class Terrain(abc.ABC):
13
+
10
14
  delta = 0.010
11
15
 
12
16
  @abc.abstractmethod
13
- def height(self, x: float, y: float) -> float:
17
+ def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
14
18
  pass
15
19
 
16
- def normal(self, x: float, y: float) -> jtp.Vector:
20
+ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
17
21
  """
18
22
  Compute the normal vector of the terrain at a specific (x, y) location.
19
23
 
20
24
  Args:
21
- x (float): The x-coordinate of the location.
22
- y (float): The y-coordinate of the location.
25
+ x: The x-coordinate of the location.
26
+ y: The y-coordinate of the location.
23
27
 
24
28
  Returns:
25
- jtp.Vector: The normal vector of the terrain surface at the specified location.
29
+ The normal vector of the terrain surface at the specified location.
26
30
  """
27
31
 
28
32
  # https://stackoverflow.com/a/5282364
@@ -40,43 +44,117 @@ class Terrain(abc.ABC):
40
44
 
41
45
  @jax_dataclasses.pytree_dataclass
42
46
  class FlatTerrain(Terrain):
43
- def height(self, x: float, y: float) -> float:
44
- return 0.0
47
+
48
+ z: float = dataclasses.field(default=0.0, kw_only=True)
49
+
50
+ @staticmethod
51
+ def build(height: jtp.FloatLike) -> FlatTerrain:
52
+
53
+ return FlatTerrain(z=float(height))
54
+
55
+ def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
56
+
57
+ return jnp.array(self.z, dtype=float)
58
+
59
+ def __hash__(self) -> int:
60
+
61
+ return hash(self.z)
62
+
63
+ def __eq__(self, other: FlatTerrain) -> bool:
64
+
65
+ if not isinstance(other, FlatTerrain):
66
+ return False
67
+
68
+ return self.z == other.z
45
69
 
46
70
 
47
71
  @jax_dataclasses.pytree_dataclass
48
- class PlaneTerrain(Terrain):
49
- plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0])
72
+ class PlaneTerrain(FlatTerrain):
73
+
74
+ plane_normal: tuple[float, float, float] = jax_dataclasses.field(
75
+ default=(0.0, 0.0, 0.0), kw_only=True
76
+ )
50
77
 
51
78
  @staticmethod
52
- def build(plane_normal: list) -> "PlaneTerrain":
79
+ def build(
80
+ plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
81
+ ) -> PlaneTerrain:
53
82
  """
54
83
  Create a PlaneTerrain instance with a specified plane normal vector.
55
84
 
56
85
  Args:
57
- plane_normal (list): The normal vector of the terrain plane.
86
+ plane_normal: The normal vector of the terrain plane.
87
+ plane_height_over_origin: The height of the plane over the origin.
58
88
 
59
89
  Returns:
60
90
  PlaneTerrain: A PlaneTerrain instance.
61
91
  """
62
- if not isinstance(plane_normal, list):
63
- raise TypeError(
64
- f"Expected a list for the plane normal vector, got: {type(plane_normal)}."
65
- )
66
92
 
67
- return PlaneTerrain(plane_normal=plane_normal)
93
+ plane_normal = jnp.array(plane_normal, dtype=float)
94
+ plane_height_over_origin = jnp.array(plane_height_over_origin, dtype=float)
95
+
96
+ if plane_normal.shape != (3,):
97
+ msg = "Expected a 3D vector for the plane normal, got '{}'."
98
+ raise ValueError(msg.format(plane_normal.shape))
68
99
 
69
- def height(self, x: float, y: float) -> float:
100
+ # Make sure that the plane normal is a unit vector.
101
+ plane_normal = plane_normal / jnp.linalg.norm(plane_normal)
102
+
103
+ return PlaneTerrain(
104
+ z=float(plane_height_over_origin),
105
+ plane_normal=tuple(plane_normal.tolist()),
106
+ )
107
+
108
+ def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
70
109
  """
71
110
  Compute the height of the terrain at a specific (x, y) location on a plane.
72
111
 
73
112
  Args:
74
- x (float): The x-coordinate of the location.
75
- y (float): The y-coordinate of the location.
113
+ x: The x-coordinate of the location.
114
+ y: The y-coordinate of the location.
76
115
 
77
116
  Returns:
78
- float: The height of the terrain at the specified location on the plane.
117
+ The height of the terrain at the specified location on the plane.
79
118
  """
80
119
 
81
- a, b, c = self.plane_normal
82
- return -(a * x + b * y) / c
120
+ # Equation of the plane: A x + B y + C z + D = 0
121
+ # Normal vector coordinates: (A, B, C)
122
+ # The height over the origin: -D/C
123
+
124
+ # Get the plane equation coefficients from the terrain normal.
125
+ A, B, C = self.plane_normal
126
+
127
+ # Compute the final coefficient D considering the terrain height.
128
+ D = -C * self.z
129
+
130
+ # Invert the plane equation to get the height at the given (x, y) coordinates.
131
+ return jnp.array(-(A * x + B * y + D) / C).astype(float)
132
+
133
+ def __hash__(self) -> int:
134
+
135
+ from jaxsim.utils.wrappers import HashedNumpyArray
136
+
137
+ return hash(
138
+ (
139
+ hash(self.z),
140
+ HashedNumpyArray.hash_of_array(
141
+ array=jnp.array(self.plane_normal, dtype=float)
142
+ ),
143
+ )
144
+ )
145
+
146
+ def __eq__(self, other: PlaneTerrain) -> bool:
147
+
148
+ if not isinstance(other, PlaneTerrain):
149
+ return False
150
+
151
+ if not (
152
+ jnp.allclose(self.z, other.z)
153
+ and jnp.allclose(
154
+ jnp.array(self.plane_normal, dtype=float),
155
+ jnp.array(other.plane_normal, dtype=float),
156
+ )
157
+ ):
158
+ return False
159
+
160
+ return True
jaxsim/typing.py CHANGED
@@ -7,14 +7,14 @@ import jax
7
7
  # JAX types
8
8
  # =========
9
9
 
10
- ScalarJax = jax.Array
11
- IntJax = ScalarJax
12
- BoolJax = ScalarJax
13
- FloatJax = ScalarJax
10
+ Array = jax.Array
11
+ Scalar = Array
12
+ Vector = Array
13
+ Matrix = Array
14
14
 
15
- ArrayJax = jax.Array
16
- VectorJax = ArrayJax
17
- MatrixJax = ArrayJax
15
+ Int = Scalar
16
+ Bool = Scalar
17
+ Float = Scalar
18
18
 
19
19
  PyTree = (
20
20
  dict[Hashable, "PyTree"] | list["PyTree"] | tuple["PyTree"] | None | jax.Array | Any
@@ -24,19 +24,11 @@ PyTree = (
24
24
  # Mixed JAX / NumPy types
25
25
  # =======================
26
26
 
27
- Array = jax.typing.ArrayLike
28
- Scalar = Array
29
- Vector = Array
30
- Matrix = Array
27
+ ArrayLike = jax.typing.ArrayLike | tuple
28
+ ScalarLike = int | float | Scalar | ArrayLike
29
+ VectorLike = Vector | ArrayLike | tuple
30
+ MatrixLike = Matrix | ArrayLike
31
31
 
32
- Int = int | IntJax
33
- Bool = bool | ArrayJax
34
- Float = float | FloatJax
35
-
36
- ScalarLike = Scalar | int | float
37
- ArrayLike = Array
38
- VectorLike = Vector
39
- MatrixLike = Matrix
40
- IntLike = Int
41
- BoolLike = Bool
42
- FloatLike = Float
32
+ IntLike = int | Int | jax.typing.ArrayLike
33
+ BoolLike = bool | Bool | jax.typing.ArrayLike
34
+ FloatLike = float | Float | jax.typing.ArrayLike