jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/common.py CHANGED
@@ -1,17 +1,18 @@
1
1
  import abc
2
2
  import contextlib
3
3
  import dataclasses
4
+ import enum
4
5
  import functools
5
- from typing import ContextManager
6
+ from collections.abc import Callable, Iterator
7
+ from typing import ParamSpec, TypeVar
6
8
 
7
9
  import jax
8
10
  import jax.numpy as jnp
9
11
  import jax_dataclasses
10
- import jaxlie
11
12
  from jax_dataclasses import Static
12
13
 
13
14
  import jaxsim.typing as jtp
14
- from jaxsim.high_level.common import VelRepr
15
+ from jaxsim.math import Adjoint
15
16
  from jaxsim.utils import JaxsimDataclass, Mutability
16
17
 
17
18
  try:
@@ -20,6 +21,32 @@ except ImportError:
20
21
  from typing_extensions import Self
21
22
 
22
23
 
24
+ _P = ParamSpec("_P")
25
+ _R = TypeVar("_R")
26
+
27
+
28
+ def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
29
+ """Apply a JAX named scope to a function for improved profiling and clarity."""
30
+
31
+ @functools.wraps(fn)
32
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
33
+ with jax.named_scope(name or fn.__name__):
34
+ return fn(*args, **kwargs)
35
+
36
+ return wrapper
37
+
38
+
39
+ @enum.unique
40
+ class VelRepr(enum.IntEnum):
41
+ """
42
+ Enumeration of all supported 6D velocity representations.
43
+ """
44
+
45
+ Body = enum.auto()
46
+ Mixed = enum.auto()
47
+ Inertial = enum.auto()
48
+
49
+
23
50
  @jax_dataclasses.pytree_dataclass
24
51
  class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
25
52
  """
@@ -33,7 +60,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
33
60
  @contextlib.contextmanager
34
61
  def switch_velocity_representation(
35
62
  self, velocity_representation: VelRepr
36
- ) -> ContextManager[Self]:
63
+ ) -> Iterator[Self]:
37
64
  """
38
65
  Context manager to temporarily switch the velocity representation.
39
66
 
@@ -48,7 +75,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
48
75
 
49
76
  try:
50
77
 
51
- # First, we replace the velocity representation
78
+ # First, we replace the velocity representation.
52
79
  with self.mutable_context(
53
80
  mutability=Mutability.MUTABLE_NO_VALIDATION,
54
81
  restore_after_exception=True,
@@ -59,7 +86,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
59
86
  # We run this in a mutable context with restoration so that any exception
60
87
  # occurring, we restore the original object in case it was modified.
61
88
  with self.mutable_context(
62
- mutability=self._mutability(), restore_after_exception=True
89
+ mutability=self.mutability(), restore_after_exception=True
63
90
  ):
64
91
  yield self
65
92
 
@@ -76,16 +103,17 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
76
103
  array: jtp.Array,
77
104
  other_representation: VelRepr,
78
105
  transform: jtp.Matrix,
79
- is_force: bool = False,
106
+ *,
107
+ is_force: bool,
80
108
  ) -> jtp.Array:
81
- """
109
+ r"""
82
110
  Convert a 6D quantity from inertial-fixed to another representation.
83
111
 
84
112
  Args:
85
113
  array: The 6D quantity to convert.
86
114
  other_representation: The representation to convert to.
87
115
  transform:
88
- The `math:W \mathbf{H}_O` transform, where `math:O` is the
116
+ The :math:`W \mathbf{H}_O` transform, where :math:`O` is the
89
117
  reference frame of the other representation.
90
118
  is_force: Whether the quantity is a 6D force or a 6D velocity.
91
119
 
@@ -110,11 +138,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
110
138
  case VelRepr.Body:
111
139
 
112
140
  if not is_force:
113
- O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint()
141
+ O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
114
142
  O_array = O_Xv_W @ W_array
115
143
 
116
144
  else:
117
- O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T
145
+ O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
118
146
  O_array = O_Xf_W @ W_array
119
147
 
120
148
  return O_array
@@ -124,11 +152,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
124
152
  W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
125
153
 
126
154
  if not is_force:
127
- OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint()
155
+ OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
128
156
  OW_array = OW_Xv_W @ W_array
129
157
 
130
158
  else:
131
- OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose()
159
+ OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
132
160
  OW_array = OW_Xf_W @ W_array
133
161
 
134
162
  return OW_array
@@ -142,9 +170,10 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
142
170
  array: jtp.Array,
143
171
  other_representation: VelRepr,
144
172
  transform: jtp.Matrix,
145
- is_force: bool = False,
173
+ *,
174
+ is_force: bool,
146
175
  ) -> jtp.Array:
147
- """
176
+ r"""
148
177
  Convert a 6D quantity from another representation to inertial-fixed.
149
178
 
150
179
  Args:
@@ -177,11 +206,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
177
206
  O_array = array
178
207
 
179
208
  if not is_force:
180
- W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint()
209
+ W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
181
210
  W_array = W_Xv_O @ O_array
182
211
 
183
212
  else:
184
- W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T
213
+ W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
185
214
  W_array = W_Xf_O @ O_array
186
215
 
187
216
  return W_array
@@ -192,11 +221,11 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
192
221
  W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
193
222
 
194
223
  if not is_force:
195
- W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint()
224
+ W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
196
225
  W_array = W_Xv_BW @ BW_array
197
226
 
198
227
  else:
199
- W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T
228
+ W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
200
229
  W_array = W_Xf_BW @ BW_array
201
230
 
202
231
  return W_array