jaxsim 0.4.3.dev12__py3-none-any.whl → 0.4.3.dev18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev12'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev12')
15
+ __version__ = version = '0.4.3.dev18'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev18')
jaxsim/api/common.py CHANGED
@@ -3,7 +3,7 @@ import contextlib
3
3
  import dataclasses
4
4
  import enum
5
5
  import functools
6
- from typing import ContextManager
6
+ from collections.abc import Iterator
7
7
 
8
8
  import jax
9
9
  import jax.numpy as jnp
@@ -44,7 +44,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
44
44
  @contextlib.contextmanager
45
45
  def switch_velocity_representation(
46
46
  self, velocity_representation: VelRepr
47
- ) -> ContextManager[Self]:
47
+ ) -> Iterator[Self]:
48
48
  """
49
49
  Context manager to temporarily switch the velocity representation.
50
50
 
jaxsim/api/contact.py CHANGED
@@ -114,19 +114,24 @@ def collidable_point_forces(
114
114
 
115
115
  @jax.jit
116
116
  def collidable_point_dynamics(
117
- model: js.model.JaxSimModel, data: js.data.JaxSimModelData
118
- ) -> tuple[jtp.Matrix, jtp.Matrix]:
117
+ model: js.model.JaxSimModel,
118
+ data: js.data.JaxSimModelData,
119
+ link_forces: jtp.MatrixLike | None = None,
120
+ ) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
119
121
  r"""
120
- Compute the 6D force applied to each collidable point and the corresponding
121
- material deformation rate.
122
+ Compute the 6D force applied to each collidable point.
122
123
 
123
124
  Args:
124
125
  model: The model to consider.
125
126
  data: The data of the considered model.
127
+ link_forces:
128
+ The 6D external forces to apply to the links expressed in the same
129
+ representation of data.
126
130
 
127
131
  Returns:
128
- The 6D force applied to each collidable point and the corresponding
129
- material deformation rate.
132
+ The 6D force applied to each collidable point and additional data based on the contact model configured:
133
+ - Soft: the material deformation rate.
134
+ - Rigid: nothing.
130
135
 
131
136
  Note:
132
137
  The material deformation rate is always returned in the mixed frame
@@ -138,7 +143,8 @@ def collidable_point_dynamics(
138
143
  # all collidable points belonging to the robot.
139
144
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
140
145
 
141
- # Import privately the soft contacts classes.
146
+ # Import privately the contacts classes.
147
+ from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
142
148
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
143
149
 
144
150
  # Build the soft contact model.
@@ -161,9 +167,31 @@ def collidable_point_dynamics(
161
167
  W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
162
168
  W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
163
169
  )
170
+ aux_data = dict(m_dot=CW_ṁ)
171
+
172
+ case RigidContacts():
173
+ assert isinstance(model.contact_model, RigidContacts)
174
+ assert isinstance(data.state.contact, RigidContactsState)
175
+
176
+ # Build the contact model.
177
+ rigid_contacts = RigidContacts(
178
+ parameters=data.contacts_params, terrain=model.terrain
179
+ )
180
+
181
+ # Compute the 6D force expressed in the inertial frame and applied to each
182
+ # collidable point.
183
+ W_f_Ci, _ = rigid_contacts.compute_contact_forces(
184
+ position=W_p_Ci,
185
+ velocity=W_ṗ_Ci,
186
+ model=model,
187
+ data=data,
188
+ link_forces=link_forces,
189
+ )
190
+
191
+ aux_data = dict()
164
192
 
165
193
  case _:
166
- raise ValueError("Invalid contact model {}".format(model.contact_model))
194
+ raise ValueError(f"Invalid contact model {model.contact_model}")
167
195
 
168
196
  # Convert the 6D forces to the active representation.
169
197
  f_Ci = jax.vmap(
@@ -175,7 +203,7 @@ def collidable_point_dynamics(
175
203
  )
176
204
  )(W_f_Ci)
177
205
 
178
- return f_Ci, CW_ṁ
206
+ return f_Ci, aux_data
179
207
 
180
208
 
181
209
  @functools.partial(jax.jit, static_argnames=["link_names"])
jaxsim/api/data.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
4
  import functools
5
- from typing import Sequence
5
+ from collections.abc import Sequence
6
6
 
7
7
  import jax
8
8
  import jax.numpy as jnp
jaxsim/api/frame.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import functools
2
- from typing import Sequence
2
+ from collections.abc import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
jaxsim/api/joint.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import functools
2
- from typing import Sequence
2
+ from collections.abc import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
jaxsim/api/link.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import functools
2
- from typing import Sequence
2
+ from collections.abc import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
jaxsim/api/model.py CHANGED
@@ -4,7 +4,8 @@ import copy
4
4
  import dataclasses
5
5
  import functools
6
6
  import pathlib
7
- from typing import Any, Sequence
7
+ from collections.abc import Sequence
8
+ from typing import Any
8
9
 
9
10
  import jax
10
11
  import jax.numpy as jnp
@@ -13,6 +14,7 @@ import rod
13
14
  from jax_dataclasses import Static
14
15
 
15
16
  import jaxsim.api as js
17
+ import jaxsim.exceptions
16
18
  import jaxsim.terrain
17
19
  import jaxsim.typing as jtp
18
20
  from jaxsim.math import Adjoint, Cross
@@ -1889,6 +1891,8 @@ def step(
1889
1891
  and the new state of the integrator.
1890
1892
  """
1891
1893
 
1894
+ from jaxsim.rbda.contacts.rigid import RigidContacts
1895
+
1892
1896
  # Extract the integrator kwargs.
1893
1897
  # The following logic allows using integrators having kwargs colliding with the
1894
1898
  # kwargs of this step function.
@@ -1900,12 +1904,12 @@ def step(
1900
1904
 
1901
1905
  # Extract the initial resources.
1902
1906
  t0_ns = data.time_ns
1903
- state_x0 = data.state
1907
+ state_t0 = data.state
1904
1908
  integrator_state_x0 = integrator_state
1905
1909
 
1906
1910
  # Step the dynamics forward.
1907
- state_xf, integrator_state_xf = integrator.step(
1908
- x0=state_x0,
1911
+ state_tf, integrator_state_tf = integrator.step(
1912
+ x0=state_t0,
1909
1913
  t0=jnp.array(t0_ns / 1e9).astype(float),
1910
1914
  dt=dt,
1911
1915
  params=integrator_state_x0,
@@ -1927,11 +1931,61 @@ def step(
1927
1931
  ),
1928
1932
  )
1929
1933
 
1930
- return (
1934
+ data_tf = (
1931
1935
  # Store the new state of the model and the new time.
1932
1936
  data.replace(
1933
- state=state_xf,
1937
+ state=state_tf,
1934
1938
  time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
1935
- ),
1936
- integrator_state_xf,
1939
+ )
1940
+ )
1941
+
1942
+ # Post process the simulation state, if needed.
1943
+ match model.contact_model:
1944
+
1945
+ # Rigid contact models use an impact model that produces a discontinuous model velocity.
1946
+ # Hence here we need to reset the velocity after each impact to guarantee that
1947
+ # the linear velocity of the active collidable points is zero.
1948
+ case RigidContacts():
1949
+ # Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization
1950
+ # enabled are used with ForwardEuler integrator.
1951
+ jaxsim.exceptions.raise_runtime_error_if(
1952
+ condition=jnp.logical_and(
1953
+ isinstance(
1954
+ integrator,
1955
+ jaxsim.integrators.fixed_step.ForwardEuler
1956
+ | jaxsim.integrators.fixed_step.ForwardEulerSO3,
1957
+ ),
1958
+ jnp.array(
1959
+ [data_tf.contacts_params.K, data_tf.contacts_params.D]
1960
+ ).any(),
1961
+ ),
1962
+ msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
1963
+ )
1964
+
1965
+ with data_tf.switch_velocity_representation(VelRepr.Mixed):
1966
+ W_p_C = js.contact.collidable_point_positions(model, data_tf)
1967
+ M = js.model.free_floating_mass_matrix(model, data_tf)
1968
+ J_WC = js.contact.jacobian(model, data_tf)
1969
+ px, py, _ = W_p_C.T
1970
+ terrain_height = jax.vmap(model.terrain.height)(px, py)
1971
+ inactive_collidable_points, _ = RigidContacts.detect_contacts(
1972
+ W_p_C=W_p_C,
1973
+ terrain_height=terrain_height,
1974
+ )
1975
+ BW_nu_post_impact = RigidContacts.compute_impact_velocity(
1976
+ data=data_tf,
1977
+ inactive_collidable_points=inactive_collidable_points,
1978
+ M=M,
1979
+ J_WC=J_WC,
1980
+ )
1981
+ data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
1982
+ data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
1983
+ # Restore the input velocity representation.
1984
+ data_tf = data_tf.replace(
1985
+ velocity_representation=data.velocity_representation, validate=False
1986
+ )
1987
+
1988
+ return (
1989
+ data_tf,
1990
+ integrator_state_tf,
1937
1991
  )
jaxsim/api/ode.py CHANGED
@@ -50,7 +50,7 @@ def wrap_system_dynamics_for_integration(
50
50
  # The wrapped dynamics will hold a reference of this object.
51
51
  model_closed = model.copy()
52
52
  data_closed = data.copy().replace(
53
- state=js.ode_data.ODEState.zero(model=model_closed)
53
+ state=js.ode_data.ODEState.zero(model=model_closed, data=data)
54
54
  )
55
55
 
56
56
  def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
@@ -88,7 +88,7 @@ def system_velocity_dynamics(
88
88
  *,
89
89
  joint_forces: jtp.Vector | None = None,
90
90
  link_forces: jtp.Vector | None = None,
91
- ) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
91
+ ) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
92
92
  """
93
93
  Compute the dynamics of the system velocity.
94
94
 
@@ -102,18 +102,10 @@ def system_velocity_dynamics(
102
102
 
103
103
  Returns:
104
104
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
105
- representation, the derivative of the joint velocities, the derivative of
106
- the material deformation, and the dictionary of auxiliary data returned by
107
- the system dynamics evaluation.
105
+ representation, the derivative of the joint velocities, and auxiliary data
106
+ returned by the system dynamics evaluation.
108
107
  """
109
108
 
110
- # Build joint torques if not provided.
111
- τ = (
112
- jnp.atleast_1d(joint_forces.squeeze())
113
- if joint_forces is not None
114
- else jnp.zeros_like(data.joint_positions())
115
- ).astype(float)
116
-
117
109
  # Build link forces if not provided.
118
110
  # These forces are expressed in the frame corresponding to the velocity
119
111
  # representation of data.
@@ -123,6 +115,15 @@ def system_velocity_dynamics(
123
115
  else jnp.zeros((model.number_of_links(), 6))
124
116
  ).astype(float)
125
117
 
118
+ # We expect that the 6D forces included in the `link_forces` argument are expressed
119
+ # in the frame corresponding to the velocity representation of `data`.
120
+ references = js.references.JaxSimModelReferences.build(
121
+ model=model,
122
+ link_forces=O_f_L,
123
+ data=data,
124
+ velocity_representation=data.velocity_representation,
125
+ )
126
+
126
127
  # ======================
127
128
  # Compute contact forces
128
129
  # ======================
@@ -131,19 +132,17 @@ def system_velocity_dynamics(
131
132
  # with the terrain.
132
133
  W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
133
134
 
134
- # Import privately the soft contacts classes.
135
- from jaxsim.rbda.contacts.soft import SoftContactsState
136
-
137
- # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
138
- assert isinstance(data.state.contact, SoftContactsState)
139
- ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)
140
-
135
+ aux_data = {}
141
136
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
142
137
 
143
138
  # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
144
- # and the corresponding material deformation rates.
139
+ # along with contact-specific auxiliary states.
145
140
  with data.switch_velocity_representation(VelRepr.Inertial):
146
- W_f_Ci, = js.contact.collidable_point_dynamics(model=model, data=data)
141
+ W_f_Ci, aux_data = js.contact.collidable_point_dynamics(
142
+ model=model,
143
+ data=data,
144
+ link_forces=references.link_forces(model=model, data=data),
145
+ )
147
146
 
148
147
  # Construct the vector defining the parent link index of each collidable point.
149
148
  # We use this vector to sum the 6D forces of all collidable points rigidly
@@ -161,6 +160,74 @@ def system_velocity_dynamics(
161
160
 
162
161
  W_f_Li_terrain = mask.T @ W_f_Ci
163
162
 
163
+ # ===========================
164
+ # Compute system acceleration
165
+ # ===========================
166
+
167
+ # Compute the total link forces
168
+ with (
169
+ data.switch_velocity_representation(VelRepr.Inertial),
170
+ references.switch_velocity_representation(VelRepr.Inertial),
171
+ ):
172
+ references = references.apply_link_forces(
173
+ model=model,
174
+ data=data,
175
+ forces=W_f_Li_terrain,
176
+ additive=True,
177
+ )
178
+ # Get the link forces in the data representation
179
+ with references.switch_velocity_representation(data.velocity_representation):
180
+ f_L_total = references.link_forces(model=model, data=data)
181
+
182
+ # The following method always returns the inertial-fixed acceleration, and expects
183
+ # the link_forces expressed in the inertial frame.
184
+ W_v̇_WB, s̈ = system_acceleration(
185
+ model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
186
+ )
187
+
188
+ return W_v̇_WB, s̈, aux_data
189
+
190
+
191
+ def system_acceleration(
192
+ model: js.model.JaxSimModel,
193
+ data: js.data.JaxSimModelData,
194
+ *,
195
+ joint_forces: jtp.VectorLike | None = None,
196
+ link_forces: jtp.MatrixLike | None = None,
197
+ ) -> tuple[jtp.Vector, jtp.Vector]:
198
+ """
199
+ Compute the system acceleration in inertial-fixed representation.
200
+
201
+ Args:
202
+ model: The model to consider.
203
+ data: The data of the considered model.
204
+ joint_forces: The joint forces to apply.
205
+ link_forces:
206
+ The 6D forces to apply to the links expressed in the same representation of data.
207
+
208
+ Returns:
209
+ A tuple containing the base 6D acceleration in inertial-fixed representation
210
+ and the joint accelerations.
211
+ """
212
+
213
+ # ====================
214
+ # Validate input data
215
+ # ====================
216
+
217
+ # Build link forces if not provided.
218
+ f_L = (
219
+ jnp.atleast_2d(link_forces.squeeze())
220
+ if link_forces is not None
221
+ else jnp.zeros((model.number_of_links(), 6))
222
+ ).astype(float)
223
+
224
+ # Build joint torques if not provided.
225
+ τ = (
226
+ jnp.atleast_1d(joint_forces.squeeze())
227
+ if joint_forces is not None
228
+ else jnp.zeros_like(data.joint_positions())
229
+ ).astype(float)
230
+
164
231
  # ====================
165
232
  # Enforce joint limits
166
233
  # ====================
@@ -198,29 +265,25 @@ def system_velocity_dynamics(
198
265
 
199
266
  references = js.references.JaxSimModelReferences.build(
200
267
  model=model,
201
- joint_force_references=τ_total,
202
- link_forces=O_f_L,
203
268
  data=data,
204
269
  velocity_representation=data.velocity_representation,
270
+ joint_force_references=τ_total,
271
+ link_forces=f_L,
205
272
  )
206
273
 
207
- with references.switch_velocity_representation(VelRepr.Inertial):
208
- W_f_L = references.link_forces(model=model, data=data)
209
-
210
- # Compute the total external 6D forces applied to the links.
211
- W_f_L_total = W_f_L + W_f_Li_terrain
212
-
213
274
  # - Joint accelerations: s̈ ∈ ℝⁿ
214
275
  # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
215
- with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
276
+ with (
277
+ data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
278
+ references.switch_velocity_representation(VelRepr.Inertial),
279
+ ):
216
280
  W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
217
281
  model=model,
218
282
  data=data,
219
- joint_forces=τ_total,
220
- link_forces=W_f_L_total,
283
+ joint_forces=references.joint_force_references(),
284
+ link_forces=references.link_forces(),
221
285
  )
222
-
223
- return W_v̇_WB, s̈, ṁ, dict()
286
+ return W_v̇_WB, s̈
224
287
 
225
288
 
226
289
  @jax.jit
@@ -291,14 +354,29 @@ def system_dynamics(
291
354
  by the system dynamics evaluation.
292
355
  """
293
356
 
357
+ from jaxsim.rbda.contacts.rigid import RigidContacts
358
+ from jaxsim.rbda.contacts.soft import SoftContacts
359
+
294
360
  # Compute the accelerations and the material deformation rate.
295
- W_v̇_WB, s̈, ṁ, aux_dict = system_velocity_dynamics(
361
+ W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
296
362
  model=model,
297
363
  data=data,
298
364
  joint_forces=joint_forces,
299
365
  link_forces=link_forces,
300
366
  )
301
367
 
368
+ ode_state_kwargs = {}
369
+
370
+ match model.contact_model:
371
+ case SoftContacts():
372
+ ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
373
+
374
+ case RigidContacts():
375
+ pass
376
+
377
+ case _:
378
+ raise ValueError("Unable to determine contact state class prefix.")
379
+
302
380
  # Extract the velocities.
303
381
  W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
304
382
  model=model,
@@ -317,7 +395,7 @@ def system_dynamics(
317
395
  base_linear_velocity=W_v̇_WB[0:3],
318
396
  base_angular_velocity=W_v̇_WB[3:6],
319
397
  joint_velocities=s̈,
320
- tangential_deformation=ṁ,
398
+ **ode_state_kwargs,
321
399
  )
322
400
 
323
401
  return ode_state_derivative, aux_dict
jaxsim/api/ode_data.py CHANGED
@@ -6,6 +6,7 @@ import jax_dataclasses
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
8
  from jaxsim.rbda import ContactsState
9
+ from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
9
10
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
10
11
  from jaxsim.utils import JaxsimDataclass
11
12
 
@@ -133,7 +134,7 @@ class ODEState(JaxsimDataclass):
133
134
  base_quaternion: jtp.Vector | None = None,
134
135
  base_linear_velocity: jtp.Vector | None = None,
135
136
  base_angular_velocity: jtp.Vector | None = None,
136
- tangential_deformation: jtp.Matrix | None = None,
137
+ **kwargs,
137
138
  ) -> ODEState:
138
139
  """
139
140
  Build an `ODEState` from a `JaxSimModel`.
@@ -148,9 +149,7 @@ class ODEState(JaxsimDataclass):
148
149
  The linear velocity of the base link in inertial-fixed representation.
149
150
  base_angular_velocity:
150
151
  The angular velocity of the base link in inertial-fixed representation.
151
- tangential_deformation:
152
- The matrix of 3D tangential material deformations corresponding to
153
- each collidable point.
152
+ kwargs: Additional arguments needed to build the contact state.
154
153
 
155
154
  Returns:
156
155
  The `ODEState` built from the `JaxSimModel`.
@@ -163,6 +162,7 @@ class ODEState(JaxsimDataclass):
163
162
  # Get the contact model from the `JaxSimModel`.
164
163
  match model.contact_model:
165
164
  case SoftContacts():
165
+ tangential_deformation = kwargs.get("tangential_deformation", None)
166
166
  contact = SoftContactsState.build_from_jaxsim_model(
167
167
  model=model,
168
168
  **(
@@ -171,6 +171,8 @@ class ODEState(JaxsimDataclass):
171
171
  else dict()
172
172
  ),
173
173
  )
174
+ case RigidContacts():
175
+ contact = RigidContactsState.build()
174
176
  case _:
175
177
  raise ValueError("Unable to determine contact state class prefix.")
176
178
 
@@ -214,7 +216,7 @@ class ODEState(JaxsimDataclass):
214
216
 
215
217
  # Get the contact model from the `JaxSimModel`.
216
218
  match contact:
217
- case SoftContactsState():
219
+ case SoftContactsState() | RigidContactsState():
218
220
  pass
219
221
  case None:
220
222
  contact = SoftContactsState.zero(model=model)
@@ -224,7 +226,7 @@ class ODEState(JaxsimDataclass):
224
226
  return ODEState(physics_model=physics_model_state, contact=contact)
225
227
 
226
228
  @staticmethod
227
- def zero(model: js.model.JaxSimModel) -> ODEState:
229
+ def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
228
230
  """
229
231
  Build a zero `ODEState` from a `JaxSimModel`.
230
232
 
@@ -235,7 +237,9 @@ class ODEState(JaxsimDataclass):
235
237
  A zero `ODEState` instance.
236
238
  """
237
239
 
238
- model_state = ODEState.build(model=model)
240
+ model_state = ODEState.build(
241
+ model=model, contact=data.state.contact.zero(model=model)
242
+ )
239
243
 
240
244
  return model_state
241
245