jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py CHANGED
@@ -1,230 +1,24 @@
1
- from typing import Any, Protocol
2
-
3
1
  import jax
4
2
  import jax.numpy as jnp
5
3
 
6
4
  import jaxsim.api as js
7
- import jaxsim.rbda
8
5
  import jaxsim.typing as jtp
9
- from jaxsim.integrators import Time
10
- from jaxsim.math import Quaternion
11
- from jaxsim.rbda import contacts
6
+ from jaxsim.api.data import JaxSimModelData
7
+ from jaxsim.math import Quaternion, Skew
12
8
 
13
9
  from .common import VelRepr
14
- from .ode_data import ODEState
15
-
16
-
17
- class SystemDynamicsFromModelAndData(Protocol):
18
- """
19
- Protocol defining the signature of a function computing the system dynamics
20
- given a model and data object.
21
- """
22
-
23
- def __call__(
24
- self,
25
- model: js.model.JaxSimModel,
26
- data: js.data.JaxSimModelData,
27
- **kwargs: dict[str, Any],
28
- ) -> tuple[ODEState, dict[str, Any]]:
29
- """
30
- Compute the system dynamics given a model and data object.
31
-
32
- Args:
33
- model: The model to consider.
34
- data: The data of the considered model.
35
- **kwargs: Additional keyword arguments.
36
-
37
- Returns:
38
- A tuple with an `ODEState` object storing in each of its attributes the
39
- corresponding derivative, and the dictionary of auxiliary data returned
40
- by the system dynamics evaluation.
41
- """
42
-
43
- pass
44
-
45
-
46
- def wrap_system_dynamics_for_integration(
47
- *,
48
- system_dynamics: SystemDynamicsFromModelAndData,
49
- **kwargs: dict[str, Any],
50
- ) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
51
- """
52
- Wrap the system dynamics considered by JaxSim integrators in a generic
53
- `f(x, t, **u, **parameters)` function.
54
-
55
- Args:
56
- system_dynamics: The system dynamics to wrap.
57
- **kwargs: Additional kwargs to close over the system dynamics.
58
-
59
- Returns:
60
- The system dynamics closed over the additional kwargs to be used by
61
- JaxSim integrators.
62
- """
63
-
64
- # Close `system_dynamics` over additional kwargs.
65
- # Similarly to what done in `jaxsim.api.model.step`, to be future-proof, we use the
66
- # following logic to allow the caller to close over arguments having the same name
67
- # of the ones used in the `wrap_system_dynamics_for_integration` function.
68
- kwargs = kwargs.copy() if kwargs is not None else {}
69
- colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
70
- system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
71
-
72
- # Remove `model` and `data` for backward compatibility.
73
- # It's no longer necessary to close over them at this stage, as this is always
74
- # done in `jaxsim.api.model.step`.
75
- # We can remove the following lines in a few releases.
76
- _ = system_dynamics_kwargs.pop("data", None)
77
- _ = system_dynamics_kwargs.pop("model", None)
78
-
79
- # Create the function with the signature expected by our generic integrators.
80
- # Note that our system dynamics is time independent.
81
- def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
82
-
83
- # Get the data and model objects from the kwargs.
84
- data_f = kwargs_f.pop("data")
85
- model_f = kwargs_f.pop("model")
86
-
87
- # Update the state and time stored inside data.
88
- with data_f.editable(validate=True) as data_rw:
89
- data_rw.state = x
90
-
91
- # Evaluate the system dynamics, allowing to override the kwargs originally
92
- # passed when the closure was created.
93
- return system_dynamics(
94
- model=model_f,
95
- data=data_rw,
96
- **(system_dynamics_kwargs | kwargs_f),
97
- )
98
-
99
- f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
100
- return f
101
-
102
10
 
103
11
  # ==================================
104
12
  # Functions defining system dynamics
105
13
  # ==================================
106
14
 
107
15
 
