jaxsim 0.2.1.dev80__tar.gz → 0.2.1.dev98__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/PKG-INFO +1 -1
  2. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/pyproject.toml +0 -3
  3. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/_version.py +2 -2
  4. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/data.py +19 -1
  5. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/frame.py +4 -4
  6. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/kin_dyn_parameters.py +21 -18
  7. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/model.py +26 -7
  8. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/ode_data.py +31 -0
  9. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/joint_model.py +25 -18
  10. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/descriptions/joint.py +3 -1
  11. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/soft_contacts.py +17 -0
  12. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/utils/__init__.py +1 -1
  13. jaxsim-0.2.1.dev98/src/jaxsim/utils/wrappers.py +78 -0
  14. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim.egg-info/PKG-INFO +1 -1
  15. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim.egg-info/SOURCES.txt +1 -1
  16. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/conftest.py +19 -4
  17. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_api_com.py +1 -1
  18. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_api_data.py +3 -3
  19. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_api_frame.py +10 -12
  20. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_api_link.py +4 -4
  21. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_api_model.py +8 -9
  22. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_automatic_differentiation.py +8 -8
  23. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_contact.py +1 -1
  24. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_pytree.py +10 -24
  25. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/utils_idyntree.py +5 -8
  26. jaxsim-0.2.1.dev80/src/jaxsim/utils/hashless.py +0 -18
  27. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.devcontainer/Dockerfile +0 -0
  28. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.devcontainer/devcontainer.json +0 -0
  29. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.gitattributes +0 -0
  30. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.github/CODEOWNERS +0 -0
  31. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.github/workflows/ci_cd.yml +0 -0
  32. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.github/workflows/read_the_docs.yml +0 -0
  33. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.github/workflows/style.yml +0 -0
  34. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.gitignore +0 -0
  35. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.pre-commit-config.yaml +0 -0
  36. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/.readthedocs.yaml +0 -0
  37. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/CONTRIBUTING.md +0 -0
  38. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/LICENSE +0 -0
  39. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/README.md +0 -0
  40. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/Makefile +0 -0
  41. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/conf.py +0 -0
  42. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/guide/install.rst +0 -0
  43. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/index.rst +0 -0
  44. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/make.bat +0 -0
  45. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/api.rst +0 -0
  46. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/index.rst +0 -0
  47. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/integrators.rst +0 -0
  48. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/math.rst +0 -0
  49. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/mujoco.rst +0 -0
  50. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/parsers.rst +0 -0
  51. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/rbda.rst +0 -0
  52. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/typing.rst +0 -0
  53. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/docs/modules/utils.rst +0 -0
  54. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/environment.yml +0 -0
  55. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/examples/.gitattributes +0 -0
  56. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/examples/.gitignore +0 -0
  57. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/examples/PD_controller.ipynb +0 -0
  58. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/examples/Parallel_computing.ipynb +0 -0
  59. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/examples/README.md +0 -0
  60. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/examples/assets/cartpole.urdf +0 -0
  61. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/pixi.lock +0 -0
  62. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/setup.cfg +0 -0
  63. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/setup.py +0 -0
  64. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/__init__.py +0 -0
  65. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/__init__.py +0 -0
  66. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/com.py +0 -0
  67. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/common.py +0 -0
  68. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/contact.py +0 -0
  69. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/joint.py +0 -0
  70. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/link.py +0 -0
  71. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/ode.py +0 -0
  72. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/api/references.py +0 -0
  73. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/integrators/__init__.py +0 -0
  74. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/integrators/common.py +0 -0
  75. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/integrators/fixed_step.py +0 -0
  76. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/integrators/variable_step.py +0 -0
  77. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/logging.py +0 -0
  78. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/__init__.py +0 -0
  79. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/adjoint.py +0 -0
  80. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/cross.py +0 -0
  81. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/inertia.py +0 -0
  82. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/quaternion.py +0 -0
  83. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/rotation.py +0 -0
  84. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/skew.py +0 -0
  85. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/math/transform.py +0 -0
  86. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/mujoco/__init__.py +0 -0
  87. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/mujoco/__main__.py +0 -0
  88. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/mujoco/loaders.py +0 -0
  89. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/mujoco/model.py +0 -0
  90. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/mujoco/visualizer.py +0 -0
  91. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/__init__.py +0 -0
  92. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  93. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  94. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/descriptions/link.py +0 -0
  95. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/descriptions/model.py +0 -0
  96. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  97. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/rod/__init__.py +0 -0
  98. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/rod/parser.py +0 -0
  99. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/parsers/rod/utils.py +0 -0
  100. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/__init__.py +0 -0
  101. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/aba.py +0 -0
  102. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/collidable_points.py +0 -0
  103. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/crba.py +0 -0
  104. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  105. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/jacobian.py +0 -0
  106. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/rnea.py +0 -0
  107. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/rbda/utils.py +0 -0
  108. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/terrain/__init__.py +0 -0
  109. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/terrain/terrain.py +0 -0
  110. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/typing.py +0 -0
  111. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  112. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim/utils/tracing.py +0 -0
  113. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  114. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim.egg-info/not-zip-safe +0 -0
  115. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim.egg-info/requires.txt +0 -0
  116. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/src/jaxsim.egg-info/top_level.txt +0 -0
  117. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/__init__.py +0 -0
  118. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_api_joint.py +0 -0
  119. {jaxsim-0.2.1.dev80 → jaxsim-0.2.1.dev98}/tests/test_simulations.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.1.dev80
