jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/api/model.py CHANGED
@@ -5,7 +5,6 @@ import dataclasses
5
5
  import functools
6
6
  import pathlib
7
7
  from collections.abc import Sequence
8
- from typing import Any
9
8
 
10
9
  import jax
11
10
  import jax.numpy as jnp
@@ -30,21 +29,28 @@ class JaxSimModel(JaxsimDataclass):
30
29
  The JaxSim model defining the kinematics and dynamics of a robot.
31
30
  """
32
31
 
32
+ # link_spatial_inertial_matrices, motion_subspaces
33
+
33
34
  model_name: Static[str]
34
35
 
35
- time_step: jaxsim.integrators.TimeStep = dataclasses.field(
36
- default_factory=lambda: jnp.array(0.001, dtype=float),
36
+ time_step: float = dataclasses.field(
37
+ default=0.001,
37
38
  )
38
39
 
39
40
  terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
40
41
  default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
41
42
  )
42
43
 
43
- # Note that this is the default contact model.
44
+ gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY
45
+
44
46
  contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
45
47
  default=None, repr=False
46
48
  )
47
49
 
50
+ contacts_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field(
51
+ default=None, repr=False
52
+ )
53
+
48
54
  kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = (
49
55
  dataclasses.field(default=None, repr=False)
50
56
  )
@@ -53,10 +59,6 @@ class JaxSimModel(JaxsimDataclass):
53
59
  default=None, repr=False
54
60
  )
55
61
 
56
- integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
57
- default=None, repr=False
58
- )
59
-
60
62
  _description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
61
63
  dataclasses.field(default=None, repr=False)
62
64
  )
@@ -89,7 +91,7 @@ class JaxSimModel(JaxsimDataclass):
89
91
  return hash(
90
92
  (
91
93
  hash(self.model_name),
92
- hash(float(self.time_step)),
94
+ hash(self.time_step),
93
95
  hash(self.kin_dyn_parameters),
94
96
  hash(self.contact_model),
95
97
  )
@@ -106,11 +108,9 @@ class JaxSimModel(JaxsimDataclass):
106
108
  *,
107
109
  model_name: str | None = None,
108
110
  time_step: jtp.FloatLike | None = None,
109
- integrator: (
110
- jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
111
- ) = None,
112
111
  terrain: jaxsim.terrain.Terrain | None = None,
113
112
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
113
+ contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
114
114
  is_urdf: bool | None = None,
115
115
  considered_joints: Sequence[str] | None = None,
116
116
  ) -> JaxSimModel:
@@ -130,10 +130,7 @@ class JaxSimModel(JaxsimDataclass):
130
130
  contact_model:
131
131
  The contact model to consider.
132
132
  If not specified, a soft contacts model is used.
133
- integrator:
134
- The integrator to use. If not specified, a default one is used.
135
- This argument can either be a pre-built integrator instance or one
136
- of the integrator classes defined in JaxSim.
133
+ contact_params: The parameters of the contact model.
137
134
  is_urdf:
138
135
  The optional flag to force the model description to be parsed as a URDF.
139
136
  This is usually automatically inferred.
@@ -164,9 +161,9 @@ class JaxSimModel(JaxsimDataclass):
164
161
  model_description=intermediate_description,
165
162
  model_name=model_name,
166
163
  time_step=time_step,
167
- integrator=integrator,
168
164
  terrain=terrain,
169
165
  contact_model=contact_model,
166
+ contacts_params=contact_params,
170
167
  )
171
168
 
172
169
  # Store the origin of the model, in case downstream logic needs it.
@@ -182,11 +179,10 @@ class JaxSimModel(JaxsimDataclass):
182
179
  *,
183
180
  model_name: str | None = None,
184
181
  time_step: jtp.FloatLike | None = None,
185
- integrator: (
186
- jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
187
- ) = None,
188
182
  terrain: jaxsim.terrain.Terrain | None = None,
189
183
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
184
+ contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
185
+ gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
190
186
  ) -> JaxSimModel:
191
187
  """
192
188
  Build a Model object from an intermediate model description.
@@ -202,13 +198,11 @@ class JaxSimModel(JaxsimDataclass):
202
198
  manually overridden in the function that steps the simulation.
203
199
  terrain: The terrain to consider (the default is a flat infinite plane).
204
200
  The optional name of the model overriding the physics model name.
205
- integrator:
206
- The integrator to use. If not specified, a default one is used.
207
- This argument can either be a pre-built integrator instance or one
208
- of the integrator classes defined in JaxSim.
209
201
  contact_model:
210
202
  The contact model to consider.
211
203
  If not specified, a soft contacts model is used.
204
+ contacts_params: The parameters of the soft contacts.
205
+ gravity: The gravity constant.
212
206
 
213
207
  Returns:
214
208
  The built Model object.
@@ -228,7 +222,7 @@ class JaxSimModel(JaxsimDataclass):
228
222
  time_step = (
229
223
  time_step
230
224
  if time_step is not None
231
- else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
225
+ else JaxSimModel.__dataclass_fields__["time_step"].default
232
226
  )
233
227
 
234
228
  # Create the default contact model.
@@ -237,39 +231,11 @@ class JaxSimModel(JaxsimDataclass):
237
231
  contact_model = (
238
232
  contact_model
239
233
  if contact_model is not None
240
- else jaxsim.rbda.contacts.SoftContacts.build()
234
+ else jaxsim.rbda.contacts.RelaxedRigidContacts.build()
241
235
  )
242
236
 
243
- # Build the integrator if not provided.
244
- match integrator:
245
-
246
- # If None, build a default integrator.
247
- case None:
248
-
249
- integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
250
- dynamics=js.ode.wrap_system_dynamics_for_integration(
251
- system_dynamics=js.ode.system_dynamics
252
- )
253
- )
254
-
255
- # If it's a pre-built integrator (also a custom one from the user)
256
- # just use it as is.
257
- case _ if isinstance(integrator, jaxsim.integrators.Integrator):
258
- pass
259
-
260
- # If an integrator class is passed, assume that it is a JaxSim integrator
261
- # and build it with the default system dynamics.
262
- case _ if issubclass(integrator, jaxsim.integrators.Integrator):
263
-
264
- integrator_cls = integrator
265
- integrator = integrator_cls.build(
266
- dynamics=js.ode.wrap_system_dynamics_for_integration(
267
- system_dynamics=js.ode.system_dynamics
268
- )
269
- )
270
-
271
- case _:
272
- raise ValueError(f"Invalid integrator: {integrator}")
237
+ if contacts_params is None:
238
+ contacts_params = contact_model._parameters_class()
273
239
 