108
- @jax.jit
109
- @js.common.named_scope
110
- def system_velocity_dynamics(
111
- model: js.model.JaxSimModel,
112
- data: js.data.JaxSimModelData,
113
- *,
114
- link_forces: jtp.Vector | None = None,
115
- joint_force_references: jtp.Vector | None = None,
116
- ) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
117
- """
118
- Compute the dynamics of the system velocity.
119
-
120
- Args:
121
- model: The model to consider.
122
- data: The data of the considered model.
123
- link_forces:
124
- The 6D forces to apply to the links expressed in the frame corresponding to
125
- the velocity representation of `data`.
126
- joint_force_references: The joint force references to apply.
127
-
128
- Returns:
129
- A tuple containing the derivative of the base 6D velocity in inertial-fixed
130
- representation, the derivative of the joint velocities, and auxiliary data
131
- returned by the system dynamics evaluation.
132
- """
133
-
134
- # Build link forces if not provided.
135
- # These forces are expressed in the frame corresponding to the velocity
136
- # representation of data.
137
- O_f_L = (
138
- jnp.atleast_2d(link_forces.squeeze())
139
- if link_forces is not None
140
- else jnp.zeros((model.number_of_links(), 6))
141
- ).astype(float)
142
-
143
- # We expect that the 6D forces included in the `link_forces` argument are expressed
144
- # in the frame corresponding to the velocity representation of `data`.
145
- references = js.references.JaxSimModelReferences.build(
146
- model=model,
147
- link_forces=O_f_L,
148
- joint_force_references=joint_force_references,
149
- data=data,
150
- velocity_representation=data.velocity_representation,
151
- )
152
-
153
- # ======================
154
- # Compute contact forces
155
- # ======================
156
-
157
- # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
158
- # with the terrain.
159
- W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float)
160
-
161
- # Initialize a dictionary of auxiliary data.
162
- # This dictionary is used to store additional data computed by the contact model.
163
- aux_data = {}
164
-
165
- if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
166
-
167
- with (
168
- data.switch_velocity_representation(VelRepr.Inertial),
169
- references.switch_velocity_representation(VelRepr.Inertial),
170
- ):
171
-
172
- # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
173
- # along with contact-specific auxiliary states.
174
- W_f_C, aux_data = js.contact.collidable_point_dynamics(
175
- model=model,
176
- data=data,
177
- link_forces=references.link_forces(model=model, data=data),
178
- joint_force_references=references.joint_force_references(model=model),
179
- )
180
-
181
- # Compute the 6D forces applied to the links equivalent to the forces applied
182
- # to the frames associated to the collidable points.
183
- W_f_L_terrain = model.contact_model.link_forces_from_contact_forces(
184
- model=model,
185
- data=data,
186
- contact_forces=W_f_C,
187
- )
188
-
189
- # ===========================
190
- # Compute system acceleration
191
- # ===========================
192
-
193
- # Compute the total link forces.
194
- with (
195
- data.switch_velocity_representation(VelRepr.Inertial),
196
- references.switch_velocity_representation(VelRepr.Inertial),
197
- ):
198
-
199
- # Sum the contact forces just computed with the link forces applied by the user.
200
- references = references.apply_link_forces(
201
- model=model,
202
- data=data,
203
- forces=W_f_L_terrain,
204
- additive=True,
205
- )
206
-
207
- # Get the link forces in inertial-fixed representation.
208
- f_L_total = references.link_forces(model=model, data=data)
209
-
210
- # Compute the system acceleration in inertial-fixed representation.
211
- # This representation is useful for integration purpose.
212
- W_v̇_WB, s̈ = system_acceleration(
213
- model=model,
214
- data=data,
215
- joint_force_references=joint_force_references,
216
- link_forces=f_L_total,
217
- )
218
-
219
- return W_v̇_WB, s̈, aux_data
220
-
221
-
222
16
  def system_acceleration(
223
17
  model: js.model.JaxSimModel,
224
18
  data: js.data.JaxSimModelData,
225
19
  *,
226
20
  link_forces: jtp.MatrixLike | None = None,
227
- joint_force_references: jtp.VectorLike | None = None,
21
+ joint_torques: jtp.VectorLike | None = None,
228
22
  ) -> tuple[jtp.Vector, jtp.Vector]:
229
23
  """
230
24
  Compute the system acceleration in the active representation.
@@ -235,7 +29,7 @@ def system_acceleration(
235
29
  link_forces:
236
30
  The 6D forces to apply to the links expressed in the same
237
31
  velocity representation of data.
238
- joint_force_references: The joint force references to apply.
32
+ joint_torques: The joint torques applied to the joints.
239
33
 
240
34
  Returns:
241
35
  A tuple containing the base 6D acceleration in the active representation
@@ -253,80 +47,6 @@ def system_acceleration(
253
47
  else jnp.zeros((model.number_of_links(), 6))
254
48
  ).astype(float)
255
49
 
256
- # Build joint torques if not provided.
257
- τ_references = (
258
- jnp.atleast_1d(joint_force_references.squeeze())
259
- if joint_force_references is not None
260
- else jnp.zeros_like(data.joint_positions())
261
- ).astype(float)
262
-
263
- # ====================
264
- # Enforce joint limits
265
- # ====================
266
-
267
- τ_position_limit = jnp.zeros_like(τ_references).astype(float)
268
-
269
- if model.dofs() > 0:
270
-
271
- # Stiffness and damper parameters for the joint position limits.
272
- k_j = jnp.array(
273
- model.kin_dyn_parameters.joint_parameters.position_limit_spring
274
- ).astype(float)
275
- d_j = jnp.array(
276
- model.kin_dyn_parameters.joint_parameters.position_limit_damper
277
- ).astype(float)
278
-
279
- # Compute the joint position limit violations.
280
- lower_violation = jnp.clip(
281
- data.state.physics_model.joint_positions
282
- - model.kin_dyn_parameters.joint_parameters.position_limits_min,
283
- max=0.0,
284
- )
285
-
286
- upper_violation = jnp.clip(
287
- data.state.physics_model.joint_positions
288
- - model.kin_dyn_parameters.joint_parameters.position_limits_max,
289
- min=0.0,
290
- )
291
-
292
- # Compute the joint position limit torque.
293
- τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)
294
-
295
- τ_position_limit -= (
296
- jnp.positive(τ_position_limit)
297
- * jnp.diag(d_j)
298
- @ data.state.physics_model.joint_velocities
299
- )
300
-
301
- # ====================
302
- # Joint friction model
303
- # ====================
304
-
305
- τ_friction = jnp.zeros_like(τ_references).astype(float)
306
-
307
- if model.dofs() > 0:
308
-
309
- # Static and viscous joint friction parameters
310
- kc = jnp.array(
311
- model.kin_dyn_parameters.joint_parameters.friction_static
312
- ).astype(float)
313
- kv = jnp.array(
314
- model.kin_dyn_parameters.joint_parameters.friction_viscous
315
- ).astype(float)
316
-
317
- # Compute the joint friction torque.
318
- τ_friction = -(
319
- jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
320
- + jnp.diag(kv) @ data.state.physics_model.joint_velocities
321
- )
322
-
323
- # ========================
324
- # Compute forward dynamics
325
- # ========================
326
-
327
- # Compute the total joint forces.
328
- τ_total = τ_references + τ_friction + τ_position_limit
329
-
330
50
  # Store the link forces in a references object.
331
51
  references = js.references.JaxSimModelReferences.build(
332
52
  model=model,
@@ -345,7 +65,7 @@ def system_acceleration(
345
65
  v̇_WB, s̈ = js.model.forward_dynamics_aba(
346
66
  model=model,
347
67
  data=data,
348
- joint_forces=τ_total,
68
+ joint_forces=joint_torques,
349
69
  link_forces=references.link_forces(model=model, data=data),
350
70
  )
351
71
 
@@ -359,7 +79,7 @@ def system_position_dynamics(
359
79
  data: js.data.JaxSimModelData,
360
80
  baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
361
81
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
362
- """
82
+ r"""
363
83
  Compute the dynamics of the system position.
364
84
 
365
85
  Args:
@@ -371,16 +91,18 @@ def system_position_dynamics(
371
91
  Returns:
372
92
  A tuple containing the derivative of the base position, the derivative of the
373
93
  base quaternion, and the derivative of the joint positions.
374
- """
375
94
 
