jaxsim 0.2.1.dev20 (tar.gz) → 0.2.1.dev38 (tar.gz)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for reference only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (115):
  1. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/PKG-INFO +3 -3
  2. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/environment.yml +1 -1
  3. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/setup.cfg +1 -1
  4. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/model.py +38 -11
  6. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/joint_model.py +11 -62
  7. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/descriptions/joint.py +30 -47
  8. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/descriptions/model.py +14 -6
  9. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/kinematic_graph.py +245 -78
  10. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/rod/parser.py +2 -2
  11. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/rod/utils.py +17 -32
  12. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/utils/jaxsim_dataclass.py +4 -2
  13. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim.egg-info/PKG-INFO +3 -3
  14. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim.egg-info/requires.txt +2 -2
  15. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_api_model.py +96 -17
  16. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/utils_idyntree.py +80 -4
  17. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.devcontainer/Dockerfile +0 -0
  18. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.devcontainer/devcontainer.json +0 -0
  19. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.gitattributes +0 -0
  20. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.github/CODEOWNERS +0 -0
  21. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.github/workflows/ci_cd.yml +0 -0
  22. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.github/workflows/read_the_docs.yml +0 -0
  23. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.github/workflows/style.yml +0 -0
  24. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.gitignore +0 -0
  25. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.pre-commit-config.yaml +0 -0
  26. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/.readthedocs.yaml +0 -0
  27. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/CONTRIBUTING.md +0 -0
  28. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/LICENSE +0 -0
  29. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/README.md +0 -0
  30. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/Makefile +0 -0
  31. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/conf.py +0 -0
  32. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/guide/install.rst +0 -0
  33. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/index.rst +0 -0
  34. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/make.bat +0 -0
  35. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/api.rst +0 -0
  36. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/index.rst +0 -0
  37. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/integrators.rst +0 -0
  38. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/math.rst +0 -0
  39. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/mujoco.rst +0 -0
  40. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/parsers.rst +0 -0
  41. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/rbda.rst +0 -0
  42. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/typing.rst +0 -0
  43. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/docs/modules/utils.rst +0 -0
  44. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/examples/.gitattributes +0 -0
  45. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/examples/.gitignore +0 -0
  46. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/examples/PD_controller.ipynb +0 -0
  47. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/examples/Parallel_computing.ipynb +0 -0
  48. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/examples/README.md +0 -0
  49. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/examples/assets/cartpole.urdf +0 -0
  50. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/pixi.lock +0 -0
  51. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/pyproject.toml +0 -0
  52. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/setup.py +0 -0
  53. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/__init__.py +0 -0
  54. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/__init__.py +0 -0
  55. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/com.py +0 -0
  56. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/common.py +0 -0
  57. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/contact.py +0 -0
  58. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/data.py +0 -0
  59. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/joint.py +0 -0
  60. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  61. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/link.py +0 -0
  62. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/ode.py +0 -0
  63. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/ode_data.py +0 -0
  64. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/api/references.py +0 -0
  65. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/integrators/__init__.py +0 -0
  66. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/integrators/common.py +0 -0
  67. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/integrators/fixed_step.py +0 -0
  68. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/integrators/variable_step.py +0 -0
  69. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/logging.py +0 -0
  70. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/__init__.py +0 -0
  71. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/adjoint.py +0 -0
  72. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/cross.py +0 -0
  73. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/inertia.py +0 -0
  74. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/quaternion.py +0 -0
  75. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/rotation.py +0 -0
  76. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/skew.py +0 -0
  77. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/math/transform.py +0 -0
  78. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/mujoco/__init__.py +0 -0
  79. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/mujoco/__main__.py +0 -0
  80. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/mujoco/loaders.py +0 -0
  81. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/mujoco/model.py +0 -0
  82. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/mujoco/visualizer.py +0 -0
  83. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/__init__.py +0 -0
  84. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  85. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  86. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/descriptions/link.py +0 -0
  87. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/parsers/rod/__init__.py +0 -0
  88. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/__init__.py +0 -0
  89. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/aba.py +0 -0
  90. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/collidable_points.py +0 -0
  91. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/crba.py +0 -0
  92. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  93. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/jacobian.py +0 -0
  94. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/rnea.py +0 -0
  95. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/soft_contacts.py +0 -0
  96. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/rbda/utils.py +0 -0
  97. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/terrain/__init__.py +0 -0
  98. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/terrain/terrain.py +0 -0
  99. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/typing.py +0 -0
  100. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/utils/__init__.py +0 -0
  101. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/utils/hashless.py +0 -0
  102. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim/utils/tracing.py +0 -0
  103. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  104. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  105. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim.egg-info/not-zip-safe +0 -0
  106. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/src/jaxsim.egg-info/top_level.txt +0 -0
  107. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/__init__.py +0 -0
  108. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/conftest.py +0 -0
  109. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_api_com.py +0 -0
  110. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_api_data.py +0 -0
  111. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_api_joint.py +0 -0
  112. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_api_link.py +0 -0
  113. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_automatic_differentiation.py +0 -0
  114. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_pytree.py +0 -0
  115. {jaxsim-0.2.1.dev20 → jaxsim-0.2.1.dev38}/tests/test_simulations.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.1.dev20
