jaxsim 0.2.1.dev123__py3-none-any.whl → 0.2.1.dev155__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/__init__.py CHANGED
@@ -8,8 +8,9 @@ def _jnp_options() -> None:
8
8
 
9
9
  import jax
10
10
 
11
- # Enable by default
12
- if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"):
11
+ # Enable by default 64bit precision in JAX.
12
+ if os.environ.get("JAX_ENABLE_X64", "1") != "0":
13
+
13
14
  logging.info("Enabling JAX to use 64bit precision")
14
15
  jax.config.update("jax_enable_x64", True)
15
16
 
@@ -27,6 +28,7 @@ def _np_options() -> None:
27
28
 
28
29
 
29
30
  def _is_editable() -> bool:
31
+
30
32
  import importlib.util
31
33
  import pathlib
32
34
  import site
@@ -45,11 +47,40 @@ def _is_editable() -> bool:
45
47
  return jaxsim_package_dir not in site.getsitepackages()
46
48
 
47
49
 
48
- # Initialize the logging verbosity
49
- if _is_editable():
50
- logging.configure(level=logging.LoggingLevel.DEBUG)
51
- else:
52
- logging.configure(level=logging.LoggingLevel.WARNING)
50
+ def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
51
+ """
52
+ Get the default logging level.
53
+
54
+ Args:
55
+ env_var: The environment variable to check.
56
+
57
+ Returns:
58
+ The logging level to set.
59
+ """
60
+
61
+ import os
62
+
63
+ # Define the default logging level depending on the installation mode.
64
+ default_logging_level = (
65
+ logging.LoggingLevel.DEBUG
66
+ if _is_editable() # noqa: F821
67
+ else logging.LoggingLevel.WARNING
68
+ )
69
+
70
+ # Allow to override the default logging level with an environment variable.
71
+ try:
72
+ return logging.LoggingLevel[
73
+ os.environ.get(env_var, default_logging_level.name).upper()
74
+ ]
75
+
76
+ except KeyError as exc:
77
+ msg = f"Invalid logging level defined in {env_var}='{os.environ[env_var]}'"
78
+ raise RuntimeError(msg) from exc
79
+
80
+
81
+ # Configure the logger with the default logging level.
82
+ logging.configure(level=_get_default_logging_level(env_var="JAXSIM_LOGGING_LEVEL"))
83
+
53
84
 
54
85
  # Configure JAX
55
86
  _jnp_options()
@@ -59,6 +90,7 @@ _np_options()
59
90
 
60
91
  del _jnp_options
61
92
  del _np_options
93
+ del _get_default_logging_level
62
94
  del _is_editable
63
95
 
64
96
  from . import terrain # isort:skip
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.1.dev123'
16
- __version_tuple__ = version_tuple = (0, 2, 1, 'dev123')
15
+ __version__ = version = '0.2.1.dev155'
16
+ __version_tuple__ = version_tuple = (0, 2, 1, 'dev155')
jaxsim/api/contact.py CHANGED
@@ -365,20 +365,20 @@ def jacobian(
365
365
 
366
366
  W_H_C = transforms(model=model, data=data)
367
367
 
368
- def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
368
+ def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
369
369
  C_X_W = jaxsim.math.Adjoint.from_transform(
370
370
  transform=W_H_C, inverse=True
371
371
  )
372
372
  C_J_WC = C_X_W @ W_J_WC
373
373
  return C_J_WC
374
374
 
375
- O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
375
+ O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)
376
376
 
377
377
  case VelRepr.Mixed:
378
378
 
379
379
  W_H_C = transforms(model=model, data=data)
380
380
 
381
- def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
381
+ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
382
382
 
383
383
  W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
384
384
 
@@ -389,7 +389,7 @@ def jacobian(
389
389
  CW_J_WC = CW_X_W @ W_J_WC
390
390
  return CW_J_WC
391
391
 
392
- O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
392
+ O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC)
393
393
 
394
394
  case _:
395
395
  raise ValueError(output_vel_repr)
@@ -6,6 +6,7 @@ import jax.lax
6
6
  import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
  import jaxlie
9
+ import numpy as np
9
10
  from jax_dataclasses import Static
10
11
 
11
12
  import jaxsim.typing as jtp
@@ -220,7 +221,9 @@ class KynDynParameters(JaxsimDataclass):
220
221
  (
221
222
  hash(self.number_of_links()),
222
223
  hash(self.number_of_joints()),
223
- hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())),
224
+ hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())),
225
+ hash(self._parent_array),
226
+ hash(self._support_body_array_bool),
224
227
  )
225
228
  )
226
229
 
jaxsim/api/link.py CHANGED
@@ -241,8 +241,8 @@ def jacobian(
241
241
  )
242
242
 
243
243
  # Compute the actual doubly-left free-floating jacobian of the link.
244
- κ = model.kin_dyn_parameters.support_body_array_bool[link_index]
245
- B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B
244
+ κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
245
+ B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B
246
246
 
247
247
  # Adjust the input representation such that `J_WL_I @ I_ν`.
248
248
  match data.velocity_representation:
jaxsim/api/model.py CHANGED
@@ -16,7 +16,8 @@ from jax_dataclasses import Static
16
16
  import jaxsim.api as js
17
17
  import jaxsim.parsers.descriptions
18
18
  import jaxsim.typing as jtp
19
- from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability
19
+ from jaxsim.math import Cross
20
+ from jaxsim.utils import JaxsimDataclass, Mutability
20
21
 
21
22
  from .common import VelRepr
22
23
 
@@ -32,6 +33,7 @@ class JaxSimModel(JaxsimDataclass):
32
33
  terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
33
34
  default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
34
35
  )
36
+
35
37
  kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
36
38
  dataclasses.field(default=None, repr=False, compare=False, hash=False)
37
39
  )
@@ -40,13 +42,9 @@ class JaxSimModel(JaxsimDataclass):
40
42
  default=None, repr=False, compare=False, hash=False
41
43
  )
42
44
 
43
- _description: Static[
44
- HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
45
- ] = dataclasses.field(default=None, repr=False, compare=False, hash=False)
46
-
47
- @property
48
- def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
49
- return self._description.get()
45
+ description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = (
46
+ dataclasses.field(default=None, repr=False, compare=False, hash=False)
47
+ )
50
48
 
51
49
  def __eq__(self, other: JaxSimModel) -> bool:
52
50
 
@@ -60,6 +58,7 @@ class JaxSimModel(JaxsimDataclass):
60
58
  return hash(
61
59
  (
62
60
  hash(self.model_name),
61
+ hash(self.description),
63
62
  hash(self.kin_dyn_parameters),
64
63
  )
65
64
  )
@@ -156,7 +155,7 @@ class JaxSimModel(JaxsimDataclass):
156
155
  # Build the model
157
156
  model = JaxSimModel(
158
157
  model_name=model_name,
159
- _description=HashlessObject(obj=model_description),
158
+ description=model_description,
160
159
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
161
160
  model_description=model_description
162
161
  ),