376
- ṡ = data.joint_velocities(model=model)
377
- W_Q_B = data.base_orientation(dcm=False)
378
-
379
- with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
380
- W_ṗ_B = data.base_velocity()[0:3]
95
+ Note:
96
+ In inertial-fixed representation, the linear component of the base velocity is not
97
+ the derivative of the base position. In fact, the base velocity is defined as:
98
+ :math:`{} ^W v_{W, B} = \begin{bmatrix} {} ^W \dot{p}_B S({} ^W \omega_{W, B}) {} ^W p _B\\ {} ^W \omega_{W, B} \end{bmatrix}`.
99
+ Where :math:`S(\cdot)` is the skew-symmetric matrix operator.
100
+ """
381
101
 
382
- with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
383
- W_ω_WB = data.base_velocity()[3:6]
102
+ ṡ = data.joint_velocities
103
+ W_Q_B = data.base_orientation
104
+ W_ω_WB = data.base_velocity[3:6]
105
+ W_ṗ_B = data.base_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position
384
106
 
385
107
  W_Q̇_B = Quaternion.derivative(
386
108
  quaternion=W_Q_B,
@@ -399,9 +121,9 @@ def system_dynamics(
399
121
  data: js.data.JaxSimModelData,
400
122
  *,
401
123
  link_forces: jtp.Vector | None = None,
402
- joint_force_references: jtp.Vector | None = None,
124
+ joint_torques: jtp.Vector | None = None,
403
125
  baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
404
- ) -> tuple[ODEState, dict[str, Any]]:
126
+ ) -> JaxSimModelData:
405
127
  """
406
128
  Compute the dynamics of the system.
407
129
 
