jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +88 -72
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/collision.py +14 -0
  26. jaxsim/parsers/descriptions/link.py +13 -2
  27. jaxsim/parsers/kinematic_graph.py +5 -0
  28. jaxsim/parsers/rod/utils.py +7 -8
  29. jaxsim/rbda/__init__.py +7 -0
  30. jaxsim/rbda/aba.py +295 -0
  31. jaxsim/rbda/collidable_points.py +142 -0
  32. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  33. jaxsim/rbda/forward_kinematics.py +113 -0
  34. jaxsim/rbda/jacobian.py +201 -0
  35. jaxsim/rbda/rnea.py +237 -0
  36. jaxsim/rbda/soft_contacts.py +296 -0
  37. jaxsim/rbda/utils.py +152 -0
  38. jaxsim/terrain/__init__.py +2 -0
  39. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  40. jaxsim/utils/__init__.py +1 -4
  41. jaxsim/utils/hashless.py +18 -0
  42. jaxsim/utils/jaxsim_dataclass.py +281 -30
  43. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  44. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  45. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  46. jaxsim/high_level/__init__.py +0 -2
  47. jaxsim/high_level/common.py +0 -11
  48. jaxsim/high_level/joint.py +0 -148
  49. jaxsim/high_level/link.py +0 -259
  50. jaxsim/high_level/model.py +0 -1686
  51. jaxsim/math/conv.py +0 -114
  52. jaxsim/math/joint.py +0 -102
  53. jaxsim/math/plucker.py +0 -100
  54. jaxsim/physics/__init__.py +0 -12
  55. jaxsim/physics/algos/__init__.py +0 -0
  56. jaxsim/physics/algos/aba.py +0 -254
  57. jaxsim/physics/algos/aba_motors.py +0 -284
  58. jaxsim/physics/algos/forward_kinematics.py +0 -79
  59. jaxsim/physics/algos/jacobian.py +0 -98
  60. jaxsim/physics/algos/rnea.py +0 -180
  61. jaxsim/physics/algos/rnea_motors.py +0 -196
  62. jaxsim/physics/algos/soft_contacts.py +0 -523
  63. jaxsim/physics/algos/utils.py +0 -69
  64. jaxsim/physics/model/__init__.py +0 -0
  65. jaxsim/physics/model/ground_contact.py +0 -55
  66. jaxsim/physics/model/physics_model.py +0 -388
  67. jaxsim/physics/model/physics_model_state.py +0 -283
  68. jaxsim/simulation/__init__.py +0 -4
  69. jaxsim/simulation/integrators.py +0 -393
  70. jaxsim/simulation/ode.py +0 -290
  71. jaxsim/simulation/ode_data.py +0 -96
  72. jaxsim/simulation/ode_integration.py +0 -62
  73. jaxsim/simulation/simulator.py +0 -543
  74. jaxsim/simulation/simulator_callbacks.py +0 -79
  75. jaxsim/simulation/utils.py +0 -15
  76. jaxsim/sixd/__init__.py +0 -2
  77. jaxsim/utils/oop.py +0 -536
  78. jaxsim/utils/vmappable.py +0 -117
  79. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  80. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  81. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/high_level/link.py DELETED
@@ -1,259 +0,0 @@
1
- import dataclasses
2
- import functools
3
- from typing import Any
4
-
5
- import jax.lax
6
- import jax.numpy as jnp
7
- import jax_dataclasses
8
- import numpy as np
9
- from jax_dataclasses import Static
10
-
11
- import jaxsim.parsers
12
- import jaxsim.typing as jtp
13
- from jaxsim import sixd
14
- from jaxsim.physics.algos.jacobian import jacobian
15
- from jaxsim.utils import Vmappable, oop
16
-
17
- from .common import VelRepr
18
-
19
-
20
def _free_floating_input_transform(X: jtp.Matrix, dofs: int) -> jtp.Matrix:
    """
    Build the block-diagonal transform diag(X, I_dofs) of the generalized velocity.

    The base part of the generalized velocity is transformed by X, while the
    joint velocities are left untouched (identity block).

    Args:
        X: The 6x6 adjoint matrix applied to the 6D base velocity.
        dofs: The number of joint degrees of freedom (must be concrete, since it
            is used as an array shape).

    Returns:
        The (6 + dofs) x (6 + dofs) block-diagonal matrix.
    """

    zero_6n = jnp.zeros(shape=(6, dofs))

    return jnp.vstack(
        [
            jnp.block([X, zero_6n]),
            jnp.block([zero_6n.T, jnp.eye(dofs)]),
        ]
    )


