jaxsim 0.2.1.dev40__tar.gz → 0.2.1.dev47__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/PKG-INFO +1 -1
  2. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/_version.py +2 -2
  3. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/joint.py +3 -12
  4. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/kin_dyn_parameters.py +9 -10
  5. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/joint_model.py +38 -44
  6. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/__init__.py +1 -1
  7. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/joint.py +10 -35
  8. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/kinematic_graph.py +6 -4
  9. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/rod/parser.py +1 -1
  10. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/rod/utils.py +4 -8
  11. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/PKG-INFO +1 -1
  12. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.devcontainer/Dockerfile +0 -0
  13. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.devcontainer/devcontainer.json +0 -0
  14. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.gitattributes +0 -0
  15. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.github/CODEOWNERS +0 -0
  16. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.github/workflows/ci_cd.yml +0 -0
  17. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.github/workflows/read_the_docs.yml +0 -0
  18. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.github/workflows/style.yml +0 -0
  19. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.gitignore +0 -0
  20. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.pre-commit-config.yaml +0 -0
  21. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/.readthedocs.yaml +0 -0
  22. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/CONTRIBUTING.md +0 -0
  23. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/LICENSE +0 -0
  24. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/README.md +0 -0
  25. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/Makefile +0 -0
  26. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/conf.py +0 -0
  27. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/guide/install.rst +0 -0
  28. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/index.rst +0 -0
  29. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/make.bat +0 -0
  30. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/api.rst +0 -0
  31. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/index.rst +0 -0
  32. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/integrators.rst +0 -0
  33. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/math.rst +0 -0
  34. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/mujoco.rst +0 -0
  35. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/parsers.rst +0 -0
  36. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/rbda.rst +0 -0
  37. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/typing.rst +0 -0
  38. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/docs/modules/utils.rst +0 -0
  39. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/environment.yml +0 -0
  40. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/examples/.gitattributes +0 -0
  41. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/examples/.gitignore +0 -0
  42. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/examples/PD_controller.ipynb +0 -0
  43. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/examples/Parallel_computing.ipynb +0 -0
  44. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/examples/README.md +0 -0
  45. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/examples/assets/cartpole.urdf +0 -0
  46. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/pixi.lock +0 -0
  47. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/pyproject.toml +0 -0
  48. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/setup.cfg +0 -0
  49. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/setup.py +0 -0
  50. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/__init__.py +0 -0
  51. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/__init__.py +0 -0
  52. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/com.py +0 -0
  53. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/common.py +0 -0
  54. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/contact.py +0 -0
  55. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/data.py +0 -0
  56. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/link.py +0 -0
  57. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/model.py +0 -0
  58. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/ode.py +0 -0
  59. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/ode_data.py +0 -0
  60. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/api/references.py +0 -0
  61. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/__init__.py +0 -0
  62. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/common.py +0 -0
  63. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/fixed_step.py +0 -0
  64. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/integrators/variable_step.py +0 -0
  65. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/logging.py +0 -0
  66. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/__init__.py +0 -0
  67. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/adjoint.py +0 -0
  68. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/cross.py +0 -0
  69. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/inertia.py +0 -0
  70. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/quaternion.py +0 -0
  71. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/rotation.py +0 -0
  72. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/skew.py +0 -0
  73. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/math/transform.py +0 -0
  74. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/__init__.py +0 -0
  75. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/__main__.py +0 -0
  76. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/loaders.py +0 -0
  77. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/model.py +0 -0
  78. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/mujoco/visualizer.py +0 -0
  79. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/__init__.py +0 -0
  80. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  81. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/link.py +0 -0
  82. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/descriptions/model.py +0 -0
  83. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/parsers/rod/__init__.py +0 -0
  84. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/__init__.py +0 -0
  85. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/aba.py +0 -0
  86. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/collidable_points.py +0 -0
  87. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/crba.py +0 -0
  88. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  89. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/jacobian.py +0 -0
  90. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/rnea.py +0 -0
  91. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/soft_contacts.py +0 -0
  92. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/rbda/utils.py +0 -0
  93. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/terrain/__init__.py +0 -0
  94. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/terrain/terrain.py +0 -0
  95. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/typing.py +0 -0
  96. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/__init__.py +0 -0
  97. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/hashless.py +0 -0
  98. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  99. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim/utils/tracing.py +0 -0
  100. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  101. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  102. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/not-zip-safe +0 -0
  103. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/requires.txt +0 -0
  104. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/src/jaxsim.egg-info/top_level.txt +0 -0
  105. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/__init__.py +0 -0
  106. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/conftest.py +0 -0
  107. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_api_com.py +0 -0
  108. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_api_data.py +0 -0
  109. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_api_joint.py +0 -0
  110. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_api_link.py +0 -0
  111. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_api_model.py +0 -0
  112. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_automatic_differentiation.py +0 -0
  113. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_pytree.py +0 -0
  114. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/test_simulations.py +0 -0
  115. {jaxsim-0.2.1.dev40 → jaxsim-0.2.1.dev47}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.1.dev40