@@ -301,7 +300,7 @@ def reduce(
301
300
  locked_joint_positions:
302
301
  A dictionary containing the positions of the joints to be considered
303
302
  in the reduction process. The removed joints in the reduced model
304
- will have their position locked to their value in this dictionary.
303
+ will have their position locked to their value of this dictionary.
305
304
  If a joint is not part of the dictionary, its position is set to zero.
306
305
  """
307
306
 
@@ -314,10 +313,9 @@ def reduce(
314
313
  new_joints = set(model.joint_names()) - set(locked_joint_positions)
315
314
  raise ValueError(f"Passed joints not existing in the model: {new_joints}")
316
315
 
317
- # Copy the model description with a deep copy of the joints.
318
- intermediate_description = dataclasses.replace(
319
- model.description, joints=copy.deepcopy(model.description.joints)
320
- )
316
+ # Operate on a deep copy of the model description in order to prevent problems
317
+ # when mutable attributes are updated.
318
+ intermediate_description = copy.deepcopy(model.description)
321
319
 
322
320
  # Update the initial position of the joints.
323
321
  # This is necessary to compute the correct pose of the link pairs connected
@@ -685,8 +683,6 @@ def forward_dynamics_aba(
685
683
  another representation C_v̇_WB expressed in a generic frame C.
686
684
  """
687
685
 
688
- from jaxsim.math import Cross
689
-
690
686
  # In Mixed representation, we need to include a cross product in ℝ⁶.
691
687
  # In Inertial and Body representations, the cross product is always zero.
692
688
  C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
@@ -871,6 +867,126 @@ def free_floating_mass_matrix(
871
867
  raise ValueError(data.velocity_representation)
872
868
 
873
869
 
870
+ @jax.jit
871
+ def free_floating_coriolis_matrix(
872
+ model: JaxSimModel, data: js.data.JaxSimModelData
873
+ ) -> jtp.Matrix:
874
+ """
875
+ Compute the free-floating Coriolis matrix of the model.
876
+
877
+ Args:
878
+ model: The model to consider.
879
+ data: The data of the considered model.
880
+
881
+ Returns:
882
+ The free-floating Coriolis matrix of the model.
883
+
884
+ Note:
885
+ This function, contrarily to other quantities of the equations of motion,
886
+ does not exploit any iterative algorithm. Therefore, the computation of
887
+ the Coriolis matrix may be much slower than other quantities.
888
+ """
889
+
890
+ # We perform all the calculation in body-fixed.
891
+ # The Coriolis matrix computed in this representation is converted later
892
+ # to the active representation stored in data.
893
+ with data.switch_velocity_representation(VelRepr.Body):
894
+
895
+ B_ν = data.generalized_velocity()
896
+
897
+ # Doubly-left free-floating Jacobian.
898
+ L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
899
+
900
+ # Doubly-left free-floating Jacobian derivative.
901
+ L_J̇_WL_B = jax.vmap(
902
+ lambda link_index: js.link.jacobian_derivative(
903
+ model=model, data=data, link_index=link_index
904
+ )
905
+ )(js.link.names_to_idxs(model=model, link_names=model.link_names()))
906
+
907
+ L_M_L = link_spatial_inertia_matrices(model=model)
908
+
909
+ # Body-fixed link velocities.
910
+ # Note: we could have called link.velocity() instead of computing it ourselves,
911
+ # but since we need the link Jacobians later, we can save a double calculation.
912
+ L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)
913
+
914
+ # Compute the contribution of each link to the Coriolis matrix.
915
+ def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
916
+
917
+ return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
918
+
919
+ C_B_links = jax.vmap(compute_link_contribution)(
920
+ L_M_L,
921
+ L_v_WL,
922
+ L_J_WL_B,
923
+ L_J̇_WL_B,
924
+ )
925
+
926
+ # We need to adjust the Coriolis matrix for fixed-base models.
927
+ # In this case, the base link does not contribute to the matrix, and we need to zero
928
+ # the off-diagonal terms mapping joint quantities onto the base configuration.
929
+ if model.floating_base():
930
+ C_B = C_B_links.sum(axis=0)
931
+ else:
932
+ C_B = C_B_links[1:].sum(axis=0)
933
+ C_B = C_B.at[0:6, 6:].set(0.0)
934
+ C_B = C_B.at[6:, 0:6].set(0.0)
935
+
936
+ # Adjust the representation of the Coriolis matrix.
937
+ # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
938
+ match data.velocity_representation:
939
+
940
+ case VelRepr.Body:
941
+ return C_B
942
+
943
+ case VelRepr.Inertial:
944
+
945
+ n = model.dofs()
946
+ W_H_B = data.base_transform()
947
+ B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
948
+ B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))
949
+
950
+ with data.switch_velocity_representation(VelRepr.Inertial):
951
+ W_v_WB = data.base_velocity()
952
+ B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
953
+
954
+ B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))
955
+
956
+ with data.switch_velocity_representation(VelRepr.Body):
957
+ M = free_floating_mass_matrix(model=model, data=data)
958
+
959
+ C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)
960
+
961
+ return C
962
+
963
+ case VelRepr.Mixed:
964
+
965
+ n = model.dofs()
966
+ BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
967
+ B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
968
+ B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))
969
+
970
+ with data.switch_velocity_representation(VelRepr.Mixed):
971
+ BW_v_WB = data.base_velocity()
972
+ BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
973
+
974
+ BW_v_BW_B = BW_v_WB - BW_v_W_BW
975
+ B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
976
+
977
+ B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))
978
+
979
+ with data.switch_velocity_representation(VelRepr.Body):
980
+ M = free_floating_mass_matrix(model=model, data=data)
981
+
982
+ C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)
983
+
984
+ return C
985
+
986
+ case _:
987
+ raise ValueError(data.velocity_representation)
988
+
989
+
874
990
  @jax.jit