@jax_dataclasses.pytree_dataclass
class Link(Vmappable):
    """
    High-level class to operate in r/o on a single link of a simulated model.
    """

    link_description: Static[jaxsim.parsers.descriptions.LinkDescription]

    # Back-reference to the owning model. It is excluded from repr, comparison
    # and hashing so that the link does not recursively print/compare the model.
    _parent_model: Any = dataclasses.field(
        default=None, repr=False, compare=False, hash=False
    )

    @property
    def parent_model(self) -> "jaxsim.high_level.model.Model":
        """
        Return the model this link belongs to.

        Returns:
            The parent model, or None if the link is not attached to a model.
        """

        return self._parent_model

    @functools.partial(oop.jax_tf.method_ro, jit=False)
    def valid(self) -> jtp.Bool:
        """
        Check whether the link is attached to a parent model.

        Returns:
            A boolean array that is True if the parent model is set.
        """

        return jnp.array(self.parent_model is not None, dtype=bool)

    # ==========
    # Properties
    # ==========

    @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
    def name(self) -> str:
        """
        Return the name of the link, as stored in its description.
        """

        return self.link_description.name

    @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
    def index(self) -> jtp.Int:
        """
        Return the index of the link, as stored in its description.

        Returns:
            The link index as an integer array.
        """

        return jnp.array(self.link_description.index, dtype=int)

    # ========
    # Dynamics
    # ========

    @functools.partial(oop.jax_tf.method_ro, jit=False)
    def mass(self) -> jtp.Float:
        """
        Return the mass of the link, as stored in its description.
        """

        return jnp.array(self.link_description.mass, dtype=float)

    @functools.partial(oop.jax_tf.method_ro, jit=False)
    def spatial_inertia(self) -> jtp.Matrix:
        """
        Return the spatial inertia matrix of the link, as stored in its description.
        """

        return jnp.array(self.link_description.inertia, dtype=float)

    @functools.partial(oop.jax_tf.method_ro, vmap_in_axes=(0, None))
    def com_position(self, in_link_frame: bool = True) -> jtp.Vector:
        """
        Return the position of the link center of mass.

        Args:
            in_link_frame: If True, express the CoM position in the link frame;
                otherwise express it in the inertial frame using the link transform.

        Returns:
            The 3D position of the link CoM.
        """

        from jaxsim.math.inertia import Inertia

        # Extract the CoM position (in link frame) from the spatial inertia.
        _, L_p_CoM, _ = Inertia.to_params(M=self.spatial_inertia())

        def com_in_link_frame():
            return L_p_CoM.squeeze()

        def com_in_inertial_frame():
            W_H_L = self.transform()

            # Transform the CoM position using homogeneous coordinates.
            W_p_CoM_homogeneous = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1])

            return W_p_CoM_homogeneous[0:3].squeeze()

        # lax.select is used because in_link_frame may be a traced value.
        return jax.lax.select(
            pred=in_link_frame,
            on_true=com_in_link_frame(),
            on_false=com_in_inertial_frame(),
        )

    # ==========
    # Kinematics
    # ==========

    @functools.partial(oop.jax_tf.method_ro)
    def position(self) -> jtp.Vector:
        """
        Return the position of the link frame, i.e. the translation part of transform().
        """

        return self.transform()[0:3, 3]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"])
    def orientation(self, dcm: bool = False) -> jtp.Vector:
        """
        Return the orientation of the link frame.

        Args:
            dcm: If True, return the 3x3 rotation matrix; otherwise return the
                quaternion in wxyz order.

        Returns:
            The rotation matrix or the wxyz quaternion of the link frame.
        """

        R = self.transform()[0:3, 0:3]

        # Reorder the xyzw quaternion returned by SO3 into wxyz.
        to_wxyz = np.array([3, 0, 1, 2])
        return R if dcm else sixd.so3.SO3.from_matrix(R).as_quaternion_xyzw()[to_wxyz]

    @functools.partial(oop.jax_tf.method_ro)
    def transform(self) -> jtp.Matrix:
        """
        Return the homogeneous transform W_H_L of the link frame.

        The transform is selected from the parent model's forward kinematics
        using the link index.
        """

        return self.parent_model.forward_kinematics()[self.index()]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
    def velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
        """
        Return the 6D velocity of the link.

        The velocity is computed as the link jacobian times the generalized
        velocity of the parent model.

        Args:
            vel_repr: The output velocity representation. If None, the active
                representation of the parent model is used.

        Returns:
            The 6D link velocity (linear part first, angular part last).
        """

        v_WL = (
            self.jacobian(output_vel_repr=vel_repr)
            @ self.parent_model.generalized_velocity()
        )

        return v_WL

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
    def linear_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
        """
        Return the linear part (first three components) of the link velocity.

        Args:
            vel_repr: The output velocity representation, see velocity().
        """

        return self.velocity(vel_repr=vel_repr)[0:3]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
    def angular_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
        """
        Return the angular part (last three components) of the link velocity.

        Args:
            vel_repr: The output velocity representation, see velocity().
        """

        return self.velocity(vel_repr=vel_repr)[3:6]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
    def jacobian(self, output_vel_repr: VelRepr | None = None) -> jtp.Matrix:
        """
        Return the free-floating jacobian of the link.

        The input of the jacobian (the generalized velocity) is expressed in the
        active velocity representation of the parent model.

        Args:
            output_vel_repr: The representation of the output link velocity.
                If None, the active representation of the parent model is used.

        Returns:
            The free-floating jacobian of the link.

        Raises:
            ValueError: If the parent or output velocity representation is unknown.
        """

        if output_vel_repr is None:
            output_vel_repr = self.parent_model.velocity_representation

        # Compute the doubly left-trivialized free-floating jacobian
        L_J_WL_B = jacobian(
            model=self.parent_model.physics_model,
            body_index=self.index(),
            q=self.parent_model.data.model_state.joint_positions,
        )

        # Convert the input side of the jacobian from body-fixed base velocity
        # to the active representation of the parent model.
        if self.parent_model.velocity_representation is VelRepr.Body:
            L_J_WL_target = L_J_WL_B

        elif self.parent_model.velocity_representation is VelRepr.Inertial:
            dofs = self.parent_model.dofs()
            W_H_B = self.parent_model.base_transform()

            B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()
            B_T_W = _free_floating_input_transform(X=B_X_W, dofs=dofs)

            L_J_WL_target = L_J_WL_B @ B_T_W

        elif self.parent_model.velocity_representation is VelRepr.Mixed:
            dofs = self.parent_model.dofs()
            W_H_B = self.parent_model.base_transform()

            # Keep only the orientation of the base: frame B[W] has the origin
            # of B and the orientation of W.
            BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))

            B_X_BW = sixd.se3.SE3.from_matrix(BW_H_B).inverse().adjoint()
            B_T_BW = _free_floating_input_transform(X=B_X_BW, dofs=dofs)

            L_J_WL_target = L_J_WL_B @ B_T_BW

        else:
            raise ValueError(self.parent_model.velocity_representation)

        # Convert the output side of the jacobian to the requested representation.
        if output_vel_repr is VelRepr.Body:
            return L_J_WL_target

        elif output_vel_repr is VelRepr.Inertial:
            W_H_L = self.transform()
            W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint()
            return W_X_L @ L_J_WL_target

        elif output_vel_repr is VelRepr.Mixed:
            W_H_L = self.transform()
            # Frame L[W]: origin of L, orientation of W.
            LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
            LW_X_L = sixd.se3.SE3.from_matrix(LW_H_L).adjoint()
            return LW_X_L @ L_J_WL_target

        else:
            raise ValueError(output_vel_repr)

    @functools.partial(oop.jax_tf.method_ro)
    def external_force(self) -> jtp.Vector:
        """
        Return the active external force acting on the link.

        This external force is a user input and is not computed by the physics engine.
        During the simulation, this external force is summed to other terms like those
        related to enforce contact constraints.

        Returns:
            The active external 6D force acting on the link in the active representation.

        Raises:
            ValueError: If the velocity representation of the parent model is unknown.
        """

        # Get the external force stored in the inertial representation
        W_f_ext = self.parent_model.data.model_input.f_ext[self.index()]

        # Express it in the active representation. Forces transform with the
        # transpose of the velocity adjoint.
        if self.parent_model.velocity_representation is VelRepr.Inertial:
            f_ext = W_f_ext

        elif self.parent_model.velocity_representation is VelRepr.Body:
            W_H_L = self.transform()
            W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint()

            f_ext = W_X_L.transpose() @ W_f_ext

        elif self.parent_model.velocity_representation is VelRepr.Mixed:
            W_p_L = self.transform()[0:3, 3]
            # Frame L[W]: origin of L, orientation of W.
            W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L)
            W_X_LW = sixd.se3.SE3.from_matrix(W_H_LW).adjoint()

            f_ext = W_X_LW.transpose() @ W_f_ext

        else:
            raise ValueError(self.parent_model.velocity_representation)

        return f_ext

    @functools.partial(oop.jax_tf.method_ro)
    def in_contact(self) -> jtp.Bool:
        """
        Return whether the link is in contact.

        The value is selected, by link index, from the parent model's in_contact().
        """

        return self.parent_model.in_contact()[self.index()]