274
240
  # Build the model.
275
241
  model = cls(
@@ -280,7 +246,8 @@ class JaxSimModel(JaxsimDataclass):
280
246
  time_step=time_step,
281
247
  terrain=terrain,
282
248
  contact_model=contact_model,
283
- integrator=integrator,
249
+ contacts_params=contacts_params,
250
+ gravity=gravity,
284
251
  # The following is wrapped as hashless since it's a static argument, and we
285
252
  # don't want to trigger recompilation if it changes. All relevant parameters
286
253
  # needed to compute kinematics and dynamics quantities are stored in the
@@ -350,7 +317,7 @@ class JaxSimModel(JaxsimDataclass):
350
317
  True if the model is floating-base, False otherwise.
351
318
  """
352
319
 
353
- return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
320
+ return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6
354
321
 
355
322
  def base_link(self) -> str:
356
323
  """
@@ -381,7 +348,7 @@ class JaxSimModel(JaxsimDataclass):
381
348
  the number of joints. In the future, this could be different.
382
349
  """
383
350
 
384
- return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
351
+ return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])
385
352
 
386
353
  def joint_names(self) -> tuple[str, ...]:
387
354
  """
@@ -464,7 +431,7 @@ def reduce(
464
431
  for joint_name in set(model.joint_names()) - set(considered_joints):
465
432
  j = intermediate_description.joints_dict[joint_name]
466
433
  with j.mutable_context():
467
- j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))
434
+ j.initial_position = locked_joint_positions.get(joint_name, 0.0)
468
435
 
469
436
  # Reduce the model description.
470
437
  # If `considered_joints` contains joints not existing in the model,
@@ -480,7 +447,6 @@ def reduce(
480
447
  time_step=model.time_step,
481
448
  terrain=model.terrain,
482
449
  contact_model=model.contact_model,
483
- integrator=model.integrator,
484
450
  )
485
451
 
486
452
  # Store the origin of the model, in case downstream logic needs it.
@@ -534,31 +500,6 @@ def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
534
500
  # ==============================
535
501
 
536
502
 
537
- @jax.jit
538
- @js.common.named_scope
539
- def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
540
- """
541
- Compute the SE(3) transforms from the world frame to the frames of all links.
542
-
543
- Args:
544
- model: The model to consider.
545
- data: The data of the considered model.
546
-
547
- Returns:
548
- A (nL, 4, 4) array containing the stacked SE(3) transforms of the links.
549
- The first axis is the link index.
550
- """
551
-
552
- W_H_LL = jaxsim.rbda.forward_kinematics_model(
553
- model=model,
554
- base_position=data.base_position(),
555
- base_quaternion=data.base_orientation(dcm=False),
556
- joint_positions=data.joint_positions(model=model),
557
- )
558
-
559
- return jnp.atleast_3d(W_H_LL).astype(float)
560
-
561
-
562
503
  @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