875
991
  def inverse_dynamics(
876
992
  model: JaxSimModel,
@@ -931,8 +1047,6 @@ def inverse_dynamics(
931
1047
  expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
932
1048
  """
933
1049
 
934
- from jaxsim.math import Cross
935
-
936
1050
  W_X_C = jaxlie.SE3.from_matrix(W_H_C).adjoint()
937
1051
  C_X_W = jaxlie.SE3.from_matrix(W_H_C).inverse().adjoint()
938
1052
  C_v_WC = C_X_W @ W_v_WC
@@ -1364,12 +1478,7 @@ def link_bias_accelerations(
1364
1478
  # ================================================
1365
1479
 
1366
1480
  # Compute the base transform.
1367
- W_H_B = jaxlie.SE3.from_rotation_and_translation(
1368
- rotation=jaxlie.SO3.from_quaternion_xyzw(
1369
- xyzw=jaxsim.math.Quaternion.to_xyzw(wxyz=data.base_orientation())
1370
- ),
1371
- translation=data.base_position(),
1372
- ).as_matrix()
1481
+ W_H_B = data.base_transform()
1373
1482
 
1374
1483
  def other_representation_to_inertial(
1375
1484
  C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
@@ -1410,9 +1519,12 @@ def link_bias_accelerations(
1410
1519
  W_H_C = W_H_BW
1411
1520
  with data.switch_velocity_representation(VelRepr.Mixed):
1412
1521
  W_ṗ_B = data.base_velocity()[0:3]
1413
- W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1522
+ BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
1523
+ W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
1524
+ W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
1414
1525
  with data.switch_velocity_representation(VelRepr.Mixed):
1415
1526
  C_v_WB = BW_v_WB = data.base_velocity()
1527
+
1416
1528
  case _:
1417
1529
  raise ValueError(data.velocity_representation)
1418
1530
 
jaxsim/api/ode.py CHANGED
@@ -223,7 +223,9 @@ def system_velocity_dynamics(
223
223
 
224
224
  @jax.jit
225
225
  def system_position_dynamics(
226
- model: js.model.JaxSimModel, data: js.data.JaxSimModelData
226
+ model: js.model.JaxSimModel,
227
+ data: js.data.JaxSimModelData,
228
+ baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
227
229
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
228
230
  """
229
231
  Compute the dynamics of the system position.
@@ -231,6 +233,8 @@ def system_position_dynamics(
231
233
  Args:
232
234
  model: The model to consider.
233
235
  data: The data of the considered model.
236
+ baumgarte_quaternion_regularization:
237
+ The Baumgarte regularization coefficient for adjusting the quaternion norm.
234
238
 
235
239
  Returns:
236
240
  A tuple containing the derivative of the base position, the derivative of the
@@ -250,6 +254,7 @@ def system_position_dynamics(
250
254
  quaternion=W_Q_B,
251
255
  omega=W_ω_WB,
252
256
  omega_in_body_fixed=False,
257
+ K=baumgarte_quaternion_regularization,
253
258
  ).squeeze()
254
259
 
255
260
  return W_ṗ_B, W_Q̇_B, ṡ
@@ -262,6 +267,7 @@ def system_dynamics(
262
267
  *,
263
268
  joint_forces: jtp.Vector | None = None,
264
269
  link_forces: jtp.Vector | None = None,
270
+ baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
265
271
  ) -> tuple[ODEState, dict[str, Any]]:
266
272
  """
267
273
  Compute the dynamics of the system.
@@ -271,6 +277,9 @@ def system_dynamics(
271
277
  data: The data of the considered model.
272
278
  joint_forces: The joint forces to apply.
273
279
  link_forces: The 6D forces to apply to the links.
280
+ baumgarte_quaternion_regularization:
281
+ The Baumgarte regularization coefficient used to adjust the norm of the
282
+ quaternion (only used in integrators not operating on the SO(3) manifold).
274
283
 
275
284
  Returns:
276
285
  A tuple with an `ODEState` object storing in each of its attributes the
@@ -287,7 +296,11 @@ def system_dynamics(
287
296
  )
288
297
 
289
298
  # Extract the velocities.
290
- W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(model=model, data=data)
299
+ W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
300
+ model=model,
301
+ data=data,
302
+ baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
303
+ )
291
304
 
292
305
  # Create an ODEState object populated with the derivative of each leaf.
293
306
  # Our integrators, operating on generic pytrees, will be able to handle it
@@ -34,8 +34,8 @@ class JointModel:
34
34
  already in a vectorized form. In other words, it cannot be created using vmap.
35
35
  """
36
36
 
37
- λ_H_pre: jax.Array
38
- suc_H_i: jax.Array
37
+ λ_H_pre: jtp.Array
38
+ suc_H_i: jtp.Array
39
39
 
40
40
  joint_dofs: Static[tuple[int, ...]]
41
41
  joint_names: Static[tuple[str, ...]]
jaxsim/mujoco/loaders.py CHANGED
@@ -188,10 +188,9 @@ class RodModelToMjcf:
188
188
  )
189
189
 
190
190
  # If considered joints are passed, make sure that they are all part of the model.
191
- if considered_joints - set([j.name for j in rod_model.joints()]):
192
- extra_joints = set(considered_joints) - set(
193
- [j.name for j in rod_model.joints()]
194
- )
191
+ if considered_joints - {j.name for j in rod_model.joints()}:
192
+ extra_joints = set(considered_joints) - {j.name for j in rod_model.joints()}
193
+
195
194
  msg = f"Couldn't find the following joints in the model: '{extra_joints}'"
196
195
  raise ValueError(msg)
197
196
 
@@ -352,7 +351,7 @@ class RodModelToMjcf:
352
351
  # Set alpha=0 to the color of all collision elements
353
352
  for geometry_element in mujoco_element.findall(".//geom[@rgba]"):
354
353
  if geometry_element.attrib.get("name") in collision_names:
355
- r, g, b, a = geometry_element.attrib["rgba"].split(" ")
354
+ r, g, b, _ = geometry_element.attrib["rgba"].split(" ")
356
355
  geometry_element.set("rgba", f"{r} {g} {b} 0")
357
356
 
358
357
  # -----------------------
jaxsim/mujoco/model.py CHANGED
@@ -73,7 +73,7 @@ class MujocoModelHelper:
73
73
  new_hfield = generate_hfield(heightmap, (nrow, ncol))
74
74
  model.hfield_data = new_hfield
75
75
 
76
- return MujocoModelHelper(model=model, data=mj.MjData(model))
76
+ return MujocoModelHelper(model=model, data=data)
77
77
 
78
78
  def time(self) -> float:
79
79
  """Return the simulation time."""
@@ -173,4 +173,4 @@ class MujocoVisualizer:
173
173
  try:
174
174
  yield handle
175
175
  finally:
176
- handle.close() if close_on_exit else None
176
+ _ = handle.close() if close_on_exit else None
@@ -1,11 +1,13 @@
1
+ from __future__ import annotations
2
+
1
3
  import abc
2
4
  import dataclasses
3
- from typing import List
4
5
 
5
6
  import jax.numpy as jnp
6
7
  import numpy as np
7
8
  import numpy.typing as npt
8
9
 
10
+ import jaxsim.typing as jtp
9
11
  from jaxsim import logging
10
12
 
11
13
  from .link import LinkDescription
@@ -17,9 +19,9 @@ class CollidablePoint:
17
19
  Represents a collidable point associated with a parent link.
18
20
 
19
21
  Attributes:
20
- parent_link (LinkDescription): The parent link to which the collidable point is attached.
21
- position (npt.NDArray): The position of the collidable point relative to the parent link.
22
- enabled (bool): A flag indicating whether the collidable point is enabled for collision detection.
22
+ parent_link: The parent link to which the collidable point is attached.
23
+ position: The position of the collidable point relative to the parent link.
24
+ enabled: A flag indicating whether the collidable point is enabled for collision detection.
23
25
 
24
26
  """
25
27
 
@@ -29,7 +31,7 @@ class CollidablePoint:
29
31
 
30
32
  def change_link(
31
33
  self, new_link: LinkDescription, new_H_old: npt.NDArray
32
- ) -> "CollidablePoint":
34
+ ) -> CollidablePoint:
33
35
  """
34
36
  Move the collidable point to a new parent link.
35
37
 
@@ -39,8 +41,8 @@ class CollidablePoint:
39
41
 
40
42
  Returns:
41
43
  CollidablePoint: A new collidable point associated with the new parent link.
42
-
43
44
  """
45
+
44
46
  msg = f"Moving collidable point: {self.parent_link.name} -> {new_link.name}"
45
47
  logging.debug(msg=msg)
46
48
 
@@ -50,15 +52,24 @@ class CollidablePoint:
50
52
  enabled=self.enabled,
51
53
  )
52
54
 
53
- def __eq__(self, other):
54
- retval = (
55
- self.parent_link == other.parent_link
56
- and (self.position == other.position).all()
57
- and self.enabled == other.enabled
55
+ def __hash__(self) -> int:
56
+
57
+ return hash(
58
+ (
59
+ hash(self.parent_link),
60
+ hash(tuple(self.position.tolist())),
61
+ hash(self.enabled),
62
+ )
58
63
  )
59
- return retval
60
64
 
61
- def __str__(self):
65
+ def __eq__(self, other: CollidablePoint) -> bool:
66
+
67
+ if not isinstance(other, CollidablePoint):
68
+ return False
69
+
70
+ return hash(self) == hash(other)
71
+
72
+ def __str__(self) -> str:
62
73
  return (
63
74
  f"{self.__class__.__name__}("
64
75
  + f"parent_link={self.parent_link.name}"
@@ -74,11 +85,11 @@ class CollisionShape(abc.ABC):
74
85
  Abstract base class for representing collision shapes.
75
86
 
76
87
  Attributes:
77
- collidable_points (List[CollidablePoint]): A list of collidable points associated with the collision shape.
88
+ collidable_points: A list of collidable points associated with the collision shape.
78
89
 
79
90
  """
80
91
 
81
- collidable_points: List[CollidablePoint]
92
+ collidable_points: tuple[CollidablePoint]
82
93
 
83
94
  def __str__(self):
84
95
  return (
@@ -95,14 +106,26 @@ class BoxCollision(CollisionShape):
95
106
  Represents a box-shaped collision shape.
96
107
 
97
108
  Attributes:
98
- center (npt.NDArray): The center of the box in the local frame of the collision shape.
109
+ center: The center of the box in the local frame of the collision shape.
99
110
 
100
111
  """
101
112
 
102
- center: npt.NDArray
113
+ center: jtp.VectorLike
103
114
 
104
- def __eq__(self, other):
105
- return (self.center == other.center).all() and super().__eq__(other)
115
+ def __hash__(self) -> int:
116
+ return hash(
117
+ (
118
+ hash(super()),
119
+ hash(tuple(self.center.tolist())),
120
+ )
121
+ )
122
+
123
+ def __eq__(self, other: BoxCollision) -> bool:
124
+
125
+ if not isinstance(other, BoxCollision):
126
+ return False
127
+
128
+ return hash(self) == hash(other)
106
129
 
107
130
 
108
131
  @dataclasses.dataclass
@@ -111,11 +134,23 @@ class SphereCollision(CollisionShape):
111
134
  Represents a spherical collision shape.
112
135
 
113
136
  Attributes:
114
- center (npt.NDArray): The center of the sphere in the local frame of the collision shape.
137
+ center: The center of the sphere in the local frame of the collision shape.
115
138
 
116
139
  """
117
140
 
118
- center: npt.NDArray
141
+ center: jtp.VectorLike
142
+
143
+ def __hash__(self) -> int:
144
+ return hash(
145
+ (
146
+ hash(super()),
147
+ hash(tuple(self.center.tolist())),
148
+ )
149
+ )
150
+
151
+ def __eq__(self, other: BoxCollision) -> bool:
152
+
153
+ if not isinstance(other, BoxCollision):
154
+ return False
119
155
 
120
- def __eq__(self, other):
121
- return (self.center == other.center).all() and super().__eq__(other)
156
+ return hash(self) == hash(other)
@@ -1,11 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
- from typing import ClassVar, Tuple, Union
4
+ from typing import ClassVar
5
5
 
6
6
  import jax_dataclasses
7
7
  import numpy as np
8
- import numpy.typing as npt
9
8
 
10
9
  import jaxsim.typing as jtp
11
10
  from jaxsim.utils import JaxsimDataclass, Mutability
@@ -15,6 +14,7 @@ from .link import LinkDescription
15
14
 
16
15
  @dataclasses.dataclass(frozen=True)
17
16
  class JointType:
17
+
18
18
  Fixed: ClassVar[int] = 0
19
19
  Revolute: ClassVar[int] = 1
20
20
  Prismatic: ClassVar[int] = 2
@@ -64,29 +64,31 @@ class JointDescription(JaxsimDataclass):
64
64
  """
65
65
 
66
66
  name: jax_dataclasses.Static[str]
67
- axis: npt.NDArray
68
- pose: npt.NDArray
69
- jtype: jax_dataclasses.Static[JointType]
67
+ axis: jtp.Vector
68
+ pose: jtp.Matrix
69
+ jtype: jax_dataclasses.Static[jtp.IntLike]
70
70
  child: LinkDescription = dataclasses.dataclass(repr=False)
71
71
  parent: LinkDescription = dataclasses.dataclass(repr=False)
72
72
 
73
- index: int | None = None
73
+ index: jtp.IntLike | None = None
74
+
75
+ friction_static: jtp.FloatLike = 0.0
76
+ friction_viscous: jtp.FloatLike = 0.0
74
77
 
75
- friction_static: float = 0.0
76
- friction_viscous: float = 0.0
78
+ position_limit_damper: jtp.FloatLike = 0.0
79
+ position_limit_spring: jtp.FloatLike = 0.0
77
80
 
78
- position_limit_damper: float = 0.0
79
- position_limit_spring: float = 0.0
81
+ position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0)
82
+ initial_position: jtp.FloatLike | jtp.VectorLike = 0.0
80
83
 
81
- position_limit: Tuple[float, float] = (0.0, 0.0)
82
- initial_position: Union[float, npt.NDArray] = 0.0
84
+ motor_inertia: jtp.FloatLike = 0.0
85
+ motor_viscous_friction: jtp.FloatLike = 0.0
86
+ motor_gear_ratio: jtp.FloatLike = 1.0
83
87
 
84
- motor_inertia: float = 0.0
85
- motor_viscous_friction: float = 0.0
86
- motor_gear_ratio: float = 1.0
88
+ def __post_init__(self) -> None:
87
89
 
88
- def __post_init__(self):
89
90
  if self.axis is not None:
91
+
90
92
  with self.mutable_context(
91
93
  mutability=Mutability.MUTABLE, restore_after_exception=False
92
94
  ):
@@ -94,4 +96,24 @@ class JointDescription(JaxsimDataclass):
94
96
  self.axis = self.axis / norm_of_axis
95
97
 
96
98
  def __hash__(self) -> int:
97
- return hash(self.__repr__())
99
+
100
+ return hash(
101
+ (
102
+ hash(self.name),
103
+ hash(tuple(self.axis.tolist())),
104
+ hash(tuple(self.pose.flatten().tolist())),
105
+ hash(int(self.jtype)),
106
+ hash(self.child),
107
+ hash(self.parent),
108
+ hash(int(self.index)) if self.index is not None else 0,
109
+ hash(float(self.friction_static)),
110
+ hash(float(self.friction_viscous)),
111
+ hash(float(self.position_limit_damper)),
112
+ hash(float(self.position_limit_spring)),
113
+ hash((float(el) for el in self.position_limit)),
114
+ hash(tuple(np.atleast_1d(self.initial_position).tolist())),
115
+ hash(float(self.motor_inertia)),
116
+ hash(float(self.motor_viscous_friction)),
117
+ hash(float(self.motor_gear_ratio)),
118
+ ),
119
+ )
@@ -5,6 +5,7 @@ import dataclasses
5
5
  import jax.numpy as jnp
6
6
  import jax_dataclasses
7
7
  import jaxlie
8
+ import numpy as np
8
9
  from jax_dataclasses import Static
9
10
 
10
11
  import jaxsim.typing as jtp
@@ -23,7 +24,7 @@ class LinkDescription(JaxsimDataclass):
23
24
  index: An optional index for the link (it gets automatically assigned).
24
25
  parent: The parent link of this link.
25
26
  pose: The pose transformation matrix of the link.
26
- children: List of child links.
27
+ children: The children links.
27
28
  """
28
29
 
29
30
  name: Static[str]
@@ -33,7 +34,7 @@ class LinkDescription(JaxsimDataclass):
33
34
  parent: LinkDescription = dataclasses.field(default=None, repr=False)
34
35
  pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
35
36
 
36
- children: Static[list[LinkDescription]] = dataclasses.field(
37
+ children: Static[tuple[LinkDescription]] = dataclasses.field(
37
38
  default_factory=list, repr=False
38
39
  )
39
40
 
@@ -43,10 +44,12 @@ class LinkDescription(JaxsimDataclass):
43
44
  (
44
45
  hash(self.name),
45
46
  hash(float(self.mass)),
46
- hash(tuple(self.inertia.flatten().tolist())),
47
- hash(int(self.index)),
48
- hash(self.parent),
49
- hash(tuple(hash(c) for c in self.children)),
47
+ hash(tuple(np.atleast_1d(self.inertia).flatten().tolist())),
48
+ hash(int(self.index)) if self.index is not None else 0,
49
+ hash(tuple(np.atleast_1d(self.pose).flatten().tolist())),
50
+ hash(tuple(self.children)),
51
+ # Here only using the name to prevent circular recursion:
52
+ hash(self.parent.name) if self.parent is not None else 0,
50
53
  )
51
54
  )
52
55
 
@@ -27,7 +27,7 @@ class ModelDescription(KinematicGraph):
27
27
 
28
28
  fixed_base: bool = True
29
29
 
30
- collision_shapes: list[CollisionShape] = dataclasses.field(
30
+ collision_shapes: tuple[CollisionShape, ...] = dataclasses.field(
31
31
  default_factory=list, repr=False, hash=False
32
32
  )
33
33
 
@@ -37,7 +37,7 @@ class ModelDescription(KinematicGraph):
37
37
  links: list[LinkDescription],
38
38
  joints: list[JointDescription],
39
39
  frames: list[LinkDescription] | None = None,
40
- collisions: list[CollisionShape] = (),
40
+ collisions: tuple[CollisionShape, ...] = (),
41
41
  fixed_base: bool = False,
42
42
  base_link_name: str | None = None,
43
43
  considered_joints: Sequence[str] | None = None,
@@ -87,7 +87,7 @@ class ModelDescription(KinematicGraph):
87
87
  for collision_shape in collisions:
88
88
 
89
89
  # Get all the collidable points of the shape
90
- coll_points = list(collision_shape.collidable_points)
90
+ coll_points = tuple(collision_shape.collidable_points)
91
91
 
92
92
  # Assume they have an unique parent link
93
93
  if not len(set({cp.parent_link.name for cp in coll_points})) == 1:
@@ -111,7 +111,7 @@ class ModelDescription(KinematicGraph):
111
111
  continue
112
112
 
113
113
  # Create a new collision shape
114
- new_collision_shape = CollisionShape(collidable_points=[])
114
+ new_collision_shape = CollisionShape(collidable_points=())
115
115
  final_collisions.append(new_collision_shape)
116
116
 
117
117
  # If the frame was found, update the collidable points' pose and add them
@@ -133,19 +133,19 @@ class ModelDescription(KinematicGraph):
133
133
  ),
134
134
  )
135
135
 
136
- # Store the updated collision
137
- new_collision_shape.collidable_points.append(moved_cp)
136
+ # Store the updated collision.
137
+ new_collision_shape.collidable_points += (moved_cp,)
138
138
 
139
139
  # Build the model
140
140
  model = ModelDescription(
141
141
  name=name,
142
142
  root_pose=kinematic_graph.root_pose,
143
143
  fixed_base=fixed_base,
144
- collision_shapes=final_collisions,
144
+ collision_shapes=tuple(final_collisions),
145
145
  root=kinematic_graph.root,
146
146
  joints=kinematic_graph.joints,
147
147
  frames=kinematic_graph.frames,
148
- _joints_removed=kinematic_graph._joints_removed,
148
+ _joints_removed=kinematic_graph.joints_removed,
149
149
  )
150
150
 
151
151
  # Check that the root link of kinematic graph is the desired base link.
@@ -174,7 +174,7 @@ class ModelDescription(KinematicGraph):
174
174
  links=list(self.links_dict.values()),
175
175
  joints=self.joints,
176
176
  frames=self.frames,
177
- collisions=self.collision_shapes,
177
+ collisions=tuple(self.collision_shapes),
178
178
  fixed_base=self.fixed_base,
179
179
  base_link_name=list(iter(self))[0].name,
180
180
  model_pose=self.root_pose,
@@ -182,8 +182,8 @@ class ModelDescription(KinematicGraph):
182
182
  )
183
183
 
184
184
  # Include the unconnected/removed joints from the original model.
185
- for joint in self._joints_removed:
186
- reduced_model_description._joints_removed.append(joint)
185
+ for joint in self.joints_removed:
186
+ reduced_model_description.joints_removed.append(joint)
187
187
 
188
188
  return reduced_model_description
189
189
 
@@ -243,3 +243,23 @@ class ModelDescription(KinematicGraph):
243
243
 
244
244
  # Return enabled collidable points
245
245
  return [cp for cp in all_collidable_points if cp.enabled]
246
+
247
+ def __eq__(self, other: ModelDescription) -> bool:
248
+
249
+ if not isinstance(other, ModelDescription):
250
+ return False
251
+
252
+ return hash(self) == hash(other)
253
+
254
+ def __hash__(self) -> int:
255
+
256
+ return hash(
257
+ (
258
+ hash(self.name),
259
+ hash(self.fixed_base),
260
+ hash(self.root),
261
+ hash(tuple(self.joints)),
262
+ hash(tuple(self.frames)),
263
+ hash(self.root_pose),
264
+ )
265
+ )
@@ -31,14 +31,21 @@ class RootPose(NamedTuple):
31
31
  root_position: npt.NDArray = np.zeros(3)
32
32
  root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])
