jaxsim 0.4.3.dev181__py3-none-any.whl → 0.4.3.dev200__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev181'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev181')
15
+ __version__ = version = '0.4.3.dev200'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev200')
jaxsim/api/contact.py CHANGED
@@ -36,11 +36,10 @@ def collidable_point_kinematics(
36
36
  the linear component of the mixed 6D frame velocity.
37
37
  """
38
38
 
39
- from jaxsim.rbda import collidable_points
40
-
41
39
  # Switch to inertial-fixed since the RBDAs expect velocities in this representation.
42
40
  with data.switch_velocity_representation(VelRepr.Inertial):
43
- W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel(
41
+
42
+ W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
44
43
  model=model,
45
44
  base_position=data.base_position(),
46
45
  base_quaternion=data.base_orientation(dcm=False),
@@ -304,6 +303,15 @@ def in_contact(
304
303
 
305
304
 
306
305
  def estimate_good_soft_contacts_parameters(
306
+ *args, **kwargs
307
+ ) -> jaxsim.rbda.contacts.ContactParamsTypes:
308
+
309
+ msg = "This method is deprecated, please use `{}`."
310
+ logging.warning(msg.format(estimate_good_contact_parameters.__name__))
311
+ return estimate_good_contact_parameters(*args, **kwargs)
312
+
313
+
314
+ def estimate_good_contact_parameters(
307
315
  model: js.model.JaxSimModel,
308
316
  *,
309
317
  standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
@@ -312,14 +320,9 @@ def estimate_good_soft_contacts_parameters(
312
320
  damping_ratio: jtp.FloatLike = 1.0,
313
321
  max_penetration: jtp.FloatLike | None = None,
314
322
  **kwargs,
315
- ) -> (
316
- jaxsim.rbda.contacts.RelaxedRigidContactsParams
317
- | jaxsim.rbda.contacts.RigidContactsParams
318
- | jaxsim.rbda.contacts.SoftContactsParams
319
- | jaxsim.rbda.contacts.ViscoElasticContactsParams
320
- ):
323
+ ) -> jaxsim.rbda.contacts.ContactParamsTypes:
321
324
  """
322
- Estimate good parameters for soft-like contact models.
325
+ Estimate good contact parameters.
323
326
 
324
327
  Args:
325
328
  model: The model to consider.
@@ -332,12 +335,19 @@ def estimate_good_soft_contacts_parameters(
332
335
  max_penetration:
333
336
  The maximum penetration allowed in steady state when the robot is
334
337
  supported by the configured number of active collidable points.
338
+ kwargs:
339
+ Additional model-specific parameters passed to the builder method of
340
+ the parameters class.
335
341
 
336
342
  Returns:
337
- The estimated good soft contacts parameters.
343
+ The estimated good contacts parameters.
344
+
345
+ Note:
346
+ This is primarily a convenience function for soft-like contact models.
347
+ However, it provides with some good default parameters also for the other ones.
338
348
 
339
349
  Note:
340
- This method provides a good starting point for the soft contacts parameters.
350
+ This method provides a good set of contacts parameters.
341
351
  The user is encouraged to fine-tune the parameters based on the
342
352
  specific application.
343
353
  """
@@ -364,6 +374,7 @@ def estimate_good_soft_contacts_parameters(
364
374
  max_δ = (
365
375
  max_penetration
366
376
  if max_penetration is not None
377
+ # Consider as default a 0.5% of the model height.
367
378
  else 0.005 * estimate_model_height(model=model)
368
379
  )
369
380
 
@@ -381,8 +392,11 @@ def estimate_good_soft_contacts_parameters(
381
392
  max_penetration=max_δ,
382
393
  number_of_active_collidable_points_steady_state=nc,
383
394
  damping_ratio=damping_ratio,
384
- p=model.contact_model.parameters.p,
385
- q=model.contact_model.parameters.q,
395
+ **dict(
396
+ p=model.contact_model.parameters.p,
397
+ q=model.contact_model.parameters.q,
398
+ )
399
+ | kwargs,
386
400
  )
387
401
 
388
402
  case contacts.ViscoElasticContacts():
@@ -396,15 +410,40 @@ def estimate_good_soft_contacts_parameters(
396
410
  max_penetration=max_δ,
397
411
  number_of_active_collidable_points_steady_state=nc,
398
412
  damping_ratio=damping_ratio,
399
- p=model.contact_model.parameters.p,
400
- q=model.contact_model.parameters.q,
401
- **kwargs,
413
+ **dict(
414
+ p=model.contact_model.parameters.p,
415
+ q=model.contact_model.parameters.q,
416
+ )
417
+ | kwargs,
402
418
  )
403
419
  )
404
420
 
421
+ case contacts.RigidContacts():
422
+ assert isinstance(model.contact_model, contacts.RigidContacts)
423
+
424
+ # Disable Baumgarte stabilization by default since it does not play
425
+ # well with the forward Euler integrator.
426
+ K = kwargs.get("K", 0.0)
427
+
428
+ parameters = contacts.RigidContactsParams.build(
429
+ mu=static_friction_coefficient,
430
+ **dict(
431
+ K=K,
432
+ D=2 * jnp.sqrt(K),
433
+ )
434
+ | kwargs,
435
+ )
436
+
437
+ case contacts.RelaxedRigidContacts():
438
+ assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
439
+
440
+ parameters = contacts.RelaxedRigidContactsParams.build(
441
+ mu=static_friction_coefficient,
442
+ **kwargs,
443
+ )
444
+
405
445
  case _:
406
- logging.warning("The active contact model is not soft-like, no-op.")
407
- parameters = model.contact_model.parameters
446
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
408
447
 
409
448
  return parameters
410
449
 
jaxsim/api/data.py CHANGED
@@ -34,7 +34,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
34
34
 
35
35
  state: ODEState
36
36
 
37
- gravity: jtp.Array
37
+ gravity: jtp.Vector
38
38
 
39
39
  contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
40
40
 
@@ -224,7 +224,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
224
224
  jaxsim.rbda.contacts.SoftContacts
225
225
  | jaxsim.rbda.contacts.ViscoElasticContacts,
226
226
  ):
227
- contacts_params = js.contact.estimate_good_soft_contacts_parameters(
227
+
228
+ contacts_params = js.contact.estimate_good_contact_parameters(
228
229
  model=model, standard_gravity=standard_gravity
229
230
  )
230
231
 
jaxsim/api/model.py CHANGED
@@ -40,6 +40,8 @@ class JaxSimModel(JaxsimDataclass):
40
40
  default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
41
41
  )
42
42
 
43
+ # Note that this is the default contact model.
44
+ # Its parameters, if any, are then overridden from those stored in JaxSimModelData.
43
45
  contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
44
46
  default=None, repr=False
45
47
  )
@@ -2044,24 +2046,18 @@ def step(
2044
2046
  M = js.model.free_floating_mass_matrix(model, data_tf)
2045
2047
  W_p_C = js.contact.collidable_point_positions(model, data_tf)
2046
2048
 
2047
- # Compute the height of the terrain below each collidable point.
2048
- px, py, _ = W_p_C.T
2049
- terrain_height = jax.vmap(model.terrain.height)(px, py)
2050
-
2051
- # Compute the contact state.
2052
- inactive_collidable_points, _ = (
2053
- jaxsim.rbda.contacts.RigidContacts.detect_contacts(
2054
- W_p_C=W_p_C,
2055
- terrain_height=terrain_height,
2056
- )
2057
- )
2049
+ # Compute the penetration depth of the collidable points.
2050
+ δ, *_ = jax.vmap(
2051
+ jaxsim.rbda.contacts.common.compute_penetration_data,
2052
+ in_axes=(0, 0, None),
2053
+ )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)
2058
2054
 
2059
2055
  # Compute the impact velocity.
2060
2056
  # It may be discontinuous in case new contacts are made.
2061
2057
  BW_nu_post_impact = (
2062
2058
  jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
2063
2059
  data=data_tf,
2064
- inactive_collidable_points=inactive_collidable_points,
2060
+ inactive_collidable_points=(δ <= 0),
2065
2061
  M=M,
2066
2062
  J_WC=J_WC,
2067
2063
  )
jaxsim/mujoco/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
1
  from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf
2
2
  from .model import MujocoModelHelper
3
+ from .utils import mujoco_data_from_jaxsim
3
4
  from .visualizer import MujocoVideoRecorder, MujocoVisualizer
jaxsim/mujoco/loaders.py CHANGED
@@ -646,7 +646,7 @@ class MujocoCamera:
646
646
  def build(cls, **kwargs) -> MujocoCamera:
647
647
 
648
648
  if not all(isinstance(value, str) for value in kwargs.values()):
649
- raise ValueError("Values must be strings")
649
+ raise ValueError(f"Values must be strings: {kwargs}")
650
650
 
651
651
  return cls(**kwargs)
652
652
 
jaxsim/mujoco/model.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import pathlib
5
- from collections.abc import Callable
5
+ from collections.abc import Callable, Sequence
6
6
  from typing import Any
7
7
 
8
8
  import mujoco as mj
@@ -107,7 +107,8 @@ class MujocoModelHelper:
107
107
  size = [float(el) for el in hfield_element["@size"].split(" ")]
108
108
  size[0], size[1] = heightmap_radius_xy
109
109
  size[2] = 1.0
110
- size[3] = max(0, -min(hfield))
110
+ # The following could be zero but Mujoco complains if it's exactly zero.
111
+ size[3] = max(0.000_001, -min(hfield))
111
112
 
112
113
  # Replace the 'size' attribute.
113
114
  hfields_dict[heightmap_name]["@size"] = " ".join(str(el) for el in size)
@@ -315,7 +316,7 @@ class MujocoModelHelper:
315
316
  self.data.qpos[sl] = position
316
317
 
317
318
  def set_joint_positions(
318
- self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
319
+ self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray]
319
320
  ) -> None:
320
321
  """Set the positions of multiple joints."""
321
322
 
jaxsim/mujoco/utils.py ADDED
@@ -0,0 +1,101 @@
1
+ import mujoco as mj
2
+ import numpy as np
3
+
4
+ from . import MujocoModelHelper
5
+
6
+
7
+ def mujoco_data_from_jaxsim(
8
+ mujoco_model: mj.MjModel,
9
+ jaxsim_model,
10
+ jaxsim_data,
11
+ mujoco_data: mj.MjData | None = None,
12
+ update_removed_joints: bool = True,
13
+ ) -> mj.MjData:
14
+ """
15
+ Create a Mujoco data object from a JaxSim model and data objects.
16
+
17
+ Args:
18
+ mujoco_model: The Mujoco model object corresponding to the JaxSim model.
19
+ jaxsim_model: The JaxSim model object from which the Mujoco model was created.
20
+ jaxsim_data: The JaxSim data object containing the state of the model.
21
+ mujoco_data: An optional Mujoco data object. If None, a new one will be created.
22
+ update_removed_joints:
23
+ If True, the positions of the joints that have been removed during the
24
+ model reduction process will be set to their initial values.
25
+
26
+ Returns:
27
+ The Mujoco data object containing the state of the JaxSim model.
28
+
29
+ Note:
30
+ This method is useful to initialize a Mujoco data object used for visualization
31
+ with the state of a JaxSim model. In particular, this function takes care of
32
+ initializing the positions of the joints that have been removed during the
33
+ model reduction process. After the initial creation of the Mujoco data object,
34
+ it's faster to update the state using an external MujocoModelHelper object.
35
+ """
36
+
37
+ # The package `jaxsim.mujoco` is supposed to be jax-independent.
38
+ # We import all the JaxSim resources privately.
39
+ import jaxsim.api as js
40
+
41
+ if not isinstance(jaxsim_model, js.model.JaxSimModel):
42
+ raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.")
43
+
44
+ if not isinstance(jaxsim_data, js.data.JaxSimModelData):
45
+ raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.")
46
+
47
+ # Create the helper to operate on the Mujoco model and data.
48
+ model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)
49
+
50
+ # If the model is fixed-base, the Mujoco model won't have the joint corresponding
51
+ # to the floating base, and the helper would raise an exception.
52
+ if jaxsim_model.floating_base():
53
+
54
+ # Set the model position.
55
+ model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))
56
+
57
+ # Set the model orientation.
58
+ model_helper.set_base_orientation(
59
+ orientation=np.array(jaxsim_data.base_orientation())
60
+ )
61
+
62
+ # Set the joint positions.
63
+ if jaxsim_model.dofs() > 0:
64
+
65
+ model_helper.set_joint_positions(
66
+ joint_names=list(jaxsim_model.joint_names()),
67
+ positions=np.array(
68
+ jaxsim_data.joint_positions(
69
+ model=jaxsim_model, joint_names=jaxsim_model.joint_names()
70
+ )
71
+ ),
72
+ )
73
+
74
+ # Updating these joints is not necessary after the first time.
75
+ # Users can disable this update after initialization.
76
+ if update_removed_joints:
77
+
78
+ # Create a dictionary with the joints that have been removed for various reasons
79
+ # (like link lumping due to model reduction).
80
+ joints_removed_dict = {
81
+ j.name: j
82
+ for j in jaxsim_model.description._joints_removed
83
+ if j.name not in set(jaxsim_model.joint_names())
84
+ }
85
+
86
+ # Set the positions of the removed joints.
87
+ _ = [
88
+ model_helper.set_joint_position(
89
+ position=joints_removed_dict[joint_name].initial_position,
90
+ joint_name=joint_name,
91
+ )
92
+ # Select all original joint that have been removed from the JaxSim model
93
+ # that are still present in the Mujoco model.
94
+ for joint_name in joints_removed_dict
95
+ if joint_name in model_helper.joint_names()
96
+ ]
97
+
98
+ # Return the mujoco data with updated kinematics.
99
+ mj.mj_forward(mujoco_model, model_helper.data)
100
+
101
+ return model_helper.data
jaxsim/mujoco/visualizer.py CHANGED
@@ -89,7 +89,7 @@ class MujocoVideoRecorder:
89
89
  if not exist_ok and path.is_file():
90
90
  raise FileExistsError(f"The file '{path}' already exists.")
91
91
 
92
- media.write_video(path=path, images=self.frames, fps=self.fps)
92
+ media.write_video(path=path, images=np.array(self.frames), fps=self.fps)
93
93
 
94
94
  @staticmethod
95
95
  def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
jaxsim/rbda/contacts/__init__.py CHANGED
@@ -4,3 +4,10 @@ from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
4
4
  from .rigid import RigidContacts, RigidContactsParams
5
5
  from .soft import SoftContacts, SoftContactsParams
6
6
  from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
7
+
8
+ ContactParamsTypes = (
9
+ SoftContactsParams
10
+ | RigidContactsParams
11
+ | RelaxedRigidContactsParams
12
+ | ViscoElasticContactsParams
13
+ )
jaxsim/rbda/contacts/common.py CHANGED
@@ -1,8 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import abc
4
+ import functools
4
5
  from typing import Any
5
6
 
7
+ import jax
8
+ import jax.numpy as jnp
9
+
6
10
  import jaxsim.api as js
7
11
  import jaxsim.terrain
8
12
  import jaxsim.typing as jtp
@@ -14,6 +18,47 @@ except ImportError:
14
18
  from typing_extensions import Self
15
19
 
16
20
 
21
+ @functools.partial(jax.jit, static_argnames=("terrain",))
22
+ def compute_penetration_data(
23
+ p: jtp.VectorLike,
24
+ v: jtp.VectorLike,
25
+ terrain: jaxsim.terrain.Terrain,
26
+ ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
27
+ """
28
+ Compute the penetration data (depth, rate, and terrain normal) of a collidable point.
29
+
30
+ Args:
31
+ p: The position of the collidable point.
32
+ v:
33
+ The linear velocity of the point (linear component of the mixed 6D velocity
34
+ of the implicit frame `C = (W_p_C, [W])` associated to the point).
35
+ terrain: The considered terrain.
36
+
37
+ Returns:
38
+ A tuple containing the penetration depth, the penetration velocity,
39
+ and the considered terrain normal.
40
+ """
41
+
42
+ # Pre-process the position and the linear velocity of the collidable point.
43
+ W_ṗ_C = jnp.array(v).squeeze()
44
+ px, py, pz = jnp.array(p).squeeze()
45
+
46
+ # Compute the terrain normal and the contact depth.
47
+ n̂ = terrain.normal(x=px, y=py).squeeze()
48
+ h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])
49
+
50
+ # Compute the penetration depth normal to the terrain.
51
+ δ = jnp.maximum(0.0, jnp.dot(h, n̂))
52
+
53
+ # Compute the penetration normal velocity.
54
+ δ_dot = -jnp.dot(W_ṗ_C, n̂)
55
+
56
+ # Enforce the penetration rate to be zero when the penetration depth is zero.
57
+ δ_dot = jnp.where(δ > 0, δ_dot, 0.0)
58
+
59
+ return δ, δ_dot, n̂
60
+
61
+
17
62
  class ContactsParams(JaxsimDataclass):
18
63
  """
19
64
  Abstract class representing the parameters of a contact model.
@@ -86,7 +131,7 @@ class ContactModel(JaxsimDataclass):
86
131
  model: js.model.JaxSimModel,
87
132
  data: js.data.JaxSimModelData,
88
133
  **kwargs,
89
- ) -> tuple[jtp.Vector, tuple[Any, ...]]:
134
+ ) -> tuple[jtp.Matrix, tuple[Any, ...]]:
90
135
  """
91
136
  Compute the contact forces.
92
137
 
@@ -95,8 +140,9 @@ class ContactModel(JaxsimDataclass):
95
140
  data: The data of the considered model.
96
141
 
97
142
  Returns:
98
- A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
99
- and as second element a tuple of optional additional information.
143
+ A tuple containing as first element the computed 6D contact force applied to
144
+ the contact points and expressed in the world frame, and as second element
145
+ a tuple of optional additional information.
100
146
  """
101
147
 
102
148
  pass