jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/high_level/link.py DELETED
@@ -1,259 +0,0 @@
1
- import dataclasses
2
- import functools
3
- from typing import Any
4
-
5
- import jax.lax
6
- import jax.numpy as jnp
7
- import jax_dataclasses
8
- import numpy as np
9
- from jax_dataclasses import Static
10
-
11
- import jaxsim.parsers
12
- import jaxsim.typing as jtp
13
- from jaxsim import sixd
14
- from jaxsim.physics.algos.jacobian import jacobian
15
- from jaxsim.utils import Vmappable, oop
16
-
17
- from .common import VelRepr
18
-
19
-
20
@jax_dataclasses.pytree_dataclass
class Link(Vmappable):
    """
    High-level class to operate in r/o on a single link of a simulated model.

    The link is identified by its static parsed description, while every
    state-dependent quantity (kinematics, velocities, forces, contacts) is read
    from the parent model the link is attached to.
    """

    # Parsed description of the link (name, index, mass, inertia, ...).
    link_description: Static[jaxsim.parsers.descriptions.LinkDescription]

    # Model this link belongs to. It is excluded from repr/compare/hash.
    _parent_model: Any = dataclasses.field(
        default=None, repr=False, compare=False, hash=False
    )

    @property
    def parent_model(self) -> "jaxsim.high_level.model.Model":
        """
        Return the model this link is attached to.

        Returns:
            The parent model, or None if the link has not been attached to one.
        """

        return self._parent_model

    @functools.partial(oop.jax_tf.method_ro, jit=False)
    def valid(self) -> jtp.Bool:
        """
        Check whether the link is attached to a parent model.

        Returns:
            A boolean array that is True when the link has a parent model.
        """

        return jnp.array(self.parent_model is not None, dtype=bool)

    # ==========
    # Properties
    # ==========

    @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
    def name(self) -> str:
        """
        Return the name of the link.

        Returns:
            The link name, as stored in its parsed description.
        """

        return self.link_description.name

    @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
    def index(self) -> jtp.Int:
        """
        Return the index of the link.

        Returns:
            The link index, as stored in its parsed description, as an int array.
        """

        return jnp.array(self.link_description.index, dtype=int)

    # ========
    # Dynamics
    # ========

    @functools.partial(oop.jax_tf.method_ro, jit=False)
    def mass(self) -> jtp.Float:
        """
        Return the mass of the link.

        Returns:
            The link mass, as stored in its parsed description, as a float array.
        """

        return jnp.array(self.link_description.mass, dtype=float)

    @functools.partial(oop.jax_tf.method_ro, jit=False)
    def spatial_inertia(self) -> jtp.Matrix:
        """
        Return the spatial inertia of the link.

        Returns:
            The link spatial inertia matrix, as stored in its parsed description.
        """

        return jnp.array(self.link_description.inertia, dtype=float)

    @functools.partial(oop.jax_tf.method_ro, vmap_in_axes=(0, None))
    def com_position(self, in_link_frame: bool = True) -> jtp.Vector:
        """
        Return the position of the link's center of mass (CoM).

        Args:
            in_link_frame: If True, express the CoM position in the link frame;
                otherwise express it in the world (inertial) frame.

        Returns:
            The 3D position of the link's CoM in the selected frame.
        """

        from jaxsim.math.inertia import Inertia

        _, L_p_CoM, _ = Inertia.to_params(M=self.spatial_inertia())
        L_p_CoM = L_p_CoM.squeeze()

        # Map the CoM to the world frame using homogeneous coordinates.
        W_H_L = self.transform()
        W_p_CoM_homogeneous = W_H_L @ jnp.hstack([L_p_CoM, 1])
        W_p_CoM = W_p_CoM_homogeneous[0:3].squeeze()

        # Both candidates are computed so that the selection can also be traced
        # when `in_link_frame` is not a static (Python) value.
        return jax.lax.select(
            pred=in_link_frame,
            on_true=L_p_CoM,
            on_false=W_p_CoM,
        )

    # ==========
    # Kinematics
    # ==========

    @functools.partial(oop.jax_tf.method_ro)
    def position(self) -> jtp.Vector:
        """
        Return the position of the link frame.

        Returns:
            The translation part of the world-to-link transform.
        """

        return self.transform()[0:3, 3]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"])
    def orientation(self, dcm: bool = False) -> jtp.Vector:
        """
        Return the orientation of the link frame with respect to the world frame.

        Args:
            dcm: If True, return the 3x3 rotation matrix; otherwise return the
                quaternion in wxyz order.

        Returns:
            The link orientation, as a rotation matrix or as a wxyz quaternion.
        """

        R = self.transform()[0:3, 0:3]

        if dcm:
            return R

        # Reorder the xyzw quaternion returned by SO3 into wxyz.
        to_wxyz = np.array([3, 0, 1, 2])
        return sixd.so3.SO3.from_matrix(R).as_quaternion_xyzw()[to_wxyz]

    @functools.partial(oop.jax_tf.method_ro)
    def transform(self) -> jtp.Matrix:
        """
        Return the homogeneous transform of the link frame.

        Returns:
            The entry of the parent model's forward kinematics at this link's index.
        """

        return self.parent_model.forward_kinematics()[self.index()]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
    def velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
        """
        Return the 6D velocity of the link.

        Args:
            vel_repr: Representation of the returned velocity. Defaults to the
                parent model's velocity representation.

        Returns:
            The link Jacobian multiplied by the model's generalized velocity.
        """

        J = self.jacobian(output_vel_repr=vel_repr)
        return J @ self.parent_model.generalized_velocity()

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
    def linear_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
        """
        Return the linear component of the link's 6D velocity.

        Args:
            vel_repr: Representation of the velocity. Defaults to the parent
                model's velocity representation.

        Returns:
            The first three entries of the 6D link velocity.
        """

        return self.velocity(vel_repr=vel_repr)[0:3]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
    def angular_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
        """
        Return the angular component of the link's 6D velocity.

        Args:
            vel_repr: Representation of the velocity. Defaults to the parent
                model's velocity representation.

        Returns:
            The last three entries of the 6D link velocity.
        """

        return self.velocity(vel_repr=vel_repr)[3:6]

    @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
    def jacobian(self, output_vel_repr: VelRepr | None = None) -> jtp.Matrix:
        """
        Return the free-floating Jacobian of the link.

        The input representation is the parent model's active velocity
        representation. This is the representation in which the base velocity
        inside the generalized velocity is expressed. The output representation
        is selected with `output_vel_repr`.

        Args:
            output_vel_repr: Representation of the link velocity produced by the
                Jacobian. Defaults to the parent model's velocity representation.

        Returns:
            The 6 x (6 + dofs) free-floating Jacobian of the link.

        Raises:
            ValueError: If the input or output velocity representation is not
                supported.
        """

        if output_vel_repr is None:
            output_vel_repr = self.parent_model.velocity_representation

        def to_body_input_transform(B_X_base_frame: jtp.Matrix) -> jtp.Matrix:
            # Build blkdiag(B_X_base_frame, I_dofs). It maps generalized
            # velocities whose base part is expressed in another frame to ones
            # whose base part is body-fixed.
            dofs = self.parent_model.dofs()
            zero_6n = jnp.zeros(shape=(6, dofs))

            return jnp.vstack(
                [
                    jnp.block([B_X_base_frame, zero_6n]),
                    jnp.block([zero_6n.T, jnp.eye(dofs)]),
                ]
            )

        # Compute the doubly left-trivialized free-floating jacobian.
        L_J_WL_B = jacobian(
            model=self.parent_model.physics_model,
            body_index=self.index(),
            q=self.parent_model.data.model_state.joint_positions,
        )

        # Convert the input representation from body-fixed to the active one.
        input_vel_repr = self.parent_model.velocity_representation

        if input_vel_repr is VelRepr.Body:
            L_J_WL_target = L_J_WL_B

        elif input_vel_repr is VelRepr.Inertial:
            W_H_B = self.parent_model.base_transform()
            B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()

            L_J_WL_target = L_J_WL_B @ to_body_input_transform(B_X_W)

        elif input_vel_repr is VelRepr.Mixed:
            W_H_B = self.parent_model.base_transform()
            # Frame B[W]: origin of B, orientation of W.
            BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
            B_X_BW = sixd.se3.SE3.from_matrix(BW_H_B).inverse().adjoint()

            L_J_WL_target = L_J_WL_B @ to_body_input_transform(B_X_BW)

        else:
            raise ValueError(self.parent_model.velocity_representation)

        # Convert the output representation from body-fixed to the requested one.
        if output_vel_repr is VelRepr.Body:
            return L_J_WL_target

        elif output_vel_repr is VelRepr.Inertial:
            W_H_L = self.transform()
            W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint()
            return W_X_L @ L_J_WL_target

        elif output_vel_repr is VelRepr.Mixed:
            W_H_L = self.transform()
            # Frame L[W]: origin of L, orientation of W.
            LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
            LW_X_L = sixd.se3.SE3.from_matrix(LW_H_L).adjoint()
            return LW_X_L @ L_J_WL_target

        else:
            raise ValueError(output_vel_repr)

    @functools.partial(oop.jax_tf.method_ro)
    def external_force(self) -> jtp.Vector:
        """
        Return the active external force acting on the link.

        This external force is a user input and is not computed by the physics engine.
        During the simulation, this external force is summed to other terms like those
        related to enforce contact constraints.

        Returns:
            The active external 6D force acting on the link in the active representation.

        Raises:
            ValueError: If the parent model's velocity representation is not
                supported.
        """

        # Get the external force stored in the inertial representation.
        W_f_ext = self.parent_model.data.model_input.f_ext[self.index()]

        # Express it in the active representation. Forces transform with the
        # transpose of the velocity adjoint: X_f = X^T W_f.
        if self.parent_model.velocity_representation is VelRepr.Inertial:
            f_ext = W_f_ext

        elif self.parent_model.velocity_representation is VelRepr.Body:
            W_H_L = self.transform()
            W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint()

            f_ext = W_X_L.transpose() @ W_f_ext

        elif self.parent_model.velocity_representation is VelRepr.Mixed:
            # Frame L[W]: origin of L, orientation of W.
            W_p_L = self.transform()[0:3, 3]
            W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L)
            W_X_LW = sixd.se3.SE3.from_matrix(W_H_LW).adjoint()

            f_ext = W_X_LW.transpose() @ W_f_ext

        else:
            raise ValueError(self.parent_model.velocity_representation)

        return f_ext

    @functools.partial(oop.jax_tf.method_ro)
    def in_contact(self) -> jtp.Bool:
        """
        Return whether the link is in contact.

        Returns:
            The entry of the parent model's `in_contact()` output at this link's index.
        """

        return self.parent_model.in_contact()[self.index()]