33
33
 
34
+ def __hash__(self) -> int:
35
+
36
+ return hash(
37
+ (
38
+ hash(tuple(self.root_position.tolist())),
39
+ hash(tuple(self.root_quaternion.tolist())),
40
+ )
41
+ )
42
+
34
43
  def __eq__(self, other: RootPose) -> bool:
35
44
 
36
45
  if not isinstance(other, RootPose):
37
46
  return False
38
47
 
39
- return np.allclose(self.root_position, other.root_position) and np.allclose(
40
- self.root_quaternion, other.root_quaternion
41
- )
48
+ return hash(self) == hash(other)
42
49
 
43
50
 
44
51
  @dataclasses.dataclass(frozen=True)
@@ -54,22 +61,24 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
54
61
  """
55
62
 
56
63
  root: descriptions.LinkDescription
57
- frames: list[descriptions.LinkDescription] = dataclasses.field(default_factory=list)
64
+ frames: list[descriptions.LinkDescription] = dataclasses.field(
65
+ default_factory=list, hash=False, compare=False
66
+ )
58
67
  joints: list[descriptions.JointDescription] = dataclasses.field(
59
- default_factory=list
68
+ default_factory=list, hash=False, compare=False
60
69
  )
61
70
 
62
71
  root_pose: RootPose = dataclasses.field(default_factory=lambda: RootPose())
63
72
 
64
73
  # Private attribute storing optional additional info.
65
74
  _extra_info: dict[str, Any] = dataclasses.field(
66
- repr=False, compare=False, default_factory=dict
75
+ default_factory=dict, repr=False, hash=False, compare=False
67
76
  )
68
77
 
69
78
  # Private attribute storing the unconnected joints from the parsed model and
70
79
  # the joints removed after model reduction.
71
80
  _joints_removed: list[descriptions.JointDescription] = dataclasses.field(
72
- default_factory=list, repr=False, compare=False
81
+ default_factory=list, repr=False, hash=False, compare=False
73
82
  )
74
83
 
75
84
  @functools.cached_property
@@ -98,14 +107,17 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
98
107
  for index, link in enumerate(self):
99
108
  link.mutable(validate=False).index = index
100
109
 
101
- # Get the names of the links and frames.
110
+ # Get the names of the links, frames, and joints.
102
111
  link_names = [l.name for l in self]
103
112
  frame_names = [f.name for f in self.frames]
113
+ joint_names = [j.name for j in self.joints]
104
114
 
105
115
  # Make sure that they are unique.
106
116
  assert len(link_names) == len(set(link_names))
107
117
  assert len(frame_names) == len(set(frame_names))
118
+ assert len(joint_names) == len(set(joint_names))
108
119
  assert set(link_names).isdisjoint(set(frame_names))
120
+ assert set(link_names).isdisjoint(set(joint_names))
109
121
 
110
122
  # Order frames with their name.
111
123
  super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name))
@@ -251,7 +263,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
251
263
 
252
264
  # Reset the connections of the root link.
253
265
  for link in links_dict.values():
254
- link.children = []
266
+ link.children = tuple()
255
267
 
256
268
  # Couple links and joints creating the kinematic graph.
257
269
  for joint in joints:
@@ -268,7 +280,8 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
268
280
 
269
281
  # Assign link's children and make sure they are unique.
270
282
  if child_link.name not in {l.name for l in parent_link.children}:
271
- parent_link.children.append(child_link)
283
+ with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):
284
+ parent_link.children = parent_link.children + (child_link,)
272
285
 
273
286
  # Collect all the links of the kinematic graph.
274
287
  all_links_in_graph = list(
@@ -315,7 +328,7 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
315
328
  # Update the unconnected links by removing their children. The other properties
316
329
  # are left untouched, it's the caller's responsibility to post-process them if needed.
317
330
  for link in unconnected_links:
318
- link.children = []
331
+ link.children = tuple()
319
332
  msg = "Link '{}' won't be part of the kinematic graph because unconnected"
320
333
  logging.debug(msg=msg.format(link.name))
321
334
 
@@ -615,6 +628,17 @@ class KinematicGraph(Sequence[descriptions.LinkDescription]):
615
628
  horizontal=True,
616
629
  )
617
630
 
631
+ @property
632
+ def joints_removed(self) -> list[descriptions.JointDescription]:
633
+ """
634
+ Get the list of joints removed during the graph reduction.
635
+
636
+ Returns:
637
+ The list of removed joints.
638
+ """
639
+
640
+ return self._joints_removed
641
+
618
642
  @staticmethod
619
643
  def breadth_first_search(
620
644
  root: descriptions.LinkDescription,
@@ -785,6 +809,7 @@ class KinematicGraphTransforms:
785
809
 
786
810
  # Get the joint.
787
811
  joint = self.graph.joints_dict[name]
812
+ assert joint.name == name
788
813
 
789
814
  # Get the transform of the parent link.
790
815
  M_H_L = self.transform(name=joint.parent.name)
@@ -7,7 +7,7 @@ import numpy as np
7
7
  import rod
8
8
 
9
9
  from jaxsim import logging
10
- from jaxsim.math.quaternion import Quaternion
10
+ from jaxsim.math import Quaternion
11
11
  from jaxsim.parsers import descriptions, kinematic_graph
12
12
 
13
13
  from . import utils
@@ -6,7 +6,7 @@ import numpy.typing as npt
6
6
  import rod
7
7
 
8
8
  import jaxsim.typing as jtp
9
- from jaxsim.math.inertia import Inertia
9
+ from jaxsim.math import Inertia
10
10
  from jaxsim.parsers import descriptions
11
11
 
12
12
 
@@ -59,9 +59,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
59
59
  return M_L.astype(dtype=float)
60
60
 
61
61
 
62
- def joint_to_joint_type(
63
- joint: rod.Joint,
64
- ) -> descriptions.JointType:
62
+ def joint_to_joint_type(joint: rod.Joint) -> int:
65
63
  """