3
+ Version: 0.2.1.dev98
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -24,13 +24,10 @@ profile = "black"
24
24
  [tool.pytest.ini_options]
25
25
  addopts = "-rsxX -v --strict-markers"
26
26
  minversion = "6.0"
27
- preview = true
28
27
  testpaths = [
29
28
  "tests",
30
29
  ]
31
30
 
32
- target-version = "py311"
33
-
34
31
  [tool.ruff]
35
32
  exclude = [
36
33
  ".git",
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.1.dev80'
16
- __version_tuple__ = version_tuple = (0, 2, 1, 'dev80')
15
+ __version__ = version = '0.2.1.dev98'
16
+ __version_tuple__ = version_tuple = (0, 2, 1, 'dev98')
@@ -30,7 +30,7 @@ except ImportError:
30
30
  @jax_dataclasses.pytree_dataclass
31
31
  class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
32
32
  """
33
- Class containing the state of a `JaxSimModel` object.
33
+ Class containing the data of a `JaxSimModel` object.
34
34
  """
35
35
 
36
36
  state: ODEState
@@ -43,6 +43,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
43
43
  default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
44
44
  )
45
45
 
46
+ def __hash__(self) -> int:
47
+
48
+ return hash(
49
+ (
50
+ hash(self.state),
51
+ hash(tuple(self.gravity.flatten().tolist())),
52
+ hash(self.soft_contacts_params),
53
+ hash(jnp.atleast_1d(self.time_ns).flatten().tolist()),
54
+ )
55
+ )
56
+
57
+ def __eq__(self, other: JaxSimModelData) -> bool:
58
+
59
+ if not isinstance(other, JaxSimModelData):
60
+ return False
61
+
62
+ return hash(self) == hash(other)
63
+
46
64
  def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
47
65
  """
48
66
  Check if the current state is valid for the given model.
@@ -30,7 +30,7 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -
30
30
  """
31
31
 
32
32
  # Get the intermediate representation parsed from the model description.
33
- ir = model.description.get()
33
+ ir = model.description
34
34
 
35
35
  # Extract the indices of the frame and the link it is attached to.
36
36
  F = ir.frames[frame_idx - model.number_of_links()]
@@ -51,7 +51,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
51
51
  The index of the frame.
52
52
  """
53
53
 
54
- frame_names = np.array([frame.name for frame in model.description.get().frames])
54
+ frame_names = np.array([frame.name for frame in model.description.frames])
55
55
 
56
56
  if frame_name in frame_names:
57
57
  idx_in_list = np.argwhere(frame_names == frame_name)
@@ -72,7 +72,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
72
72
  The name of the frame.
73
73
  """
74
74
 
