jaxsim 0.4.3.dev64__py3-none-any.whl → 0.4.3.dev68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/__init__.py CHANGED
@@ -20,11 +20,6 @@ def _jnp_options() -> None:
20
20
  if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
21
21
  logging.warning("Failed to enable 64bit precision in JAX")
22
22
 
23
- else:
24
- logging.warning(
25
- "Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
26
- )
27
-
28
23
 
29
24
  def _np_options() -> None:
30
25
  import numpy as np
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev64'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev64')
15
+ __version__ = version = '0.4.3.dev68'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev68')
jaxsim/api/contact.py CHANGED
@@ -131,8 +131,7 @@ def collidable_point_dynamics(
131
131
  Returns:
132
132
  The 6D force applied to each collidable point and additional data based on the contact model configured:
133
133
  - Soft: the material deformation rate.
134
- - Rigid: no additional data.
135
- - QuasiRigid: no additional data.
134
+ - Rigid: nothing.
136
135
 
137
136
  Note:
138
137
  The material deformation rate is always returned in the mixed frame
@@ -145,10 +144,6 @@ def collidable_point_dynamics(
145
144
  W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
146
145
 
147
146
  # Import privately the contacts classes.
148
- from jaxsim.rbda.contacts.relaxed_rigid import (
149
- RelaxedRigidContacts,
150
- RelaxedRigidContactsState,
151
- )
152
147
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
153
148
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
154
149
 
@@ -195,27 +190,6 @@ def collidable_point_dynamics(
195
190
 
196
191
  aux_data = dict()
197
192
 
198
- case RelaxedRigidContacts():
199
- assert isinstance(model.contact_model, RelaxedRigidContacts)
200
- assert isinstance(data.state.contact, RelaxedRigidContactsState)
201
-
202
- # Build the contact model.
203
- relaxed_rigid_contacts = RelaxedRigidContacts(
204
- parameters=data.contacts_params, terrain=model.terrain
205
- )
206
-
207
- # Compute the 6D force expressed in the inertial frame and applied to each
208
- # collidable point.
209
- W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
210
- position=W_p_Ci,
211
- velocity=W_ṗ_Ci,
212
- model=model,
213
- data=data,
214
- link_forces=link_forces,
215
- )
216
-
217
- aux_data = dict()
218
-
219
193
  case _:
220
194
  raise ValueError(f"Invalid contact model {model.contact_model}")
221
195
 
jaxsim/api/data.py CHANGED
@@ -39,9 +39,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
39
39
  contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
40
40
 
41
41
  time_ns: jtp.Int = dataclasses.field(
42
- default_factory=lambda: jnp.array(
43
- 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
44
- ),
42
+ default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
45
43
  )
46
44
 
47
45
  def __hash__(self) -> int:
@@ -174,14 +172,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
174
172
  )
175
173
 
176
174
  time_ns = (
177
- jnp.array(
178
- time * 1e9,
179
- dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
180
- )
175
+ jnp.array(time * 1e9, dtype=jnp.uint64)
181
176
  if time is not None
182
- else jnp.array(
183
- 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
184
- )
177
+ else jnp.array(0, dtype=jnp.uint64)
185
178
  )
186
179
 
187
180
  if isinstance(model.contact_model, SoftContacts):
@@ -593,18 +586,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
593
586
  The updated `JaxSimModelData` object.
594
587
  """
595
588
 
596
- W_Q_B = jnp.array(base_quaternion, dtype=float)
597
-
598
- W_Q_B = jax.lax.select(
599
- pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
600
- on_true=W_Q_B,
601
- on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
602
- )
589
+ base_quaternion = jnp.array(base_quaternion)
603
590
 
604
591
  return self.replace(
605
592
  validate=True,
606
593
  state=self.state.replace(
607
- physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
594
+ physics_model=self.state.physics_model.replace(
595
+ base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
596
+ float
597
+ )
598
+ )
608
599
  ),
609
600
  )
610
601
 
@@ -746,13 +737,6 @@ def random_model_data(
746
737
  jtp.FloatLike | Sequence[jtp.FloatLike],
747
738
  jtp.FloatLike | Sequence[jtp.FloatLike],
748
739
  ] = ((-1, -1, 0.5), 1.0),
749
- joint_pos_bounds: (
750
- tuple[
751
- jtp.FloatLike | Sequence[jtp.FloatLike],
752
- jtp.FloatLike | Sequence[jtp.FloatLike],
753
- ]
754
- | None
755
- ) = None,
756
740
  base_vel_lin_bounds: tuple[
757
741
  jtp.FloatLike | Sequence[jtp.FloatLike],
758
742
  jtp.FloatLike | Sequence[jtp.FloatLike],
@@ -778,8 +762,6 @@ def random_model_data(
778
762
  key: The random key.
779
763
  velocity_representation: The velocity representation to use.
780
764
  base_pos_bounds: The bounds for the base position.
781
- joint_pos_bounds:
782
- The bounds for the joint positions (reading the joint limits if None).
783
765
  base_vel_lin_bounds: The bounds for the base linear velocity.
784
766
  base_vel_ang_bounds: The bounds for the base angular velocity.
785
767
  joint_vel_bounds: The bounds for the joint velocities.
@@ -824,19 +806,8 @@ def random_model_data(
824
806
  ).wxyz
825
807
 
826
808
  if model.number_of_joints() > 0:
827
-
828
- s_min, s_max = (
829
- jnp.array(joint_pos_bounds, dtype=float)
830
- if joint_pos_bounds is not None
831
- else (None, None)
832
- )
833
-
834
- physics_model_state.joint_positions = (
835
- js.joint.random_joint_positions(model=model, key=k3)
836
- if (s_min is None or s_max is None)
837
- else jax.random.uniform(
838
- key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
839
- )
809
+ physics_model_state.joint_positions = js.joint.random_joint_positions(
810
+ model=model, key=k3
840
811
  )
841
812
 
842
813
  physics_model_state.joint_velocities = jax.random.uniform(
jaxsim/api/joint.py CHANGED
@@ -180,77 +180,17 @@ def random_joint_positions(
180
180
 
181
181
  Args:
182
182
  model: The model to consider.
183
- joint_names: The names of the considered joints (all if None).
184
- key: The random key (initialized from seed 0 if None).
185
-
186
- Note:
187
- If the joint range or revolute joints is larger than 2π, their joint positions
188
- will be sampled from an interval of size 2π.
183
+ joint_names: The names of the joints.
184
+ key: The random key.
189
185
 
190
186
  Returns:
191
187
  The random joint positions.
192
188
  """
193
189
 
194
- # Consider the key corresponding to a zero seed if it was not passed.
195
190
  key = key if key is not None else jax.random.PRNGKey(seed=0)
196
191
 
197
- # Get the joint limits parsed from the model description.
198
192
  s_min, s_max = position_limits(model=model, joint_names=joint_names)
199
193
 
200
- # Get the joint indices.
201
- # Note that it will trigger an exception if the given `joint_names` are not valid.
202
- joint_names = joint_names if joint_names is not None else model.joint_names()
203
- joint_indices = names_to_idxs(model=model, joint_names=joint_names)
204
-
205
- from jaxsim.parsers.descriptions.joint import JointType
206
-
207
- # Filter for revolute joints.
208
- is_revolute = jnp.where(
209
- jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
210
- == JointType.Revolute,
211
- True,
212
- False,
213
- )
214
-
215
- # Shorthand for π.
216
- π = jnp.pi
217
-
218
- # Filter for revolute with full range (or continuous).
219
- is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
220
-
221
- # Clip the lower limit to -π if the joint range is larger than [-π, π].
222
- s_min = jnp.where(
223
- jnp.logical_and(
224
- is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
225
- ),
226
- -π,
227
- s_min,
228
- )
229
-
230
- # Clip the upper limit to +π if the joint range is larger than [-π, π].
231
- s_max = jnp.where(
232
- jnp.logical_and(
233
- is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
234
- ),
235
- π,
236
- s_max,
237
- )
238
-
239
- # Shift the lower limit if the upper limit is smaller than +π.
240
- s_min = jnp.where(
241
- jnp.logical_and(is_revolute_full_range, s_max < π),
242
- s_max - 2 * π,
243
- s_min,
244
- )
245
-
246
- # Shift the upper limit if the lower limit is larger than -π.
247
- s_max = jnp.where(
248
- jnp.logical_and(is_revolute_full_range, s_min > -π),
249
- s_min + 2 * π,
250
- s_max,
251
- )
252
-
253
- # Sample the joint positions.
254
194
  s_random = jax.random.uniform(
255
195
  minval=s_min,
256
196
  maxval=s_max,
jaxsim/api/model.py CHANGED
@@ -1931,22 +1931,11 @@ def step(
1931
1931
  ),
1932
1932
  )
1933
1933
 
1934
- tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
1935
- tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
1936
-
1937
- jax.lax.cond(
1938
- pred=tf_ns < t0_ns,
1939
- true_fun=lambda: jax.debug.print(
1940
- "The simulation time overflowed, resetting simulation time to 0."
1941
- ),
1942
- false_fun=lambda: None,
1943
- )
1944
-
1945
1934
  data_tf = (
1946
1935
  # Store the new state of the model and the new time.
1947
1936
  data.replace(
1948
1937
  state=state_tf,
1949
- time_ns=tf_ns,
1938
+ time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
1950
1939
  )
1951
1940
  )
1952
1941
 
jaxsim/api/ode.py CHANGED
@@ -175,15 +175,17 @@ def system_velocity_dynamics(
175
175
  forces=W_f_Li_terrain,
176
176
  additive=True,
177
177
  )
178
-
179
- # Get the link forces in inertial representation
178
+ # Get the link forces in the data representation
179
+ with references.switch_velocity_representation(data.velocity_representation):
180
180
  f_L_total = references.link_forces(model=model, data=data)
181
181
 
182
- v̇_WB, = system_acceleration(
183
- model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
184
- )
182
+ # The following method always returns the inertial-fixed acceleration, and expects
183
+ # the link_forces expressed in the inertial frame.
184
+ W_v̇_WB, s̈ = system_acceleration(
185
+ model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
186
+ )
185
187
 
186
- return v̇_WB, s̈, aux_data
188
+ return W_v̇_WB, s̈, aux_data
187
189
 
188
190
 
189
191
  def system_acceleration(
@@ -194,7 +196,7 @@ def system_acceleration(
194
196
  link_forces: jtp.MatrixLike | None = None,
195
197
  ) -> tuple[jtp.Vector, jtp.Vector]:
196
198
  """
197
- Compute the system acceleration in the active representation.
199
+ Compute the system acceleration in inertial-fixed representation.
198
200
 
199
201
  Args:
200
202
  model: The model to consider.
@@ -204,7 +206,7 @@ def system_acceleration(
204
206
  The 6D forces to apply to the links expressed in the same representation of data.
205
207
 
206
208
  Returns:
207
- A tuple containing the base 6D acceleration in in the active representation
209
+ A tuple containing the base 6D acceleration in inertial-fixed representation
208
210
  and the joint accelerations.
209
211
  """
210
212
 
@@ -270,15 +272,18 @@ def system_acceleration(
270
272
  )
271
273
 
272
274
  # - Joint accelerations: s̈ ∈ ℝⁿ
273
- # - Base acceleration: v̇_WB ∈ ℝ⁶
274
- v̇_WB, s̈ = js.model.forward_dynamics_aba(
275
- model=model,
276
- data=data,
277
- joint_forces=references.joint_force_references(model=model),
278
- link_forces=references.link_forces(model=model, data=data),
279
- )
280
-
281
- return v̇_WB,
275
+ # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
276
+ with (
277
+ data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
278
+ references.switch_velocity_representation(VelRepr.Inertial),
279
+ ):
280
+ W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
281
+ model=model,
282
+ data=data,
283
+ joint_forces=references.joint_force_references(),
284
+ link_forces=references.link_forces(),
285
+ )
286
+ return W_v̇_WB, s̈
282
287
 
283
288
 
284
289
  @jax.jit
@@ -348,7 +353,7 @@ def system_dynamics(
348
353
  corresponding derivative, and the dictionary of auxiliary data returned
349
354
  by the system dynamics evaluation.
350
355
  """
351
- from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
356
+
352
357
  from jaxsim.rbda.contacts.rigid import RigidContacts
353
358
  from jaxsim.rbda.contacts.soft import SoftContacts
354
359
 
@@ -366,7 +371,7 @@ def system_dynamics(
366
371
  case SoftContacts():
367
372
  ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
368
373
 
369
- case RigidContacts() | RelaxedRigidContacts():
374
+ case RigidContacts():
370
375
  pass
371
376
 
372
377
  case _:
jaxsim/api/ode_data.py CHANGED
@@ -6,10 +6,6 @@ import jax_dataclasses
6
6
  import jaxsim.api as js
7
7
  import jaxsim.typing as jtp
8
8
  from jaxsim.rbda import ContactsState
9
- from jaxsim.rbda.contacts.relaxed_rigid import (
10
- RelaxedRigidContacts,
11
- RelaxedRigidContactsState,
12
- )
13
9
  from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
14
10
  from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
15
11
  from jaxsim.utils import JaxsimDataclass
@@ -177,10 +173,6 @@ class ODEState(JaxsimDataclass):
177
173
  )
178
174
  case RigidContacts():
179
175
  contact = RigidContactsState.build()
180
-
181
- case RelaxedRigidContacts():
182
- contact = RelaxedRigidContactsState.build()
183
-
184
176
  case _:
185
177
  raise ValueError("Unable to determine contact state class prefix.")
186
178
 
@@ -224,9 +216,7 @@ class ODEState(JaxsimDataclass):
224
216
 
225
217
  # Get the contact model from the `JaxSimModel`.
226
218
  match contact:
227
- case (
228
- SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
229
- ):
219
+ case SoftContactsState() | RigidContactsState():
230
220
  pass
231
221
  case None:
232
222
  contact = SoftContactsState.zero(model=model)
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
497
497
  b: jtp.Matrix,
498
498
  c: jtp.Vector,
499
499
  index_of_solution: jtp.IntLike = 0,
500
- ) -> tuple[bool, int | None]:
500
+ ) -> [bool, int | None]:
501
501
  """
502
502
  Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
503
503
 
jaxsim/math/inertia.py CHANGED
@@ -45,7 +45,7 @@ class Inertia:
45
45
  M (jtp.Matrix): The 6x6 inertia matrix.
46
46
 
47
47
  Returns:
48
- tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
48
+ Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
49
49
 
50
50
  Raises:
51
51
  ValueError: If the input matrix M has an unexpected shape.
jaxsim/mujoco/loaders.py CHANGED
@@ -211,7 +211,7 @@ class RodModelToMjcf:
211
211
  joints_dict = {j.name: j for j in rod_model.joints()}
212
212
 
213
213
  # Convert all the joints not considered to fixed joints.
214
- for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
214
+ for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
215
215
  joints_dict[joint_name].type = "fixed"
216
216
 
217
217
  # Convert the ROD model to URDF.
@@ -289,10 +289,10 @@ class RodModelToMjcf:
289
289
  mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
290
290
 
291
291
  # Get the joint names.
292
- mj_joint_names = {
292
+ mj_joint_names = set(
293
293
  mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
294
294
  for idx in range(mj_model.njnt)
295
- }
295
+ )
296
296
 
297
297
  # Check that the Mujoco model only has the considered joints.
298
298
  if mj_joint_names != considered_joints:
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
394
394
  return copy.deepcopy(self)
395
395
 
396
396
  # Check if all considered joints are part of the full kinematic graph
397
- if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
397
+ if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
398
398
  extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
399
399
  msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
400
400
  raise ValueError(msg)
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
536
536
  root_link_name=full_graph.root.name,
537
537
  )
538
538
 
539
- assert {f.name for f in self.frames}.isdisjoint(
540
- {f.name for f in unconnected_frames + reduced_frames}
539
+ assert set(f.name for f in self.frames).isdisjoint(
540
+ set(f.name for f in unconnected_frames + reduced_frames)
541
541
  )
542
542
 
543
543
  for link in unconnected_links:
@@ -223,7 +223,7 @@ def extract_model_data(
223
223
  child=links_dict[j.child],
224
224
  jtype=utils.joint_to_joint_type(joint=j),
225
225
  axis=(
226
- np.array(j.axis.xyz.xyz, dtype=float)
226
+ np.array(j.axis.xyz.xyz)
227
227
  if j.axis is not None
228
228
  and j.axis.xyz is not None
229
229
  and j.axis.xyz.xyz is not None
@@ -232,43 +232,39 @@ def extract_model_data(
232
232
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
233
233
  initial_position=0.0,
234
234
  position_limit=(
235
- float(
236
- j.axis.limit.lower
237
- if j.axis is not None
238
- and j.axis.limit is not None
239
- and j.axis.limit.lower is not None
240
- else jnp.finfo(float).min
235
+ (
236
+ float(j.axis.limit.lower)
237
+ if j.axis is not None and j.axis.limit is not None
238
+ else np.finfo(float).min
241
239
  ),
242
- float(
243
- j.axis.limit.upper
244
- if j.axis is not None
245
- and j.axis.limit is not None
246
- and j.axis.limit.upper is not None
247
- else jnp.finfo(float).max
240
+ (
241
+ float(j.axis.limit.upper)
242
+ if j.axis is not None and j.axis.limit is not None
243
+ else np.finfo(float).max
248
244
  ),
249
245
  ),
250
- friction_static=float(
246
+ friction_static=(
251
247
  j.axis.dynamics.friction
252
248
  if j.axis is not None
253
249
  and j.axis.dynamics is not None
254
250
  and j.axis.dynamics.friction is not None
255
251
  else 0.0
256
252
  ),
257
- friction_viscous=float(
253
+ friction_viscous=(
258
254
  j.axis.dynamics.damping
259
255
  if j.axis is not None
260
256
  and j.axis.dynamics is not None
261
257
  and j.axis.dynamics.damping is not None
262
258
  else 0.0
263
259
  ),
264
- position_limit_damper=float(
260
+ position_limit_damper=(
265
261
  j.axis.limit.dissipation
266
262
  if j.axis is not None
267
263
  and j.axis.limit is not None
268
264
  and j.axis.limit.dissipation is not None
269
265
  else 0.0
270
266
  ),
271
- position_limit_spring=float(
267
+ position_limit_spring=(
272
268
  j.axis.limit.stiffness
273
269
  if j.axis is not None
274
270
  and j.axis.limit is not None
@@ -277,7 +273,7 @@ def extract_model_data(
277
273
  ),
278
274
  )
279
275
  for j in sdf_model.joints()
280
- if j.type in {"revolute", "continuous", "prismatic", "fixed"}
276
+ if j.type in {"revolute", "prismatic", "fixed"}
281
277
  and j.parent != "world"
282
278
  and j.child in links_dict.keys()
283
279
  ]
@@ -9,6 +9,7 @@ import jax_dataclasses
9
9
 
10
10
  import jaxsim.api as js
11
11
  import jaxsim.typing as jtp
12
+ from jaxsim import math
12
13
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
13
14
  from jaxsim.terrain import FlatTerrain, Terrain
14
15
 
@@ -271,17 +272,9 @@ class RigidContacts(ContactModel):
271
272
  link_forces=link_forces,
272
273
  )
273
274
 
274
- with (
275
- references.switch_velocity_representation(VelRepr.Mixed),
276
- data.switch_velocity_representation(VelRepr.Mixed),
277
- ):
278
- BW_ν̇_free = jnp.hstack(
279
- js.ode.system_acceleration(
280
- model=model,
281
- data=data,
282
- joint_forces=references.joint_force_references(model=model),
283
- link_forces=references.link_forces(model=model, data=data),
284
- )
275
+ with references.switch_velocity_representation(VelRepr.Mixed):
276
+ BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
277
+ model, data, references=references
285
278
  )
286
279
 
287
280
  free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
@@ -387,6 +380,43 @@ class RigidContacts(ContactModel):
387
380
  n_constraints = 6 * n_collidable_points
388
381
  return jnp.zeros(shape=(n_constraints,))
389
382
 
383
+ @staticmethod
384
+ def _compute_mixed_nu_dot_free(
385
+ model: js.model.JaxSimModel,
386
+ data: js.data.JaxSimModelData,
387
+ references: js.references.JaxSimModelReferences | None = None,
388
+ ) -> jtp.Array:
389
+ references = (
390
+ references
391
+ if references is not None
392
+ else js.references.JaxSimModelReferences.zero(model=model, data=data)
393
+ )
394
+
395
+ with (
396
+ data.switch_velocity_representation(VelRepr.Mixed),
397
+ references.switch_velocity_representation(VelRepr.Mixed),
398
+ ):
399
+ BW_v_WB = data.base_velocity()
400
+ W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
401
+ W_v̇_WB, s̈ = js.ode.system_acceleration(
402
+ model=model,
403
+ data=data,
404
+ joint_forces=references.joint_force_references(model=model),
405
+ link_forces=references.link_forces(model=model, data=data),
406
+ )
407
+
408
+ # Convert the inertial-fixed base acceleration to a mixed base acceleration.
409
+ W_H_B = data.base_transform()
410
+ W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
411
+ BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
412
+ term1 = BW_X_W @ W_v̇_WB
413
+ term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
414
+ BW_v̇_WB = term1 - term2
415
+
416
+ BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
417
+
418
+ return BW_ν̇
419
+
390
420
  @staticmethod
391
421
  def _linear_acceleration_of_collidable_points(
392
422
  model: js.model.JaxSimModel,
jaxsim/terrain/terrain.py CHANGED
@@ -46,82 +46,66 @@ class Terrain(abc.ABC):
46
46
  @jax_dataclasses.pytree_dataclass
47
47
  class FlatTerrain(Terrain):
48
48
 
49
- _height: float = dataclasses.field(default=0.0, kw_only=True)
49
+ z: float = dataclasses.field(default=0.0, kw_only=True)
50
50
 
51
51
  @staticmethod
52
52
  def build(height: jtp.FloatLike) -> FlatTerrain:
53
53
 
54
- return FlatTerrain(_height=float(height))
54
+ return FlatTerrain(z=float(height))
55
55
 
56
56
  def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
57
57
 
58
- return jnp.array(self._height, dtype=float)
59
-
60
- def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
61
-
62
- return jnp.array([0.0, 0.0, 1.0], dtype=float)
58
+ return jnp.array(self.z, dtype=float)
63
59
 
64
60
  def __hash__(self) -> int:
65
61
 
66
- return hash(self._height)
62
+ return hash(self.z)
67
63
 
68
64
  def __eq__(self, other: FlatTerrain) -> bool:
69
65
 
70
66
  if not isinstance(other, FlatTerrain):
71
67
  return False
72
68
 
73
- return self._height == other._height
69
+ return self.z == other.z
74
70
 
75
71
 
76
72
  @jax_dataclasses.pytree_dataclass
77
73
  class PlaneTerrain(FlatTerrain):
78
74
 
79
- _normal: tuple[float, float, float] = jax_dataclasses.field(
75
+ plane_normal: tuple[float, float, float] = jax_dataclasses.field(
80
76
  default=(0.0, 0.0, 1.0), kw_only=True
81
77
  )
82
78
 
83
79
  @staticmethod
84
- def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
80
+ def build(
81
+ plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
82
+ ) -> PlaneTerrain:
85
83
  """
86
84
  Create a PlaneTerrain instance with a specified plane normal vector.
87
85
 
88
86
  Args:
89
- normal: The normal vector of the terrain plane.
90
- height: The height of the plane over the origin.
87
+ plane_normal: The normal vector of the terrain plane.
88
+ plane_height_over_origin: The height of the plane over the origin.
91
89
 
92
90
  Returns:
93
91
  PlaneTerrain: A PlaneTerrain instance.
94
92
  """
95
93
 
96
- normal = jnp.array(normal, dtype=float)
97
- height = jnp.array(height, dtype=float)
94
+ plane_normal = jnp.array(plane_normal, dtype=float)
95
+ plane_height_over_origin = jnp.array(plane_height_over_origin, dtype=float)
98
96
 
99
- if normal.shape != (3,):
97
+ if plane_normal.shape != (3,):
100
98
  msg = "Expected a 3D vector for the plane normal, got '{}'."
101
- raise ValueError(msg.format(normal.shape))
99
+ raise ValueError(msg.format(plane_normal.shape))
102
100
 
103
101
  # Make sure that the plane normal is a unit vector.
104
- normal = normal / jnp.linalg.norm(normal)
102
+ plane_normal = plane_normal / jnp.linalg.norm(plane_normal)
105
103
 
106
104
  return PlaneTerrain(
107
- _height=height.item(),
108
- _normal=tuple(normal.tolist()),
105
+ z=float(plane_height_over_origin),
106
+ plane_normal=tuple(plane_normal.tolist()),
109
107
  )
110
108
 
111
- def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
112
- """
113
- Compute the normal vector of the terrain at a specific (x, y) location.
114
-
115
- Args:
116
- x: The x-coordinate of the location.
117
- y: The y-coordinate of the location.
118
-
119
- Returns:
120
- The normal vector of the terrain surface at the specified location.
121
- """
122
-
123
- return jnp.array(self._normal, dtype=float)
124
-
125
109
  def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
126
110
  """
127
111
  Compute the height of the terrain at a specific (x, y) location on a plane.
@@ -139,10 +123,10 @@ class PlaneTerrain(FlatTerrain):
139
123
  # The height over the origin: -D/C
140
124
 
141
125
  # Get the plane equation coefficients from the terrain normal.
142
- A, B, C = self._normal
126
+ A, B, C = self.plane_normal
143
127
 
144
128
  # Compute the final coefficient D considering the terrain height.
145
- D = -C * self._height
129
+ D = -C * self.z
146
130
 
147
131
  # Invert the plane equation to get the height at the given (x, y) coordinates.
148
132
  return jnp.array(-(A * x + B * y + D) / C).astype(float)
@@ -153,9 +137,9 @@ class PlaneTerrain(FlatTerrain):
153
137
 
154
138
  return hash(
155
139
  (
156
- hash(self._height),
140
+ hash(self.z),
157
141
  HashedNumpyArray.hash_of_array(
158
- array=jnp.array(self._normal, dtype=float)
142
+ array=jnp.array(self.plane_normal, dtype=float)
159
143
  ),
160
144
  )
161
145
  )
@@ -166,10 +150,10 @@ class PlaneTerrain(FlatTerrain):
166
150
  return False
167
151
 
168
152
  if not (
169
- np.allclose(self._height, other._height)
153
+ np.allclose(self.z, other.z)
170
154
  and np.allclose(
171
- np.array(self._normal, dtype=float),
172
- np.array(other._normal, dtype=float),
155
+ np.array(self.plane_normal, dtype=float),
156
+ np.array(other.plane_normal, dtype=float),
173
157
  )
174
158
  ):
175
159
  return False
jaxsim/typing.py CHANGED
@@ -16,7 +16,7 @@ Int = Scalar
16
16
  Bool = Scalar
17
17
  Float = Scalar
18
18
 
19
- PyTree: object = (
19
+ PyTree = (
20
20
  dict[Hashable, TypeVar("PyTree")]
21
21
  | list[TypeVar("PyTree")]
22
22
  | tuple[TypeVar("PyTree")]
@@ -135,10 +135,9 @@ class JaxsimDataclass(abc.ABC):
135
135
  """
136
136
 
137
137
  return tuple(
138
- map(
139
- lambda leaf: getattr(leaf, "shape", None),
140
- jax.tree_util.tree_leaves(tree),
141
- )
138
+ leaf.shape if hasattr(leaf, "shape") else None
139
+ for leaf in jax.tree_util.tree_leaves(tree)
140
+ if hasattr(leaf, "shape")
142
141
  )
143
142
 
144
143
  @staticmethod
@@ -155,10 +154,9 @@ class JaxsimDataclass(abc.ABC):
155
154
  """
156
155
 
157
156
  return tuple(
158
- map(
159
- lambda leaf: getattr(leaf, "dtype", None),
160
- jax.tree_util.tree_leaves(tree),
161
- )
157
+ leaf.dtype if hasattr(leaf, "dtype") else None
158
+ for leaf in jax.tree_util.tree_leaves(tree)
159
+ if hasattr(leaf, "dtype")
162
160
  )
163
161
 
164
162
  @staticmethod
@@ -174,10 +172,9 @@ class JaxsimDataclass(abc.ABC):
174
172
  """
175
173
 
176
174
  return tuple(
177
- map(
178
- lambda leaf: getattr(leaf, "weak_type", None),
179
- jax.tree_util.tree_leaves(tree),
180
- )
175
+ leaf.weak_type if hasattr(leaf, "weak_type") else False
176
+ for leaf in jax.tree_util.tree_leaves(tree)
177
+ if hasattr(leaf, "weak_type")
181
178
  )
182
179
 
183
180
  @staticmethod
jaxsim/utils/wrappers.py CHANGED
@@ -110,7 +110,7 @@ class HashedNumpyArray:
110
110
  return np.allclose(
111
111
  self.array,
112
112
  other.array,
113
- **(dict(atol=self.precision) if self.precision is not None else {}),
113
+ **({dict(atol=self.precision)} if self.precision is not None else {}),
114
114
  )
115
115
 
116
116
  return hash(self) == hash(other)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev64
3
+ Version: 0.4.3.dev68
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -61,7 +61,6 @@ Description-Content-Type: text/markdown
61
61
  License-File: LICENSE
62
62
  Requires-Dist: coloredlogs
63
63
  Requires-Dist: jax>=0.4.13
64
- Requires-Dist: jaxopt>=0.8.0
65
64
  Requires-Dist: jaxlib>=0.4.13
66
65
  Requires-Dist: jaxlie>=1.3.0
67
66
  Requires-Dist: jax-dataclasses>=1.4.0
@@ -1,29 +1,29 @@
1
- jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=lLNskxtfHW1HqvnLRuhux3LlK89fMiZFUWknSYopw7k,426
1
+ jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
2
+ jaxsim/_version.py,sha256=XDf5LPSlhAhH48AO29kysLP_4FTR5VWOpS0LrK5RSfo,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
- jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
5
+ jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
- jaxsim/api/contact.py,sha256=C_PgMjWYYiqpA7Oz3IxHeFgrp855-xG6AQr6Ze98CtI,21863
10
- jaxsim/api/data.py,sha256=mFUw2mj8AIXduW6HnkGN7eooZHfJhwnWbtYZfLF6gk4,28206
9
+ jaxsim/api/contact.py,sha256=HyEAjF7BySDDOlRahN0l7V15IPB0HPXuoM0twamuEW0,20913
10
+ jaxsim/api/data.py,sha256=CUh9lvhVk3_clNQ26BUBGpjvFSsK_PrVWVMEWpMdHRM,27206
11
11
  jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
12
- jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
12
+ jaxsim/api/joint.py,sha256=L81bQe-noPT6_54KOSF7KBjRmEPAS433ULn2EcXI8vI,5115
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
14
14
  jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
15
- jaxsim/api/model.py,sha256=K0q8-j-04f6B3MEXsctDGtWiuWlN3HbDrsS7zoPYStk,65871
16
- jaxsim/api/ode.py,sha256=VuOLvCFoyGLmhNf2vFP5BI9BAPz78V_RW5tJ4hrizsw,13041
17
- jaxsim/api/ode_data.py,sha256=7RSoBhfCJdP6P9InQbDwdBVpClPMMuetewI-6AWm-_0,20276
15
+ jaxsim/api/model.py,sha256=HXoqCtQ3KStGoxhgvFm8P_Sc-lbEM4l5No2MoHzNlOk,65558
16
+ jaxsim/api/ode.py,sha256=Vb2sN4zwpXnaJDD9-ziz2qvfmfa4jvIQ0fONbBIRGmU,13368
17
+ jaxsim/api/ode_data.py,sha256=U7F6TL6bENAxpQQl4PupPoDG7d7VfTTFqDAs3xwu6Hs,20003
18
18
  jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=XIrJVJDO0ldaZ93WgoGNlFoRvazsRJTpO3DrK9kIXqM,20437
20
+ jaxsim/integrators/common.py,sha256=ntjflaV3qWaFH_E65pAGZ6QipdnFsgQDasKtIKpxTe4,20432
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
22
  jaxsim/integrators/variable_step.py,sha256=5StkFh9oQba34zlkIoXG2fUN78gbxkHePWbrpQ-QZOI,21274
23
23
  jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
24
24
  jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
25
25
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
26
- jaxsim/math/inertia.py,sha256=01hz6wMFreN2jBA0rVoBS1YMVh77KvwuzXSOpI3pxNk,1614
26
+ jaxsim/math/inertia.py,sha256=_hNpoeyEpAGr9ExDQJjckbjhk39luJFF-jv0SKqefnQ,1614
27
27
  jaxsim/math/joint_model.py,sha256=EzAveaG5B6ZnCFNUzN30KEQUVesd83lfWXJarYR-kUw,9989
28
28
  jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,4754
29
29
  jaxsim/math/rotation.py,sha256=k-nwT79zmWrys3NNAB-lGWxat7Kqm_6JnFRoimJ8rBg,2156
@@ -31,18 +31,18 @@ jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
31
31
  jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
32
32
  jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
33
33
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
34
- jaxsim/mujoco/loaders.py,sha256=_8Af_5Yo0-lWHE-46BBMcrqSJnDNxr3peyc519DExtA,25322
34
+ jaxsim/mujoco/loaders.py,sha256=XB-fgXuWMTFiaand5MZlLFQ5__Sh8MK5CJsxIU34MBk,25328
35
35
  jaxsim/mujoco/model.py,sha256=AQksXemXWACJ3yvefV2G5HLwwBU9ISoJrOD1wlxdY5w,16386
36
36
  jaxsim/mujoco/visualizer.py,sha256=T1vU-w4NKSmgEkZ0FqVcGmIvYrYO0len2UBSsU4MOZ0,6978
37
37
  jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
- jaxsim/parsers/kinematic_graph.py,sha256=wT2bgaCS8VQJTHy2H9sENkVPDOiMkRikxEF1t_WaahQ,34748
38
+ jaxsim/parsers/kinematic_graph.py,sha256=KijMWKyhTLKSNUmOOk4sYQMgPh_OkA_brncL7gBRHaY,34757
39
39
  jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
40
40
  jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
41
41
  jaxsim/parsers/descriptions/joint.py,sha256=VSb6C0FBBKMqwrHBKfc-Bbn4rl_J0RzUxMQlhIEvOPM,5185
42
42
  jaxsim/parsers/descriptions/link.py,sha256=Eh0W5qL7_Uw0GV-BkNKXhm9Q2dRTfIWCX5D-87zQkxA,3711
43
43
  jaxsim/parsers/descriptions/model.py,sha256=I2Vsbv8Josl4Le7b5rIvhqA2k9Bbv5JxMqwytayxds0,9833
44
44
  jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
45
- jaxsim/parsers/rod/parser.py,sha256=9EigYv2oGn4bfIY1q0Cd_55yVKfN2rXP_MuZSZqGxYM,13681
45
+ jaxsim/parsers/rod/parser.py,sha256=HskeCqDsbtwH2BDk3vfxvx391wUTVGLaUXNvBrdNo-4,13486
46
46
  jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5698
47
47
  jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
48
48
  jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
@@ -54,17 +54,16 @@ jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
54
54
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
55
  jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
56
  jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
57
- jaxsim/rbda/contacts/relaxed_rigid.py,sha256=9YkPLbK6Kk0wPkuj47r7NBqY2tARyJsiCbrvDlOWHSI,12700
58
- jaxsim/rbda/contacts/rigid.py,sha256=fbZk7sC6YOnTs_tzQRfsyBpHyT22XF-wB-EvOSZmhos,14746
57
+ jaxsim/rbda/contacts/rigid.py,sha256=8Vbnxng-ERZ5ka_eZGIBuhBDr2PNjc7m-Or255AfEw4,15862
59
58
  jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
60
59
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
61
- jaxsim/terrain/terrain.py,sha256=xUQg47yGxIOcTkLPbnO3sruEGBhoCd16j1evTGlmNjI,5010
60
+ jaxsim/terrain/terrain.py,sha256=ctyNANIFSM3tZmamprjaEDcWgUSP0oNJbmT1zw9RjPs,4565
62
61
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
63
- jaxsim/utils/jaxsim_dataclass.py,sha256=FSiUvdnq4Y1T9Jaa_mw4ZBQJe8H7deLr3Kupxtlh4iI,11322
62
+ jaxsim/utils/jaxsim_dataclass.py,sha256=5xJbY0G8d7C0OTNIW9T4vQxiDak6TGZT9gpNOvRykFI,11373
64
63
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
65
- jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
66
- jaxsim-0.4.3.dev64.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
- jaxsim-0.4.3.dev64.dist-info/METADATA,sha256=0-JS1eJjFMSaMzwqbCSpWYU2GcrZkxT1LBDo7lhWICo,17276
68
- jaxsim-0.4.3.dev64.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
69
- jaxsim-0.4.3.dev64.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
- jaxsim-0.4.3.dev64.dist-info/RECORD,,
64
+ jaxsim/utils/wrappers.py,sha256=JhLUh1g8iU-lhjbuZRfkscPZhYlLCOorVM2Xl3ulRBI,4054
65
+ jaxsim-0.4.3.dev68.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
66
+ jaxsim-0.4.3.dev68.dist-info/METADATA,sha256=IrZMXHUptvvLA5YgloveNIge4OdEBjT-DxhdHBrn_WM,17247
67
+ jaxsim-0.4.3.dev68.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
68
+ jaxsim-0.4.3.dev68.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
69
+ jaxsim-0.4.3.dev68.dist-info/RECORD,,
@@ -1,384 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
- from typing import Any
5
-
6
- import jax
7
- import jax.numpy as jnp
8
- import jax_dataclasses
9
- import jaxopt
10
-
11
- import jaxsim.api as js
12
- import jaxsim.typing as jtp
13
- from jaxsim.api.common import VelRepr
14
- from jaxsim.math import Adjoint
15
- from jaxsim.terrain.terrain import FlatTerrain, Terrain
16
-
17
- from .common import ContactModel, ContactsParams, ContactsState
18
-
19
-
20
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
    """
    Parameters of the relaxed rigid contacts model.

    By default every parameter is a scalar JAX array. Use `build` to create
    an instance overriding only a subset of the parameters, and `valid` to
    check that the stored values lie in their admissible ranges.
    """

    # Time constant (validated as non-negative).
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient (validated as strictly positive).
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance (validated as non-negative and not above `d_max`).
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance (validated as not above 1).
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width (validated as non-negative).
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint (validated as non-negative).
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent (validated as non-negative).
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness. It is deliberately not range-checked by `valid`: the contact
    # model treats a negative value as an explicit spring constant.
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping. It is deliberately not range-checked by `valid`: the contact
    # model treats a negative value as an explicit damper constant.
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient (validated as non-negative).
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Maximum number of iterations of the solver (validated as positive).
    max_iterations: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(50, dtype=int)
    )

    # Solver tolerance (validated as strictly positive).
    tolerance: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e-6, dtype=float)
    )

    def __hash__(self) -> int:
        """Return a hash computed from the value of every parameter."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
                HashedNumpyArray(self.max_iterations),
                HashedNumpyArray(self.tolerance),
            )
        )

    def __eq__(self, other: object) -> bool:
        """Compare two parameter sets through their hashes."""

        # Objects of a different type are never equal, even if their hash
        # happens to collide with the one of this instance.
        if not isinstance(other, RelaxedRigidContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
        max_iterations: jtp.IntLike | None = None,
        tolerance: jtp.FloatLike | None = None,
    ) -> RelaxedRigidContactsParams:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Each argument left to `None` falls back to the default value of the
        corresponding dataclass field. Explicit values are converted to JAX
        arrays having the same dtype of the field default.

        Returns:
            The parameters object populated with the given and default values.
        """

        # Collect the arguments explicitly. Calling `locals()` from inside a
        # comprehension does not reliably expose the arguments of the
        # enclosing function, so the overrides would be silently lost.
        overrides = {
            "time_constant": time_constant,
            "damping_coefficient": damping_coefficient,
            "d_min": d_min,
            "d_max": d_max,
            "width": width,
            "midpoint": midpoint,
            "power": power,
            "stiffness": stiffness,
            "damping": damping,
            "mu": mu,
            "max_iterations": max_iterations,
            "tolerance": tolerance,
        }

        kwargs = {}

        for name, value in overrides.items():

            # The fields are declared with `default_factory`, therefore the
            # default value must be produced by calling the factory
            # (`Field.default` is `dataclasses.MISSING` for such fields).
            default = cls.__dataclass_fields__[name].default_factory()

            kwargs[name] = jnp.array(
                default if value is None else value, dtype=default.dtype
            )

        return cls(**kwargs)

    def valid(self) -> bool:
        """
        Check whether all the parameters lie in their admissible range.

        Returns:
            True if every range check passes, False otherwise.
        """

        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            and jnp.all(self.width >= 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
            and jnp.all(self.max_iterations > 0)
            and jnp.all(self.tolerance > 0.0)
        )
149
-
150
-
151
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsState(ContactsState):
    """
    Class storing the state of the relaxed rigid contacts model.

    The class declares no fields of its own, so building it requires no data.
    """

    def __eq__(self, other: object) -> bool:
        """Compare two states through their hashes."""

        # Without this guard, any unrelated object whose hash collides with
        # the one of this state would be reported as equal.
        if not isinstance(other, RelaxedRigidContactsState):
            return False

        return hash(self) == hash(other)

    @staticmethod
    def build() -> RelaxedRigidContactsState:
        """Create a `RelaxedRigidContactsState` instance."""

        return RelaxedRigidContactsState()

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
        """
        Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.

        Args:
            model: The model (unused, accepted for interface uniformity).

        Returns:
            A newly built state.
        """

        return RelaxedRigidContactsState.build()

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check the validity of the state against a model.

        Args:
            model: The model (unused).

        Returns:
            Always True.
        """

        return True
171
-
172
-
173
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(ContactModel):
    """
    Relaxed rigid contacts model.

    The contact forces are obtained by minimizing a least-squares objective
    with the L-BFGS solver of `jaxopt`, using the regularization terms
    computed from the model parameters.
    """

    # Parameters of the contact model.
    parameters: RelaxedRigidContactsParams = dataclasses.field(
        default_factory=RelaxedRigidContactsParams
    )

    # Terrain used to query the surface normal (static pytree attribute).
    terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
        default_factory=FlatTerrain
    )

    def compute_contact_forces(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        link_forces: jtp.MatrixLike | None = None,
    ) -> tuple[jtp.Vector, tuple[Any, ...]]:
        """
        Compute the contact forces of all the collidable points.

        Args:
            position:
                The position of the collidable points, one (x, y, z) row
                for each point.
            velocity:
                The linear velocity of the collidable points, one row for
                each point.
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional 6D forces applied to the links. When omitted, a
                zero force is used for every link.

        Returns:
            A tuple containing the 6D contact force of each collidable point
            (3D linear force padded with a zero torque, transformed by
            `mixed_to_inertial`) and a tuple with no additional data.
        """

        link_forces = (
            link_forces
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
        )

        def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
            """Project the height over the terrain on the terrain normal."""

            x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))

            # NOTE(review): the normal is read from `self.terrain` while the
            # height is read from `model.terrain`. Confirm that the two
            # terrains are always meant to be the same object.
            n̂ = self.terrain.normal(x=x, y=y).squeeze()
            h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])

            return jnp.dot(h, n̂)

        # Compute the activation state of the collidable points.
        # A negative value marks a point penetrating the terrain.
        δ = jax.vmap(_detect_contact)(*position.T)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians. The rows of the points
            # that are not in contact are zeroed.
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, height: J * (height < 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )

            W_H_C = js.contact.transforms(model=model, data=data)

            # Acceleration of the system computed with the given link forces
            # and without the contact forces solved for below.
            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                )
            )

            BW_ν = data.generalized_velocity()

            # Linear part of the contact Jacobian derivatives, masked like
            # the Jacobians above.
            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, height: J̇ * (height < 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

        # The regularizers do not depend on `data`, hence they are not
        # affected by the active velocity representation.
        a_ref, R, K, D = self._regularizers(
            model=model,
            penetration=δ,
            velocity=velocity,
            parameters=self.parameters,
        )

        G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the least-squares optimization problem.
        A = G + R
        b = CW_al_free_WC - a_ref

        def objective(x: jtp.Array) -> jtp.Array:
            """Squared norm of the residual of the regularized constraint."""
            return jnp.sum(jnp.square(A @ x + b))

        # Compute the 3D linear force in C[W] frame.
        opt = jaxopt.LBFGS(
            fun=objective,
            maxiter=self.parameters.max_iterations,
            tol=self.parameters.tolerance,
            maxls=30,
            history_size=10,
            max_stepsize=100.0,
        )

        # Initialize the solver with a spring-damper guess of the forces.
        init_params = (
            K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
            + D[:, jnp.newaxis] * velocity
        ).flatten()

        CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)

        def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
            """Pad the linear force to 6D and apply the force transform."""

            # The rotation is replaced by the identity before building the
            # adjoint, so that only the translation of the transform is used.
            W_Xf_CW = Adjoint.from_transform(
                W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
                inverse=True,
            ).T

            return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])

        W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)

        return W_f_C, (None,)

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        penetration: jtp.Array,
        velocity: jtp.Array,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            penetration: The penetration of the collidable points.
            velocity: The velocity of the collidable points.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the regularization
            matrix, the stiffness, and the damping.
        """

        # The unpacking relies on the declaration order of the fields of
        # the parameters class. The trailing solver options are not needed.
        Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ, *_ = jax_dataclasses.astuple(
            parameters
        )

        M_L = js.model.link_spatial_inertia_matrices(model=model)

        def _imp_aref(
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """
            Calculates impedance and offset acceleration in constraint frame.

            Args:
                penetration: penetration in constraint frame
                velocity: velocity in constraint frame

            Returns:
                imp: impedance
                a_ref: offset acceleration in constraint frame
                K_f: computed stiffness
                D_f: computed damping
            """

            position = jnp.zeros(shape=(3,)).at[2].set(penetration)

            imp_x = jnp.abs(position) / width
            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)

            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)

            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
            imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))

            # When passing negative values, K and D represent a spring and damper, respectively.
            K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
            D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))

            a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)

            return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)

        def _compute_row(
            *,
            link_idx: jtp.Int,
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """Compute the terms of a single collidable point."""

            # Compute the reference acceleration.
            ξ, a_ref, K_f, D_f = _imp_aref(
                penetration=penetration,
                velocity=velocity,
            )

            # Compute the regularization terms.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Zero all the terms of the points that are not in contact.
            return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K_f, D_f))

        # Process all the collidable points and flatten the per-point results.
        a_ref, R, K_f, D_f = jax.tree.map(
            jnp.concatenate,
            (
                *jax.vmap(_compute_row)(
                    link_idx=jnp.array(
                        model.kin_dyn_parameters.contact_parameters.body
                    ),
                    penetration=penetration,
                    velocity=velocity,
                ),
            ),
        )

        return a_ref, jnp.diag(R), K_f, D_f