3
+ Version: 0.2.1.dev47
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.1.dev40'
16
- __version_tuple__ = version_tuple = (0, 2, 1, 'dev40')
15
+ __version__ = version = '0.2.1.dev47'
16
+ __version_tuple__ = version_tuple = (0, 2, 1, 'dev47')
@@ -3,7 +3,6 @@ from typing import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
- import numpy as np
7
6
 
8
7
  import jaxsim.api as js
9
8
  import jaxsim.typing as jtp
@@ -30,17 +29,9 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
30
29
  # Note: the index of the joint for RBDAs starts from 1, but
31
30
  # the index for accessing the right element starts from 0.
32
31
  # Therefore, there is a -1.
33
- return (
34
- jnp.array(
35
- np.argwhere(
36
- np.array(model.kin_dyn_parameters.joint_model.joint_names)
37
- == joint_name
38
- )
39
- - 1
40
- )
41
- .squeeze()
42
- .astype(int)
43
- )
32
+ return jnp.array(
33
+ model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
34
+ ).squeeze()
44
35
  return jnp.array(-1).astype(int)
45
36
 
46
37
 
@@ -382,21 +382,20 @@ class KynDynParameters(JaxsimDataclass):
382
382
  )
383
383
 
384
384
  # Compute the transforms and motion subspaces of the joints.
385
- # TODO: understand how to use joint_indices to access joint_types, right now
386
- # it fails when used within a JIT context.
387
- pre_H_suc_and_S = [
388
- supported_joint_motion(
389
- joint_type=self.joint_model.joint_types[i + 1],
390
- joint_position=jnp.array(s),
385
+ if self.number_of_joints() == 0:
386
+ pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1))
387
+ else:
388
+ pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
389
+ jnp.array(self.joint_model.joint_types[1:]).astype(int),
390
+ jnp.array(joint_positions),
391
+ jnp.array(self.joint_model.joint_axis),
391
392
  )
392
- for i, s in enumerate(jnp.array(joint_positions).astype(float))
393
- ]
394
393
 
395
394
  # Extract the transforms and motion subspaces of the joints.
396
395
  # We stack the base transform W_H_B at index 0, and a dummy motion subspace
397
396
  # for either the fixed or free-floating joint connecting the world to the base.
398
- pre_H_suc = jnp.stack([W_H_B] + [H for H, _ in pre_H_suc_and_S])
399
- S = jnp.stack([jnp.vstack(jnp.zeros(6))] + [S for _, S in pre_H_suc_and_S])
397
+ pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
398
+ S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])
400
399
 
401
400
  # Extract the successor-to-child fixed transforms.
402
401
  # Note that here we include also the index 0 since suc_H_child[0] stores the
@@ -1,7 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import functools
4
-
5
3
  import jax