75
- return model.description.get().frames[frame_index - model.number_of_links()].name
75
+ return model.description.frames[frame_index - model.number_of_links()].name
76
76
 
77
77
 
78
78
  @functools.partial(jax.jit, static_argnames=["frame_names"])
@@ -144,7 +144,7 @@ def transform(
144
144
  W_H_L = js.link.transform(model=model, data=data, link_index=L)
145
145
 
146
146
  # Get the static frame pose wrt the parent link.
147
- frame = model.description.get().frames[frame_index - model.number_of_links()]
147
+ frame = model.description.frames[frame_index - model.number_of_links()]
148
148
  L_H_F = frame.pose
149
149
 
150
150
  # Combine the transforms computing the frame pose.
@@ -11,7 +11,7 @@ from jax_dataclasses import Static
11
11
  import jaxsim.typing as jtp
12
12
  from jaxsim.math import Inertia, JointModel, supported_joint_motion
13
13
  from jaxsim.parsers.descriptions import JointDescription, ModelDescription
14
- from jaxsim.utils import JaxsimDataclass
14
+ from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
15
15
 
16
16
 
17
17
  @jax_dataclasses.pytree_dataclass
@@ -32,8 +32,8 @@ class KynDynParameters(JaxsimDataclass):
32
32
 
33
33
  # Static
34
34
  link_names: Static[tuple[str]]
35
- parent_array: Static[jtp.Vector]
36
- support_body_array_bool: Static[jtp.Matrix]
35
+ _parent_array: Static[HashedNumpyArray]
36
+ _support_body_array_bool: Static[HashedNumpyArray]
37
37
 
38
38
  # Links
39
39
  link_parameters: LinkParameters
@@ -45,6 +45,14 @@ class KynDynParameters(JaxsimDataclass):
45
45
  joint_model: JointModel
46
46
  joint_parameters: JointParameters | None
47
47
 
48
+ @property
49
+ def parent_array(self) -> jtp.Vector:
50
+ return self._parent_array.get()
51
+
52
+ @property
53
+ def support_body_array_bool(self) -> jtp.Matrix:
54
+ return self._support_body_array_bool.get()
55
+
48
56
  @staticmethod
49
57
  def build(model_description: ModelDescription) -> KynDynParameters:
50
58
  """
@@ -191,8 +199,8 @@ class KynDynParameters(JaxsimDataclass):
191
199
 
192
200
  return KynDynParameters(
193
201
  link_names=tuple(l.name for l in ordered_links),
194
- parent_array=parent_array,
195
- support_body_array_bool=support_body_array_bool,
202
+ _parent_array=HashedNumpyArray(array=parent_array),
203
+ _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
196
204
  link_parameters=link_parameters,
197
205
  joint_model=joint_model,
198
206
  joint_parameters=joint_parameters,
@@ -204,23 +212,18 @@ class KynDynParameters(JaxsimDataclass):
204
212
  if not isinstance(other, KynDynParameters):
205
213
  return False
206
214
 
207
- equal = True
208
- equal = equal and self.number_of_links() == other.number_of_links()
209
- equal = equal and self.number_of_joints() == other.number_of_joints()
210
- equal = equal and jnp.allclose(self.parent_array, other.parent_array)
211
-
212
- return equal
215
+ return hash(self) == hash(other)
213
216
 
214
217
  def __hash__(self) -> int:
215
218
 
216
- h = (
217
- hash(self.number_of_links()),
218
- hash(self.number_of_joints()),
219
- hash(tuple(self.parent_array.tolist())),
219
+ return hash(
220
+ (
221
+ hash(self.number_of_links()),
222
+ hash(self.number_of_joints()),
223
+ hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())),
224
+ )
220
225
  )
221
226
 
222
- return hash(h)
223
-
224
227
  # =============================
225
228
  # Helpers to extract parameters
226
229
  # =============================
@@ -388,7 +391,7 @@ class KynDynParameters(JaxsimDataclass):
388
391
  pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
389
392
  jnp.array(self.joint_model.joint_types[1:]).astype(int),
390
393
  jnp.array(joint_positions),
391
- jnp.array(self.joint_model.joint_axis),
394
+ jnp.array([j.axis for j in self.joint_model.joint_axis]),
392
395
  )
393
396
 
394
397
  # Extract the transforms and motion subspaces of the joints.
@@ -32,18 +32,37 @@ class JaxSimModel(JaxsimDataclass):
32
32
  terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
33
33
  default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
34
34
  )
35
+ kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
36
+ dataclasses.field(default=None, repr=False, compare=False, hash=False)
37
+ )
35
38
 
36
39
  built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
37
40
  default=None, repr=False, compare=False, hash=False
38
41
  )
39
42
 
40
- description: Static[
43
+ _description: Static[
41
44
  HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
42
45
  ] = dataclasses.field(default=None, repr=False, compare=False, hash=False)
43
46
 
44
- kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
45
- dataclasses.field(default=None, repr=False, compare=False, hash=False)
46
- )
47
+ @property
48
+ def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
49
+ return self._description.get()
50
+
51
+ def __eq__(self, other: JaxSimModel) -> bool:
52
+
53
+ if not isinstance(other, JaxSimModel):
54
+ return False
55
+
56
+ return hash(self) == hash(other)
57
+
58
+ def __hash__(self) -> int:
59
+
60
+ return hash(
61
+ (
62
+ hash(self.model_name),
63
+ hash(self.kin_dyn_parameters),
64
+ )
65
+ )
47
66
 
48
67
  # ========================
49
68
  # Initialization and state
@@ -137,7 +156,7 @@ class JaxSimModel(JaxsimDataclass):
137
156
  # Build the model
138
157
  model = JaxSimModel(
139
158
  model_name=model_name,
140
- description=HashlessObject(obj=model_description),
159
+ _description=HashlessObject(obj=model_description),
141
160
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
142
161
  model_description=model_description
143
162
  ),
@@ -260,7 +279,7 @@ class JaxSimModel(JaxsimDataclass):
260
279
  The names of the links in the model.
261
280
  """
262
281
 
263
- return tuple([frame.name for frame in self.description.get().frames])
282
+ return tuple(frame.name for frame in self.description.frames)
264
283
 
265
284
 
266
285
  # =====================
@@ -297,7 +316,7 @@ def reduce(
297
316
 
298
317
  # Copy the model description with a deep copy of the joints.
299
318
  intermediate_description = dataclasses.replace(
300
- model.description.get(), joints=copy.deepcopy(model.description.get().joints)
319
+ model.description, joints=copy.deepcopy(model.description.joints)
301
320
  )
302
321
 
303
322
  # Update the initial position of the joints.
@@ -281,6 +281,24 @@ class PhysicsModelState(JaxsimDataclass):
281
281
  default_factory=lambda: jnp.zeros(3)
282
282
  )
283
283
 
284
+ def __hash__(self) -> int:
285
+
286
+ return hash(
287
+ (
288
+ hash(tuple(jnp.atleast_1d(self.joint_positions.flatten().tolist()))),
289
+ hash(tuple(jnp.atleast_1d(self.joint_velocities.flatten().tolist()))),
290
+ hash(tuple(self.base_position.flatten().tolist())),
291
+ hash(tuple(self.base_quaternion.flatten().tolist())),
292
+ )
293
+ )
294
+
295
+ def __eq__(self, other: PhysicsModelState) -> bool:
296
+
297
+ if not isinstance(other, PhysicsModelState):
298
+ return False
299
+
300
+ return hash(self) == hash(other)
301
+
284
302
  @staticmethod
285
303
  def build_from_jaxsim_model(
286
304
  model: js.model.JaxSimModel | None = None,
@@ -593,6 +611,19 @@ class SoftContactsState(JaxsimDataclass):
593
611
 
594
612
  tangential_deformation: jtp.Matrix
595
613
 
614
+ def __hash__(self) -> int:
615
+
616
+ return hash(
617
+ tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
618
+ )
619
+
620
+ def __eq__(self, other: SoftContactsState) -> bool:
621
+
622
+ if not isinstance(other, SoftContactsState):
623
+ return False
624
+
625
+ return hash(self) == hash(other)
626
+
596
627
  @staticmethod
597
628
  def build_from_jaxsim_model(
598
629
  model: js.model.JaxSimModel | None = None,
@@ -39,7 +39,7 @@ class JointModel:
39
39
 
40
40
  joint_dofs: Static[tuple[int, ...]]
41
41
  joint_names: Static[tuple[str, ...]]
42
- joint_types: Static[tuple[JointType, ...]]
42
+ joint_types: Static[tuple[int, ...]]
43
43
  joint_axis: Static[tuple[JointGenericAxis, ...]]
44
44
 
45
45
  @staticmethod
@@ -109,7 +109,7 @@ class JointModel:
109
109
  joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
110
110
  joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
111
111
  joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
112
- joint_axis=tuple([j.axis for j in ordered_joints]),
112
+ joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
113
113
  )
114
114
 
115
115
  def parent_H_child(
@@ -201,7 +201,7 @@ class JointModel:
201
201
  pre_H_suc, S = supported_joint_motion(
202
202
  self.joint_types[joint_index],
203
203
  joint_position,
204
- self.joint_axis[joint_index],
204
+ self.joint_axis[joint_index].axis,
205
205
  )
206
206
 
207
207
  return pre_H_suc, S
@@ -224,9 +224,9 @@ class JointModel:
224
224
 
225
225
  @jax.jit
226
226
  def supported_joint_motion(
227
- joint_type: JointType,
227
+ joint_type: jtp.IntLike,
228
228
  joint_position: jtp.VectorLike,
229
- joint_axis: JointGenericAxis,
229
+ joint_axis: jtp.VectorLike | None = None,
230
230
  /,
231
231
  ) -> tuple[jtp.Matrix, jtp.Array]:
232
232
  """
@@ -234,8 +234,8 @@ def supported_joint_motion(
234
234
 
235
235
  Args:
236
236
  joint_type: The type of the joint.
237
- joint_axis: The axis of rotation or translation of the joint.
238
237
  joint_position: The position of the joint.
238
+ joint_axis: The optional 3D axis of rotation or translation of the joint.
239
239
 
240
240
  Returns:
241
241
  A tuple containing the homogeneous transformation and the motion subspace.
@@ -244,26 +244,33 @@ def supported_joint_motion(
244
244
  # Prepare the joint position
245
245
  s = jnp.array(joint_position).astype(float)
246
246
 
247
- def compute_F():
247
+ def compute_F() -> tuple[jtp.Matrix, jtp.Array]:
248
248
  return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))
249
249
 
250
- def compute_R():
250
+ def compute_R() -> tuple[jtp.Matrix, jtp.Array]:
251
+
252
+ # Get the additional argument specifying the joint axis.
253
+ # This is a metadata required by only some joint types.
254
+ axis = jnp.array(joint_axis).astype(float).squeeze()
255
+
251
256
  pre_H_suc = jaxlie.SE3.from_rotation(
252
- rotation=jaxlie.SO3.from_matrix(
253
- Rotation.from_axis_angle(vector=s * joint_axis)
254
- )
257
+ rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis))
255
258
  )
256
259
 
257
- S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()]))
260
+ S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))
261
+
258
262
  return pre_H_suc, S
259
263
 
260
- def compute_P():
261
- pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
262
- rotation=jaxlie.SO3.identity(),
263
- translation=jnp.array(s * joint_axis),
264
- )
264
+ def compute_P() -> tuple[jtp.Matrix, jtp.Array]:
265
+
266
+ # Get the additional argument specifying the joint axis.
267
+ # This is a metadata required by only some joint types.
268
+ axis = jnp.array(joint_axis).astype(float).squeeze()
269
+
270
+ pre_H_suc = jaxlie.SE3.from_translation(translation=jnp.array(s * axis))
271
+
272
+ S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)]))
265
273
 
266
- S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
267
274
  return pre_H_suc, S
268
275
 
269
276
  pre_H_suc, S = jax.lax.switch(
@@ -30,9 +30,11 @@ class JointGenericAxis:
30
30
  axis: jtp.Vector
31
31
 
32
32
  def __hash__(self) -> int:
33
- return hash((tuple(np.array(self.axis).tolist())))
33
+
34
+ return hash(tuple(self.axis.tolist()))
34
35
 
35
36
  def __eq__(self, other: JointGenericAxis) -> bool:
37
+
36
38
  if not isinstance(other, JointGenericAxis):
37
39
  return False
38
40
 
@@ -29,6 +29,23 @@ class SoftContactsParams(JaxsimDataclass):
29
29
  default_factory=lambda: jnp.array(0.5, dtype=float)
30
30
  )
31
31
 
32
+ def __hash__(self) -> int:
33
+
34
+ return hash(
35
+ (
36
+ hash(tuple(jnp.atleast_1d(self.K).flatten().tolist())),
37
+ hash(tuple(jnp.atleast_1d(self.D).flatten().tolist())),
38
+ hash(tuple(jnp.atleast_1d(self.mu).flatten().tolist())),
39
+ )
40
+ )
41
+
42
+ def __eq__(self, other: SoftContactsParams) -> bool:
43
+
44
+ if not isinstance(other, SoftContactsParams):
45
+ return NotImplemented
46
+
47
+ return hash(self) == hash(other)
48
+
32
49
  @staticmethod
33
50
  def build(
34
51
  K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
@@ -1,5 +1,5 @@
1
1
  from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
2
2
 
3
- from .hashless import HashlessObject
4
3
  from .jaxsim_dataclass import JaxsimDataclass
5
4
  from .tracing import not_tracing, tracing
5
+ from .wrappers import HashedNumpyArray, HashlessObject
@@ -0,0 +1,78 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from typing import Generic, TypeVar
5
+
6
+ import jax
7
+ import jax_dataclasses
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+
11
+ T = TypeVar("T")
12
+
13
+
14
+ @dataclasses.dataclass
15
+ class HashlessObject(Generic[T]):
16
+ """
17
+ A class that wraps an object and makes it hashless.
18
+
19
+ This is useful for creating particular JAX pytrees.
20
+ For example, to create a pytree with a static leaf that is ignored
21
+ by JAX when it compares two instances to trigger a JIT recompilation.
22
+ """
23
+
24
+ obj: T
25
+
26
+ def get(self: HashlessObject[T]) -> T:
27
+ return self.obj
28
+
29
+ def __hash__(self) -> int:
30
+
31
+ return 0
32
+
33
+ def __eq__(self, other: HashlessObject[T]) -> bool:
34
+
35
+ if not isinstance(other, HashlessObject) and isinstance(
36
+ other.get(), type(self.get())
37
+ ):
38
+ return False
39
+
40
+ return hash(self) == hash(other)
41
+
42
+
43
+ @jax_dataclasses.pytree_dataclass
44
+ class HashedNumpyArray:
45
+ """
46
+ A class that wraps a numpy array and makes it hashable.
47
+
48
+ This is useful for creating particular JAX pytrees.
49
+ For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf.
50
+
51
+ Note:
52
+ Calculating with the wrapper class the hash of a very large array can be
53
+ very expensive. If the array is large and only the equality operator is needed,
54
+ set `large_array=True` to use a faster comparison method.
55
+ """
56
+
57
+ array: jax.Array | npt.NDArray
58
+
59
+ large_array: jax_dataclasses.Static[bool] = dataclasses.field(
60
+ default=False, repr=False, compare=False, hash=False
61
+ )
62
+
63
+ def get(self) -> jax.Array | npt.NDArray:
64
+ return self.array
65
+
66
+ def __hash__(self) -> int:
67
+
68
+ return hash(tuple(np.atleast_1d(self.array).flatten().tolist()))
69
+
70
+ def __eq__(self, other: HashedNumpyArray) -> bool:
71
+
72
+ if not isinstance(other, HashedNumpyArray):
73
+ return False
74
+
75
+ if self.large_array:
76
+ return np.array_equal(self.array, other.array)
77
+
78
+ return hash(self) == hash(other)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.1.dev80
3
+ Version: 0.2.1.dev98
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -99,9 +99,9 @@ src/jaxsim/rbda/utils.py
99
99
  src/jaxsim/terrain/__init__.py
100
100
  src/jaxsim/terrain/terrain.py
101
101
  src/jaxsim/utils/__init__.py
102
- src/jaxsim/utils/hashless.py
103
102
  src/jaxsim/utils/jaxsim_dataclass.py
104
103
  src/jaxsim/utils/tracing.py
104
+ src/jaxsim/utils/wrappers.py
105
105
  tests/__init__.py
106
106
  tests/conftest.py
107
107
  tests/test_api_com.py
@@ -177,18 +177,19 @@ def jaxsim_model_sphere() -> js.model.JaxSimModel:
177
177
 
178
178
 
179
179
  @pytest.fixture(scope="session")
180
- def jaxsim_model_ergocub() -> js.model.JaxSimModel:
180
+ def ergocub_model_description_path() -> pathlib.Path:
181
181
  """
182
- Fixture providing the JaxSim model of the ErgoCub robot.
182
+ Fixture providing the path to the URDF model description of the ErgoCub robot.
183
183
 
184
184
  Returns:
185
- The JaxSim model of the ErgoCub robot.
185
+ The path to the URDF model description of the ErgoCub robot.
186
186
  """
187
187
 
188
188
  try:
189
189
  os.environ["ROBOT_DESCRIPTION_COMMIT"] = "v0.7.1"
190
190
 
191
191
  import robot_descriptions.ergocub_description
192
+
192
193
  finally:
193
194
  _ = os.environ.pop("ROBOT_DESCRIPTION_COMMIT", None)
194
195
 
@@ -198,7 +199,21 @@ def jaxsim_model_ergocub() -> js.model.JaxSimModel:
198
199
  )
199
200
  )
200
201
 
201
- return build_jaxsim_model(model_description=model_urdf_path)
202
+ return model_urdf_path
203
+
204
+
205
+ @pytest.fixture(scope="session")
206
+ def jaxsim_model_ergocub(
207
+ ergocub_model_description_path: pathlib.Path,
208
+ ) -> js.model.JaxSimModel:
209
+ """
210
+ Fixture providing the JaxSim model of the ErgoCub robot.
211
+
212
+ Returns:
213
+ The JaxSim model of the ErgoCub robot.
214
+ """
215
+
216
+ return build_jaxsim_model(model_description=ergocub_model_description_path)
202
217
 
203
218
 
204
219
  @pytest.fixture(scope="session")
@@ -15,7 +15,7 @@ def test_com_properties(
15
15
 
16
16
  model = jaxsim_models_types
17
17
 
18
- key, subkey = jax.random.split(prng_key, num=2)
18
+ _, subkey = jax.random.split(prng_key, num=2)
19
19
  data = js.data.random_model_data(
20
20
  model=model, key=subkey, velocity_representation=velocity_representation
21
21
  )
@@ -27,7 +27,7 @@ def test_data_joint_indexing(
27
27
 
28
28
  model = jaxsim_models_types
29
29
 
30
- key, subkey = jax.random.split(prng_key, num=2)
30
+ _, subkey = jax.random.split(prng_key, num=2)
31
31
  data = js.data.random_model_data(
32
32
  model=model, key=subkey, velocity_representation=velocity_representation
33
33
  )
@@ -56,7 +56,7 @@ def test_data_switch_velocity_representation(
56
56
 
57
57
  model = jaxsim_models_types
58
58
 
59
- key, subkey = jax.random.split(prng_key, num=2)
59
+ _, subkey = jax.random.split(prng_key, num=2)
60
60
  data = js.data.random_model_data(
61
61
  model=model, key=subkey, velocity_representation=VelRepr.Inertial
62
62
  )
@@ -98,7 +98,7 @@ def test_data_change_velocity_representation(
98
98
 
99
99
  model = jaxsim_models_types
100
100
 
101
- key, subkey = jax.random.split(prng_key, num=2)
101
+ _, subkey = jax.random.split(prng_key, num=2)
102
102
  data = js.data.random_model_data(
103
103
  model=model, key=subkey, velocity_representation=VelRepr.Inertial
104
104
  )