66
64
  Extract the joint type from an SDF joint.
67
65
 
@@ -69,7 +67,7 @@ def joint_to_joint_type(
69
67
  joint: The parsed SDF joint.
70
68
 
71
69
  Returns:
72
- The corresponding joint type description.
70
+ The integer corresponding to the joint type.
73
71
  """
74
72
 
75
73
  axis = joint.axis
@@ -138,7 +136,7 @@ def create_box_collision(
138
136
  collidable_points = [
139
137
  descriptions.CollidablePoint(
140
138
  parent_link=link_description,
141
- position=corner,
139
+ position=np.array(corner),
142
140
  enabled=True,
143
141
  )
144
142
  for corner in box_corners_wrt_link.T
@@ -197,7 +195,7 @@ def create_sphere_collision(
197
195
  collidable_points = [
198
196
  descriptions.CollidablePoint(
199
197
  parent_link=link_description,
200
- position=point,
198
+ position=np.array(point),
201
199
  enabled=True,
202
200
  )
203
201
  for point in sphere_points_wrt_link.T
jaxsim/rbda/crba.py CHANGED
@@ -111,7 +111,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
111
111
  # a while loop using a for loop with fixed number of iterations.
112
112
  def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]:
113
113
  def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:
114
- j, Fi, M = carry
114
+ j, _, _ = carry
115
115
  out = jax.lax.cond(
116
116
  pred=(λ[j] > 0),
117
117
  true_fun=while_loop_body,
@@ -120,7 +120,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
120
120
  )
121
121
  return out, None
122
122
 
123
- j, Fi, M = carry
123
+ j, _, _ = carry
124
124
  return jax.lax.cond(
125
125
  pred=(k == j),
126
126
  true_fun=compute_inner,
@@ -49,7 +49,7 @@ def forward_kinematics_model(
49
49
  # Compute the parent-to-child adjoints and the motion subspaces of the joints.
50
50
  # These transforms define the relative kinematics of the entire model, including
51
51
  # the base transform for both floating-base and fixed-base models.
52
- i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
52
+ i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
53
53
  joint_positions=s, base_transform=W_H_B.as_matrix()
54
54
  )
55
55
 
jaxsim/typing.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import Any, Hashable
1
+ from collections.abc import Hashable
2
+ from typing import Any
2
3
 
3
4
  import jax
4
5
 
@@ -24,6 +25,7 @@ PyTree = (
24
25
  # =======================
25
26
 
26
27
  Array = jax.typing.ArrayLike
28
+ Scalar = Array
27
29
  Vector = Array
28
30
  Matrix = Array
29
31
 
@@ -31,6 +33,7 @@ Int = int | IntJax
31
33
  Bool = bool | ArrayJax
32
34
  Float = float | FloatJax
33
35
 
36
+ ScalarLike = Scalar | int | float
34
37
  ArrayLike = Array
35
38
  VectorLike = Vector
36
39
  MatrixLike = Matrix
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.1.dev123
3
+ Version: 0.2.1.dev155
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -68,7 +68,7 @@ Requires-Dist: mujoco >=3.0.0 ; extra == 'viz'
68
68
 
69
69
  JaxSim is a **differentiable physics engine** and **multibody dynamics library** designed for applications in control and robot learning, implemented with JAX.
70
70
 
71
- Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
71
+ Its design facilitates research and accelerates prototyping in the intersection of robotics and artificial intelligence.
72
72
 
73
73
  ## Features
74
74
 
@@ -91,7 +91,7 @@ Its design facilitates research and accelerates prototyping in the intersection
91
91
 
92
92
  ### JaxSim as a multibody dynamics library
93
93
 
94
- - Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
94
+ - Provides rigid body dynamics algorithms (RBDAs) like RNEA, ABA, CRBA, and Jacobians.
95
95
  - Provides all the quantities included in the Euler-Poincaré formulation of the equations of motion.
96
96
  - Supports body-fixed, inertial-fixed, and mixed [velocity representations][notation].
97
97
  - Exposes all the necessary quantities to develop controllers in centroidal coordinates.
@@ -198,10 +198,10 @@ The main differences between MJX/Brax and JaxSim are as follows:
198
198
 
199
199
  - JaxSim supports out-of-the-box all SDF models with [Pose Frame Semantics][PFS].
200
200
  - JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.
201
- Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
201
+ Our contact model requires careful tuning of its spring-damper parameters, but being an instantaneous
202
202
  function of the state $(\mathbf{q}, \boldsymbol{\nu})$, it doesn't require running any optimization algorithm
203
203
  when stepping the simulation forward.
204
- - JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.
204
+ - JaxSim mitigates the stiffness of the contact-aware system dynamics by providing variable-step integrators.
205
205
 
206
206
  [brax]: https://github.com/google/brax
207
207
  [mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
@@ -1,18 +1,18 @@
1
- jaxsim/__init__.py,sha256=OcrfoYS1DGcmAGqu2AqlCTiUVxcpi-IsVwcr_16x74Q,1789
2
- jaxsim/_version.py,sha256=jAmAP2ZYBVoFw5W8EyAYOEpIKED5Jy8XK259JnSs0Fk,428
1
+ jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
2
+ jaxsim/_version.py,sha256=VTwovSkbhBoA3d5Ku6UfV56QTou0p3bVvIXRd63p0yQ,428
3
3
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
4
- jaxsim/typing.py,sha256=MeuOCQtLAr-sPkvB_sU8FtwGNRirz1auCwIgRC-QZl8,646
4
+ jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
5
5
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
6
6
  jaxsim/api/com.py,sha256=Yof6otFi-mLWAs1rqjmeNJTOWIH9gn7BdU5EIjiL6Ts,13481
7
7
  jaxsim/api/common.py,sha256=bqQ__pIQZbh-j8rkoHUkYHAgGiJnDzjHG-q4Ny0OOYQ,6646
8
- jaxsim/api/contact.py,sha256=Cvr-EfQtHP3nymtWdo-9WWU24Bkta-2Pp3nKsdjo6uc,12778
8
+ jaxsim/api/contact.py,sha256=79kcdq7C1_kWgxd1QWBabBhIPkwWEVLk-Fiz9kh-4so,12800
9
9
  jaxsim/api/data.py,sha256=xfKJz6Rw0YTk-EHCGiT8BFQrs_ggOz01lRi1Qh1mb28,27256
10
10
  jaxsim/api/frame.py,sha256=0YXOrGmx3cSQqa4_Ky-n6zyup3I3xvXNEgub-Bc5xUw,6222
11
11
  jaxsim/api/joint.py,sha256=-5DogPg4g4mmLckyVIVNjwv-Rxz0IWS7_md9nDlhPWA,4581
12
- jaxsim/api/kin_dyn_parameters.py,sha256=zMca7OmCsCWK_cavLTSZSeYh9Qu1-409cdsyWvWPAUQ,26090
13
- jaxsim/api/link.py,sha256=oW5-DShmmeCRk3JOJIwzo3HbWuNGmpm_wBJ4fkmrROM,16645
14
- jaxsim/api/model.py,sha256=1HlQ5FMzeJAk-cE1pmELgVjzMYUX9-iipw3N4WssAL4,55435
15
- jaxsim/api/ode.py,sha256=BfvV_14uu0szWecoDiV8rTu-dvSFLK7eyrO38ZqHB_w,10157
12
+ jaxsim/api/kin_dyn_parameters.py,sha256=b1e96I8hKU5fh4StLObdVcDpr_6ZglrgD3SRyrqTu18,26203
13
+ jaxsim/api/link.py,sha256=MdMWaMpM5Dj5JHK8uwHZ4zR4Fjq3R4asi2sGTxk1OAs,16647
14
+ jaxsim/api/model.py,sha256=sCx9CcP23A1I_ae4UqTq4Fpq5u0aDki72CqgnR1H50w,59465
15
+ jaxsim/api/ode.py,sha256=luTQJsIXUtCp_81dR42X7WrMvwrXtYbyJiqss29v7zA,10786
16
16
  jaxsim/api/ode_data.py,sha256=D6FzMkvY_qNuoFEImyp7sxAk-0pJOd3oZeSr9bBTcLk,23089
17
17
  jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
18
18
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
@@ -23,31 +23,31 @@ jaxsim/math/__init__.py,sha256=inJ9nRFkqstuGa8OyFkfWVudo5U9Ug4WgDBuKva8AIA,337
23
23
  jaxsim/math/adjoint.py,sha256=DT21izjVW497GRrgNfx8tv0ZeWW5QncWMGMhI0acUNw,4425
24
24
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
25
25
  jaxsim/math/inertia.py,sha256=UAB7ym4gXFanejcs_ovZMpteHCc6poWYmt-mLmd5hhk,1640
26
- jaxsim/math/joint_model.py,sha256=hfkEQOglk265P9W0-nvXxI95k21x2deYEv0mud6XUKc,9989
26
+ jaxsim/math/joint_model.py,sha256=xJSocGOyLzLJIQo4j5rBfMCPD4ltUQ2jCZfN747i2Ck,9989
27
27
  jaxsim/math/quaternion.py,sha256=X9b8jHf0QemKUjIZSnXRJc3DdMr42CBhBy_mi9_X_AM,5068
28
28
  jaxsim/math/rotation.py,sha256=Z90daUjGpuNEVLfWB3SVtM9EtwAIaneVj9A9UpWXqhA,2182
29
29
  jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
30
30
  jaxsim/math/transform.py,sha256=ZKzoXwKRowE17p3P0EgKrq8mocdMTGUKNyaKt6-bAX0,2941
31
31
  jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
32
32
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
33
- jaxsim/mujoco/loaders.py,sha256=Tq3b2tq_BulGm9GCM00NQMzcD7B0yKvoyiP94aFbjx4,21023
34
- jaxsim/mujoco/model.py,sha256=5-4KTbEbU19zjrSuvUVdLo3noWxTvlCNsFIs3rQTNDY,13506
35
- jaxsim/mujoco/visualizer.py,sha256=YlteqcCbeB1B6saAHKBz1IJad3N5Rp163reZrzKLAys,5065
33
+ jaxsim/mujoco/loaders.py,sha256=7rjpeJ6_GuitlCty-ZkLhTILQ0GmsFzDMgve-7Gkkh4,20984
34
+ jaxsim/mujoco/model.py,sha256=1KVRjSLOTCuHt53apBPQTnFYJRknlVoKLQaxWsNK8qc,13494
35
+ jaxsim/mujoco/visualizer.py,sha256=PXgQzwetS9mRJYHBknDMLsQ9152FdrSvZuT9xE_dfIQ,5069
36
36
  jaxsim/parsers/__init__.py,sha256=sonYi-bBWAoB04kp1mxT4uIORxjb7SdZ0ukGPmVx98Y,44
37
- jaxsim/parsers/kinematic_graph.py,sha256=zFt7x7pPGJar36Azukdi1eI_sa1kMWD3B8kZqcHx6iw,33934
37
+ jaxsim/parsers/kinematic_graph.py,sha256=WdIxntWfxXf67x90oM5KHHXFrSITMwVahqWgcOjYFzc,34730
38
38
  jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
39
- jaxsim/parsers/descriptions/collision.py,sha256=HUWwuRgI9KznY29FFw1_zU3bGigDEezrcPOJSxSJGNU,3382
40
- jaxsim/parsers/descriptions/joint.py,sha256=lRnYMmjpASpz0Ueuqzwnj5Ze4yLRgPTx66H0_kbQnNI,3042
41
- jaxsim/parsers/descriptions/link.py,sha256=GC-6ZgRZuRVpcRo1sY6YaR8lkCHkR4DvHNs2Ydw_tn4,2887
42
- jaxsim/parsers/descriptions/model.py,sha256=uO5xOJtViihVPnSSsmfQJvCh45ANyi9KYAzLOhH0R8g,8993
39
+ jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
40
+ jaxsim/parsers/descriptions/joint.py,sha256=z_nYSS0fdkcaerjUlPX0U1Vn1ArBT0u_XdKjqxG3HcY,3959
41
+ jaxsim/parsers/descriptions/link.py,sha256=QvEE7J6iMQLibpLqlcBV428UA7NMpFFXJwe35GYnjAY,3124
42
+ jaxsim/parsers/descriptions/model.py,sha256=V9nSyCK3mo7680WYMDEx1MTfdDTJzbCGPqAp3qA2XRE,9511
43
43
  jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
44
- jaxsim/parsers/rod/parser.py,sha256=Q13TOkmpU0SHpgSV8WRYWb290aPNNLsaz4eMlD4Mq5w,13525
45
- jaxsim/parsers/rod/utils.py,sha256=9oO4YsQRaR2v700IkNOXRnPpn5i4N8HFfgjPkMLK2mc,5732
44
+ jaxsim/parsers/rod/parser.py,sha256=4COuhkAYv4-GIpCqvkXEJWpDEQczEkBM3KwpqX48Rek,13514
45
+ jaxsim/parsers/rod/utils.py,sha256=KSjgy6WsmTrD5HZEA2x8hOBSRU4bUGOOHzxKkeFO5r8,5721
46
46
  jaxsim/rbda/__init__.py,sha256=MqEZwzu8SHPAlIFHmSXmCjehuOJGRX58OrBVAbBVMwg,374
47
47
  jaxsim/rbda/aba.py,sha256=0OoCzHhf1v-qqr1y5PIrD7_mPwAlid0fjXxUrIa5E_s,9118
48
48
  jaxsim/rbda/collidable_points.py,sha256=4ZNJbEj2nEi15jBLR-GNbdaqKgkN58FBgqd_TXupEgg,4948
49
- jaxsim/rbda/crba.py,sha256=GodskOZjtrSlbQAqxRv1un_706O7BaJK-U2qa18vJk8,4741
50
- jaxsim/rbda/forward_kinematics.py,sha256=OHugNU7C0UxYAW0o1rqH1ZgniSwurz6L1T1MJxfxq08,3418
49
+ jaxsim/rbda/crba.py,sha256=awsWEQXLE0UPEXIcZCVsAqBEPjyahMNzY9ux6nE1l-s,4739
50
+ jaxsim/rbda/forward_kinematics.py,sha256=94W7TUXvZjMb-99CyYR8pObuxIYYX9B_dtRZqsNcThs,3418
51
51
  jaxsim/rbda/jacobian.py,sha256=M79bGir-2w_iJ2GurYhOGgMfJnp7ZMOCW6AeeWKK8iM,10745
52
52
  jaxsim/rbda/rnea.py,sha256=DjwkvXQVUSUclM3Uy3UPZ2tao91R5dGd4o7TsS2qObI,7650
53
53
  jaxsim/rbda/soft_contacts.py,sha256=52zJOF31hFpqoaOednTvi8j_UxhRcdGNjzOPb2v2MPc,11257
@@ -58,8 +58,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
58
58
  jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
59
59
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
60
60
  jaxsim/utils/wrappers.py,sha256=EJMcblYKUjxw9HJShVf81Ig3pHUJno6Dx6h-RnY--wM,2040
61
- jaxsim-0.2.1.dev123.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
62
- jaxsim-0.2.1.dev123.dist-info/METADATA,sha256=hOUnUjhHwGvdZVMR_-z5pKFsd681swunNdYkqgC338I,9745
63
- jaxsim-0.2.1.dev123.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
64
- jaxsim-0.2.1.dev123.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
65
- jaxsim-0.2.1.dev123.dist-info/RECORD,,
61
+ jaxsim-0.2.1.dev155.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
62
+ jaxsim-0.2.1.dev155.dist-info/METADATA,sha256=Jhyuk3qGnK7WzHYbuA4b5mCXofQw7WZAj6zqh01jKkU,9740
63
+ jaxsim-0.2.1.dev155.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
64
+ jaxsim-0.2.1.dev155.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
65
+ jaxsim-0.2.1.dev155.dist-info/RECORD,,