6
4
  import jax.numpy as jnp
7
5
  import jax_dataclasses
@@ -9,12 +7,7 @@ import jaxlie
9
7
  from jax_dataclasses import Static
10
8
 
11
9
  import jaxsim.typing as jtp
12
- from jaxsim.parsers.descriptions import (
13
- JointDescriptor,
14
- JointGenericAxis,
15
- JointType,
16
- ModelDescription,
17
- )
10
+ from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
18
11
  from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
19
12
 
20
13
  from .rotation import Rotation
@@ -46,7 +39,8 @@ class JointModel:
46
39
 
47
40
  joint_dofs: Static[tuple[int, ...]]
48
41
  joint_names: Static[tuple[str, ...]]
49
- joint_types: Static[tuple[JointType | JointDescriptor, ...]]
42
+ joint_types: Static[tuple[JointType, ...]]
43
+ joint_axis: Static[tuple[JointGenericAxis, ...]]
50
44
 
51
45
  @staticmethod
52
46
  def build(description: ModelDescription) -> JointModel:
@@ -114,7 +108,8 @@ class JointModel:
114
108
  # Static attributes
115
109
  joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
116
110
  joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
117
- joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]),
111
+ joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
112
+ joint_axis=tuple([j.axis for j in ordered_joints]),
118
113
  )
119
114
 