3
+ Version: 0.2.1.dev38
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -44,7 +44,7 @@ Requires-Dist: black[jupyter]~=24.0; extra == "style"
44
44
  Requires-Dist: isort; extra == "style"
45
45
  Requires-Dist: pre-commit; extra == "style"
46
46
  Provides-Extra: testing
47
- Requires-Dist: idyntree; extra == "testing"
47
+ Requires-Dist: idyntree>=12.2.1; extra == "testing"
48
48
  Requires-Dist: pytest>=6.0; extra == "testing"
49
49
  Requires-Dist: pytest-icdiff; extra == "testing"
50
50
  Requires-Dist: robot-descriptions; extra == "testing"
@@ -56,7 +56,7 @@ Provides-Extra: all
56
56
  Requires-Dist: black[jupyter]~=24.0; extra == "all"
57
57
  Requires-Dist: isort; extra == "all"
58
58
  Requires-Dist: pre-commit; extra == "all"
59
- Requires-Dist: idyntree; extra == "all"
59
+ Requires-Dist: idyntree>=12.2.1; extra == "all"
60
60
  Requires-Dist: pytest>=6.0; extra == "all"
61
61
  Requires-Dist: pytest-icdiff; extra == "all"
62
62
  Requires-Dist: robot-descriptions; extra == "all"
@@ -22,7 +22,7 @@ dependencies:
22
22
  - isort
23
23
  - pre-commit
24
24
  # [testing]
25
- - idyntree
25
+ - idyntree >= 12.2.1
26
26
  - pytest
27
27
  - pytest-icdiff
28
28
  - robot_descriptions
@@ -69,7 +69,7 @@ style =
69
69
  isort
70
70
  pre-commit
71
71
  testing =
72
- idyntree
72
+ idyntree >= 12.2.1
73
73
  pytest >=6.0
74
74
  pytest-icdiff
75
75
  robot-descriptions
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.1.dev20'
16
- __version_tuple__ = version_tuple = (0, 2, 1, 'dev20')
15
+ __version__ = version = '0.2.1.dev38'
16
+ __version_tuple__ = version_tuple = (0, 2, 1, 'dev38')
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copy
3
4
  import dataclasses
4
5
  import functools
5
6
  import pathlib
6
- from typing import Any
7
+ from typing import Any, Sequence
7
8
 
8
9
  import jax
9
10
  import jax.numpy as jnp
@@ -55,7 +56,7 @@ class JaxSimModel(JaxsimDataclass):
55
56
  *,
56
57
  terrain: jaxsim.terrain.Terrain | None = None,
57
58
  is_urdf: bool | None = None,
58
- considered_joints: list[str] | None = None,
59
+ considered_joints: Sequence[str] | None = None,
59
60
  ) -> JaxSimModel:
60
61
  """
61
62
  Build a Model object from a model description.
@@ -257,24 +258,50 @@ class JaxSimModel(JaxsimDataclass):
257
258
  # =====================
258
259
 
259
260
 
260
- def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimModel:
261
+ def reduce(
262
+ model: JaxSimModel,
263
+ considered_joints: tuple[str, ...],
264
+ locked_joint_positions: dict[str, jtp.Float] | None = None,
265
+ ) -> JaxSimModel:
261
266
  """
262
267
  Reduce the model by lumping together the links connected by removed joints.
263
268
 
264
269
  Args:
265
270
  model: The model to reduce.
266
271
  considered_joints: The sequence of joints to consider.
