jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/high_level/joint.py DELETED
@@ -1,148 +0,0 @@
1
- import dataclasses
2
- import functools
3
- from typing import Any
4
-
5
- import jax.numpy as jnp
6
- import jax_dataclasses
7
- from jax_dataclasses import Static
8
-
9
- import jaxsim.parsers
10
- import jaxsim.typing as jtp
11
- from jaxsim.utils import Vmappable, not_tracing, oop
12
-
13
-
14
- @jax_dataclasses.pytree_dataclass
15
- class Joint(Vmappable):
16
- """
17
- High-level class to operate in r/o on a single joint of a simulated model.
18
- """
19
-
20
- joint_description: Static[jaxsim.parsers.descriptions.JointDescription]
21
-
22
- _parent_model: Any = dataclasses.field(
23
- default=None, repr=False, compare=False, hash=False
24
- )
25
-
26
- @property
27
- def parent_model(self) -> "jaxsim.high_level.model.Model":
28
- """"""
29
-
30
- return self._parent_model
31
-
32
- @functools.partial(oop.jax_tf.method_ro, jit=False)
33
- def valid(self) -> jtp.Bool:
34
- """"""
35
-
36
- return jnp.array(self.parent_model is not None, dtype=bool)
37
-
38
- @functools.partial(oop.jax_tf.method_ro, jit=False)
39
- def index(self) -> jtp.Int:
40
- """"""
41
-
42
- return jnp.array(self.joint_description.index, dtype=int)
43
-
44
- @functools.partial(oop.jax_tf.method_ro)
45
- def dofs(self) -> jtp.Int:
46
- """"""
47
-
48
- return jnp.array(1, dtype=int)
49
-
50
- @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
51
- def name(self) -> str:
52
- """"""
53
-
54
- return self.joint_description.name
55
-
56
- @functools.partial(oop.jax_tf.method_ro)
57
- def position(self, dof: int | None = None) -> jtp.Float:
58
- """"""
59
-
60
- dof = dof if dof is not None else 0
61
-
62
- return jnp.array(
63
- self.parent_model.joint_positions(joint_names=(self.name(),))[dof],
64
- dtype=float,
65
- )
66
-
67
- @functools.partial(oop.jax_tf.method_ro)
68
- def velocity(self, dof: int | None = None) -> jtp.Float:
69
- """"""
70
-
71
- dof = dof if dof is not None else 0
72
-
73
- return jnp.array(
74
- self.parent_model.joint_velocities(joint_names=(self.name(),))[dof],
75
- dtype=float,
76
- )
77
-
78
- @functools.partial(oop.jax_tf.method_ro)
79
- def force_target(self, dof: int | None = None) -> jtp.Float:
80
- """"""
81
-
82
- dof = dof if dof is not None else 0
83
-
84
- return jnp.array(
85
- self.parent_model.joint_generalized_forces_targets(
86
- joint_names=(self.name(),)
87
- )[dof],
88
- dtype=float,
89
- )
90
-
91
- @functools.partial(oop.jax_tf.method_ro)
92
- def position_limit(self, dof: int | None = None) -> tuple[jtp.Float, jtp.Float]:
93
- """"""
94
-
95
- dof = dof if dof is not None else 0
96
-
97
- if not_tracing(dof) and dof != 0:
98
- msg = "Only joints with 1 DoF are currently supported"
99
- raise ValueError(msg)
100
-
101
- low, high = self.joint_description.position_limit
102
-
103
- return jnp.array(low, dtype=float), jnp.array(high, dtype=float)
104
-
105
- # =============
106
- # Motor methods
107
- # =============
108
- @functools.partial(oop.jax_tf.method_ro)
109
- def motor_inertia(self) -> jtp.Vector:
110
- """"""
111
-
112
- return jnp.array(self.joint_description.motor_inertia, dtype=float)
113
-
114
- @functools.partial(oop.jax_tf.method_ro)
115
- def motor_gear_ratio(self) -> jtp.Vector:
116
- """"""
117
-
118
- return jnp.array(self.joint_description.motor_gear_ratio, dtype=float)
119
-
120
- @functools.partial(oop.jax_tf.method_ro)
121
- def motor_viscous_friction(self) -> jtp.Vector:
122
- """"""
123
-
124
- return jnp.array(self.joint_description.motor_viscous_friction, dtype=float)
125
-
126
- # =================
127
- # Multi-DoF methods
128
- # =================
129
-
130
- @functools.partial(oop.jax_tf.method_ro)
131
- def joint_position(self) -> jtp.Vector:
132
- """"""
133
-
134
- return self.parent_model.joint_positions(joint_names=(self.name(),))
135
-
136
- @functools.partial(oop.jax_tf.method_ro)
137
- def joint_velocity(self) -> jtp.Vector:
138
- """"""
139
-
140
- return self.parent_model.joint_velocities(joint_names=(self.name(),))
141
-
142
- @functools.partial(oop.jax_tf.method_ro)
143
- def joint_force_target(self) -> jtp.Vector:
144
- """"""
145
-
146
- return self.parent_model.joint_generalized_forces_targets(
147
- joint_names=(self.name(),)
148
- )
jaxsim/high_level/link.py DELETED
@@ -1,259 +0,0 @@
1
- import dataclasses
2
- import functools
3
- from typing import Any
4
-
5
- import jax.lax
6
- import jax.numpy as jnp
7
- import jax_dataclasses
8
- import numpy as np
9
- from jax_dataclasses import Static
10
-
11
- import jaxsim.parsers
12
- import jaxsim.typing as jtp
13
- from jaxsim import sixd
14
- from jaxsim.physics.algos.jacobian import jacobian
15
- from jaxsim.utils import Vmappable, oop
16
-
17
- from .common import VelRepr
18
-
19
-
20
- @jax_dataclasses.pytree_dataclass
21
- class Link(Vmappable):
22
- """
23
- High-level class to operate in r/o on a single link of a simulated model.
24
- """
25
-
26
- link_description: Static[jaxsim.parsers.descriptions.LinkDescription]
27
-
28
- _parent_model: Any = dataclasses.field(
29
- default=None, repr=False, compare=False, hash=False
30
- )
31
-
32
- @property
33
- def parent_model(self) -> "jaxsim.high_level.model.Model":
34
- """"""
35
-
36
- return self._parent_model
37
-
38
- @functools.partial(oop.jax_tf.method_ro, jit=False)
39
- def valid(self) -> jtp.Bool:
40
- """"""
41
-
42
- return jnp.array(self.parent_model is not None, dtype=bool)
43
-
44
- # ==========
45
- # Properties
46
- # ==========
47
-
48
- @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
49
- def name(self) -> str:
50
- """"""
51
-
52
- return self.link_description.name
53
-
54
- @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
55
- def index(self) -> jtp.Int:
56
- """"""
57
-
58
- return jnp.array(self.link_description.index, dtype=int)
59
-
60
- # ========
61
- # Dynamics
62
- # ========
63
-
64
- @functools.partial(oop.jax_tf.method_ro, jit=False)
65
- def mass(self) -> jtp.Float:
66
- """"""
67
-
68
- return jnp.array(self.link_description.mass, dtype=float)
69
-
70
- @functools.partial(oop.jax_tf.method_ro, jit=False)
71
- def spatial_inertia(self) -> jtp.Matrix:
72
- """"""
73
-
74
- return jnp.array(self.link_description.inertia, dtype=float)
75
-
76
- @functools.partial(oop.jax_tf.method_ro, vmap_in_axes=(0, None))
77
- def com_position(self, in_link_frame: bool = True) -> jtp.Vector:
78
- """"""
79
-
80
- from jaxsim.math.inertia import Inertia
81
-
82
- _, L_p_CoM, _ = Inertia.to_params(M=self.spatial_inertia())
83
-
84
- def com_in_link_frame():
85
- return L_p_CoM.squeeze()
86
-
87
- def com_in_inertial_frame():
88
- W_H_L = self.transform()
89
- W_p̃_CoM = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1])
90
-
91
- return W_p̃_CoM[0:3].squeeze()
92
-
93
- return jax.lax.select(
94
- pred=in_link_frame,
95
- on_true=com_in_link_frame(),
96
- on_false=com_in_inertial_frame(),
97
- )
98
-
99
- # ==========
100
- # Kinematics
101
- # ==========
102
-
103
- @functools.partial(oop.jax_tf.method_ro)
104
- def position(self) -> jtp.Vector:
105
- """"""
106
-
107
- return self.transform()[0:3, 3]
108
-
109
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"])
110
- def orientation(self, dcm: bool = False) -> jtp.Vector:
111
- """"""
112
-
113
- R = self.transform()[0:3, 0:3]
114
-
115
- to_wxyz = np.array([3, 0, 1, 2])
116
- return R if dcm else sixd.so3.SO3.from_matrix(R).as_quaternion_xyzw()[to_wxyz]
117
-
118
- @functools.partial(oop.jax_tf.method_ro)
119
- def transform(self) -> jtp.Matrix:
120
- """"""
121
-
122
- return self.parent_model.forward_kinematics()[self.index()]
123
-
124
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
125
- def velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
126
- """"""
127
-
128
- v_WL = (
129
- self.jacobian(output_vel_repr=vel_repr)
130
- @ self.parent_model.generalized_velocity()
131
- )
132
-
133
- return v_WL
134
-
135
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
136
- def linear_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
137
- """"""
138
-
139
- return self.velocity(vel_repr=vel_repr)[0:3]
140
-
141
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
142
- def angular_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
143
- """"""
144
-
145
- return self.velocity(vel_repr=vel_repr)[3:6]
146
-
147
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
148
- def jacobian(self, output_vel_repr: VelRepr | None = None) -> jtp.Matrix:
149
- """"""
150
-
151
- if output_vel_repr is None:
152
- output_vel_repr = self.parent_model.velocity_representation
153
-
154
- # Compute the doubly left-trivialized free-floating jacobian
155
- L_J_WL_B = jacobian(
156
- model=self.parent_model.physics_model,
157
- body_index=self.index(),
158
- q=self.parent_model.data.model_state.joint_positions,
159
- )
160
-
161
- if self.parent_model.velocity_representation is VelRepr.Body:
162
- L_J_WL_target = L_J_WL_B
163
-
164
- elif self.parent_model.velocity_representation is VelRepr.Inertial:
165
- dofs = self.parent_model.dofs()
166
- W_H_B = self.parent_model.base_transform()
167
-
168
- B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()
169
- zero_6n = jnp.zeros(shape=(6, dofs))
170
-
171
- B_T_W = jnp.vstack(
172
- [
173
- jnp.block([B_X_W, zero_6n]),
174
- jnp.block([zero_6n.T, jnp.eye(dofs)]),
175
- ]
176
- )
177
-
178
- L_J_WL_target = L_J_WL_B @ B_T_W
179
-
180
- elif self.parent_model.velocity_representation is VelRepr.Mixed:
181
- dofs = self.parent_model.dofs()
182
- W_H_B = self.parent_model.base_transform()
183
- BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
184
-
185
- B_X_BW = sixd.se3.SE3.from_matrix(BW_H_B).inverse().adjoint()
186
- zero_6n = jnp.zeros(shape=(6, dofs))
187
-
188
- B_T_BW = jnp.vstack(
189
- [
190
- jnp.block([B_X_BW, zero_6n]),
191
- jnp.block([zero_6n.T, jnp.eye(dofs)]),
192
- ]
193
- )
194
-
195
- L_J_WL_target = L_J_WL_B @ B_T_BW
196
-
197
- else:
198
- raise ValueError(self.parent_model.velocity_representation)
199
-
200
- if output_vel_repr is VelRepr.Body:
201
- return L_J_WL_target
202
-
203
- elif output_vel_repr is VelRepr.Inertial:
204
- W_H_L = self.transform()
205
- W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint()
206
- return W_X_L @ L_J_WL_target
207
-
208
- elif output_vel_repr is VelRepr.Mixed:
209
- W_H_L = self.transform()
210
- LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
211
- LW_X_L = sixd.se3.SE3.from_matrix(LW_H_L).adjoint()
212
- return LW_X_L @ L_J_WL_target
213
-
214
- else:
215
- raise ValueError(output_vel_repr)
216
-
217
- @functools.partial(oop.jax_tf.method_ro)
218
- def external_force(self) -> jtp.Vector:
219
- """
220
- Return the active external force acting on the link.
221
-
222
- This external force is a user input and is not computed by the physics engine.
223
- During the simulation, this external force is summed to other terms like those
224
- related to enforce contact constraints.
225
-
226
- Returns:
227
- The active external 6D force acting on the link in the active representation.
228
- """
229
-
230
- # Get the external force stored in the inertial representation
231
- W_f_ext = self.parent_model.data.model_input.f_ext[self.index()]
232
-
233
- # Express it in the active representation
234
- if self.parent_model.velocity_representation is VelRepr.Inertial:
235
- f_ext = W_f_ext
236
-
237
- elif self.parent_model.velocity_representation is VelRepr.Body:
238
- W_H_L = self.transform()
239
- W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint()
240
-
241
- f_ext = L_f_ext = W_X_L.transpose() @ W_f_ext
242
-
243
- elif self.parent_model.velocity_representation is VelRepr.Mixed:
244
- W_p_L = self.transform()[0:3, 3]
245
- W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L)
246
- W_X_LW = sixd.se3.SE3.from_matrix(W_H_LW).adjoint()
247
-
248
- f_ext = LW_f_ext = W_X_LW.transpose() @ W_f_ext
249
-
250
- else:
251
- raise ValueError(self.parent_model.velocity_representation)
252
-
253
- return f_ext
254
-
255
- @functools.partial(oop.jax_tf.method_ro)
256
- def in_contact(self) -> jtp.Bool:
257
- """"""
258
-
259
- return self.parent_model.in_contact()[self.index()]