jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/api/ode_data.py DELETED
@@ -1,401 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
-
5
- import jax
6
- import jax.numpy as jnp
7
- import jax_dataclasses
8
-
9
- import jaxsim.api as js
10
- import jaxsim.typing as jtp
11
- from jaxsim.utils import JaxsimDataclass
12
-
13
- # ===================================================================
14
- # Define the state of the ODE system defining the integrated dynamics
15
- # ===================================================================
16
-
17
- # Note: the ODE system is the combination of the floating-base dynamics and the
18
- # soft-contacts dynamics.
19
-
20
-
21
- @jax_dataclasses.pytree_dataclass
22
- class ODEState(JaxsimDataclass):
23
- """
24
- The state of the ODE system.
25
-
26
- Attributes:
27
- physics_model: The state of the physics model.
28
- extended:
29
- Additional state variables extending the state vector corresponding to
30
- equations of motion. These extended variables are passed to the integrator.
31
- """
32
-
33
- physics_model: PhysicsModelState
34
-
35
- extended: dict[str, jtp.PyTree] = dataclasses.field(default_factory=dict)
36
-
37
- @staticmethod
38
- def build_from_jaxsim_model(
39
- model: js.model.JaxSimModel,
40
- joint_positions: jtp.Vector | None = None,
41
- joint_velocities: jtp.Vector | None = None,
42
- base_position: jtp.Vector | None = None,
43
- base_quaternion: jtp.Vector | None = None,
44
- base_linear_velocity: jtp.Vector | None = None,
45
- base_angular_velocity: jtp.Vector | None = None,
46
- **kwargs,
47
- ) -> ODEState:
48
- """
49
- Build an `ODEState` from a `JaxSimModel`.
50
-
51
- Args:
52
- model: The `JaxSimModel` associated with the ODE state.
53
- joint_positions: The vector of joint positions.
54
- joint_velocities: The vector of joint velocities.
55
- base_position: The 3D position of the base link.
56
- base_quaternion: The quaternion defining the orientation of the base link.
57
- base_linear_velocity:
58
- The linear velocity of the base link in inertial-fixed representation.
59
- base_angular_velocity:
60
- The angular velocity of the base link in inertial-fixed representation.
61
- kwargs:
62
- Additional arguments corresponding variables extending the default
63
- state vector of the physics model.
64
-
65
- Note:
66
- Kwargs can be used to supply any additional state variables that are passed
67
- to the integrator. This is useful to extend the default system dynamics,
68
- for example if the contact model requires additional state variables or to
69
- simulate additional dynamics like actuators or muscoloskeletal models.
70
-
71
- Returns:
72
- The `ODEState` built from the `JaxSimModel`.
73
-
74
- Note:
75
- If any of the state components are not provided, they are built from the
76
- `JaxSimModel` and initialized to zero.
77
- """
78
-
79
- # Initialize the extended state with the optional contact state.
80
- extended_state = model.contact_model.zero_state_variables(model=model)
81
-
82
- # Override the default extended state with optional kwargs.
83
- extended_state |= kwargs
84
-
85
- return ODEState.build(
86
- model=model,
87
- physics_model_state=PhysicsModelState.build_from_jaxsim_model(
88
- model=model,
89
- joint_positions=joint_positions,
90
- joint_velocities=joint_velocities,
91
- base_position=base_position,
92
- base_quaternion=base_quaternion,
93
- base_linear_velocity=base_linear_velocity,
94
- base_angular_velocity=base_angular_velocity,
95
- ),
96
- extended_state=extended_state,
97
- )
98
-
99
- @staticmethod
100
- def build(
101
- physics_model_state: PhysicsModelState | None = None,
102
- extended_state: dict[str, jtp.PyTree] | None = None,
103
- model: js.model.JaxSimModel | None = None,
104
- ) -> ODEState:
105
- """
106
- Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`.
107
-
108
- Args:
109
- physics_model_state: The state of the physics model.
110
- extended_state: Additional state variables extending the state vector.
111
- model: The `JaxSimModel` associated with the ODE state.
112
-
113
- Returns:
114
- A `ODEState` instance.
115
- """
116
-
117
- # Build a zero state for the physics model if not provided.
118
- physics_model_state = (
119
- physics_model_state
120
- if physics_model_state is not None
121
- else PhysicsModelState.zero(model=model)
122
- )
123
-
124
- return ODEState(
125
- physics_model=physics_model_state,
126
- extended=extended_state,
127
- )
128
-
129
- @staticmethod
130
- def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
131
- """
132
- Build a zero `ODEState` corresponding to a `JaxSimModel`.
133
-
134
- Args:
135
- model: The model to consider.
136
- data: The data of the considered model.
137
-
138
- Returns:
139
- A zero `ODEState` instance.
140
- """
141
-
142
- ode_state = ODEState.build(
143
- model=model,
144
- extended_state=jax.tree.map(
145
- lambda x: jnp.zeros_like(x), data.state.extended
146
- ),
147
- )
148
-
149
- return ode_state
150
-
151
- def valid(self, model: js.model.JaxSimModel) -> bool:
152
- """
153
- Check if the `ODEState` is valid for a given `JaxSimModel`.
154
-
155
- Args:
156
- model: The model to validate this `ODEState` against.
157
-
158
- Returns:
159
- `True` if the ODE state is valid for the given model, `False` otherwise.
160
- """
161
-
162
- # TODO: should we validate the extended state?
163
- return self.physics_model.valid(model=model)
164
-
165
-
166
- # ==================================================
167
- # Define the input and state of floating-base robots
168
- # ==================================================
169
-
170
-
171
- @jax_dataclasses.pytree_dataclass
172
- class PhysicsModelState(JaxsimDataclass):
173
- """
174
- Class storing the state of the physics model dynamics.
175
-
176
- Attributes:
177
- joint_positions: The vector of joint positions.
178
- joint_velocities: The vector of joint velocities.
179
- base_position: The 3D position of the base link.
180
- base_quaternion: The quaternion defining the orientation of the base link.
181
- base_linear_velocity:
182
- The linear velocity of the base link in inertial-fixed representation.
183
- base_angular_velocity:
184
- The angular velocity of the base link in inertial-fixed representation.
185
-
186
- """
187
-
188
- # Joint state
189
- joint_positions: jtp.Vector
190
- joint_velocities: jtp.Vector
191
-
192
- # Base state
193
- base_position: jtp.Vector = jax_dataclasses.field(
194
- default_factory=lambda: jnp.zeros(3)
195
- )
196
- base_quaternion: jtp.Vector = jax_dataclasses.field(
197
- default_factory=lambda: jnp.array([1.0, 0, 0, 0])
198
- )
199
- base_linear_velocity: jtp.Vector = jax_dataclasses.field(
200
- default_factory=lambda: jnp.zeros(3)
201
- )
202
- base_angular_velocity: jtp.Vector = jax_dataclasses.field(
203
- default_factory=lambda: jnp.zeros(3)
204
- )
205
-
206
- def __hash__(self) -> int:
207
-
208
- from jaxsim.utils.wrappers import HashedNumpyArray
209
-
210
- return hash(
211
- (
212
- HashedNumpyArray.hash_of_array(self.joint_positions),
213
- HashedNumpyArray.hash_of_array(self.joint_velocities),
214
- HashedNumpyArray.hash_of_array(self.base_position),
215
- HashedNumpyArray.hash_of_array(self.base_quaternion),
216
- HashedNumpyArray.hash_of_array(self.base_linear_velocity),
217
- HashedNumpyArray.hash_of_array(self.base_angular_velocity),
218
- )
219
- )
220
-
221
- def __eq__(self, other: PhysicsModelState) -> bool:
222
-
223
- if not isinstance(other, PhysicsModelState):
224
- return False
225
-
226
- return hash(self) == hash(other)
227
-
228
- @staticmethod
229
- def build_from_jaxsim_model(
230
- model: js.model.JaxSimModel | None = None,
231
- joint_positions: jtp.Vector | None = None,
232
- joint_velocities: jtp.Vector | None = None,
233
- base_position: jtp.Vector | None = None,
234
- base_quaternion: jtp.Vector | None = None,
235
- base_linear_velocity: jtp.Vector | None = None,
236
- base_angular_velocity: jtp.Vector | None = None,
237
- ) -> PhysicsModelState:
238
- """
239
- Build a `PhysicsModelState` from a `JaxSimModel`.
240
-
241
- Args:
242
- model: The `JaxSimModel` associated with the state.
243
- joint_positions: The vector of joint positions.
244
- joint_velocities: The vector of joint velocities.
245
- base_position: The 3D position of the base link.
246
- base_quaternion: The quaternion defining the orientation of the base link.
247
- base_linear_velocity:
248
- The linear velocity of the base link in inertial-fixed representation.
249
- base_angular_velocity:
250
- The angular velocity of the base link in inertial-fixed representation.
251
-
252
- Note:
253
- If any of the state components are not provided, they are built from the
254
- `JaxSimModel` and initialized to zero.
255
-
256
- Returns:
257
- A `PhysicsModelState` instance.
258
- """
259
-
260
- return PhysicsModelState.build(
261
- joint_positions=joint_positions,
262
- joint_velocities=joint_velocities,
263
- base_position=base_position,
264
- base_quaternion=base_quaternion,
265
- base_linear_velocity=base_linear_velocity,
266
- base_angular_velocity=base_angular_velocity,
267
- number_of_dofs=model.dofs(),
268
- )
269
-
270
- @staticmethod
271
- def build(
272
- joint_positions: jtp.Vector | None = None,
273
- joint_velocities: jtp.Vector | None = None,
274
- base_position: jtp.Vector | None = None,
275
- base_quaternion: jtp.Vector | None = None,
276
- base_linear_velocity: jtp.Vector | None = None,
277
- base_angular_velocity: jtp.Vector | None = None,
278
- number_of_dofs: jtp.Int | None = None,
279
- ) -> PhysicsModelState:
280
- """
281
- Build a `PhysicsModelState`.
282
-
283
- Args:
284
- joint_positions: The vector of joint positions.
285
- joint_velocities: The vector of joint velocities.
286
- base_position: The 3D position of the base link.
287
- base_quaternion: The quaternion defining the orientation of the base link.
288
- base_linear_velocity:
289
- The linear velocity of the base link in inertial-fixed representation.
290
- base_angular_velocity:
291
- The angular velocity of the base link in inertial-fixed representation.
292
- number_of_dofs:
293
- The number of degrees of freedom of the physics model.
294
-
295
- Returns:
296
- A `PhysicsModelState` instance.
297
- """
298
-
299
- joint_positions = (
300
- joint_positions
301
- if joint_positions is not None
302
- else jnp.zeros(number_of_dofs)
303
- )
304
-
305
- joint_velocities = (
306
- joint_velocities
307
- if joint_velocities is not None
308
- else jnp.zeros(number_of_dofs)
309
- )
310
-
311
- base_position = base_position if base_position is not None else jnp.zeros(3)
312
-
313
- base_quaternion = (
314
- base_quaternion
315
- if base_quaternion is not None
316
- else jnp.array([1.0, 0, 0, 0])
317
- )
318
-
319
- base_linear_velocity = (
320
- base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
321
- )
322
-
323
- base_angular_velocity = (
324
- base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
325
- )
326
-
327
- physics_model_state = PhysicsModelState(
328
- joint_positions=jnp.array(joint_positions, dtype=float),
329
- joint_velocities=jnp.array(joint_velocities, dtype=float),
330
- base_position=jnp.array(base_position, dtype=float),
331
- base_quaternion=jnp.array(base_quaternion, dtype=float),
332
- base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
333
- base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
334
- )
335
-
336
- # TODO (diegoferigo): assert state.valid(physics_model)
337
- return physics_model_state
338
-
339
- @staticmethod
340
- def zero(model: js.model.JaxSimModel) -> PhysicsModelState:
341
- """
342
- Build a `PhysicsModelState` with all components initialized to zero.
343
-
344
- Args:
345
- model: The `JaxSimModel` associated with the state.
346
-
347
- Returns:
348
- A `PhysicsModelState` instance.
349
- """
350
-
351
- return PhysicsModelState.build_from_jaxsim_model(model=model)
352
-
353
- def valid(self, model: js.model.JaxSimModel) -> bool:
354
- """
355
- Check if the `PhysicsModelState` is valid for a given `JaxSimModel`.
356
-
357
- Args:
358
- model: The `JaxSimModel` to validate the `PhysicsModelState` against.
359
-
360
- Returns:
361
- `True` if the `PhysicsModelState` is valid for the given model,
362
- `False` otherwise.
363
- """
364
-
365
- shape = self.joint_positions.shape
366
- expected_shape = (model.dofs(),)
367
-
368
- if shape != expected_shape:
369
- return False
370
-
371
- shape = self.joint_velocities.shape
372
- expected_shape = (model.dofs(),)
373
-
374
- if shape != expected_shape:
375
- return False
376
-
377
- shape = self.base_position.shape
378
- expected_shape = (3,)
379
-
380
- if shape != expected_shape:
381
- return False
382
-
383
- shape = self.base_quaternion.shape
384
- expected_shape = (4,)
385
-
386
- if shape != expected_shape:
387
- return False
388
-
389
- shape = self.base_linear_velocity.shape
390
- expected_shape = (3,)
391
-
392
- if shape != expected_shape:
393
- return False
394
-
395
- shape = self.base_angular_velocity.shape
396
- expected_shape = (3,)
397
-
398
- if shape != expected_shape:
399
- return False
400
-
401
- return True
jaxsim/integrators/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from . import fixed_step, variable_step
2
- from .common import Integrator, SystemDynamics, Time, TimeStep