267
-
268
- Note:
269
- If considered_joints contains joints not existing in the model, the method
270
- will raise an exception. If considered_joints is empty, the method will
271
- return a copy of the input model.
272
+ locked_joint_positions:
273
+ A dictionary containing the positions of the joints to be considered
274
+ in the reduction process. The removed joints in the reduced model
275
+ will have their position locked to their value in this dictionary.
276
+ If a joint is not part of the dictionary, its position is set to zero.
272
277
  """
273
278
 
279
+ locked_joint_positions = (
280
+ locked_joint_positions if locked_joint_positions is not None else {}
281
+ )
282
+
283
+ # If locked joints are passed, make sure that they are valid.
284
+ if not set(locked_joint_positions).issubset(model.joint_names()):
285
+ new_joints = set(model.joint_names()) - set(locked_joint_positions)
286
+ raise ValueError(f"Passed joints not existing in the model: {new_joints}")
287
+
288
+ # Copy the model description with a deep copy of the joints.
289
+ intermediate_description = dataclasses.replace(
290
+ model.description.get(), joints=copy.deepcopy(model.description.get().joints)
291
+ )
292
+
293
+ # Update the initial position of the joints.
294
+ # This is necessary to compute the correct pose of the link pairs connected
295
+ # to removed joints.
296
+ for joint_name in set(model.joint_names()) - set(considered_joints):
297
+ j = intermediate_description.joints_dict[joint_name]
298
+ with j.mutable_context():
299
+ j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))
300
+
274
301
  # Reduce the model description.
275
- # If considered_joints contains joints not existing in the model, the method
276
- # will raise an exception.
277
- reduced_intermediate_description = model.description.obj.reduce(
302
+ # If `considered_joints` contains joints not existing in the model,
303
+ # the method will raise an exception.
304
+ reduced_intermediate_description = intermediate_description.reduce(
278
305
  considered_joints=list(considered_joints)
279
306
  )
280
307
 
@@ -15,6 +15,7 @@ from jaxsim.parsers.descriptions import (
15
15
  JointType,
16
16
  ModelDescription,
17
17
  )
18
+ from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
18
19
 
19
20
  from .rotation import Rotation
20
21
 
@@ -87,21 +88,19 @@ class JointModel:
87
88
  # w.r.t. the implicit __model__ SDF frame is not the identity).
88
89
  suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose)
89
90
 
91
+ # Create the object to compute forward kinematics.
92
+ fk = KinematicGraphTransforms(graph=description)
93
+
90
94
  # Compute the parent-to-predecessor and successor-to-child transforms for
91
95
  # each joint belonging to the model.
92
96
  # Note that the joint indices starts from i=1 given our joint model,
93
97
  # therefore the entries at index 0 are not updated.
94
98
  for joint in ordered_joints:
95
99
  λ_H_pre = λ_H_pre.at[joint.index].set(
96
- description.relative_transform(
97
- relative_to=joint.parent.name,
98
- name=joint.name,
99
- )
100
+ fk.relative_transform(relative_to=joint.parent.name, name=joint.name)
100
101
  )
101
102
  suc_H_i = suc_H_i.at[joint.index].set(
102
- description.relative_transform(
103
- relative_to=joint.name, name=joint.child.name
104
- )
103
+ fk.relative_transform(relative_to=joint.name, name=joint.child.name)
105
104
  )
106
105
 
107
106
  # Define the DoFs of the base link.
@@ -243,16 +242,16 @@ def supported_joint_motion(
243
242
  """
244
243
 
245
244
  if isinstance(joint_type, JointType):
246
- code = joint_type
245
+ type_enum = joint_type
247
246
  elif isinstance(joint_type, JointDescriptor):
248
- code = joint_type.code
247
+ type_enum = joint_type.joint_type
249
248
  else:
250
249
  raise ValueError(joint_type)
251
250
 
252
251
  # Prepare the joint position
253
252
  s = jnp.array(joint_position).astype(float)
254
253
 
255
- match code:
254
+ match type_enum:
256
255
 
257
256
  case JointType.R:
258
257
  joint_type: JointGenericAxis
@@ -276,58 +275,8 @@ def supported_joint_motion(
276
275
  S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))
277
276
 
278
277
  case JointType.F:
279
- raise ValueError("Fixed joints shouldn't be here")
280
-
281
- case JointType.Rx:
282
-
283
- pre_H_suc = jaxlie.SE3.from_rotation(
284
- rotation=jaxlie.SO3.from_x_radians(theta=s)
285
- )
286
-
287
- S = jnp.vstack([0, 0, 0, 1.0, 0, 0])
288
-
289
- case JointType.Ry:
290
-
291
- pre_H_suc = jaxlie.SE3.from_rotation(
292
- rotation=jaxlie.SO3.from_y_radians(theta=s)
293
- )
294
-
295
- S = jnp.vstack([0, 0, 0, 0, 1.0, 0])
296
-
297
- case JointType.Rz:
298
-
299
- pre_H_suc = jaxlie.SE3.from_rotation(
300
- rotation=jaxlie.SO3.from_z_radians(theta=s)
301
- )
302
-
303
- S = jnp.vstack([0, 0, 0, 0, 0, 1.0])
304
-
305
- case JointType.Px:
306
-
307
- pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
308
- rotation=jaxlie.SO3.identity(),
309
- translation=jnp.array([s, 0.0, 0.0]),
310
- )
311
-
312
- S = jnp.vstack([1.0, 0, 0, 0, 0, 0])
313
-
314
- case JointType.Py:
315
-
316
- pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
317
- rotation=jaxlie.SO3.identity(),
318
- translation=jnp.array([0.0, s, 0.0]),
319
- )
320
-
321
- S = jnp.vstack([0, 1.0, 0, 0, 0, 0])
322
-
323
- case JointType.Pz:
324
-
325
- pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
326
- rotation=jaxlie.SO3.identity(),
327
- translation=jnp.array([0.0, 0.0, s]),
328
- )
329
-
330
- S = jnp.vstack([0, 0, 1.0, 0, 0, 0])
278
+ pre_H_suc = jaxlie.SE3.identity()
279
+ S = jnp.zeros(shape=(6, 1))
331
280
 
332
281
  case _:
333
282
  raise ValueError(joint_type)
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import dataclasses
2
4
  import enum
3
5
  from typing import Tuple, Union
@@ -6,79 +8,60 @@ import jax_dataclasses
6
8
  import numpy as np
7
9
  import numpy.typing as npt
8
10
 
11
+ import jaxsim.typing as jtp
9
12
  from jaxsim.utils import JaxsimDataclass, Mutability
10
13
 
11
14
  from .link import LinkDescription
12
15
 
13
16
 
17
+ @enum.unique
14
18
  class JointType(enum.IntEnum):
15
19
  """
16
- Enumeration of joint types for robot joints.
17
-
18
- Args:
19
- F: Fixed joint (no movement).
20
- R: Revolute joint (rotation).
21
- P: Prismatic joint (translation).
22
- Rx: Revolute joint with rotation about the X-axis.
23
- Ry: Revolute joint with rotation about the Y-axis.
24
- Rz: Revolute joint with rotation about the Z-axis.
25
- Px: Prismatic joint with translation along the X-axis.
26
- Py: Prismatic joint with translation along the Y-axis.
27
- Pz: Prismatic joint with translation along the Z-axis.
20
+ Type of supported joints.
28
21
  """
29
22
 
30
- F = enum.auto() # Fixed
31
- R = enum.auto() # Revolute
32
- P = enum.auto() # Prismatic
23
+ @staticmethod
24
+ def _generate_next_value_(name, start, count, last_values):
25
+ # Start auto Enum value from 0 instead of 1
26
+ return count
27
+
28
+ #: Fixed joint.
29
+ F = enum.auto()
33
30
 
34
- # Revolute joints, single axis
35
- Rx = enum.auto()
36
- Ry = enum.auto()
37
- Rz = enum.auto()
31
+ #: Revolute joint (1 DoF around axis).
32
+ R = enum.auto()
38
33
 
39
- # Prismatic joints, single axis
40
- Px = enum.auto()
41
- Py = enum.auto()
42
- Pz = enum.auto()
34
+ #: Prismatic joint (1 DoF along axis).
35
+ P = enum.auto()
43
36
 
44
37
 
45
- @dataclasses.dataclass
38
+ @jax_dataclasses.pytree_dataclass
46
39
  class JointDescriptor:
47
40
  """
48
- Description of a joint type with a specific code.
49
-
50
- Args:
51
- code (JointType): The code representing the joint type.
52
-
41
+ Base class for joint types requiring to store additional metadata.
53
42
  """
54
43
 
55
- code: JointType
44
+ #: The joint type.
45
+ joint_type: JointType
56
46
 