120
115
  def parent_H_child(
@@ -204,8 +199,9 @@ class JointModel:
204
199
  """
205
200
 
206
201
  pre_H_suc, S = supported_joint_motion(
207
- joint_type=self.joint_types[joint_index],
208
- joint_position=joint_position,
202
+ self.joint_types[joint_index],
203
+ joint_position,
204
+ self.joint_axis[joint_index],
209
205
  )
210
206
 
211
207
  return pre_H_suc, S
@@ -226,59 +222,57 @@ class JointModel:
226
222
  return self.suc_H_i[joint_index]
227
223
 
228
224
 
229
- @functools.partial(jax.jit, static_argnames=["joint_type"])
225
+ @jax.jit
230
226
  def supported_joint_motion(
231
- joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike
227
+ joint_type: JointType,
228
+ joint_position: jtp.VectorLike,
229
+ joint_axis: JointGenericAxis,
230
+ /,
232
231
  ) -> tuple[jtp.Matrix, jtp.Array]:
233
232
  """
234
233
  Compute the homogeneous transformation and motion subspace of a joint.
235
234
 
236
235
  Args:
237
236
  joint_type: The type of the joint.
237
+ joint_axis: The axis of rotation or translation of the joint.
238
238
  joint_position: The position of the joint.
239
239
 
240
240
  Returns:
241
241
  A tuple containing the homogeneous transformation and the motion subspace.
242
242
  """
243
243
 
244
- if isinstance(joint_type, JointType):
245
- type_enum = joint_type
246
- elif isinstance(joint_type, JointDescriptor):
247
- type_enum = joint_type.joint_type
248
- else:
249
- raise ValueError(joint_type)
250
-
251
244
  # Prepare the joint position
252
245
  s = jnp.array(joint_position).astype(float)
253
246
 
254
- match type_enum:
255
-
256
- case JointType.R:
257
- joint_type: JointGenericAxis
247
+ def compute_F():
248
+ return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))
258
249
 
259
- pre_H_suc = jaxlie.SE3.from_rotation(
260
- rotation=jaxlie.SO3.from_matrix(
261
- Rotation.from_axis_angle(vector=s * joint_type.axis)
262
- )
250
+ def compute_R():
251
+ pre_H_suc = jaxlie.SE3.from_rotation(
252
+ rotation=jaxlie.SO3.from_matrix(
253
+ Rotation.from_axis_angle(vector=s * joint_axis)
263
254
  )
255
+ )
264
256
 
265
- S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()]))
266
-
267
- case JointType.P:
268
- joint_type: JointGenericAxis
269
-
270
- pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
271
- rotation=jaxlie.SO3.identity(),
272
- translation=jnp.array(s * joint_type.axis),
273
- )
257
+ S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()]))
258
+ return pre_H_suc, S
274
259
 
275
- S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))
260
+ def compute_P():
261
+ pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
262
+ rotation=jaxlie.SO3.identity(),
263
+ translation=jnp.array(s * joint_axis),
264
+ )
276
265
 
277
- case JointType.F:
278
- pre_H_suc = jaxlie.SE3.identity()
279
- S = jnp.zeros(shape=(6, 1))
266
+ S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
267
+ return pre_H_suc, S
280
268
 
281
- case _:
282
- raise ValueError(joint_type)
269
+ pre_H_suc, S = jax.lax.switch(
270
+ index=joint_type,
271
+ branches=(
272
+ compute_F, # JointType.Fixed
273
+ compute_R, # JointType.Revolute
274
+ compute_P, # JointType.Prismatic
275
+ ),
276
+ )
283
277
 
284
278
  return pre_H_suc.as_matrix(), S
@@ -1,4 +1,4 @@
1
1
  from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
2
- from .joint import JointDescription, JointDescriptor, JointGenericAxis, JointType
2
+ from .joint import JointDescription, JointGenericAxis, JointType
3
3
  from .link import LinkDescription
4
4
  from .model import ModelDescription
@@ -1,8 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
- import enum
5
- from typing import Tuple, Union
4
+ from typing import ClassVar, Tuple, Union
6
5
 
7
6
  import jax_dataclasses
8
7
  import numpy as np
@@ -14,39 +13,15 @@ from jaxsim.utils import JaxsimDataclass, Mutability
14
13
  from .link import LinkDescription
15
14
 
16
15
 
17
- @enum.unique
18
- class JointType(enum.IntEnum):
19
- """
20
- Type of supported joints.
21
- """
22
-
23
- @staticmethod
24
- def _generate_next_value_(name, start, count, last_values):
25
- # Start auto Enum value from 0 instead of 1
26
- return count
27
-
28
- #: Fixed joint.
29
- F = enum.auto()
30
-
31
- #: Revolute joint (1 DoF around axis).
32
- R = enum.auto()
33
-
34
- #: Prismatic joint (1 DoF along axis).
35
- P = enum.auto()
36
-
37
-
38
- @jax_dataclasses.pytree_dataclass
39
- class JointDescriptor:
40
- """
41
- Base class for joint types requiring to store additional metadata.
42
- """
43
-
44
- #: The joint type.
45
- joint_type: JointType
16
+ @dataclasses.dataclass(frozen=True)
17
+ class JointType:
18
+ Fixed: ClassVar[int] = 0
19
+ Revolute: ClassVar[int] = 1
20
+ Prismatic: ClassVar[int] = 2
46
21
 
47
22
 
48
23
  @jax_dataclasses.pytree_dataclass
49
- class JointGenericAxis(JointDescriptor):
24
+ class JointGenericAxis:
50
25
  """
51
26
  A joint requiring the specification of a 3D axis.
52
27
  """
@@ -55,7 +30,7 @@ class JointGenericAxis(JointDescriptor):
55
30
  axis: jtp.Vector
56
31
 
57
32
  def __hash__(self) -> int:
58
- return hash((self.joint_type, tuple(np.array(self.axis).tolist())))
33
+ return hash((tuple(np.array(self.axis).tolist())))
59
34
 
60
35
  def __eq__(self, other: JointGenericAxis) -> bool:
61
36
  if not isinstance(other, JointGenericAxis):
@@ -73,7 +48,7 @@ class JointDescription(JaxsimDataclass):
73
48
  name (str): The name of the joint.
74
49
  axis (npt.NDArray): The axis of rotation or translation for the joint.
75
50
  pose (npt.NDArray): The pose transformation matrix of the joint.
76
- jtype (Union[JointType, JointDescriptor]): The type of the joint.
51
+ jtype (JointType): The type of the joint.
77
52
  child (LinkDescription): The child link attached to the joint.
78
53
  parent (LinkDescription): The parent link attached to the joint.
79
54
  index (Optional[int]): An optional index for the joint.
@@ -89,7 +64,7 @@ class JointDescription(JaxsimDataclass):
89
64
  name: jax_dataclasses.Static[str]
90
65
  axis: npt.NDArray
91
66
  pose: npt.NDArray
92
- jtype: jax_dataclasses.Static[Union[JointType, JointDescriptor]]
67
+ jtype: jax_dataclasses.Static[JointType]
93
68
  child: LinkDescription = dataclasses.dataclass(repr=False)
94
69
  parent: LinkDescription = dataclasses.dataclass(repr=False)
95
70
 
@@ -689,6 +689,7 @@ class KinematicGraphTransforms:
689
689
  # Compute the joint transform from the predecessor to the successor frame.
690
690
  pre_H_J = self.pre_H_suc(
691
691
  joint_type=joint.jtype,
692
+ joint_axis=joint.axis,
692
693
  joint_position=self._initial_joint_positions[joint.name],
693
694
  )
694
695
 
@@ -762,14 +763,15 @@ class KinematicGraphTransforms:
762
763
 
763
764
  @staticmethod
764
765
  def pre_H_suc(
765
- joint_type: descriptions.JointType | descriptions.JointDescriptor,
766
+ joint_type: descriptions.JointType,
767
+ joint_axis: descriptions.JointGenericAxis,
766
768
  joint_position: float | None = None,
767
769
  ) -> npt.NDArray:
768
770
 
769
771
  import jaxsim.math
770
772
 
771
773
  return np.array(
772
- jaxsim.math.supported_joint_motion(
773
- joint_type=joint_type, joint_position=joint_position
774
- )[0]
774
+ jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[
775
+ 0
776
+ ]
775
777
  )
@@ -352,7 +352,7 @@ def build_model_description(
352
352
  considered_joints=[
353
353
  j.name
354
354
  for j in sdf_data.joint_descriptions
355
- if j.jtype is not descriptions.JointType.F
355
+ if j.jtype is not descriptions.JointType.Fixed
356
356
  ],
357
357
  )
358
358
 
@@ -61,7 +61,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
61
61
 
62
62
  def joint_to_joint_type(
63
63
  joint: rod.Joint,
64
- ) -> descriptions.JointType | descriptions.JointDescriptor:
64
+ ) -> descriptions.JointType:
65
65
  """
66
66
  Extract the joint type from an SDF joint.
67
67
 
@@ -76,7 +76,7 @@ def joint_to_joint_type(
76
76
  joint_type = joint.type
77
77
 
78
78
  if joint_type == "fixed":
79
- return descriptions.JointType.F
79
+ return descriptions.JointType.Fixed
80
80
 
81
81
  if not (axis.xyz is not None and axis.xyz.xyz is not None):
82
82
  raise ValueError("Failed to read axis xyz data")
@@ -86,14 +86,10 @@ def joint_to_joint_type(
86
86
  axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)
87
87
 
88
88
  if joint_type in {"revolute", "continuous"}:
89
- return descriptions.JointGenericAxis(
90
- joint_type=descriptions.JointType.R, axis=axis_xyz
91
- )
89
+ return descriptions.JointType.Revolute
92
90
 
93
91
  if joint_type == "prismatic":
94
- return descriptions.JointGenericAxis(
95
- joint_type=descriptions.JointType.P, axis=axis_xyz
96
- )
92
+ return descriptions.JointType.Prismatic
97
93
 
98
94
  raise ValueError("Joint not supported", axis_xyz, joint_type)
99
95
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.1.dev40
3
+ Version: 0.2.1.dev47
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes