jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py CHANGED
@@ -6,12 +6,12 @@ from ._version import __version__
6
6
  def _jnp_options() -> None:
7
7
  import os
8
8
 
9
- from jax.config import config
9
+ import jax
10
10
 
11
11
  # Enable by default
12
12
  if not ("JAX_ENABLE_X64" in os.environ and os.environ["JAX_ENABLE_X64"] == "0"):
13
13
  logging.info("Enabling JAX to use 64bit precision")
14
- config.update("jax_enable_x64", True)
14
+ jax.config.update("jax_enable_x64", True)
15
15
 
16
16
  import jax.numpy as jnp
17
17
  import numpy as np
@@ -61,7 +61,6 @@ del _jnp_options
61
61
  del _np_options
62
62
  del _is_editable
63
63
 
64
- from . import high_level, logging, math, simulation, sixd
65
- from .high_level.common import VelRepr
66
- from .simulation.ode_integration import IntegratorType
67
- from .simulation.simulator import JaxSim
64
+ from . import terrain # isort:skip
65
+ from . import api, integrators, logging, math, rbda
66
+ from .api.common import VelRepr
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.dev401'
16
- __version_tuple__ = version_tuple = (0, 1, 'dev401')
15
+ __version__ = version = '0.2.0'
16
+ __version_tuple__ = version_tuple = (0, 2, 0)
jaxsim/api/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import common # isort:skip
2
+ from . import model, data # isort:skip
3
+ from . import com, contact, joint, kin_dyn_parameters, link, ode, ode_data, references
jaxsim/api/com.py ADDED
@@ -0,0 +1,240 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.math
7
+ import jaxsim.typing as jtp
8
+
9
+ from .common import VelRepr
10
+
11
+
12
@jax.jit
def com_position(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    """
    Compute the position of the center of mass of the model.

    The link CoMs are mapped to homogeneous coordinates in the base frame,
    weighted by the link masses and averaged by the total model mass. The
    result is then mapped back to the world frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The position of the center of mass of the model w.r.t. the world frame.
    """

    total_mass = js.model.total_mass(model=model)

    # Poses of all links and of the base w.r.t. the world, plus the inverse
    # of the base pose.
    W_H_L = js.model.forward_kinematics(model=model, data=data)
    W_H_B = data.base_transform()
    B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()

    def weighted_link_com_in_base(link_index) -> jtp.Vector:
        # Homogeneous CoM of a single link, expressed in the base frame and
        # scaled by the mass of that link.
        link_mass = js.link.mass(model=model, link_index=link_index)
        L_p_LCoM = js.link.com_position(
            model=model, data=data, link_index=link_index, in_link_frame=True
        )
        return link_mass * B_H_W @ W_H_L[link_index] @ jnp.hstack([L_p_LCoM, 1])

    link_indices = jnp.arange(model.number_of_links())
    weighted_coms = jax.vmap(weighted_link_com_in_base)(link_indices)

    # Mass-weighted average, then force the homogeneous component to exactly 1.
    B_p_CoM_homogeneous = (1 / total_mass) * weighted_coms.sum(axis=0)
    B_p_CoM_homogeneous = B_p_CoM_homogeneous.at[3].set(1)

    return (W_H_B @ B_p_CoM_homogeneous)[0:3].astype(float)
46
+
47
+
48
@jax.jit
def com_linear_velocity(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the linear velocity of the center of mass of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The linear velocity of the center of mass of the model in the
        active representation.

    Note:
        The linear velocity of the center of mass is expressed in the mixed frame
        :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`[C] = [B]` if the active velocity representation is body-fixed.
    """

    # The 6D average centroidal velocity is expressed in G[B] for the
    # body-fixed representation and in G[W] for inertial-fixed or mixed.
    # Its first three entries are the linear component.
    G_v_WG = average_centroidal_velocity(model=model, data=data)
    return G_v_WG[0:3]
76
+
77
+
78
@jax.jit
def centroidal_momentum(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the centroidal momentum of the model.

    It is obtained by multiplying the centroidal momentum Jacobian
    (Centroidal Momentum Matrix) by the generalized velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The centroidal momentum of the model.

    Note:
        The centroidal momentum is expressed in the mixed frame
        :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`[C] = [B]` if the active velocity representation is body-fixed.
    """

    nu = data.generalized_velocity()
    G_Jh = centroidal_momentum_jacobian(model=model, data=data)

    G_h = G_Jh @ nu
    return G_h
103
+
104
+
105
+ @jax.jit
106
+ def centroidal_momentum_jacobian(
107
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
108
+ ) -> jtp.Matrix:
109
+ r"""
110
+ Compute the Jacobian of the centroidal momentum of the model.
111
+
112
+ Args:
113
+ model: The model to consider.
114
+ data: The data of the considered model.
115
+
116
+ Returns:
117
+ The Jacobian of the centroidal momentum of the model.
118
+
119
+ Note:
120
+ The frame corresponding to the output representation of this Jacobian is either
121
+ :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
122
+ or :math:`G[B]`, if the active velocity representation is body-fixed.
123
+
124
+ Note:
125
+ This Jacobian is also known in the literature as Centroidal Momentum Matrix.
126
+ """
127
+
128
+ # Compute the Jacobian of the total momentum with body-fixed output representation.
129
+ # We convert the output representation either to G[W] or G[B] below.
130
+ B_Jh = js.model.total_momentum_jacobian(
131
+ model=model, data=data, output_vel_repr=VelRepr.Body
132
+ )
133
+
134
+ W_H_B = data.base_transform()
135
+ B_H_W = jaxsim.math.Transform.inverse(W_H_B)
136
+
137
+ W_p_CoM = com_position(model=model, data=data)
138
+
139
+ match data.velocity_representation:
140
+ case VelRepr.Inertial | VelRepr.Mixed:
141
+ W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
142
+ case VelRepr.Body:
143
+ W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
144
+ case _:
145
+ raise ValueError(data.velocity_representation)
146
+
147
+ # Compute the transform for 6D forces.
148
+ G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T
149
+
150
+ return G_Xf_B @ B_Jh
151
+
152
+
153
+ @jax.jit
154
+ def locked_centroidal_spatial_inertia(
155
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
156
+ ):
157
+ """
158
+ Compute the locked centroidal spatial inertia of the model.
159
+
160
+ Args:
161
+ model: The model to consider.
162
+ data: The data of the considered model.
163
+
164
+ Returns:
165
+ The locked centroidal spatial inertia of the model.
166
+ """
167
+
168
+ with data.switch_velocity_representation(VelRepr.Body):
169
+ B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)
170
+
171
+ W_H_B = data.base_transform()
172
+ W_p_CoM = com_position(model=model, data=data)
173
+
174
+ match data.velocity_representation:
175
+ case VelRepr.Inertial | VelRepr.Mixed:
176
+ W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
177
+ case VelRepr.Body:
178
+ W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
179
+ case _:
180
+ raise ValueError(data.velocity_representation)
181
+
182
+ B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G)
183
+
184
+ B_Xv_G = B_H_G.adjoint()
185
+ G_Xf_B = B_Xv_G.transpose()
186
+
187
+ return G_Xf_B @ B_Mbb_B @ B_Xv_G
188
+
189
+
190
@jax.jit
def average_centroidal_velocity(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the average centroidal velocity of the model.

    It is the product of the average centroidal velocity Jacobian and the
    generalized velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The average centroidal velocity of the model.

    Note:
        The average velocity is expressed in the mixed frame
        :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`[C] = [B]` if the active velocity representation is body-fixed.
    """

    nu = data.generalized_velocity()
    G_J_avg = average_centroidal_velocity_jacobian(model=model, data=data)

    G_v_avg = G_J_avg @ nu
    return G_v_avg
215
+
216
+
217
@jax.jit
def average_centroidal_velocity_jacobian(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    r"""
    Compute the Jacobian of the average centroidal velocity of the model.

    It is defined as :math:`J_G = M_G^{-1} J_h`, where :math:`J_h` is the
    centroidal momentum Jacobian and :math:`M_G` is the locked centroidal
    spatial inertia.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The Jacobian of the average centroidal velocity of the model.

    Note:
        The frame corresponding to the output representation of this Jacobian is either
        :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
        or :math:`G[B]`, if the active velocity representation is body-fixed.
    """

    G_Jh = centroidal_momentum_jacobian(model=model, data=data)
    G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)

    # Solve G_Mbb @ G_J = G_Jh for G_J instead of forming the explicit inverse
    # of the 6x6 inertia. This is cheaper and numerically more accurate.
    return jnp.linalg.solve(G_Mbb, G_Jh)
jaxsim/api/common.py ADDED
@@ -0,0 +1,216 @@
1
+ import abc
2
+ import contextlib
3
+ import dataclasses
4
+ import enum
5
+ import functools
6
+ from typing import ContextManager
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import jax_dataclasses
11
+ import jaxlie
12
+ from jax_dataclasses import Static
13
+
14
+ import jaxsim.typing as jtp
15
+ from jaxsim.utils import JaxsimDataclass, Mutability
16
+
17
+ try:
18
+ from typing import Self
19
+ except ImportError:
20
+ from typing_extensions import Self
21
+
22
+
23
@enum.unique
class VelRepr(enum.IntEnum):
    """
    Supported representations of 6D velocities.

    Members are integers numbered from 1, in the order Body, Mixed, Inertial.
    """

    # Body-fixed representation.
    Body = 1

    # Mixed representation.
    Mixed = 2

    # Inertial-fixed representation.
    Inertial = 3
32
+
33
+
34
+ @jax_dataclasses.pytree_dataclass
35
+ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
36
+ """
37
+ Base class for model data structures with velocity representation.
38
+ """
39
+
40
+ velocity_representation: Static[VelRepr] = dataclasses.field(
41
+ default=VelRepr.Inertial, kw_only=True
42
+ )
43
+
44
+ @contextlib.contextmanager
45
+ def switch_velocity_representation(
46
+ self, velocity_representation: VelRepr
47
+ ) -> ContextManager[Self]:
48
+ """
49
+ Context manager to temporarily switch the velocity representation.
50
+
51
+ Args:
52
+ velocity_representation: The new velocity representation.
53
+
54
+ Yields:
55
+ The same object with the new velocity representation.
56
+ """
57
+
58
+ original_representation = self.velocity_representation
59
+
60
+ try:
61
+
62
+ # First, we replace the velocity representation
63
+ with self.mutable_context(
64
+ mutability=Mutability.MUTABLE_NO_VALIDATION,
65
+ restore_after_exception=True,
66
+ ):
67
+ self.velocity_representation = velocity_representation
68
+
69
+ # Then, we yield the data with changed representation.
70
+ # We run this in a mutable context with restoration so that any exception
71
+ # occurring, we restore the original object in case it was modified.
72
+ with self.mutable_context(
73
+ mutability=self.mutability(), restore_after_exception=True
74
+ ):
75
+ yield self
76
+
77
+ finally:
78
+ with self.mutable_context(
79
+ mutability=Mutability.MUTABLE_NO_VALIDATION,
80
+ restore_after_exception=True,
81
+ ):
82
+ self.velocity_representation = original_representation
83
+
84
+ @staticmethod
85
+ @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
86
+ def inertial_to_other_representation(
87
+ array: jtp.Array,
88
+ other_representation: VelRepr,
89
+ transform: jtp.Matrix,
90
+ is_force: bool = False,
91
+ ) -> jtp.Array:
92
+ r"""
93
+ Convert a 6D quantity from inertial-fixed to another representation.
94
+
95
+ Args:
96
+ array: The 6D quantity to convert.
97
+ other_representation: The representation to convert to.
98
+ transform:
99
+ The `math:W \mathbf{H}_O` transform, where `math:O` is the
100
+ reference frame of the other representation.
101
+ is_force: Whether the quantity is a 6D force or a 6D velocity.
102
+
103
+ Returns:
104
+ The 6D quantity in the other representation.
105
+ """
106
+
107
+ W_array = array.squeeze()
108
+ W_H_O = transform.squeeze()
109
+
110
+ if W_array.size != 6:
111
+ raise ValueError(W_array.size, 6)
112
+
113
+ if W_H_O.shape != (4, 4):
114
+ raise ValueError(W_H_O.shape, (4, 4))
115
+
116
+ match other_representation:
117
+
118
+ case VelRepr.Inertial:
119
+ return W_array
120
+
121
+ case VelRepr.Body:
122
+
123
+ if not is_force:
124
+ O_Xv_W = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint()
125
+ O_array = O_Xv_W @ W_array
126
+
127
+ else:
128
+ O_Xf_W = jaxlie.SE3.from_matrix(W_H_O).adjoint().T
129
+ O_array = O_Xf_W @ W_array
130
+
131
+ return O_array
132
+
133
+ case VelRepr.Mixed:
134
+ W_p_O = W_H_O[0:3, 3]
135
+ W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
136
+
137
+ if not is_force:
138
+ OW_Xv_W = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint()
139
+ OW_array = OW_Xv_W @ W_array
140
+
141
+ else:
142
+ OW_Xf_W = jaxlie.SE3.from_matrix(W_H_OW).adjoint().transpose()
143
+ OW_array = OW_Xf_W @ W_array
144
+
145
+ return OW_array
146
+
147
+ case _:
148
+ raise ValueError(other_representation)
149
+
150
+ @staticmethod
151
+ @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
152
+ def other_representation_to_inertial(
153
+ array: jtp.Array,
154
+ other_representation: VelRepr,
155
+ transform: jtp.Matrix,
156
+ is_force: bool = False,
157
+ ) -> jtp.Array:
158
+ r"""
159
+ Convert a 6D quantity from another representation to inertial-fixed.
160
+
161
+ Args:
162
+ array: The 6D quantity to convert.
163
+ other_representation: The representation to convert from.
164
+ transform:
165
+ The `math:W \mathbf{H}_O` transform, where `math:O` is the
166
+ reference frame of the other representation.
167
+ is_force: Whether the quantity is a 6D force or a 6D velocity.
168
+
169
+ Returns:
170
+ The 6D quantity in the inertial-fixed representation.
171
+ """
172
+
173
+ W_array = array.squeeze()
174
+ W_H_O = transform.squeeze()
175
+
176
+ if W_array.size != 6:
177
+ raise ValueError(W_array.size, 6)
178
+
179
+ if W_H_O.shape != (4, 4):
180
+ raise ValueError(W_H_O.shape, (4, 4))
181
+
182
+ match other_representation:
183
+ case VelRepr.Inertial:
184
+ W_array = array
185
+ return W_array
186
+
187
+ case VelRepr.Body:
188
+ O_array = array
189
+
190
+ if not is_force:
191
+ W_Xv_O: jtp.Array = jaxlie.SE3.from_matrix(W_H_O).adjoint()
192
+ W_array = W_Xv_O @ O_array
193
+
194
+ else:
195
+ W_Xf_O = jaxlie.SE3.from_matrix(W_H_O).inverse().adjoint().T
196
+ W_array = W_Xf_O @ O_array
197
+
198
+ return W_array
199
+
200
+ case VelRepr.Mixed:
201
+ BW_array = array
202
+ W_p_O = W_H_O[0:3, 3]
203
+ W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
204
+
205
+ if not is_force:
206
+ W_Xv_BW: jtp.Array = jaxlie.SE3.from_matrix(W_H_OW).adjoint()
207
+ W_array = W_Xv_BW @ BW_array
208
+
209
+ else:
210
+ W_Xf_BW = jaxlie.SE3.from_matrix(W_H_OW).inverse().adjoint().T
211
+ W_array = W_Xf_BW @ BW_array
212
+
213
+ return W_array
214
+
215
+ case _:
216
+ raise ValueError(other_representation)