57
- def __hash__(self) -> int:
58
- return hash(self.__repr__())
59
47
 
60
-
61
- @dataclasses.dataclass
48
+ @jax_dataclasses.pytree_dataclass
62
49
  class JointGenericAxis(JointDescriptor):
63
50
  """
64
- Description of a joint type with a generic axis.
65
-
66
- Attributes:
67
- axis (npt.NDArray): The axis of rotation or translation for the joint.
68
-
51
+ A joint requiring the specification of a 3D axis.
69
52
  """
70
53
 
71
- axis: npt.NDArray
54
+ #: The axis of rotation or translation of the joint (must have norm 1).
55
+ axis: jtp.Vector
72
56
 
73
- def __post_init__(self):
74
- if np.allclose(self.axis, 0.0):
75
- raise ValueError(self.axis)
57
+ def __hash__(self) -> int:
58
+ return hash((self.joint_type, tuple(np.array(self.axis).tolist())))
76
59
 
77
- def __eq__(self, other):
78
- return super().__eq__(other) and np.allclose(self.axis, other.axis)
60
+ def __eq__(self, other: JointGenericAxis) -> bool:
61
+ if not isinstance(other, JointGenericAxis):
62
+ return False
79
63
 
80
- def __hash__(self) -> int:
81
- return hash(self.__repr__())
64
+ return hash(self) == hash(other)
82
65
 
83
66
 
84
67
  @jax_dataclasses.pytree_dataclass
@@ -4,7 +4,7 @@ from typing import List
4
4
 
5
5
  from jaxsim import logging
6
6
 
7
- from ..kinematic_graph import KinematicGraph, RootPose
7
+ from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose
8
8
  from .collision import CollidablePoint, CollisionShape
9
9
  from .joint import JointDescription
10
10
  from .link import LinkDescription
@@ -75,6 +75,9 @@ class ModelDescription(KinematicGraph):
75
75
  considered_joints=considered_joints
76
76
  )
77
77
 
78
+ # Create the object to compute forward kinematics.
79
+ fk = KinematicGraphTransforms(graph=kinematic_graph)
80
+
78
81
  # Store here the final model collisions
79
82
  final_collisions: List[CollisionShape] = []
80
83
 
@@ -121,7 +124,7 @@ class ModelDescription(KinematicGraph):
121
124
  # relative pose
122
125
  moved_cp = cp.change_link(
123
126
  new_link=real_parent_link_of_shape,
124
- new_H_old=kinematic_graph.relative_transform(
127
+ new_H_old=fk.relative_transform(
125
128
  relative_to=real_parent_link_of_shape.name,
126
129
  name=cp.parent_link.name,
127
130
  ),
@@ -139,7 +142,9 @@ class ModelDescription(KinematicGraph):
139
142
  root=kinematic_graph.root,
140
143
  joints=kinematic_graph.joints,
141
144
  frames=kinematic_graph.frames,
145
+ _joints_removed=kinematic_graph._joints_removed,
142
146
  )
147
+
143
148
  assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name
144
149
 
145
150
  return model
@@ -158,15 +163,12 @@ class ModelDescription(KinematicGraph):
158
163
  ValueError: If the specified joints are not part of the model.
159
164
  """
160
165
 
161
- msg = "The model reduction logic assumes that removed joints have zero angles"
162
- logging.info(msg=msg)
163
-
164
166
  if len(set(considered_joints) - set(self.joint_names())) != 0:
165
167
  extra_joints = set(considered_joints) - set(self.joint_names())
166
168
  msg = f"Found joints not part of the model: {extra_joints}"
167
169
  raise ValueError(msg)
168
170
 
169
- return ModelDescription.build_model_from(
171
+ reduced_model_description = ModelDescription.build_model_from(
170
172
  name=self.name,
171
173
  links=list(self.links_dict.values()),
172
174
  joints=self.joints,
@@ -177,6 +179,12 @@ class ModelDescription(KinematicGraph):
177
179
  considered_joints=considered_joints,
178
180
  )
179
181
 
182
+ # Include the unconnected/removed joints from the original model.
183
+ for joint in self._joints_removed:
184
+ reduced_model_description._joints_removed.append(joint)
185
+
186
+ return reduced_model_description
187
+
180
188
  def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None:
181
189
  """
182
190
  Enable or disable collision shapes associated with a link.