jaxsim 0.4.3.dev327__tar.gz → 0.4.3.dev350__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/PKG-INFO +1 -1
  2. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/joint.py +8 -9
  4. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/kin_dyn_parameters.py +6 -4
  5. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/link.py +3 -4
  6. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/model.py +10 -19
  7. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/references.py +1 -1
  8. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/common.py +2 -2
  9. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/variable_step.py +6 -12
  10. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/loaders.py +9 -138
  11. jaxsim-0.4.3.dev350/src/jaxsim/mujoco/utils.py +223 -0
  12. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/joint.py +1 -26
  13. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/kinematic_graph.py +3 -3
  14. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/rod/parser.py +3 -6
  15. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/rod/utils.py +1 -1
  16. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/jacobian.py +2 -2
  17. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/utils.py +1 -1
  18. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/terrain/terrain.py +9 -1
  19. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/tracing.py +3 -9
  20. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/wrappers.py +1 -1
  21. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/PKG-INFO +1 -1
  22. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_api_data.py +5 -3
  23. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_api_joint.py +1 -1
  24. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_api_link.py +1 -1
  25. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_api_model.py +8 -6
  26. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_exceptions.py +10 -12
  27. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_pytree.py +6 -7
  28. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_simulations.py +4 -0
  29. jaxsim-0.4.3.dev327/src/jaxsim/mujoco/utils.py +0 -101
  30. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.devcontainer/Dockerfile +0 -0
  31. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.devcontainer/devcontainer.json +0 -0
  32. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.gitattributes +0 -0
  33. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.github/CODEOWNERS +0 -0
  34. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.github/dependabot.yml +0 -0
  35. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.github/workflows/ci_cd.yml +0 -0
  36. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.github/workflows/pixi.yml +0 -0
  37. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.github/workflows/read_the_docs.yml +0 -0
  38. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.gitignore +0 -0
  39. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.pre-commit-config.yaml +0 -0
  40. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/.readthedocs.yaml +0 -0
  41. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/CONTRIBUTING.md +0 -0
  42. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/LICENSE +0 -0
  43. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/README.md +0 -0
  44. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/Makefile +0 -0
  45. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/conf.py +0 -0
  46. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/examples.rst +0 -0
  47. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/guide/install.rst +0 -0
  48. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/index.rst +0 -0
  49. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/make.bat +0 -0
  50. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/api.rst +0 -0
  51. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/integrators.rst +0 -0
  52. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/math.rst +0 -0
  53. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/mujoco.rst +0 -0
  54. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/parsers.rst +0 -0
  55. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/rbda.rst +0 -0
  56. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/typing.rst +0 -0
  57. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/docs/modules/utils.rst +0 -0
  58. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/environment.yml +0 -0
  59. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/.gitattributes +0 -0
  60. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/.gitignore +0 -0
  61. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/README.md +0 -0
  62. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/assets/build_cartpole_urdf.py +0 -0
  63. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/assets/cartpole.urdf +0 -0
  64. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/jaxsim_as_multibody_dynamics_library.ipynb +0 -0
  65. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/jaxsim_as_physics_engine.ipynb +0 -0
  66. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/examples/jaxsim_for_robot_controllers.ipynb +0 -0
  67. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/pixi.lock +0 -0
  68. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/pyproject.toml +0 -0
  69. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/setup.cfg +0 -0
  70. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/setup.py +0 -0
  71. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/__init__.py +0 -0
  72. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/__init__.py +0 -0
  73. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/com.py +0 -0
  74. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/common.py +0 -0
  75. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/contact.py +0 -0
  76. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/data.py +0 -0
  77. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/frame.py +0 -0
  78. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/ode.py +0 -0
  79. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/api/ode_data.py +0 -0
  80. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/exceptions.py +0 -0
  81. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/__init__.py +0 -0
  82. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/integrators/fixed_step.py +0 -0
  83. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/logging.py +0 -0
  84. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/__init__.py +0 -0
  85. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/adjoint.py +0 -0
  86. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/cross.py +0 -0
  87. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/inertia.py +0 -0
  88. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/joint_model.py +0 -0
  89. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/quaternion.py +0 -0
  90. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/rotation.py +0 -0
  91. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/skew.py +0 -0
  92. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/math/transform.py +0 -0
  93. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/__init__.py +0 -0
  94. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/__main__.py +0 -0
  95. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/model.py +0 -0
  96. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/mujoco/visualizer.py +0 -0
  97. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/__init__.py +0 -0
  98. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  99. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  100. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/link.py +0 -0
  101. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/descriptions/model.py +0 -0
  102. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/parsers/rod/__init__.py +0 -0
  103. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/__init__.py +0 -0
  104. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/aba.py +0 -0
  105. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/collidable_points.py +0 -0
  106. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  107. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/common.py +0 -0
  108. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/relaxed_rigid.py +0 -0
  109. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/rigid.py +0 -0
  110. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/soft.py +0 -0
  111. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/contacts/visco_elastic.py +0 -0
  112. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/crba.py +0 -0
  113. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  114. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/rbda/rnea.py +0 -0
  115. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/terrain/__init__.py +0 -0
  116. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/typing.py +0 -0
  117. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/__init__.py +0 -0
  118. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  119. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  120. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  121. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/requires.txt +0 -0
  122. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/src/jaxsim.egg-info/top_level.txt +0 -0
  123. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/__init__.py +0 -0
  124. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/conftest.py +0 -0
  125. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_api_com.py +0 -0
  126. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_api_contact.py +0 -0
  127. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_api_frame.py +0 -0
  128. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_automatic_differentiation.py +0 -0
  129. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/test_contact.py +0 -0
  130. {jaxsim-0.4.3.dev327 → jaxsim-0.4.3.dev350}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev327