563
504
  def generalized_free_floating_jacobian(
564
505
  model: JaxSimModel,
@@ -592,7 +533,7 @@ def generalized_free_floating_jacobian(
592
533
  # Compute the doubly-left free-floating full jacobian.
593
534
  B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
594
535
  model=model,
595
- joint_positions=data.joint_positions(),
536
+ joint_positions=data.joint_positions,
596
537
  )
597
538
 
598
539
  # ======================================================================
@@ -603,7 +544,7 @@ def generalized_free_floating_jacobian(
603
544
 
604
545
  case VelRepr.Inertial:
605
546
 
606
- W_H_B = data.base_transform()
547
+ W_H_B = data._base_transform
607
548
  B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
608
549
 
609
550
  B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
@@ -617,7 +558,7 @@ def generalized_free_floating_jacobian(
617
558
 
618
559
  case VelRepr.Mixed:
619
560
 
620
- W_R_B = data.base_orientation(dcm=True)
561
+ W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation)
621
562
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
622
563
  B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
623
564
 
@@ -651,7 +592,7 @@ def generalized_free_floating_jacobian(
651
592
 
652
593
  case VelRepr.Inertial:
653
594
 
654
- W_H_B = data.base_transform()
595
+ W_H_B = data._base_transform
655
596
  W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
656
597
 
657
598
  O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
@@ -669,7 +610,7 @@ def generalized_free_floating_jacobian(
669
610
 
670
611
  case VelRepr.Mixed:
671
612
 
672
- W_H_B = data.base_transform()
613
+ W_H_B = data._base_transform
673
614
 
674
615
  LW_H_L = jax.vmap(
675
616
  lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3))
@@ -718,15 +659,15 @@ def generalized_free_floating_jacobian_derivative(
718
659
  # Compute the derivative of the doubly-left free-floating full jacobian.
719
660
  B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(
720
661
  model=model,
721
- joint_positions=data.joint_positions(),
722
- joint_velocities=data.joint_velocities(),
662
+ joint_positions=data.joint_positions,
663
+ joint_velocities=data.joint_velocities,
723
664
  )
724
665
 
725
666
  # The derivative of the equation to change the input and output representations
726
667
  # of the Jacobian derivative needs the computation of the plain link Jacobian.
727
668
  B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
728
669
  model=model,
729
- joint_positions=data.joint_positions(),
670
+ joint_positions=data.joint_positions,
730
671
  )
731
672
 
732
673
  # Compute the actual doubly-left free-floating jacobian derivative of the link
@@ -734,7 +675,7 @@ def generalized_free_floating_jacobian_derivative(
734
675
  κb = model.kin_dyn_parameters.support_body_array_bool
735
676
 
736
677
  # Compute the base transform.
737
- W_H_B = data.base_transform()
678
+ W_H_B = data._base_transform
738
679
 
739
680
  # We add the 5 columns of ones to the Jacobian derivative to account for the
740
681
  # base velocity and acceleration (5 + number of links = 6 + number of joints).
@@ -758,7 +699,7 @@ def generalized_free_floating_jacobian_derivative(
758
699
 
759
700
  B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)
760
701
 
761
- W_v_WB = data.base_velocity()
702
+ W_v_WB = data.base_velocity
762
703
  B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
763
704
 
764
705
  # Compute the operator to change the representation of ν, and its
@@ -784,7 +725,7 @@ def generalized_free_floating_jacobian_derivative(
784
725
  BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
785
726
  B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
786
727
 
787
- BW_v_WB = data.base_velocity()
728
+ BW_v_WB = data.base_velocity
788
729
  BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
789
730
 
790
731
  BW_v_BW_B = BW_v_WB - BW_v_W_BW
@@ -809,7 +750,7 @@ def generalized_free_floating_jacobian_derivative(
809
750
  O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
810
751
 
811
752
  with data.switch_velocity_representation(VelRepr.Body):
812
- B_v_WB = data.base_velocity()
753
+ B_v_WB = data.base_velocity
813
754
 
814
755
  O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
815
756
 
@@ -822,9 +763,9 @@ def generalized_free_floating_jacobian_derivative(
822
763
  B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
823
764
 
824
765
  with data.switch_velocity_representation(VelRepr.Body):
825
- B_v_WB = data.base_velocity()
766
+ B_v_WB = data.base_velocity
826
767
  L_v_WL = jnp.einsum(
827
- "b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity()
768
+ "b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity
828
769
  )
829
770
 
830
771
  O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
@@ -842,7 +783,7 @@ def generalized_free_floating_jacobian_derivative(
842
783
  B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
843
784
 
844
785
  with data.switch_velocity_representation(VelRepr.Body):
845
- B_v_WB = data.base_velocity()
786
+ B_v_WB = data.base_velocity
846
787
 
847
788
  with data.switch_velocity_representation(VelRepr.Mixed):
848
789
  BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
@@ -852,7 +793,7 @@ def generalized_free_floating_jacobian_derivative(
852
793
  LW_X_B,
853
794
  B_J_WL_B
854
795
  @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
855
- @ data.generalized_velocity(),
796
+ @ data.generalized_velocity,
856
797
  )
857
798
 
858
799
  LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6]))
@@ -954,7 +895,7 @@ def forward_dynamics_aba(
954
895
  τ = (
955
896
  jnp.atleast_1d(joint_forces.squeeze())
956
897
  if joint_forces is not None
957
- else jnp.zeros_like(data.joint_positions())
898
+ else jnp.zeros_like(data.joint_positions)
958
899
  )
959
900
 
960
901
  # Build link forces, if not provided.
@@ -973,22 +914,17 @@ def forward_dynamics_aba(
973
914
  velocity_representation=data.velocity_representation,
974
915
  )
975
916
 
976
- # Extract the link and joint serializations.
977
- link_names = model.link_names()
978
- joint_names = model.joint_names()
979
-
980
917
  # Extract the state in inertial-fixed representation.
981
918
  with data.switch_velocity_representation(VelRepr.Inertial):
982
- W_p_B = data.base_position()
983
- W_v_WB = data.base_velocity()
984
- W_Q_B = data.base_orientation(dcm=False)
985
- s = data.joint_positions(model=model, joint_names=joint_names)
986
- ṡ = data.joint_velocities(model=model, joint_names=joint_names)
919
+ W_p_B = data.base_position
920
+ W_v_WB = data.base_velocity
921
+ W_Q_B = data.base_orientation
922
+ s = data.joint_positions
923
+ ṡ = data.joint_velocities
987
924
 
988
925
  # Extract the inputs in inertial-fixed representation.
989
- with references.switch_velocity_representation(VelRepr.Inertial):
990
- W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
991
- τ = references.joint_force_references(model=model, joint_names=joint_names)
926
+ W_f_L = references._link_forces
927
+ τ = references._joint_force_references
992
928
 
993
929
  # ========================
994
930
  # Compute forward dynamics
@@ -1004,7 +940,7 @@ def forward_dynamics_aba(
1004
940
  joint_velocities=ṡ,
1005
941
  joint_forces=τ,
1006
942
  link_forces=W_f_L,
1007
- standard_gravity=data.standard_gravity(),
943
+ standard_gravity=model.gravity,
1008
944
  )
1009
945
 
1010
946
  # =============
@@ -1032,14 +968,14 @@ def forward_dynamics_aba(
1032
968
 
1033
969
  case VelRepr.Body:
1034
970
  # In this case C=B
1035
- W_H_C = W_H_B = data.base_transform()
971
+ W_H_C = W_H_B = data._base_transform
1036
972
  W_v_WC = W_v_WB
1037
973
 
1038
974
  case VelRepr.Mixed:
1039
975
  # In this case C=B[W]
1040
- W_H_B = data.base_transform()
976
+ W_H_B = data._base_transform
1041
977
  W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
1042
- W_ṗ_B = data.base_velocity()[0:3]
978
+ W_ṗ_B = data.base_velocity[0:3]
1043
979
  W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
1044
980
 
1045
981
  case _:
@@ -1103,7 +1039,7 @@ def forward_dynamics_crb(
1103
1039
  τ = (
1104
1040
  jnp.atleast_1d(joint_forces)
1105
1041
  if joint_forces is not None
1106
- else jnp.zeros_like(data.joint_positions())
1042
+ else jnp.zeros_like(data.joint_positions)
1107
1043
  )
1108
1044
 
1109
1045
  # Build external forces if not provided.
@@ -1174,7 +1110,7 @@ def free_floating_mass_matrix(
1174
1110
 
1175
1111
  M_body = jaxsim.rbda.crba(
1176
1112
  model=model,
1177
- joint_positions=data.state.physics_model.joint_positions,
1113
+ joint_positions=data.joint_positions,
1178
1114
  )
1179
1115
 
1180
1116
  match data.velocity_representation:
@@ -1183,16 +1119,14 @@ def free_floating_mass_matrix(
1183
1119
 
1184
1120
  case VelRepr.Inertial:
1185
1121
 
1186
- B_X_W = Adjoint.from_transform(
1187
- transform=data.base_transform(), inverse=True
1188
- )
1122
+ B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)
1189
1123
  invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
1190
1124
 
1191
1125
  return invT.T @ M_body @ invT
1192
1126
 
1193
1127
  case VelRepr.Mixed:
1194
1128
 
1195
- BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1129
+ BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
1196
1130
  B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1197
1131
  invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
1198
1132
 
@@ -1228,7 +1162,7 @@ def free_floating_coriolis_matrix(
1228
1162
  # to the active representation stored in data.
1229
1163
  with data.switch_velocity_representation(VelRepr.Body):
1230
1164
 
1231
- B_ν = data.generalized_velocity()
1165
+ B_ν = data.generalized_velocity
1232
1166
 
1233
1167
  # Doubly-left free-floating Jacobian.
1234
1168
  L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
@@ -1275,12 +1209,12 @@ def free_floating_coriolis_matrix(
1275
1209
  case VelRepr.Inertial:
1276
1210
 
1277
1211
  n = model.dofs()
1278
- W_H_B = data.base_transform()
1212
+ W_H_B = data._base_transform
1279
1213
  B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
1280
1214
  B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))
1281
1215
 
1282
1216
  with data.switch_velocity_representation(VelRepr.Inertial):
1283
- W_v_WB = data.base_velocity()
1217
+ W_v_WB = data.base_velocity
1284
1218
  B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
1285
1219
 
1286
1220
  B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))
@@ -1295,12 +1229,12 @@ def free_floating_coriolis_matrix(
1295
1229
  case VelRepr.Mixed:
1296
1230
 
1297
1231
  n = model.dofs()
1298
- BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1232
+ BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
1299
1233
  B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
1300
1234
  B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))
1301
1235
 
1302
1236
  with data.switch_velocity_representation(VelRepr.Mixed):
1303
- BW_v_WB = data.base_velocity()
1237
+ BW_v_WB = data.base_velocity
1304
1238
  BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
1305
1239
 
1306
1240
  BW_v_BW_B = BW_v_WB - BW_v_W_BW
@@ -1357,7 +1291,7 @@ def inverse_dynamics(
1357
1291
  s̈ = (
1358
1292
  jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
1359
1293
  if joint_accelerations is not None
1360
- else jnp.zeros_like(data.joint_positions())
1294
+ else jnp.zeros_like(data.joint_positions)
1361
1295
  )
1362
1296
 
1363
1297
  # Build base acceleration, if not provided.
@@ -1394,14 +1328,14 @@ def inverse_dynamics(
1394
1328
  W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1395
1329
 
1396
1330
  case VelRepr.Body:
1397
- W_H_C = W_H_B = data.base_transform()
1331
+ W_H_C = W_H_B = data._base_transform
1398
1332
  with data.switch_velocity_representation(VelRepr.Inertial):
1399
- W_v_WC = W_v_WB = data.base_velocity()
1333
+ W_v_WC = W_v_WB = data.base_velocity
1400
1334
 
1401
1335
  case VelRepr.Mixed:
1402
- W_H_B = data.base_transform()
1336
+ W_H_B = data._base_transform
1403
1337
  W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
1404
- W_ṗ_B = data.base_velocity()[0:3]
1338
+ W_ṗ_B = data.base_velocity[0:3]
1405
1339
  W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
1406
1340
 
1407
1341
  case _:
@@ -1413,7 +1347,7 @@ def inverse_dynamics(
1413
1347
  W_v̇_WB = to_inertial(
1414
1348
  C_v̇_WB=v̇_WB,
1415
1349
  W_H_C=W_H_C,
1416
- C_v_WB=data.base_velocity(),
1350
+ C_v_WB=data.base_velocity,
1417
1351
  W_v_WC=W_v_WC,
1418
1352
  )
1419
1353
 
@@ -1425,21 +1359,16 @@ def inverse_dynamics(
1425
1359
  velocity_representation=data.velocity_representation,
1426
1360
  )
1427
1361
 
1428
- # Extract the link and joint serializations.
1429
- link_names = model.link_names()
1430
- joint_names = model.joint_names()
1431
-
1432
1362
  # Extract the state in inertial-fixed representation.
1433
1363
  with data.switch_velocity_representation(VelRepr.Inertial):
1434
- W_p_B = data.base_position()
1435
- W_v_WB = data.base_velocity()
1436
- W_Q_B = data.base_orientation(dcm=False)
1437
- s = data.joint_positions(model=model, joint_names=joint_names)
1438
- ṡ = data.joint_velocities(model=model, joint_names=joint_names)
1364
+ W_p_B = data.base_position
1365
+ W_v_WB = data.base_velocity
1366
+ W_Q_B = data.base_quaternion
1367
+ s = data.joint_positions
1368
+ ṡ = data.joint_velocities
1439
1369
 
1440
1370
  # Extract the inputs in inertial-fixed representation.
1441
- with references.switch_velocity_representation(VelRepr.Inertial):
1442
- W_f_L = references.link_forces(model=model, data=data, link_names=link_names)
1371
+ W_f_L = references._link_forces
1443
1372
 
1444
1373
  # ========================
1445
1374
  # Compute inverse dynamics
@@ -1457,7 +1386,7 @@ def inverse_dynamics(
1457
1386
  base_angular_acceleration=W_v̇_WB[3:6],
1458
1387
  joint_accelerations=s̈,
1459
1388
  link_forces=W_f_L,
1460
- standard_gravity=data.standard_gravity(),
1389
+ standard_gravity=model.gravity,
1461
1390
  )
1462
1391
 
1463
1392
  # =============
@@ -1468,7 +1397,7 @@ def inverse_dynamics(
1468
1397
  f_B = js.data.JaxSimModelData.inertial_to_other_representation(
1469
1398
  array=W_f_B,
1470
1399
  other_representation=data.velocity_representation,
1471
- transform=data.base_transform(),
1400
+ transform=data._base_transform,
1472
1401
  is_force=True,
1473
1402
  ).squeeze()
1474
1403
 
@@ -1497,21 +1426,12 @@ def free_floating_gravity_forces(
1497
1426
  )
1498
1427
 
1499
1428
  # Set just the generalized position.
1500
- with data_rnea.mutable_context(
1501
- mutability=Mutability.MUTABLE, restore_after_exception=False
1502
- ):
1503
-
1504
- data_rnea.state.physics_model.base_position = (
1505
- data.state.physics_model.base_position
1506
- )
1507
-
1508
- data_rnea.state.physics_model.base_quaternion = (
1509
- data.state.physics_model.base_quaternion
1510
- )
1511
-
1512
- data_rnea.state.physics_model.joint_positions = (
1513
- data.state.physics_model.joint_positions
1514
- )
1429
+ data_rnea = data_rnea.replace(
1430
+ model=model,
1431
+ base_position=data.base_position,
1432
+ base_quaternion=data.base_quaternion,
1433
+ joint_positions=data.joint_positions,
1434
+ )
1515
1435
 
1516
1436
  return jnp.hstack(
1517
1437
  inverse_dynamics(
@@ -1548,35 +1468,20 @@ def free_floating_bias_forces(
1548
1468
  )
1549
1469
 
1550
1470
  # Set the generalized position and generalized velocity.
1551
- with data_rnea.mutable_context(
1552
- mutability=Mutability.MUTABLE, restore_after_exception=False
1553
- ):
1554
-
1555
- data_rnea.state.physics_model.base_position = (
1556
- data.state.physics_model.base_position
1557
- )
1558
-
1559
- data_rnea.state.physics_model.base_quaternion = (
1560
- data.state.physics_model.base_quaternion
1561
- )
1562
-
1563
- data_rnea.state.physics_model.joint_positions = (
1564
- data.state.physics_model.joint_positions
1565
- )
1566
-
1567
- data_rnea.state.physics_model.joint_velocities = (
1568
- data.state.physics_model.joint_velocities
1569
- )
1570
-
1571
- # Make sure that base velocity is zero for fixed-base model.
1572
- if model.floating_base():
1573
- data_rnea.state.physics_model.base_linear_velocity = (
1574
- data.state.physics_model.base_linear_velocity
1575
- )
1576
-
1577
- data_rnea.state.physics_model.base_angular_velocity = (
1578
- data.state.physics_model.base_angular_velocity
1579
- )
1471
+ base_linear_velocity, base_angular_velocity = None, None
1472
+ if model.floating_base():
1473
+ base_velocity = data.base_velocity
1474
+ base_linear_velocity = base_velocity[:3]
1475
+ base_angular_velocity = base_velocity[3:]
1476
+ data_rnea = data_rnea.replace(
1477
+ model=model,
1478
+ base_position=data.base_position,
1479
+ base_quaternion=data.base_quaternion,
1480
+ joint_positions=data.joint_positions,
1481
+ joint_velocities=data.joint_velocities,
1482
+ base_linear_velocity=base_linear_velocity,
1483
+ base_angular_velocity=base_angular_velocity,
1484
+ )
1580
1485
 
1581
1486
  return jnp.hstack(
1582
1487
  inverse_dynamics(
@@ -1628,7 +1533,7 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec
1628
1533
  The total momentum of the model in the active velocity representation.
1629
1534
  """
1630
1535
 
1631
- ν = data.generalized_velocity()
1536
+ ν = data.generalized_velocity
1632
1537
  Jh = total_momentum_jacobian(model=model, data=data)
1633
1538
 
1634
1539
  return Jh @ ν
@@ -1668,13 +1573,11 @@ def total_momentum_jacobian(
1668
1573
  B_Jh = B_Jh_B
1669
1574
 
1670
1575
  case VelRepr.Inertial:
1671
- B_X_W = Adjoint.from_transform(
1672
- transform=data.base_transform(), inverse=True
1673
- )
1576
+ B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)
1674
1577
  B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
1675
1578
 
1676
1579
  case VelRepr.Mixed:
1677
- BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1580
+ BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
1678
1581
  B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1679
1582
  B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
1680
1583
 
@@ -1686,14 +1589,14 @@ def total_momentum_jacobian(
1686
1589
  return B_Jh
1687
1590
 
1688
1591
  case VelRepr.Inertial:
1689
- W_H_B = data.base_transform()
1592
+ W_H_B = data._base_transform
1690
1593
  B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
1691
1594
  W_Xf_B = B_Xv_W.T
1692
1595
  W_Jh = W_Xf_B @ B_Jh
1693
1596
  return W_Jh
1694
1597
 
1695
1598
  case VelRepr.Mixed:
1696
- BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
1599
+ BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
1697
1600
  B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
1698
1601
  BW_Xf_B = B_Xv_BW.T
1699
1602
  BW_Jh = BW_Xf_B @ B_Jh
@@ -1718,7 +1621,7 @@ def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.V
1718
1621
  in the active representation.
1719
1622
  """
1720
1623
 
1721
- ν = data.generalized_velocity()
1624
+ ν = data.generalized_velocity
1722
1625
  J = average_velocity_jacobian(model=model, data=data)
1723
1626
 
1724
1627
  return J @ ν
@@ -1766,9 +1669,9 @@ def average_velocity_jacobian(
1766
1669
  case VelRepr.Body:
1767
1670
 
1768
1671
  GB_J = G_J
1769
- W_p_B = data.base_position()
1672
+ W_p_B = data.base_position
1770
1673
  W_p_CoM = js.com.com_position(model=model, data=data)
1771
- B_R_W = data.base_orientation(dcm=True).transpose()
1674
+ B_R_W = jaxsim.math.Quaternion.to_dcm(data.base_orientation).transpose()
1772
1675
 
1773
1676
  B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
1774
1677
  B_X_GB = Adjoint.from_transform(transform=B_H_GB)
@@ -1778,7 +1681,7 @@ def average_velocity_jacobian(
1778
1681
  case VelRepr.Mixed:
1779
1682
 
1780
1683
  GW_J = G_J
1781
- W_p_B = data.base_position()
1684
+ W_p_B = data.base_position
1782
1685
  W_p_CoM = js.com.com_position(model=model, data=data)
1783
1686
 
1784
1687
  BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
@@ -1819,7 +1722,7 @@ def link_bias_accelerations(
1819
1722
  # ================================================
1820
1723
 
1821
1724
  # Compute the base transform.
1822
- W_H_B = data.base_transform()
1725
+ W_H_B = data._base_transform
1823
1726
 
1824
1727
  def other_representation_to_inertial(
1825
1728
  C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
@@ -1846,25 +1749,25 @@ def link_bias_accelerations(
1846
1749
  W_H_C = W_H_W = jnp.eye(4) # noqa: F841
1847
1750
  W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
1848
1751
  with data.switch_velocity_representation(VelRepr.Inertial):
1849
- C_v_WB = W_v_WB = data.base_velocity()
1752
+ C_v_WB = W_v_WB = data.base_velocity
1850
1753
 
1851
1754
  case VelRepr.Body:
1852
1755
  W_H_C = W_H_B
1853
1756
  with data.switch_velocity_representation(VelRepr.Inertial):
1854
- W_v_WC = W_v_WB = data.base_velocity() # noqa: F841
1757
+ W_v_WC = W_v_WB = data.base_velocity # noqa: F841
1855
1758
  with data.switch_velocity_representation(VelRepr.Body):
1856
- C_v_WB = B_v_WB = data.base_velocity()
1759
+ C_v_WB = B_v_WB = data.base_velocity
1857
1760
 
1858
1761
  case VelRepr.Mixed:
1859
1762
  W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
1860
1763
  W_H_C = W_H_BW
1861
1764
  with data.switch_velocity_representation(VelRepr.Mixed):
1862
- W_ṗ_B = data.base_velocity()[0:3]
1765
+ W_ṗ_B = data.base_velocity[0:3]
1863
1766
  BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1864
1767
  W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
1865
1768
  W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
1866
1769
  with data.switch_velocity_representation(VelRepr.Mixed):
1867
- C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841
1770
+ C_v_WB = BW_v_WB = data.base_velocity # noqa: F841
1868
1771
 
1869
1772
  case _:
1870
1773
  raise ValueError(data.velocity_representation)
@@ -1888,20 +1791,23 @@ def link_bias_accelerations(
1888
1791
  # Compute the parent-to-child adjoints and the motion subspaces of the joints.
1889
1792
  # These transforms define the relative kinematics of the entire model, including
1890
1793
  # the base transform for both floating-base and fixed-base models.
1891
- i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
1892
- joint_positions=data.joint_positions(), base_transform=W_H_B
1794
+ i_X_λi = model.kin_dyn_parameters.joint_transforms(
1795
+ joint_positions=data.joint_positions, base_transform=W_H_B
1893
1796
  )
1894
1797
 
1798
+ # Extract the joint motion subspaces.
1799
+ S = model.kin_dyn_parameters.motion_subspaces
1800
+
1895
1801
  # Allocate the buffer to store the body-fixed link velocities.
1896
1802
  L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))
1897
1803
 
1898
1804
  # Store the base velocity.
1899
1805
  with data.switch_velocity_representation(VelRepr.Body):
1900
- B_v_WB = data.base_velocity()
1806
+ B_v_WB = data.base_velocity
1901
1807
  L_v_WL = L_v_WL.at[0].set(B_v_WB)
1902
1808
 
1903
1809
  # Get the joint velocities.
1904
- ṡ = data.joint_velocities(model=model, joint_names=model.joint_names())
1810
+ ṡ = data.joint_velocities
1905
1811
 
1906
1812
  # Allocate the buffer to store the body-fixed link accelerations,
1907
1813
  # and initialize the base acceleration.
@@ -1980,11 +1886,11 @@ def link_bias_accelerations(
1980
1886
  )
1981
1887
 
1982
1888
  case VelRepr.Inertial:
1983
- C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
1889
+ C_H_L = W_H_L = data._link_transforms
1984
1890
  L_v_CL = L_v_WL
1985
1891
 
1986
1892
  case VelRepr.Mixed:
1987
- W_H_L = js.model.forward_kinematics(model=model, data=data)
1893
+ W_H_L = data._link_transforms
1988
1894
  LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
1989
1895
  C_H_L = LW_H_L
1990
1896
  L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
@@ -2002,77 +1908,6 @@ def link_bias_accelerations(
2002
1908
  return O_v̇_WL
2003
1909
 
2004
1910
 
2005
- @jax.jit
2006
- @js.common.named_scope
2007
- def link_contact_forces(
2008
- model: js.model.JaxSimModel,
2009
- data: js.data.JaxSimModelData,
2010
- *,
2011
- link_forces: jtp.MatrixLike | None = None,
2012
- joint_force_references: jtp.VectorLike | None = None,
2013
- **kwargs,
2014
- ) -> jtp.Matrix:
2015
- """
2016
- Compute the 6D contact forces of all links of the model.
2017
-
2018
- Args:
2019
- model: The model to consider.
2020
- data: The data of the considered model.
2021
- link_forces:
2022
- The 6D external forces to apply to the links expressed in the same
2023
- representation of data.
2024
- joint_force_references:
2025
- The joint force references to apply to the joints.
2026
- kwargs: Additional keyword arguments to pass to the active contact model..
2027
-
2028
- Returns:
2029
- A `(nL, 6)` array containing the stacked 6D contact forces of the links,
2030
- expressed in the frame corresponding to the active representation.
2031
- """
2032
-
2033
- # Note: the following code should be kept in sync with the function
2034
- # `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
2035
- # there we need to get also aux_data.
2036
-
2037
- # Build link forces if not provided.
2038
- # These forces are expressed in the frame corresponding to the velocity
2039
- # representation of data.
2040
- O_f_L = (
2041
- jnp.atleast_2d(link_forces.squeeze())
2042
- if link_forces is not None
2043
- else jnp.zeros((model.number_of_links(), 6))
2044
- ).astype(float)
2045
-
2046
- # Build joint force references if not provided.
2047
- joint_force_references = (
2048
- jnp.atleast_1d(joint_force_references)
2049
- if joint_force_references is not None
2050
- else jnp.zeros(model.dofs())
2051
- )
2052
-
2053
- # We expect that the 6D forces included in the `link_forces` argument are expressed
2054
- # in the frame corresponding to the velocity representation of `data`.
2055
- input_references = js.references.JaxSimModelReferences.build(
2056
- model=model,
2057
- data=data,
2058
- velocity_representation=data.velocity_representation,
2059
- link_forces=O_f_L,
2060
- joint_force_references=joint_force_references,
2061
- )
2062
-
2063
- # Compute the 6D forces applied to the links equivalent to the forces applied
2064
- # to the frames associated to the collidable points.
2065
- f_L, _ = model.contact_model.compute_link_contact_forces(
2066
- model=model,
2067
- data=data,
2068
- link_forces=input_references.link_forces(model=model, data=data),
2069
- joint_force_references=input_references.joint_force_references(),
2070
- **kwargs,
2071
- )
2072
-
2073
- return f_L
2074
-
2075
-
2076
1911
  # ======
2077
1912
  # Energy
2078
1913
  # ======
@@ -2113,7 +1948,7 @@ def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Flo
2113
1948
  """
2114
1949
 
2115
1950
  with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
2116
- B_ν = data.generalized_velocity()
1951
+ B_ν = data.generalized_velocity
2117
1952
  M_B = free_floating_mass_matrix(model=model, data=data)
2118
1953
 
2119
1954
  K = 0.5 * B_ν.T @ M_B @ B_ν
@@ -2135,11 +1970,8 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F
2135
1970
  """
2136
1971
 
2137
1972
  m = total_mass(model=model)
2138
- gravity = data.gravity.squeeze()
2139
1973
  W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1])
2140
-
2141
- U = -jnp.hstack([gravity, 0]) @ (m * W_p̃_CoM)
2142
- return U.squeeze().astype(float)
1974
+ return jnp.sum((m * W_p̃_CoM)[2] * model.gravity)
2143
1975
 
2144
1976
 
2145
1977
  # ==========
@@ -2153,34 +1985,22 @@ def step(
2153
1985
  model: JaxSimModel,
2154
1986
  data: js.data.JaxSimModelData,
2155
1987
  *,
2156
- t0: jtp.FloatLike = 0.0,
2157
- dt: jtp.FloatLike | None = None,
2158
- integrator: jaxsim.integrators.Integrator | None = None,
2159
- integrator_metadata: dict[str, Any] | None = None,
2160
1988
  link_forces: jtp.MatrixLike | None = None,
2161
1989
  joint_force_references: jtp.VectorLike | None = None,
2162
- **kwargs,
2163
- ) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
1990
+ ) -> js.data.JaxSimModelData:
2164
1991
  """
2165
1992
  Perform a simulation step.
2166
1993
 
2167
1994
  Args:
2168
1995
  model: The model to consider.
2169
1996
  data: The data of the considered model.
2170
- integrator: The integrator to use.
2171
- integrator_metadata: The metadata of the integrator, if needed.
2172
- t0: The initial time to consider. Only relevant for time-dependent dynamics.
2173
1997
  dt: The time step to consider. If not specified, it is read from the model.
2174
1998
  link_forces:
2175
- The 6D forces to apply to the links expressed in the frame corresponding to
2176
- the velocity representation of `data`.
1999
+ The 6D forces to apply to the links expressed in same representation of data.
2177
2000
  joint_force_references: The joint force references to consider.
2178
- kwargs: Additional kwargs to pass to the integrator.
2179
2001
 
2180
2002
  Returns:
2181
- A tuple containing the new data of the model and a dictionary of auxiliary
2182
- data computed during the step. If the integrator has metadata, the dictionary
2183
- will contain the new metadata stored in the `integrator_metadata` key.
2003
+ The new data of the model after the simulation step.
2184
2004
 
2185
2005
  Note:
2186
2006
  In order to reduce the occurrences of frame conversions performed internally,
@@ -2188,167 +2008,84 @@ def step(
2188
2008
  particularly useful for automatically differentiated logic.
2189
2009
  """
2190
2010
 
2191
- # Extract the integrator kwargs.
2192
- # The following logic allows using integrators having kwargs colliding with the
2193
- # kwargs of this step function.
2194
- kwargs = kwargs if kwargs is not None else {}
2195
- integrator_kwargs = kwargs.pop("integrator_kwargs", {})
2196
- integrator_kwargs = kwargs | integrator_kwargs
2197
-
2198
- # Extract the integrator and the optional metadata.
2199
- integrator_metadata_t0 = integrator_metadata
2200
- integrator = integrator if integrator is not None else model.integrator
2201
-
2202
- # Initialize the time-related variables.
2203
- state_t0 = data.state
2204
- t0 = jnp.array(t0, dtype=float)
2205
- dt = jnp.array(dt if dt is not None else model.time_step).astype(float)
2206
-
2207
- # The visco-elastic contacts operate at best with their own integrator.
2208
- # They can be used with Euler-like integrators, paying the price of ignoring
2209
- # some of the benefits of continuous-time integration on the system position.
2210
- # Furthermore, the requirement to know the Δt used by the integrator is not
2211
- # compatible with high-order integrators, that use advanced RK stages to evaluate
2212
- # the dynamics at intermediate times.
2213
- module = jaxsim.rbda.contacts.visco_elastic.step.__module__
2214
- name = jaxsim.rbda.contacts.visco_elastic.step.__name__
2215
- msg = "You need to use the custom '{}.{}' function with this contact model."
2216
- jaxsim.exceptions.raise_runtime_error_if(
2217
- condition=(
2218
- isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts)
2219
- & (
2220
- ~jnp.allclose(dt, model.time_step)
2221
- | ~int(
2222
- isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
2223
- )
2224
- )
2225
- ),
2226
- msg=msg.format(module, name),
2227
- )
2011
+ # TODO: some contact models here may want to perform a dynamic filtering of
2012
+ # the enabled collidable points
2228
2013
 
2229
- # =================
2230
- # Phase 1: pre-step
2231
- # =================
2014
+ # Extract the inputs
2015
+ O_f_L_external = jnp.atleast_2d(
2016
+ jnp.array(link_forces, dtype=float).squeeze()
2017
+ if link_forces is not None
2018
+ else jnp.zeros((model.number_of_links(), 6))
2019
+ )
2232
2020
 
2233
- # TODO: some contact models here may want to perform a dynamic filtering of
2234
- # the enabled collidable points.
2021
+ # Get the external forces in inertial-fixed representation.
2022
+ W_f_L_external = jax.vmap(
2023
+ lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial(
2024
+ f_L,
2025
+ other_representation=data.velocity_representation,
2026
+ transform=W_H_L,
2027
+ is_force=True,
2028
+ )
2029
+ )(O_f_L_external, data._link_transforms)
2235
2030
 
2236
- # Build the references object.
2237
- # We assume that the link forces are expressed in the frame corresponding to the
2238
- # velocity representation of the data.
2239
- references = js.references.JaxSimModelReferences.build(
2240
- model=model,
2241
- data=data,
2242
- velocity_representation=data.velocity_representation,
2243
- link_forces=link_forces,
2244
- joint_force_references=joint_force_references,
2031
+ τ_references = jnp.atleast_1d(
2032
+ jnp.array(joint_force_references, dtype=float).squeeze()
2033
+ if joint_force_references is not None
2034
+ else jnp.zeros(model.dofs())
2245
2035
  )
2246
2036
 
2247
- # =============
2248
- # Phase 2: step
2249
- # =============
2037
+ # ================================
2038
+ # Compute the total joint torques
2039
+ # ================================
2250
2040
 
2251
- # Prepare the references to pass.
2252
- with references.switch_velocity_representation(data.velocity_representation):
2253
-
2254
- f_L = references.link_forces(model=model, data=data)
2255
- τ_references = references.joint_force_references(model=model)
2256
-
2257
- # Step the dynamics forward.
2258
- state_tf, integrator_metadata_tf = integrator.step(
2259
- x0=state_t0,
2260
- t0=t0,
2261
- dt=dt,
2262
- metadata=integrator_metadata_t0,
2263
- # Always inject the current (model, data) pair into the system dynamics
2264
- # considered by the integrator, and include the input variables represented
2265
- # by the pair (f_L, τ_references).
2266
- # Note that the wrapper of the system dynamics will override (state_x0, t0)
2267
- # inside the passed data even if it is not strictly needed. This logic is
2268
- # necessary to reuse the jit-compiled step function of compatible pytrees
2269
- # of model and data produced e.g. by parameterized applications.
2270
- **(
2271
- dict(
2272
- model=model,
2273
- data=data,
2274
- link_forces=f_L,
2275
- joint_force_references=τ_references,
2276
- )
2277
- | integrator_kwargs
2278
- ),
2041
+ τ_total = js.actuation_model.compute_resultant_torques(
2042
+ model, data, joint_force_references=τ_references
2279
2043
  )
2280
2044
 
2281
- # Store the new state of the model.
2282
- data_tf = data.replace(state=state_tf)
2045
+ # ======================
2046
+ # Compute contact forces
2047
+ # ======================
2283
2048
 
2284
- # ==================
2285
- # Phase 3: post-step
2286
- # ==================
2049
+ W_f_L_terrain = jnp.zeros_like(W_f_L_external)
2287
2050
 
2288
- # Post process the simulation state, if needed.
2289
- match model.contact_model:
2051
+ if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
2290
2052
 
2291
- # Rigid contact models use an impact model that produces discontinuous model velocities.
2292
- # Hence, here we need to reset the velocity after each impact to guarantee that
2293
- # the linear velocity of the active collidable points is zero.
2294
- case jaxsim.rbda.contacts.RigidContacts():
2053
+ # Compute the 6D forces W_f ℝ^{n_L × 6} applied to links due to contact
2054
+ # with the terrain.
2055
+ W_f_L_terrain = js.contact_model.link_contact_forces(
2056
+ model=model,
2057
+ data=data,
2058
+ link_forces=W_f_L_external,
2059
+ joint_torques=τ_total,
2060
+ )
2295
2061
 
2296
- # Raise runtime error for not supported case in which Rigid contacts and
2297
- # Baumgarte stabilization are enabled and used with ForwardEuler integrator.
2298
- jaxsim.exceptions.raise_runtime_error_if(
2299
- condition=isinstance(
2300
- integrator,
2301
- jaxsim.integrators.fixed_step.ForwardEuler
2302
- | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2303
- )
2304
- & ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
2305
- msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
2306
- )
2062
+ # ==============================
2063
+ # Compute the total link forces
2064
+ # ==============================
2307
2065
 
2308
- # Extract the indices corresponding to the enabled collidable points.
2309
- indices_of_enabled_collidable_points = (
2310
- model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
2311
- )
2066
+ W_f_L_total = W_f_L_external + W_f_L_terrain
2312
2067
 
2313
- W_p_C = js.contact.collidable_point_positions(model, data_tf)[
2314
- indices_of_enabled_collidable_points
2315
- ]
2316
-
2317
- # Compute the penetration depth of the collidable points.
2318
- δ, *_ = jax.vmap(
2319
- jaxsim.rbda.contacts.common.compute_penetration_data,
2320
- in_axes=(0, 0, None),
2321
- )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
2322
-
2323
- with data_tf.switch_velocity_representation(VelRepr.Mixed):
2324
- J_WC = js.contact.jacobian(model, data_tf)[
2325
- indices_of_enabled_collidable_points
2326
- ]
2327
- M = js.model.free_floating_mass_matrix(model, data_tf)
2328
- BW_ν_pre_impact = data_tf.generalized_velocity()
2329
-
2330
- # Compute the impact velocity.
2331
- # It may be discontinuous in case new contacts are made.
2332
- BW_ν_post_impact = (
2333
- jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
2334
- generalized_velocity=BW_ν_pre_impact,
2335
- inactive_collidable_points=(δ <= 0),
2336
- M=M,
2337
- J_WC=J_WC,
2338
- )
2339
- )
2068
+ # ===============================
2069
+ # Compute the system acceleration
2070
+ # ===============================
2340
2071
 
2341
- # Reset the generalized velocity.
2342
- data_tf = data_tf.reset_base_velocity(BW_ν_post_impact[0:6])
2343
- data_tf = data_tf.reset_joint_velocities(BW_ν_post_impact[6:])
2072
+ with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
2073
+ W_v̇_WB, = js.ode.system_acceleration(
2074
+ model=model,
2075
+ data=data,
2076
+ link_forces=W_f_L_total,
2077
+ joint_torques=τ_total,
2078
+ )
2344
2079
 
2345
- # Restore the input velocity representation.
2346
- data_tf = data_tf.replace(
2347
- velocity_representation=data.velocity_representation, validate=False
2348
- )
2080
+ # =============================
2081
+ # Advance the simulation state
2082
+ # =============================
2349
2083
 
2350
- return data_tf, {} | (
2351
- dict(integrator_metadata=integrator_metadata_tf)
2352
- if integrator_metadata is not None
2353
- else {}
2084
+ data_tf = js.integrators.semi_implicit_euler_integration(
2085
+ model=model,
2086
+ data=data,
2087
+ base_acceleration_inertial=W_v̇_WB,
2088
+ joint_accelerations=s̈,
2354
2089
  )
2090
+
2091
+ return data_tf