jaxsim 0.4.3.dev327__py3-none-any.whl → 0.4.3.dev350__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev327'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev327')
15
+ __version__ = version = '0.4.3.dev350'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev350')
jaxsim/api/joint.py CHANGED
@@ -53,9 +53,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
53
53
  """
54
54
 
55
55
  exceptions.raise_value_error_if(
56
- condition=jnp.array(
57
- [joint_index < 0, joint_index >= model.number_of_joints()]
58
- ).any(),
56
+ condition=joint_index < 0,
59
57
  msg="Invalid joint index '{idx}'",
60
58
  idx=joint_index,
61
59
  )
@@ -123,10 +121,7 @@ def position_limit(
123
121
  """
124
122
 
125
123
  if model.number_of_joints() == 0:
126
- s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min
127
- s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max
128
-
129
- return jnp.atleast_1d(s_min).astype(float), jnp.atleast_1d(s_max).astype(float)
124
+ return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
130
125
 
131
126
  exceptions.raise_value_error_if(
132
127
  condition=jnp.array(
@@ -136,8 +131,12 @@ def position_limit(
136
131
  idx=joint_index,
137
132
  )
138
133
 
139
- s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
140
- s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
134
+ s_min = jnp.atleast_1d(
135
+ model.kin_dyn_parameters.joint_parameters.position_limits_min
136
+ )[joint_index]
137
+ s_max = jnp.atleast_1d(
138
+ model.kin_dyn_parameters.joint_parameters.position_limits_max
139
+ )[joint_index]
141
140
 
142
141
  return s_min.astype(float), s_max.astype(float)
143
142
 
@@ -438,7 +438,9 @@ class KynDynParameters(JaxsimDataclass):
438
438
  # Helpers to update parameters
439
439
  # ============================
440
440
 
441
- def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameters:
441
+ def set_link_mass(
442
+ self, link_index: jtp.IntLike, mass: jtp.FloatLike
443
+ ) -> KynDynParameters:
442
444
  """
443
445
  Set the mass of a link.
444
446
 
@@ -457,7 +459,7 @@ class KynDynParameters(JaxsimDataclass):
457
459
  return self.replace(link_parameters=link_parameters)
458
460
 
459
461
  def set_link_inertia(
460
- self, link_index: int, inertia: jtp.MatrixLike
462
+ self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
461
463
  ) -> KynDynParameters:
462
464
  r"""
463
465
  Set the inertia tensor of a link.
@@ -593,10 +595,10 @@ class LinkParameters(JaxsimDataclass):
593
595
  """
594
596
 
595
597
  # Extract the link parameters from the 6D spatial inertia.
596
- m, L_p_CoM, I = Inertia.to_params(M=M)
598
+ m, L_p_CoM, I_CoM = Inertia.to_params(M=M)
597
599
 
598
600
  # Extract only the necessary elements of the inertia tensor.
599
- inertia_elements = I[jnp.triu_indices(3)]
601
+ inertia_elements = I_CoM[jnp.triu_indices(3)]
600
602
 
601
603
  return LinkParameters(
602
604
  index=jnp.array(index).squeeze().astype(int),
jaxsim/api/link.py CHANGED
@@ -4,6 +4,7 @@ from collections.abc import Sequence
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  import jax.scipy.linalg
7
+ import numpy as np
7
8
 
8
9
  import jaxsim.api as js
9
10
  import jaxsim.rbda
@@ -54,9 +55,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
54
55
  """
55
56
 
56
57
  exceptions.raise_value_error_if(
57
- condition=jnp.array(
58
- [link_index < 0, link_index >= model.number_of_links()]
59
- ).any(),
58
+ condition=link_index < 0,
60
59
  msg="Invalid link index '{idx}'",
61
60
  idx=link_index,
62
61
  )
@@ -98,7 +97,7 @@ def idxs_to_names(
98
97
  The names of the links.
99
98
  """
100
99
 
101
- return tuple(idx_to_name(model=model, link_index=idx) for idx in link_indices)
100
+ return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)])
102
101
 
103
102
 
104
103
  # =========
jaxsim/api/model.py CHANGED
@@ -304,7 +304,7 @@ class JaxSimModel(JaxsimDataclass):
304
304
 
305
305
  return self.model_name
306
306
 
307
- def number_of_links(self) -> jtp.Int:
307
+ def number_of_links(self) -> int:
308
308
  """
309
309
  Return the number of links in the model.
310
310
 
@@ -317,7 +317,7 @@ class JaxSimModel(JaxsimDataclass):
317
317
 
318
318
  return self.kin_dyn_parameters.number_of_links()
319
319
 
320
- def number_of_joints(self) -> jtp.Int:
320
+ def number_of_joints(self) -> int:
321
321
  """
322
322
  Return the number of joints in the model.
323
323
 
@@ -419,7 +419,7 @@ class JaxSimModel(JaxsimDataclass):
419
419
  def reduce(
420
420
  model: JaxSimModel,
421
421
  considered_joints: tuple[str, ...],
422
- locked_joint_positions: dict[str, jtp.Float] | None = None,
422
+ locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
423
423
  ) -> JaxSimModel:
424
424
  """
425
425
  Reduce the model by lumping together the links connected by removed joints.
@@ -1038,12 +1038,7 @@ def forward_dynamics_aba(
1038
1038
  C_v̇_WB = to_active(
1039
1039
  W_v̇_WB=W_v̇_WB,
1040
1040
  W_H_C=W_H_C,
1041
- W_v_WB=jnp.hstack(
1042
- [
1043
- data.state.physics_model.base_linear_velocity,
1044
- data.state.physics_model.base_angular_velocity,
1045
- ]
1046
- ),
1041
+ W_v_WB=W_v_WB,
1047
1042
  W_v_WC=W_v_WC,
1048
1043
  )
1049
1044
 
@@ -2274,16 +2269,12 @@ def step(
2274
2269
  # Raise runtime error for not supported case in which Rigid contacts and
2275
2270
  # Baumgarte stabilization are enabled and used with ForwardEuler integrator.
2276
2271
  jaxsim.exceptions.raise_runtime_error_if(
2277
- condition=jnp.logical_and(
2278
- isinstance(
2279
- integrator,
2280
- jaxsim.integrators.fixed_step.ForwardEuler
2281
- | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2282
- ),
2283
- jnp.array(
2284
- [data_tf.contacts_params.K, data_tf.contacts_params.D]
2285
- ).any(),
2286
- ),
2272
+ condition=isinstance(
2273
+ integrator,
2274
+ jaxsim.integrators.fixed_step.ForwardEuler
2275
+ | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2276
+ )
2277
+ & ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
2287
2278
  msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
2288
2279
  )
2289
2280
 
jaxsim/api/references.py CHANGED
@@ -503,7 +503,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
503
503
  ]
504
504
 
505
505
  exceptions.raise_value_error_if(
506
- condition=jnp.logical_not(data.valid(model=model)),
506
+ condition=~data.valid(model=model),
507
507
  msg="The provided data is not valid for the model",
508
508
  )
509
509
  W_H_Fi = jax.vmap(
@@ -319,7 +319,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
319
319
  f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
320
320
 
321
321
  # Initialize the carry of the for loop with the stacked kᵢ vectors.
322
- carry0 = jax.tree_map(
322
+ carry0 = jax.tree.map(
323
323
  lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
324
324
  )
325
325
 
@@ -507,7 +507,7 @@ class ExplicitRungeKuttaSO3Mixin:
507
507
 
508
508
  # We assume that the initial quaternion is already unary.
509
509
  exceptions.raise_runtime_error_if(
510
- condition=jnp.logical_not(jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0)),
510
+ condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
511
511
  msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
512
512
  )
513
513
 
@@ -152,7 +152,7 @@ def compute_pytree_scale(
152
152
  """
153
153
 
154
154
  # Consider a zero second pytree, if not given.
155
- x2 = jax.tree.map(lambda l: jnp.zeros_like(l), x1) if x2 is None else x2
155
+ x2 = jax.tree.map(jnp.zeros_like, x1) if x2 is None else x2
156
156
 
157
157
  # Compute the scaling factors of the initial state and its derivative.
158
158
  compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
@@ -199,9 +199,7 @@ def local_error_estimation(
199
199
 
200
200
  # Consider a zero estimated final state, if not given.
201
201
  xf_estimate = (
202
- jax.tree.map(lambda l: jnp.zeros_like(l), xf)
203
- if xf_estimate is None
204
- else xf_estimate
202
+ jax.tree.map(jnp.zeros_like, xf) if xf_estimate is None else xf_estimate
205
203
  )
206
204
 
207
205
  # Estimate the error.
@@ -483,14 +481,10 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
483
481
  metadata_next,
484
482
  discarded_steps,
485
483
  ) = jax.lax.cond(
486
- pred=jnp.array(
487
- [
488
- discarded_steps >= self.max_step_rejections,
489
- local_error <= 1.0,
490
- Δt_next < self.dt_min,
491
- integrator_init,
492
- ]
493
- ).any(),
484
+ pred=discarded_steps
485
+ >= self.max_step_rejections | local_error
486
+ <= 1.0 | Δt_next
487
+ < self.dt_min | integrator_init,
494
488
  true_fun=accept_step,
495
489
  false_fun=reject_step,
496
490
  )
jaxsim/mujoco/loaders.py CHANGED
@@ -1,6 +1,3 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
1
  import pathlib
5
2
  import tempfile
6
3
  import warnings
@@ -9,10 +6,14 @@ from typing import Any
9
6
 
10
7
  import mujoco as mj
11
8
  import numpy as np
12
- import numpy.typing as npt
13
9
  import rod.urdf.exporter
14
10
  from lxml import etree as ET
15
- from scipy.spatial.transform import Rotation
11
+
12
+ from .utils import MujocoCamera
13
+
14
+ MujocoCameraType = (
15
+ MujocoCamera | Sequence[MujocoCamera] | dict[str, str] | Sequence[dict[str, str]]
16
+ )
16
17
 
17
18
 
18
19
  def load_rod_model(
@@ -167,12 +168,7 @@ class RodModelToMjcf:
167
168
  plane_normal: tuple[float, float, float] = (0, 0, 1),
168
169
  heightmap: bool | None = None,
169
170
  heightmap_samples_xy: tuple[int, int] = (101, 101),
170
- cameras: (
171
- MujocoCamera
172
- | Sequence[MujocoCamera]
173
- | dict[str, str]
174
- | Sequence[dict[str, str]]
175
- ) = (),
171
+ cameras: MujocoCameraType = (),
176
172
  ) -> tuple[str, dict[str, Any]]:
177
173
  """
178
174
  Converts a ROD model to a Mujoco MJCF string.
@@ -533,12 +529,7 @@ class UrdfToMjcf:
533
529
  model_name: str | None = None,
534
530
  plane_normal: tuple[float, float, float] = (0, 0, 1),
535
531
  heightmap: bool | None = None,
536
- cameras: (
537
- MujocoCamera
538
- | Sequence[MujocoCamera]
539
- | dict[str, str]
540
- | Sequence[dict[str, str]]
541
- ) = (),
532
+ cameras: MujocoCameraType = (),
542
533
  ) -> tuple[str, dict[str, Any]]:
543
534
  """
544
535
  Converts a URDF file to a Mujoco MJCF string.
@@ -580,12 +571,7 @@ class SdfToMjcf:
580
571
  model_name: str | None = None,
581
572
  plane_normal: tuple[float, float, float] = (0, 0, 1),
582
573
  heightmap: bool | None = None,
583
- cameras: (
584
- MujocoCamera
585
- | Sequence[MujocoCamera]
586
- | dict[str, str]
587
- | Sequence[dict[str, str]]
588
- ) = (),
574
+ cameras: MujocoCameraType = (),
589
575
  ) -> tuple[str, dict[str, Any]]:
590
576
  """
591
577
  Converts a SDF file to a Mujoco MJCF string.
@@ -617,118 +603,3 @@ class SdfToMjcf:
617
603
  heightmap=heightmap,
618
604
  cameras=cameras,
619
605
  )
620
-
621
-
622
- @dataclasses.dataclass
623
- class MujocoCamera:
624
- """
625
- Helper class storing parameters of a Mujoco camera.
626
-
627
- Refer to the official documentation for more details:
628
- https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
629
- """
630
-
631
- mode: str = "fixed"
632
-
633
- target: str | None = None
634
- fovy: str = "45"
635
- pos: str = "0 0 0"
636
-
637
- quat: str | None = None
638
- axisangle: str | None = None
639
- xyaxes: str | None = None
640
- zaxis: str | None = None
641
- euler: str | None = None
642
-
643
- name: str | None = None
644
-
645
- @classmethod
646
- def build(cls, **kwargs) -> MujocoCamera:
647
-
648
- if not all(isinstance(value, str) for value in kwargs.values()):
649
- raise ValueError(f"Values must be strings: {kwargs}")
650
-
651
- return cls(**kwargs)
652
-
653
- @staticmethod
654
- def build_from_target_view(
655
- camera_name: str,
656
- lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
657
- distance: float | int | npt.NDArray = 3,
658
- azimut: float | int | npt.NDArray = 90,
659
- elevation: float | int | npt.NDArray = -45,
660
- fovy: float | int | npt.NDArray = 45,
661
- degrees: bool = True,
662
- **kwargs,
663
- ) -> MujocoCamera:
664
- """
665
- Create a custom camera that looks at a target point.
666
-
667
- Note:
668
- The choice of the parameters is easier if we imagine to consider a target
669
- frame `T` whose origin is located over the lookat point and having the same
670
- orientation of the world frame `W`. We also introduce a camera frame `C`
671
- whose origin is located over the lower-left corner of the image, and having
672
- the x-axis pointing right and the y-axis pointing up in image coordinates.
673
- The camera renders what it sees in the -z direction of frame `C`.
674
-
675
- Args:
676
- camera_name: The name of the camera.
677
- lookat: The target point to look at (origin of `T`).
678
- distance:
679
- The distance from the target point (displacement between the origins
680
- of `T` and `C`).
681
- azimut:
682
- The rotation around z of the camera. With an angle of 0, the camera
683
- would loot at the target point towards the positive x-axis of `T`.
684
- elevation:
685
- The rotation around the x-axis of the camera frame `C`. Note that if
686
- you want to lift the view angle, the elevation is negative.
687
- fovy: The field of view of the camera.
688
- degrees: Whether the angles are in degrees or radians.
689
- **kwargs: Additional camera parameters.
690
-
691
- Returns:
692
- The custom camera.
693
- """
694
-
695
- # Start from a frame whose origin is located over the lookat point.
696
- # We initialize a -90 degrees rotation around the z-axis because due to
697
- # the default camera coordinate system (x pointing right, y pointing up).
698
- W_H_C = np.eye(4)
699
- W_H_C[0:3, 3] = np.array(lookat)
700
- W_H_C[0:3, 0:3] = Rotation.from_euler(
701
- seq="ZX", angles=[-90, 90], degrees=True
702
- ).as_matrix()
703
-
704
- # Process the azimut.
705
- R_az = Rotation.from_euler(seq="Y", angles=azimut, degrees=degrees).as_matrix()
706
- W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
707
-
708
- # Process elevation.
709
- R_el = Rotation.from_euler(
710
- seq="X", angles=elevation, degrees=degrees
711
- ).as_matrix()
712
- W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
713
-
714
- # Process distance.
715
- tf_distance = np.eye(4)
716
- tf_distance[2, 3] = distance
717
- W_H_C = W_H_C @ tf_distance
718
-
719
- # Extract the position and the quaternion.
720
- p = W_H_C[0:3, 3]
721
- Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
722
-
723
- return MujocoCamera.build(
724
- name=camera_name,
725
- mode="fixed",
726
- fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
727
- pos=" ".join(p.astype(str).tolist()),
728
- quat=" ".join(Q.astype(str).tolist()),
729
- **kwargs,
730
- )
731
-
732
- def asdict(self) -> dict[str, str]:
733
-
734
- return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
jaxsim/mujoco/utils.py CHANGED
@@ -1,7 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from collections.abc import Sequence
5
+
1
6
  import mujoco as mj
2
7
  import numpy as np
8
+ import numpy.typing as npt
9
+ from scipy.spatial.transform import Rotation
3
10
 
4
- from . import MujocoModelHelper
11
+ from .model import MujocoModelHelper
5
12
 
6
13
 
7
14
  def mujoco_data_from_jaxsim(
@@ -99,3 +106,118 @@ def mujoco_data_from_jaxsim(
99
106
  mj.mj_forward(mujoco_model, model_helper.data)
100
107
 
101
108
  return model_helper.data
109
+
110
+
111
+ @dataclasses.dataclass
112
+ class MujocoCamera:
113
+ """
114
+ Helper class storing parameters of a Mujoco camera.
115
+
116
+ Refer to the official documentation for more details:
117
+ https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
118
+ """
119
+
120
+ mode: str = "fixed"
121
+
122
+ target: str | None = None
123
+ fovy: str = "45"
124
+ pos: str = "0 0 0"
125
+
126
+ quat: str | None = None
127
+ axisangle: str | None = None
128
+ xyaxes: str | None = None
129
+ zaxis: str | None = None
130
+ euler: str | None = None
131
+
132
+ name: str | None = None
133
+
134
+ @classmethod
135
+ def build(cls, **kwargs) -> MujocoCamera:
136
+
137
+ if not all(isinstance(value, str) for value in kwargs.values()):
138
+ raise ValueError(f"Values must be strings: {kwargs}")
139
+
140
+ return cls(**kwargs)
141
+
142
+ @staticmethod
143
+ def build_from_target_view(
144
+ camera_name: str,
145
+ lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
146
+ distance: float | int | npt.NDArray = 3,
147
+ azimut: float | int | npt.NDArray = 90,
148
+ elevation: float | int | npt.NDArray = -45,
149
+ fovy: float | int | npt.NDArray = 45,
150
+ degrees: bool = True,
151
+ **kwargs,
152
+ ) -> MujocoCamera:
153
+ """
154
+ Create a custom camera that looks at a target point.
155
+
156
+ Note:
157
+ The choice of the parameters is easier if we imagine to consider a target
158
+ frame `T` whose origin is located over the lookat point and having the same
159
+ orientation of the world frame `W`. We also introduce a camera frame `C`
160
+ whose origin is located over the lower-left corner of the image, and having
161
+ the x-axis pointing right and the y-axis pointing up in image coordinates.
162
+ The camera renders what it sees in the -z direction of frame `C`.
163
+
164
+ Args:
165
+ camera_name: The name of the camera.
166
+ lookat: The target point to look at (origin of `T`).
167
+ distance:
168
+ The distance from the target point (displacement between the origins
169
+ of `T` and `C`).
170
+ azimut:
171
+ The rotation around z of the camera. With an angle of 0, the camera
172
+ would loot at the target point towards the positive x-axis of `T`.
173
+ elevation:
174
+ The rotation around the x-axis of the camera frame `C`. Note that if
175
+ you want to lift the view angle, the elevation is negative.
176
+ fovy: The field of view of the camera.
177
+ degrees: Whether the angles are in degrees or radians.
178
+ **kwargs: Additional camera parameters.
179
+
180
+ Returns:
181
+ The custom camera.
182
+ """
183
+
184
+ # Start from a frame whose origin is located over the lookat point.
185
+ # We initialize a -90 degrees rotation around the z-axis because due to
186
+ # the default camera coordinate system (x pointing right, y pointing up).
187
+ W_H_C = np.eye(4)
188
+ W_H_C[0:3, 3] = np.array(lookat)
189
+ W_H_C[0:3, 0:3] = Rotation.from_euler(
190
+ seq="ZX", angles=[-90, 90], degrees=True
191
+ ).as_matrix()
192
+
193
+ # Process the azimut.
194
+ R_az = Rotation.from_euler(seq="Y", angles=azimut, degrees=degrees).as_matrix()
195
+ W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
196
+
197
+ # Process elevation.
198
+ R_el = Rotation.from_euler(
199
+ seq="X", angles=elevation, degrees=degrees
200
+ ).as_matrix()
201
+ W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
202
+
203
+ # Process distance.
204
+ tf_distance = np.eye(4)
205
+ tf_distance[2, 3] = distance
206
+ W_H_C = W_H_C @ tf_distance
207
+
208
+ # Extract the position and the quaternion.
209
+ p = W_H_C[0:3, 3]
210
+ Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
211
+
212
+ return MujocoCamera.build(
213
+ name=camera_name,
214
+ mode="fixed",
215
+ fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
216
+ pos=" ".join(p.astype(str).tolist()),
217
+ quat=" ".join(Q.astype(str).tolist()),
218
+ **kwargs,
219
+ )
220
+
221
+ def asdict(self) -> dict[str, str]:
222
+
223
+ return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
@@ -100,32 +100,7 @@ class JointDescription(JaxsimDataclass):
100
100
  if not isinstance(other, JointDescription):
101
101
  return False
102
102
 
103
- if not (
104
- self.name == other.name
105
- and self.jtype == other.jtype
106
- and self.child == other.child
107
- and self.parent == other.parent
108
- and self.index == other.index
109
- and all(
110
- np.allclose(getattr(self, attr), getattr(other, attr))
111
- for attr in [
112
- "axis",
113
- "pose",
114
- "friction_static",
115
- "friction_viscous",
116
- "position_limit_damper",
117
- "position_limit_spring",
118
- "position_limit",
119
- "initial_position",
120
- "motor_inertia",
121
- "motor_viscous_friction",
122
- "motor_gear_ratio",
123
- ]
124
- ),
125
- ):
126
- return False
127
-
128
- return True
103
+ return hash(self) == hash(other)
129
104
 
130
105
  def __hash__(self) -> int:
131
106
 
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import copy
4
4
  import dataclasses
5
5
  import functools
6
- from collections.abc import Callable, Iterable, Sequence
6
+ from collections.abc import Callable, Iterable, Iterator, Sequence
7
7
  from typing import Any
8
8
 
9
9
  import numpy as np
@@ -82,7 +82,7 @@ class KinematicGraph(Sequence[LinkDescription]):
82
82
  default_factory=list, hash=False, compare=False
83
83
  )
84
84
 
85
- root_pose: RootPose = dataclasses.field(default_factory=lambda: RootPose())
85
+ root_pose: RootPose = dataclasses.field(default_factory=RootPose)
86
86
 
87
87
  # Private attribute storing optional additional info.
88
88
  _extra_info: dict[str, Any] = dataclasses.field(
@@ -700,7 +700,7 @@ class KinematicGraph(Sequence[LinkDescription]):
700
700
  # Sequence protocol
701
701
  # =================
702
702
 
703
- def __iter__(self) -> Iterable[LinkDescription]:
703
+ def __iter__(self) -> Iterator[LinkDescription]:
704
704
  yield from KinematicGraph.breadth_first_search(root=self.root)
705
705
 
706
706
  def __reversed__(self) -> Iterable[LinkDescription]:
@@ -85,10 +85,7 @@ def extract_model_data(
85
85
 
86
86
  # Log type of base link.
87
87
  logging.debug(
88
- msg="Model '{}' is {}".format(
89
- sdf_model.name,
90
- "fixed-base" if sdf_model.is_fixed_base() else "floating-base",
91
- )
88
+ msg=f"Model '{sdf_model.name}' is {'fixed-base' if sdf_model.is_fixed_base() else 'floating-base'}"
92
89
  )
93
90
 
94
91
  # Log detected base link.
@@ -175,7 +172,7 @@ def extract_model_data(
175
172
  for j in sdf_model.joints()
176
173
  if j.type == "fixed"
177
174
  and j.parent == "world"
178
- and j.child in links_dict.keys()
175
+ and j.child in links_dict
179
176
  and j.pose.relative_to in {"__model__", "world", None}
180
177
  ]
181
178
 
@@ -287,7 +284,7 @@ def extract_model_data(
287
284
  for j in sdf_model.joints()
288
285
  if j.type in {"revolute", "continuous", "prismatic", "fixed"}
289
286
  and j.parent != "world"
290
- and j.child in links_dict.keys()
287
+ and j.child in links_dict
291
288
  ]
292
289
 
293
290
  # Create a dictionary to find the parent joint of the links.
@@ -179,7 +179,7 @@ def create_sphere_collision(
179
179
 
180
180
  r = collision.geometry.sphere.radius
181
181
  sphere_points = r * fibonacci_sphere(
182
- samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="250"))
182
+ samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50"))
183
183
  )
184
184
 
185
185
  H = collision.pose.transform() if collision.pose is not None else np.eye(4)
jaxsim/rbda/jacobian.py CHANGED
@@ -205,7 +205,7 @@ def jacobian_full_doubly_left(
205
205
  # Convert adjoints to SE(3) transforms.
206
206
  # Returning them here prevents calling FK in case the output representation
207
207
  # of the Jacobian needs to be changed.
208
- B_H_L = jax.vmap(lambda B_X_L: Adjoint.to_transform(B_X_L))(B_X_i)
208
+ B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)
209
209
 
210
210
  # Adjust shape of doubly-left free-floating full Jacobian.
211
211
  B_J_full_WL_B = J.squeeze().astype(float)
@@ -322,7 +322,7 @@ def jacobian_derivative_full_doubly_left(
322
322
  # Convert adjoints to SE(3) transforms.
323
323
  # Returning them here prevents calling FK in case the output representation
324
324
  # of the Jacobian needs to be changed.
325
- B_H_L = jax.vmap(lambda B_X_L: Adjoint.to_transform(B_X_L))(B_X_i)
325
+ B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)
326
326
 
327
327
  # Adjust shape of doubly-left free-floating full Jacobian derivative.
328
328
  B_J̇_full_WL_B = J̇.squeeze().astype(float)
jaxsim/rbda/utils.py CHANGED
@@ -135,7 +135,7 @@ def process_inputs(
135
135
  # Check that the quaternion is unary since our RBDAs make this assumption in order
136
136
  # to prevent introducing additional normalizations that would affect AD.
137
137
  exceptions.raise_value_error_if(
138
- condition=jnp.logical_not(jnp.allclose(W_Q_B.dot(W_Q_B), 1.0)),
138
+ condition=~jnp.allclose(W_Q_B.dot(W_Q_B), 1.0),
139
139
  msg="A RBDA received a quaternion that is not normalized.",
140
140
  )
141
141
 
jaxsim/terrain/terrain.py CHANGED
@@ -8,6 +8,7 @@ import jax_dataclasses
8
8
  import numpy as np
9
9
 
10
10
  import jaxsim.typing as jtp
11
+ from jaxsim import exceptions
11
12
 
12
13
 
13
14
  class Terrain(abc.ABC):
@@ -108,7 +109,9 @@ class PlaneTerrain(FlatTerrain):
108
109
  _normal=tuple(normal.tolist()),
109
110
  )
110
111
 
111
- def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
112
+ def normal(
113
+ self, x: jtp.FloatLike | None = None, y: jtp.FloatLike | None = None
114
+ ) -> jtp.Vector:
112
115
  """
113
116
  Compute the normal vector of the terrain at a specific (x, y) location.
114
117
 
@@ -141,6 +144,11 @@ class PlaneTerrain(FlatTerrain):
141
144
  # Get the plane equation coefficients from the terrain normal.
142
145
  A, B, C = self._normal
143
146
 
147
+ exceptions.raise_value_error_if(
148
+ condition=jnp.allclose(C, 0.0),
149
+ msg="The z component of the normal cannot be zero.",
150
+ )
151
+
144
152
  # Compute the final coefficient D considering the terrain height.
145
153
  D = -C * self._height
146
154
 
jaxsim/utils/tracing.py CHANGED
@@ -8,15 +8,9 @@ import jax.interpreters.partial_eval
8
8
  def tracing(var: Any) -> bool | jax.Array:
9
9
  """Returns True if the variable is being traced by JAX, False otherwise."""
10
10
 
11
- return jax.numpy.array(
12
- [
13
- isinstance(var, t)
14
- for t in (
15
- jax._src.core.Tracer,
16
- jax.interpreters.partial_eval.DynamicJaxprTracer,
17
- )
18
- ]
19
- ).any()
11
+ return isinstance(
12
+ var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer
13
+ )
20
14
 
21
15
 
22
16
  def not_tracing(var: Any) -> bool | jax.Array:
jaxsim/utils/wrappers.py CHANGED
@@ -49,7 +49,7 @@ class CustomHashedObject(Generic[T]):
49
49
 
50
50
  obj: T
51
51
 
52
- hash_function: Callable[[T], int] = dataclasses.field(default=lambda obj: hash(obj))
52
+ hash_function: Callable[[T], int] = hash
53
53
 
54
54
  def get(self: CustomHashedObject[T]) -> T:
55
55
  return self.obj
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev327
3
+ Version: 0.4.3.dev350
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
- jaxsim/_version.py,sha256=crnjdWEkz6VH0LOon-LPCudsgKRsebWeeDMlTRY3AB0,428
2
+ jaxsim/_version.py,sha256=56tTuqXBlX9UQVJyJ_A9hRmvozgLzRGJ-9ZCppehae8,428
3
3
  jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
@@ -9,17 +9,17 @@ jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
9
  jaxsim/api/contact.py,sha256=D6RucrH9gnoUFLdmAEYwLGrimU0wLmuoDeOONu4ni74,25658
10
10
  jaxsim/api/data.py,sha256=ThRpoBlbdwf1N3xs8SWrY5d8RbfdYRwFcmkdIPgtee4,29004
11
11
  jaxsim/api/frame.py,sha256=yPSgNygHkvWlln4wShNt7vZm_fFobVEm7phsklNNyH8,12922
12
- jaxsim/api/joint.py,sha256=Vl9VJs66_es88zjBBXqPzKkA3oAktV-DNiTkXwxOSsI,7562
13
- jaxsim/api/kin_dyn_parameters.py,sha256=eFcRJbfati3YsUhiJc5ZgngE-eqzYhqtud9L3ntQ-Uw,29632
14
- jaxsim/api/link.py,sha256=au47jV7bNdH2itvZ4qngKZyLi0UTIKzqVjjKosnlMsU,12858
15
- jaxsim/api/model.py,sha256=H87gCwt3_J8NGF_L3CBlqSNxiUmJWPHzP2v-XdrUx3A,79862
12
+ jaxsim/api/joint.py,sha256=8rCIxRMeAidsaBbw7kkGp6z3-UmBPtqmYmV_arHDQJ8,7365
13
+ jaxsim/api/kin_dyn_parameters.py,sha256=Y9wnMshz83Zm4UEPOAOTINdtfkBZ86w853c8Yi2qaVs,29670
14
+ jaxsim/api/link.py,sha256=nHjffhNdi_xGkteMsqdb_hC9mdV9rNw7k3pl89Uhw_8,12798
15
+ jaxsim/api/model.py,sha256=A88AaBZpWvQ-L9blFyl1GHvTWI05rvVFKbSaHzD77_k,79563
16
16
  jaxsim/api/ode.py,sha256=_t18avoCJngQk6eMFTGpaeahbpchQP20qJnUOCPkz8s,15360
17
17
  jaxsim/api/ode_data.py,sha256=1SD-x-lYk_YSEnVpxTLd69uOKC0mFUj44ZqpSmEDOxw,20190
18
- jaxsim/api/references.py,sha256=crPzgUCDHkTKUuXQcj9ygFAI0pewIx-fTPEsU8fBiU4,20555
18
+ jaxsim/api/references.py,sha256=eIOk3MAOc9LJSKfI8M4WA8gGD-meo50vRfhXdea4sNI,20539
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=8Y3PGVOydPBJrqYtH33rtng48T12ySWfXTk-ED_2wao,18297
20
+ jaxsim/integrators/common.py,sha256=ohISUnUWTaNHt2kweg1JyzwYGZgIH_wc-01qJWJsO80,18281
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
- jaxsim/integrators/variable_step.py,sha256=hGYKG3Sq3QITgzIePmCVCrrirwagqsKnB3aYifAcKR4,22848
22
+ jaxsim/integrators/variable_step.py,sha256=Tqz5ySSgyKak_k6cTXpmtqdPNaFlO7N6zj7jBIlChyM,22681
23
23
  jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
24
24
  jaxsim/math/adjoint.py,sha256=V7r5VrTCKPLEL5gavNSx9U7xSsrb11a5e4gWqJ2MuRo,4375
25
25
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
@@ -31,28 +31,28 @@ jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
31
31
  jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
32
32
  jaxsim/mujoco/__init__.py,sha256=fZyRWre49pIhOrYdf6yJk_hOax8qWGe8OCmoq-dMVq8,201
33
33
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
34
- jaxsim/mujoco/loaders.py,sha256=CkFGydgOku5P_Pz7wdWlM2SCJRs71ePF-vsY9i90-I0,25350
34
+ jaxsim/mujoco/loaders.py,sha256=_CZekIqZNe8oFeH7zSv4gGZAZENRISwMd8dt640zjRI,20860
35
35
  jaxsim/mujoco/model.py,sha256=5_7rWk_WBkNKDHqeewIFj0t2ZGqJpE6RDXHSbRvw4e4,16493
36
- jaxsim/mujoco/utils.py,sha256=bGbLMSzcdqbinIwHHJHt8ZN1uup_6DLdB2dWqKiXwO4,3955
36
+ jaxsim/mujoco/utils.py,sha256=vZ8afASNOSxnxVW9p_1U1J_n-9nVhnBDqlV5k8c1GkM,8256
37
37
  jaxsim/mujoco/visualizer.py,sha256=nD6SNWmn-nxjjjIY9oPAHvL2j8q93DJDjZeepzke_DQ,6988
38
38
  jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
- jaxsim/parsers/kinematic_graph.py,sha256=wT2bgaCS8VQJTHy2H9sENkVPDOiMkRikxEF1t_WaahQ,34748
39
+ jaxsim/parsers/kinematic_graph.py,sha256=MJkJ7AW1TdLZmxibuiVrTfn6jHjh3OVhEF20DqwsCnM,34748
40
40
  jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
41
41
  jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
42
- jaxsim/parsers/descriptions/joint.py,sha256=VSb6C0FBBKMqwrHBKfc-Bbn4rl_J0RzUxMQlhIEvOPM,5185
42
+ jaxsim/parsers/descriptions/joint.py,sha256=2KWLP4ILPMV8q1X0J7aS3GGFeZn4zXan0dqGOWc7XuQ,4365
43
43
  jaxsim/parsers/descriptions/link.py,sha256=Eh0W5qL7_Uw0GV-BkNKXhm9Q2dRTfIWCX5D-87zQkxA,3711
44
44
  jaxsim/parsers/descriptions/model.py,sha256=I2Vsbv8Josl4Le7b5rIvhqA2k9Bbv5JxMqwytayxds0,9833
45
45
  jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
46
- jaxsim/parsers/rod/parser.py,sha256=UAL6vkAab2y6PgIv3tAL8UodNbq-mmNBuBotxHJOSVU,14035
47
- jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5698
46
+ jaxsim/parsers/rod/parser.py,sha256=EXcbtr_vMjAaUzQjfQlD1zLLYLAZXrNeFHaiZVlLwFI,13976
47
+ jaxsim/parsers/rod/utils.py,sha256=czQ2Y1_I9zGO0y2XDotHSqDorVH6zEcPhkuelApjs3k,5697
48
48
  jaxsim/rbda/__init__.py,sha256=kmy4G9aMkrqPNGdLSaSV3k15dpF52vBEUQXDFDuKIxU,337
49
49
  jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
50
50
  jaxsim/rbda/collidable_points.py,sha256=0PFLzxWKtRg8-JtfNhGlSjBMv1J98tiLymOdvlvAak4,5325
51
51
  jaxsim/rbda/crba.py,sha256=bXkXESnVbv-lxhU1Y_i0rViEcQA4z2t2_jHwdVj5CBo,5049
52
52
  jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdulqStA,3458
53
- jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
53
+ jaxsim/rbda/jacobian.py,sha256=L6Vn4Kf9I6wj-MYcFY6o67mgIfLFaaW4i2wNQJ2PDL0,10981
54
54
  jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
55
- jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
+ jaxsim/rbda/utils.py,sha256=GLt7XIl1ROkx0_fnBCKUHYdB9_IBF3Yi4OnkHSX3gxA,5365
56
56
  jaxsim/rbda/contacts/__init__.py,sha256=L5MM-2pv76YPGzxExdz2EErgGBATuAjYnNHlq5QOySs,503
57
57
  jaxsim/rbda/contacts/common.py,sha256=ai49HeLQOsnckG0H2tUKW2KQ0Au_v9jRuNdnqie-YBk,11234
58
58
  jaxsim/rbda/contacts/relaxed_rigid.py,sha256=tbyskONuUhC6BZnZSpNUnlCjkI7LR6mCtmU_HimOAVE,20893
@@ -60,13 +60,13 @@ jaxsim/rbda/contacts/rigid.py,sha256=MSzkU6SFbW6CryNlyyxQ7K0-U-8k6VROGKv_DQrwqiw
60
60
  jaxsim/rbda/contacts/soft.py,sha256=t6bqBfGAtV1AWoevY82LAcXy2XW8w_uu7bNywcyxF0s,17001
61
61
  jaxsim/rbda/contacts/visco_elastic.py,sha256=vQkfMuqQ3Qu8nbDTPY4jWBZjV3U7qtoRK1Aya3O3oFA,41424
62
62
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
63
- jaxsim/terrain/terrain.py,sha256=K91HEzPqTSyNrc_j1KfAAEF_5oDeuk_-jnnZGrcMEcY,5015
63
+ jaxsim/terrain/terrain.py,sha256=_G1QS3zWycj089R8fTP5s2VjcZpEdJxREjXZJ-oXIvc,5248
64
64
  jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
65
65
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
66
- jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
67
- jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
68
- jaxsim-0.4.3.dev327.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
- jaxsim-0.4.3.dev327.dist-info/METADATA,sha256=hgujWPHg9YNK1R6nB4QvTd985MXx1jElYRRbwc_v2W0,17513
70
- jaxsim-0.4.3.dev327.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
71
- jaxsim-0.4.3.dev327.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
- jaxsim-0.4.3.dev327.dist-info/RECORD,,
66
+ jaxsim/utils/tracing.py,sha256=eEY28MZW0Lm_jJNt1NkFqZz0ek01tvhR46OXZYCo7tc,532
67
+ jaxsim/utils/wrappers.py,sha256=ZY7olSORzZRvSzkdeNLj8yjwUIAt9L0Douwl7wItjpk,4008
68
+ jaxsim-0.4.3.dev350.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
+ jaxsim-0.4.3.dev350.dist-info/METADATA,sha256=qyh1wWUq5dTCw9iznLWlS4DlKa6kfMnAMqZgYbldbCA,17513
70
+ jaxsim-0.4.3.dev350.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
71
+ jaxsim-0.4.3.dev350.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
+ jaxsim-0.4.3.dev350.dist-info/RECORD,,