@@ -411,57 +133,32 @@ def system_dynamics(
411
133
  link_forces:
412
134
  The 6D forces to apply to the links expressed in the frame corresponding to
413
135
  the velocity representation of `data`.
414
- joint_force_references: The joint force references to apply.
136
+ joint_torques: The joint torques acting on the joints.
415
137
  baumgarte_quaternion_regularization:
416
138
  The Baumgarte regularization coefficient used to adjust the norm of the
417
139
  quaternion (only used in integrators not operating on the SO(3) manifold).
418
140
 
419
141
  Returns:
420
- A tuple with an `ODEState` object storing in each of its attributes the
142
+ A tuple with an `JaxSimModelData` object storing in each of its attributes the
421
143
  corresponding derivative, and the dictionary of auxiliary data returned
422
144
  by the system dynamics evaluation.
423
145
  """
424
146
 
425
- # Compute the accelerations and the material deformation rate.
426
- W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
427
- model=model,
428
- data=data,
429
- joint_force_references=joint_force_references,
430
- link_forces=link_forces,
431
- )
432
-
433
- # Initialize the dictionary storing the derivative of the additional state variables
434
- # that extend the state vector of the integrated ODE system.
435
- extended_ode_state = {}
436
-
437
- match model.contact_model:
438
-
439
- case contacts.SoftContacts():
440
- extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]
441
-
442
- case contacts.ViscoElasticContacts():
443
-
444
- extended_ode_state["tangential_deformation"] = jnp.zeros_like(
445
- data.state.extended["tangential_deformation"]
446
- )
447
-
448
- case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
449
- pass
450
-
451
- case _:
452
- raise ValueError(f"Invalid contact model: {model.contact_model}")
147
+ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
148
+ W_v̇_WB, s̈ = system_acceleration(
149
+ model=model,
150
+ data=data,
151
+ joint_torques=joint_torques,
152
+ link_forces=link_forces,
153
+ )
453
154
 
454
- # Extract the velocities.
455
- W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
456
- model=model,
457
- data=data,
458
- baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
459
- )
155
+ W_ṗ_B, W_Q̇_B, = system_position_dynamics(
156
+ model=model,
157
+ data=data,
158
+ baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
159
+ )
460
160
 
461
- # Create an ODEState object populated with the derivative of each leaf.
462
- # Our integrators, operating on generic pytrees, will be able to handle it
463
- # automatically as state derivative.
464
- ode_state_derivative = ODEState.build_from_jaxsim_model(
161
+ ode_state_derivative = JaxSimModelData.build(
465
162
  model=model,
466
163
  base_position=W_ṗ_B,
467
164
  base_quaternion=W_Q̇_B,
@@ -469,7 +166,6 @@ def system_dynamics(
469
166
  base_linear_velocity=W_v̇_WB[0:3],
470
167
  base_angular_velocity=W_v̇_WB[3:6],
471
168
  joint_velocities=s̈,
472
- **extended_ode_state,
473
169
  )
474
170
 
475
- return ode_state_derivative, aux_dict
171
+ return ode_state_derivative
jaxsim/api/references.py CHANGED
@@ -242,7 +242,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
242
242
  )(W_f_L, W_H_L)
243
243
 
244
244
  # The f_L output is either L_f_L or LW_f_L, depending on the representation.
245
- W_H_L = js.model.forward_kinematics(model=model, data=data)
245
+ W_H_L = data._link_transforms
246
246
  f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])
247
247
 
248
248
  return f_L
@@ -450,7 +450,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
450
450
  )(f_L, W_H_L)
451
451
 
452
452
  # The f_L input is either L_f_L or LW_f_L, depending on the representation.
453
- W_H_L = js.model.forward_kinematics(model=model, data=data)
453
+ W_H_L = data._link_transforms
454
454
  W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
455
455
 
456
456
  return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
jaxsim/exceptions.py CHANGED
@@ -23,8 +23,8 @@ def raise_if(
23
23
 
24
24
  # Disable host callback if running on unsupported hardware or if the user
25
25
  # explicitly disabled it.
26
- if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
27
- "JAXSIM_DISABLE_EXCEPTIONS", 0
26
+ if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get(
27
+ "JAXSIM_ENABLE_EXCEPTIONS", 0
28
28
  ):
29
29
  return
30
30
 
jaxsim/math/__init__.py CHANGED
@@ -1,6 +1,3 @@
1
- # Define the default standard gravity constant.
2
- StandardGravity = 9.81
3
-
4
1
  from .adjoint import Adjoint
5
2
  from .cross import Cross
6
3
  from .inertia import Inertia
@@ -11,3 +8,7 @@ from .transform import Transform
11
8
  from .utils import safe_norm
12
9
 
13
10
  from .joint_model import JointModel, supported_joint_motion # isort:skip
11
+
12
+
13
+ # Define the default standard gravity constant.
14
+ STANDARD_GRAVITY = -9.81
jaxsim/math/joint_model.py CHANGED
@@ -7,12 +7,10 @@ import jaxlie
7
7
  from jax_dataclasses import Static
8
8
 
9
9
  import jaxsim.typing as jtp
10
+ from jaxsim.math import Rotation
10
11
  from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
11
12
  from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
12
13
 
13
- from .rotation import Rotation
14
- from .transform import Transform
15
-
16
14
 
17
15
  @jax_dataclasses.pytree_dataclass
18
16
  class JointModel:
@@ -113,60 +111,6 @@ class JointModel:
113
111
  joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
114
112
  )
115
113
 
116
- def parent_H_child(
117
- self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
118
- ) -> tuple[jtp.Matrix, jtp.Array]:
119
- r"""
120
- Compute the homogeneous transformation between the parent link and
121
- the child link of a joint, and the corresponding motion subspace.
122
-
123
- Args:
124
- joint_index: The index of the joint.
125
- joint_position: The position of the joint.
126
-
127
- Returns:
128
- A tuple containing the homogeneous transformation
129
- :math:`{}^{\lambda(i)} \mathbf{H}_i(s)`
130
- and the motion subspace :math:`\mathbf{S}(s)`.
131
- """
132
-
133
- i = joint_index
134
- s = joint_position
135
-
136
- # Get the components of the joint model.
137
- λ_Hi_pre = self.parent_H_predecessor(joint_index=i)
138
- pre_Hi_suc, S = self.predecessor_H_successor(joint_index=i, joint_position=s)
139
- suc_Hi_i = self.successor_H_child(joint_index=i)
140
-
141
- # Compose all the transforms.
142
- return λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, S
143
-
144
- @jax.jit
145
- def child_H_parent(
146
- self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
147
- ) -> tuple[jtp.Matrix, jtp.Array]:
148
- r"""
149
- Compute the homogeneous transformation between the child link and
150
- the parent link of a joint, and the corresponding motion subspace.
151
-
152
- Args:
153
- joint_index: The index of the joint.
154
- joint_position: The position of the joint.
155
-
156
- Returns:
157
- A tuple containing the homogeneous transformation
158
- :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
159
- and the motion subspace :math:`\mathbf{S}(s)`.
160
- """
161
-
162
- λ_Hi_i, S = self.parent_H_child(
163
- joint_index=joint_index, joint_position=joint_position
164
- )
165
-
166
- i_Hi_λ = Transform.inverse(λ_Hi_i)
167
-
168
- return i_Hi_λ, S
169
-
170
114
  def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
171
115
  r"""
172
116
  Return the homogeneous transformation between the parent link and
@@ -182,31 +126,6 @@ class JointModel:
182
126
 
183
127
  return self.λ_H_pre[joint_index]
184
128
 
185
- def predecessor_H_successor(
186
- self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike
187
- ) -> tuple[jtp.Matrix, jtp.Array]:
188
- r"""
189
- Compute the homogeneous transformation between the predecessor and
190
- the successor frame of a joint, and the corresponding motion subspace.
191
-
192
- Args:
193
- joint_index: The index of the joint.
194
- joint_position: The position of the joint.
195
-
196
- Returns:
197
- A tuple containing the homogeneous transformation
198
- :math:`{}^{\text{pre}(i)} \mathbf{H}_{\text{suc}(i)}(s)`
199
- and the motion subspace :math:`\mathbf{S}(s)`.
200
- """
201
-
202
- pre_H_suc, S = supported_joint_motion(
203
- self.joint_types[joint_index],
204
- joint_position,
205
- self.joint_axis[joint_index].axis,
206
- )
207
-
208
- return pre_H_suc, S
209
-
210
129
  def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
211
130
  r"""
212
131
  Return the homogeneous transformation between the successor frame and
@@ -225,65 +144,56 @@ class JointModel:
225
144
 
226
145
  @jax.jit
227
146
  def supported_joint_motion(
228
- joint_type: jtp.IntLike,
229
- joint_position: jtp.VectorLike,
230
- joint_axis: jtp.VectorLike | None = None,
231
- /,
232
- ) -> tuple[jtp.Matrix, jtp.Array]:
147
+ joint_types: jtp.Array, joint_positions: jtp.Matrix, joint_axes: jtp.Matrix
148
+ ) -> jtp.Matrix:
233
149
  """
234
- Compute the homogeneous transformation and motion subspace of a joint.
150
+ Compute the transforms of the joints.
235
151
 
236
152
  Args:
237
- joint_type: The type of the joint.
238
- joint_position: The position of the joint.
239
- joint_axis: The optional 3D axis of rotation or translation of the joint.
153
+ joint_types: The types of the joints.
154
+ joint_positions: The positions of the joints.
155
+ joint_axes: The axes of the joints.
240
156
 
241
157
  Returns:
242
- A tuple containing the homogeneous transformation and the motion subspace.
158
+ The transforms of the joints.
243
159
  """
244
160
 
245
161
  # Prepare the joint position
246
- s = jnp.array(joint_position).astype(float)
162
+ s = jnp.array(joint_positions).astype(float)
247
163
 
248
164
  def compute_F() -> tuple[jtp.Matrix, jtp.Array]:
249
- return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))
165
+ return jaxlie.SE3.identity()
250
166
 
251
167
  def compute_R() -> tuple[jtp.Matrix, jtp.Array]:
252
168
 
253
169
  # Get the additional argument specifying the joint axis.
254
170
  # This is a metadata required by only some joint types.
255
- axis = jnp.array(joint_axis).astype(float).squeeze()
171
+ axis = jnp.array(joint_axes).astype(float).squeeze()
256
172
 
257
173
  pre_H_suc = jaxlie.SE3.from_matrix(
258
174
  matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis))
259
175
  )
260
176
 
261
- S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))
262
-
263
- return pre_H_suc, S
177
+ return pre_H_suc
264
178
 
265
179
  def compute_P() -> tuple[jtp.Matrix, jtp.Array]:
266
180
 
267
181
  # Get the additional argument specifying the joint axis.
268
182
  # This is a metadata required by only some joint types.
269
- axis = jnp.array(joint_axis).astype(float).squeeze()
183
+ axis = jnp.array(joint_axes).astype(float).squeeze()
270
184
 
271
185
  pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
272
186
  rotation=jaxlie.SO3.identity(),
273
187
  translation=jnp.array(s * axis),
274
188
  )
275
189
 
276
- S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)]))
190
+ return pre_H_suc
277
191
 
278
- return pre_H_suc, S
279
-
280
- pre_H_suc, S = jax.lax.switch(
281
- index=joint_type,
192
+ return jax.lax.switch(
193
+ index=joint_types,
282
194
  branches=(
283
195
  compute_F, # JointType.Fixed
284
196
  compute_R, # JointType.Revolute
285
197
  compute_P, # JointType.Prismatic
286
198
  ),
287
- )
288
-
289
- return pre_H_suc.as_matrix(), S
199
+ ).as_matrix()