3
+ Version: 0.4.3.dev350
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev327'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev327')
15
+ __version__ = version = '0.4.3.dev350'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev350')
@@ -53,9 +53,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
53
53
  """
54
54
 
55
55
  exceptions.raise_value_error_if(
56
- condition=jnp.array(
57
- [joint_index < 0, joint_index >= model.number_of_joints()]
58
- ).any(),
56
+ condition=joint_index < 0,
59
57
  msg="Invalid joint index '{idx}'",
60
58
  idx=joint_index,
61
59
  )
@@ -123,10 +121,7 @@ def position_limit(
123
121
  """
124
122
 
125
123
  if model.number_of_joints() == 0:
126
- s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min
127
- s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max
128
-
129
- return jnp.atleast_1d(s_min).astype(float), jnp.atleast_1d(s_max).astype(float)
124
+ return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
130
125
 
131
126
  exceptions.raise_value_error_if(
132
127
  condition=jnp.array(
@@ -136,8 +131,12 @@ def position_limit(
136
131
  idx=joint_index,
137
132
  )
138
133
 
139
- s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
140
- s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
134
+ s_min = jnp.atleast_1d(
135
+ model.kin_dyn_parameters.joint_parameters.position_limits_min
136
+ )[joint_index]
137
+ s_max = jnp.atleast_1d(
138
+ model.kin_dyn_parameters.joint_parameters.position_limits_max
139
+ )[joint_index]
141
140
 
142
141
  return s_min.astype(float), s_max.astype(float)
143
142
 
@@ -438,7 +438,9 @@ class KynDynParameters(JaxsimDataclass):
438
438
  # Helpers to update parameters
439
439
  # ============================
440
440
 
441
- def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameters:
441
+ def set_link_mass(
442
+ self, link_index: jtp.IntLike, mass: jtp.FloatLike
443
+ ) -> KynDynParameters:
442
444
  """
443
445
  Set the mass of a link.
444
446
 
@@ -457,7 +459,7 @@ class KynDynParameters(JaxsimDataclass):
457
459
  return self.replace(link_parameters=link_parameters)
458
460
 
459
461
  def set_link_inertia(
460
- self, link_index: int, inertia: jtp.MatrixLike
462
+ self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
461
463
  ) -> KynDynParameters:
462
464
  r"""
463
465
  Set the inertia tensor of a link.
@@ -593,10 +595,10 @@ class LinkParameters(JaxsimDataclass):
593
595
  """
594
596
 
595
597
  # Extract the link parameters from the 6D spatial inertia.
596
- m, L_p_CoM, I = Inertia.to_params(M=M)
598
+ m, L_p_CoM, I_CoM = Inertia.to_params(M=M)
597
599
 
598
600
  # Extract only the necessary elements of the inertia tensor.
599
- inertia_elements = I[jnp.triu_indices(3)]
601
+ inertia_elements = I_CoM[jnp.triu_indices(3)]
600
602
 
601
603
  return LinkParameters(
602
604
  index=jnp.array(index).squeeze().astype(int),
@@ -4,6 +4,7 @@ from collections.abc import Sequence
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  import jax.scipy.linalg
7
+ import numpy as np
7
8
 
8
9
  import jaxsim.api as js
9
10
  import jaxsim.rbda
@@ -54,9 +55,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
54
55
  """
55
56
 
56
57
  exceptions.raise_value_error_if(
57
- condition=jnp.array(
58
- [link_index < 0, link_index >= model.number_of_links()]
59
- ).any(),
58
+ condition=link_index < 0,
60
59
  msg="Invalid link index '{idx}'",
61
60
  idx=link_index,
62
61
  )
@@ -98,7 +97,7 @@ def idxs_to_names(
98
97
  The names of the links.
99
98
  """
100
99
 
101
- return tuple(idx_to_name(model=model, link_index=idx) for idx in link_indices)
100
+ return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)])
102
101
 
103
102
 
104
103
  # =========
@@ -304,7 +304,7 @@ class JaxSimModel(JaxsimDataclass):
304
304
 
305
305
  return self.model_name
306
306
 
307
- def number_of_links(self) -> jtp.Int:
307
+ def number_of_links(self) -> int:
308
308
  """
309
309
  Return the number of links in the model.
310
310
 
@@ -317,7 +317,7 @@ class JaxSimModel(JaxsimDataclass):
317
317
 
318
318
  return self.kin_dyn_parameters.number_of_links()
319
319
 
320
- def number_of_joints(self) -> jtp.Int:
320
+ def number_of_joints(self) -> int:
321
321
  """
322
322
  Return the number of joints in the model.
323
323
 
@@ -419,7 +419,7 @@ class JaxSimModel(JaxsimDataclass):
419
419
  def reduce(
420
420
  model: JaxSimModel,
421
421
  considered_joints: tuple[str, ...],
422
- locked_joint_positions: dict[str, jtp.Float] | None = None,
422
+ locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
423
423
  ) -> JaxSimModel:
424
424
  """
425
425
  Reduce the model by lumping together the links connected by removed joints.
@@ -1038,12 +1038,7 @@ def forward_dynamics_aba(
1038
1038
  C_v̇_WB = to_active(
1039
1039
  W_v̇_WB=W_v̇_WB,
1040
1040
  W_H_C=W_H_C,
1041
- W_v_WB=jnp.hstack(
1042
- [
1043
- data.state.physics_model.base_linear_velocity,
1044
- data.state.physics_model.base_angular_velocity,
1045
- ]
1046
- ),
1041
+ W_v_WB=W_v_WB,
1047
1042
  W_v_WC=W_v_WC,
1048
1043
  )
1049
1044
 
@@ -2274,16 +2269,12 @@ def step(
2274
2269
  # Raise runtime error for not supported case in which Rigid contacts and
2275
2270
  # Baumgarte stabilization are enabled and used with ForwardEuler integrator.
2276
2271
  jaxsim.exceptions.raise_runtime_error_if(
2277
- condition=jnp.logical_and(
2278
- isinstance(
2279
- integrator,
2280
- jaxsim.integrators.fixed_step.ForwardEuler
2281
- | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2282
- ),
2283
- jnp.array(
2284
- [data_tf.contacts_params.K, data_tf.contacts_params.D]
2285
- ).any(),
2286
- ),
2272
+ condition=isinstance(
2273
+ integrator,
2274
+ jaxsim.integrators.fixed_step.ForwardEuler
2275
+ | jaxsim.integrators.fixed_step.ForwardEulerSO3,
2276
+ )
2277
+ & ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
2287
2278
  msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
2288
2279
  )
2289
2280
 
@@ -503,7 +503,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
503
503
  ]
504
504
 
505
505
  exceptions.raise_value_error_if(
506
- condition=jnp.logical_not(data.valid(model=model)),
506
+ condition=~data.valid(model=model),
507
507
  msg="The provided data is not valid for the model",
508
508
  )
509
509
  W_H_Fi = jax.vmap(
@@ -319,7 +319,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
319
319
  f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
320
320
 
321
321
  # Initialize the carry of the for loop with the stacked kᵢ vectors.
322
- carry0 = jax.tree_map(
322
+ carry0 = jax.tree.map(
323
323
  lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
324
324
  )
325
325
 
@@ -507,7 +507,7 @@ class ExplicitRungeKuttaSO3Mixin:
507
507
 
508
508
  # We assume that the initial quaternion is already unary.
509
509
  exceptions.raise_runtime_error_if(
510
- condition=jnp.logical_not(jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0)),
510
+ condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
511
511
  msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
512
512
  )
513
513
 
@@ -152,7 +152,7 @@ def compute_pytree_scale(
152
152
  """
153
153
 
154
154
  # Consider a zero second pytree, if not given.
155
- x2 = jax.tree.map(lambda l: jnp.zeros_like(l), x1) if x2 is None else x2
155
+ x2 = jax.tree.map(jnp.zeros_like, x1) if x2 is None else x2
156
156
 
157
157
  # Compute the scaling factors of the initial state and its derivative.
158
158
  compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
@@ -199,9 +199,7 @@ def local_error_estimation(
199
199
 
200
200
  # Consider a zero estimated final state, if not given.
201
201
  xf_estimate = (
202
- jax.tree.map(lambda l: jnp.zeros_like(l), xf)
203
- if xf_estimate is None
204
- else xf_estimate
202
+ jax.tree.map(jnp.zeros_like, xf) if xf_estimate is None else xf_estimate
205
203
  )
206
204
 
207
205
  # Estimate the error.
@@ -483,14 +481,10 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
483
481
  metadata_next,
484
482
  discarded_steps,
485
483
  ) = jax.lax.cond(
486
- pred=jnp.array(
487
- [
488
- discarded_steps >= self.max_step_rejections,
489
- local_error <= 1.0,
490
- Δt_next < self.dt_min,
491
- integrator_init,
492
- ]
493
- ).any(),
484
+ pred=discarded_steps
485
+ >= self.max_step_rejections | local_error
486
+ <= 1.0 | Δt_next
487
+ < self.dt_min | integrator_init,
494
488
  true_fun=accept_step,
495
489
  false_fun=reject_step,
496
490
  )
@@ -1,6 +1,3 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
1
  import pathlib
5
2
  import tempfile
6
3
  import warnings
@@ -9,10 +6,14 @@ from typing import Any
9
6
 
10
7
  import mujoco as mj
11
8
  import numpy as np
12
- import numpy.typing as npt
13
9
  import rod.urdf.exporter
14
10
  from lxml import etree as ET
15
- from scipy.spatial.transform import Rotation
11
+
12
+ from .utils import MujocoCamera
13
+
14
+ MujocoCameraType = (
15
+ MujocoCamera | Sequence[MujocoCamera] | dict[str, str] | Sequence[dict[str, str]]
16
+ )
16
17
 
17
18
 
18
19
  def load_rod_model(
@@ -167,12 +168,7 @@ class RodModelToMjcf:
167
168
  plane_normal: tuple[float, float, float] = (0, 0, 1),
168
169
  heightmap: bool | None = None,
169
170
  heightmap_samples_xy: tuple[int, int] = (101, 101),
170
- cameras: (
171
- MujocoCamera
172
- | Sequence[MujocoCamera]
173
- | dict[str, str]
174
- | Sequence[dict[str, str]]
175
- ) = (),
171
+ cameras: MujocoCameraType = (),
176
172
  ) -> tuple[str, dict[str, Any]]:
177
173
  """
178
174
  Converts a ROD model to a Mujoco MJCF string.
@@ -533,12 +529,7 @@ class UrdfToMjcf:
533
529
  model_name: str | None = None,
534
530
  plane_normal: tuple[float, float, float] = (0, 0, 1),
535
531
  heightmap: bool | None = None,
536
- cameras: (
537
- MujocoCamera
538
- | Sequence[MujocoCamera]
539
- | dict[str, str]
540
- | Sequence[dict[str, str]]
541
- ) = (),
532
+ cameras: MujocoCameraType = (),
542
533
  ) -> tuple[str, dict[str, Any]]:
543
534
  """
544
535
  Converts a URDF file to a Mujoco MJCF string.
@@ -580,12 +571,7 @@ class SdfToMjcf:
580
571
  model_name: str | None = None,
581
572
  plane_normal: tuple[float, float, float] = (0, 0, 1),
582
573
  heightmap: bool | None = None,
583
- cameras: (
584
- MujocoCamera
585
- | Sequence[MujocoCamera]
586
- | dict[str, str]
587
- | Sequence[dict[str, str]]
588
- ) = (),
574
+ cameras: MujocoCameraType = (),
589
575
  ) -> tuple[str, dict[str, Any]]:
590
576
  """
591
577
  Converts a SDF file to a Mujoco MJCF string.
@@ -617,118 +603,3 @@ class SdfToMjcf:
617
603
  heightmap=heightmap,
618
604
  cameras=cameras,
619
605
  )
620
-
621
-
622
- @dataclasses.dataclass
623
- class MujocoCamera:
624
- """
625
- Helper class storing parameters of a Mujoco camera.
626
-
627
- Refer to the official documentation for more details:
628
- https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
629
- """
630
-
631
- mode: str = "fixed"
632
-
633
- target: str | None = None
634
- fovy: str = "45"
635
- pos: str = "0 0 0"
636
-
637
- quat: str | None = None
638
- axisangle: str | None = None
639
- xyaxes: str | None = None
640
- zaxis: str | None = None
641
- euler: str | None = None
642
-
643
- name: str | None = None
644
-
645
- @classmethod
646
- def build(cls, **kwargs) -> MujocoCamera:
647
-
648
- if not all(isinstance(value, str) for value in kwargs.values()):
649
- raise ValueError(f"Values must be strings: {kwargs}")
650
-
651
- return cls(**kwargs)
652
-
653
- @staticmethod
654
- def build_from_target_view(
655
- camera_name: str,
656
- lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
657
- distance: float | int | npt.NDArray = 3,
658
- azimut: float | int | npt.NDArray = 90,
659
- elevation: float | int | npt.NDArray = -45,
660
- fovy: float | int | npt.NDArray = 45,
661
- degrees: bool = True,
662
- **kwargs,
663
- ) -> MujocoCamera:
664
- """
665
- Create a custom camera that looks at a target point.
666
-
667
- Note:
668
- The choice of the parameters is easier if we imagine to consider a target
669
- frame `T` whose origin is located over the lookat point and having the same
670
- orientation of the world frame `W`. We also introduce a camera frame `C`
671
- whose origin is located over the lower-left corner of the image, and having
672
- the x-axis pointing right and the y-axis pointing up in image coordinates.
673
- The camera renders what it sees in the -z direction of frame `C`.
674
-
675
- Args:
676
- camera_name: The name of the camera.
677
- lookat: The target point to look at (origin of `T`).
678
- distance:
679
- The distance from the target point (displacement between the origins
680
- of `T` and `C`).
681
- azimut:
682
- The rotation around z of the camera. With an angle of 0, the camera
683
- would loot at the target point towards the positive x-axis of `T`.
684
- elevation:
685
- The rotation around the x-axis of the camera frame `C`. Note that if
686
- you want to lift the view angle, the elevation is negative.
687
- fovy: The field of view of the camera.
688
- degrees: Whether the angles are in degrees or radians.
689
- **kwargs: Additional camera parameters.
690
-
691
- Returns:
692
- The custom camera.
693
- """
694
-
695
- # Start from a frame whose origin is located over the lookat point.
696
- # We initialize a -90 degrees rotation around the z-axis because due to
697
- # the default camera coordinate system (x pointing right, y pointing up).
698
- W_H_C = np.eye(4)
699
- W_H_C[0:3, 3] = np.array(lookat)
700
- W_H_C[0:3, 0:3] = Rotation.from_euler(
701
- seq="ZX", angles=[-90, 90], degrees=True
702
- ).as_matrix()
703
-
704
- # Process the azimut.
705
- R_az = Rotation.from_euler(seq="Y", angles=azimut, degrees=degrees).as_matrix()
706
- W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
707
-
708
- # Process elevation.
709
- R_el = Rotation.from_euler(
710
- seq="X", angles=elevation, degrees=degrees
711
- ).as_matrix()
712
- W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
713
-
714
- # Process distance.
715
- tf_distance = np.eye(4)
716
- tf_distance[2, 3] = distance
717
- W_H_C = W_H_C @ tf_distance
718
-
719
- # Extract the position and the quaternion.
720
- p = W_H_C[0:3, 3]
721
- Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
722
-
723
- return MujocoCamera.build(
724
- name=camera_name,
725
- mode="fixed",
726
- fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
727
- pos=" ".join(p.astype(str).tolist()),
728
- quat=" ".join(Q.astype(str).tolist()),
729
- **kwargs,
730
- )
731
-
732
- def asdict(self) -> dict[str, str]:
733
-
734
- return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
@@ -0,0 +1,223 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from collections.abc import Sequence
5
+
6
+ import mujoco as mj
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ from scipy.spatial.transform import Rotation
10
+
11
+ from .model import MujocoModelHelper
12
+
13
+
14
+ def mujoco_data_from_jaxsim(
15
+ mujoco_model: mj.MjModel,
16
+ jaxsim_model,
17
+ jaxsim_data,
18
+ mujoco_data: mj.MjData | None = None,
19
+ update_removed_joints: bool = True,
20
+ ) -> mj.MjData:
21
+ """
22
+ Create a Mujoco data object from a JaxSim model and data objects.
23
+
24
+ Args:
25
+ mujoco_model: The Mujoco model object corresponding to the JaxSim model.
26
+ jaxsim_model: The JaxSim model object from which the Mujoco model was created.
27
+ jaxsim_data: The JaxSim data object containing the state of the model.
28
+ mujoco_data: An optional Mujoco data object. If None, a new one will be created.
29
+ update_removed_joints:
30
+ If True, the positions of the joints that have been removed during the
31
+ model reduction process will be set to their initial values.
32
+
33
+ Returns:
34
+ The Mujoco data object containing the state of the JaxSim model.
35
+
36
+ Note:
37
+ This method is useful to initialize a Mujoco data object used for visualization
38
+ with the state of a JaxSim model. In particular, this function takes care of
39
+ initializing the positions of the joints that have been removed during the
40
+ model reduction process. After the initial creation of the Mujoco data object,
41
+ it's faster to update the state using an external MujocoModelHelper object.
42
+ """
43
+
44
+ # The package `jaxsim.mujoco` is supposed to be jax-independent.
45
+ # We import all the JaxSim resources privately.
46
+ import jaxsim.api as js
47
+
48
+ if not isinstance(jaxsim_model, js.model.JaxSimModel):
49
+ raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.")
50
+
51
+ if not isinstance(jaxsim_data, js.data.JaxSimModelData):
52
+ raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.")
53
+
54
+ # Create the helper to operate on the Mujoco model and data.
55
+ model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)
56
+
57
+ # If the model is fixed-base, the Mujoco model won't have the joint corresponding
58
+ # to the floating base, and the helper would raise an exception.
59
+ if jaxsim_model.floating_base():
60
+
61
+ # Set the model position.
62
+ model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))
63
+
64
+ # Set the model orientation.
65
+ model_helper.set_base_orientation(
66
+ orientation=np.array(jaxsim_data.base_orientation())
67
+ )
68
+
69
+ # Set the joint positions.
70
+ if jaxsim_model.dofs() > 0:
71
+
72
+ model_helper.set_joint_positions(
73
+ joint_names=list(jaxsim_model.joint_names()),
74
+ positions=np.array(
75
+ jaxsim_data.joint_positions(
76
+ model=jaxsim_model, joint_names=jaxsim_model.joint_names()
77
+ )
78
+ ),
79
+ )
80
+
81
+ # Updating these joints is not necessary after the first time.
82
+ # Users can disable this update after initialization.
83
+ if update_removed_joints:
84
+
85
+ # Create a dictionary with the joints that have been removed for various reasons
86
+ # (like link lumping due to model reduction).
87
+ joints_removed_dict = {
88
+ j.name: j
89
+ for j in jaxsim_model.description._joints_removed
90
+ if j.name not in set(jaxsim_model.joint_names())
91
+ }
92
+
93
+ # Set the positions of the removed joints.
94
+ _ = [
95
+ model_helper.set_joint_position(
96
+ position=joints_removed_dict[joint_name].initial_position,
97
+ joint_name=joint_name,
98
+ )
99
+ # Select all original joint that have been removed from the JaxSim model
100
+ # that are still present in the Mujoco model.
101
+ for joint_name in joints_removed_dict
102
+ if joint_name in model_helper.joint_names()
103
+ ]
104
+
105
+ # Return the mujoco data with updated kinematics.
106
+ mj.mj_forward(mujoco_model, model_helper.data)
107
+
108
+ return model_helper.data
109
+
110
+
111
+ @dataclasses.dataclass
112
+ class MujocoCamera:
113
+ """
114
+ Helper class storing parameters of a Mujoco camera.
115
+
116
+ Refer to the official documentation for more details:
117
+ https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
118
+ """
119
+
120
+ mode: str = "fixed"
121
+
122
+ target: str | None = None
123
+ fovy: str = "45"
124
+ pos: str = "0 0 0"
125
+
126
+ quat: str | None = None
127
+ axisangle: str | None = None
128
+ xyaxes: str | None = None
129
+ zaxis: str | None = None
130
+ euler: str | None = None
131
+
132
+ name: str | None = None
133
+
134
+ @classmethod
135
+ def build(cls, **kwargs) -> MujocoCamera:
136
+
137
+ if not all(isinstance(value, str) for value in kwargs.values()):
138
+ raise ValueError(f"Values must be strings: {kwargs}")
139
+
140
+ return cls(**kwargs)
141
+
142
+ @staticmethod
143
+ def build_from_target_view(
144
+ camera_name: str,
145
+ lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
146
+ distance: float | int | npt.NDArray = 3,
147
+ azimut: float | int | npt.NDArray = 90,
148
+ elevation: float | int | npt.NDArray = -45,
149
+ fovy: float | int | npt.NDArray = 45,
150
+ degrees: bool = True,
151
+ **kwargs,
152
+ ) -> MujocoCamera:
153
+ """
154
+ Create a custom camera that looks at a target point.
155
+
156
+ Note:
157
+ The choice of the parameters is easier if we imagine to consider a target
158
+ frame `T` whose origin is located over the lookat point and having the same
159
+ orientation of the world frame `W`. We also introduce a camera frame `C`
160
+ whose origin is located over the lower-left corner of the image, and having
161
+ the x-axis pointing right and the y-axis pointing up in image coordinates.
162
+ The camera renders what it sees in the -z direction of frame `C`.
163
+
164
+ Args:
165
+ camera_name: The name of the camera.
166
+ lookat: The target point to look at (origin of `T`).
167
+ distance:
168
+ The distance from the target point (displacement between the origins
169
+ of `T` and `C`).
170
+ azimut:
171
+ The rotation around z of the camera. With an angle of 0, the camera
172
+ would loot at the target point towards the positive x-axis of `T`.
173
+ elevation:
174
+ The rotation around the x-axis of the camera frame `C`. Note that if
175
+ you want to lift the view angle, the elevation is negative.
176
+ fovy: The field of view of the camera.
177
+ degrees: Whether the angles are in degrees or radians.
178
+ **kwargs: Additional camera parameters.
179
+
180
+ Returns:
181
+ The custom camera.
182
+ """
183
+
184
+ # Start from a frame whose origin is located over the lookat point.
185
+ # We initialize a -90 degrees rotation around the z-axis because due to
186
+ # the default camera coordinate system (x pointing right, y pointing up).
187
+ W_H_C = np.eye(4)
188
+ W_H_C[0:3, 3] = np.array(lookat)
189
+ W_H_C[0:3, 0:3] = Rotation.from_euler(
190
+ seq="ZX", angles=[-90, 90], degrees=True
191
+ ).as_matrix()
192
+
193
+ # Process the azimut.
194
+ R_az = Rotation.from_euler(seq="Y", angles=azimut, degrees=degrees).as_matrix()
195
+ W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
196
+
197
+ # Process elevation.
198
+ R_el = Rotation.from_euler(
199
+ seq="X", angles=elevation, degrees=degrees
200
+ ).as_matrix()
201
+ W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
202
+
203
+ # Process distance.
204
+ tf_distance = np.eye(4)
205
+ tf_distance[2, 3] = distance
206
+ W_H_C = W_H_C @ tf_distance
207
+
208
+ # Extract the position and the quaternion.
209
+ p = W_H_C[0:3, 3]
210
+ Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
211
+
212
+ return MujocoCamera.build(
213
+ name=camera_name,
214
+ mode="fixed",
215
+ fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
216
+ pos=" ".join(p.astype(str).tolist()),
217
+ quat=" ".join(Q.astype(str).tolist()),
218
+ **kwargs,
219
+ )
220
+
221
+ def asdict(self) -> dict[str, str]:
222
+
223
+ return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}