jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,421 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+
9
+ import jaxsim.api as js
10
+ import jaxsim.typing as jtp
11
+ from jaxsim.utils.tracing import not_tracing
12
+
13
+ from .common import VelRepr
14
+ from .ode_data import ODEInput
15
+
16
+ try:
17
+ from typing import Self
18
+ except ImportError:
19
+ from typing_extensions import Self
20
+
21
+
22
@jax_dataclasses.pytree_dataclass
class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
    """
    Class containing the references for a `JaxSimModel` object.

    The references are the user inputs considered when integrating the dynamics
    of the model: the joint force references and the 6D forces applied to the
    links. Link forces are always stored in inertial-fixed representation and
    converted from/to the active velocity representation when set/read.
    """

    # Container of the stored references. This class reads and writes its
    # `physics_model.tau` (joint forces) and `physics_model.f_ext` (link forces).
    input: ODEInput

    @staticmethod
    def zero(
        model: js.model.JaxSimModel,
        velocity_representation: VelRepr = VelRepr.Inertial,
    ) -> JaxSimModelReferences:
        """
        Create a `JaxSimModelReferences` object with zero references.

        Args:
            model: The model for which to create the zero references.
            velocity_representation: The velocity representation to use.

        Returns:
            A `JaxSimModelReferences` object with zero references.
        """

        return JaxSimModelReferences.build(
            model=model, velocity_representation=velocity_representation
        )

    @staticmethod
    def build(
        model: js.model.JaxSimModel,
        joint_force_references: jtp.Vector | None = None,
        link_forces: jtp.Matrix | None = None,
        data: js.data.JaxSimModelData | None = None,
        velocity_representation: VelRepr | None = None,
    ) -> JaxSimModelReferences:
        """
        Create a `JaxSimModelReferences` object with the given references.

        Args:
            model: The model for which to create the references.
            joint_force_references: The joint force references.
            link_forces: The link 6D forces in the desired representation.
            data:
                The data of the model, only needed if the velocity representation is
                not inertial-fixed.
            velocity_representation: The velocity representation to use.

        Returns:
            A `JaxSimModelReferences` object with the given references.

        Note:
            If no velocity representation is passed, the one of `data` is used
            when `data` is provided, otherwise the inertial-fixed one.
        """

        # Create or adjust joint force references.
        # Going through `jnp.asarray` also accepts plain sequences and scalars.
        joint_force_references = jnp.atleast_1d(
            jnp.asarray(joint_force_references).squeeze()
            if joint_force_references is not None
            else jnp.zeros(model.dofs())
        ).astype(float)

        # Create or adjust link forces.
        f_L = jnp.atleast_2d(
            jnp.asarray(link_forces).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        ).astype(float)

        # Select the velocity representation: the explicit argument has priority,
        # then the representation of the data, then the inertial-fixed default.
        if velocity_representation is None:
            velocity_representation = (
                data.velocity_representation if data is not None else VelRepr.Inertial
            )

        # Create a zero references object.
        references = JaxSimModelReferences(
            input=ODEInput.zero(model=model),
            velocity_representation=velocity_representation,
        )

        # Store the joint force references.
        references = references.set_joint_force_references(
            forces=joint_force_references,
            model=model,
            joint_names=model.joint_names(),
        )

        # Apply the link forces, replacing the zero ones.
        references = references.apply_link_forces(
            forces=f_L,
            model=model,
            data=data,
            link_names=model.link_names(),
            additive=False,
        )

        return references

    def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
        """
        Check if the current references are valid for the given model.

        Args:
            model: The model to check against.

        Returns:
            `True` if the current references are valid for the given model,
            `False` otherwise. Without a model, there is nothing to check
            against and `True` is returned.
        """

        if model is None:
            return True

        return self.input.valid(model=model)

    # ==================
    # Extract quantities
    # ==================

    @functools.partial(jax.jit, static_argnames=["link_names"])
    def link_forces(
        self,
        model: js.model.JaxSimModel | None = None,
        data: js.data.JaxSimModelData | None = None,
        link_names: tuple[str, ...] | None = None,
    ) -> jtp.Matrix:
        """
        Return the link forces expressed in the frame of the active representation.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_names: The names of the links corresponding to the forces.

        Returns:
            If no model and no link names are provided, the link forces as a
            `(n_links,6)` matrix corresponding to the default link serialization
            of the original model used to build the references object.
            If a model is provided and no link names are provided, the link forces
            as a `(n_links,6)` matrix corresponding to the serialization of the
            provided model.
            If both a model and link names are provided, the link forces as a
            `(len(link_names),6)` matrix corresponding to the serialization of
            the passed link names vector.

        Raises:
            ValueError:
                If the model is missing and either the active representation is
                not inertial-fixed or link names are provided; if the data is
                missing (or not valid for the model) and the active representation
                is not inertial-fixed.

        Note:
            The returned link forces are those passed as user inputs when integrating
            the dynamics of the model. They are summed with other forces related
            e.g. to the contact model and other kinematic constraints.
        """

        # The forces are always stored in inertial-fixed representation.
        W_f_L = self.input.physics_model.f_ext

        # Return all link forces in inertial-fixed representation using the implicit
        # serialization.
        if model is None:
            if self.velocity_representation is not VelRepr.Inertial:
                msg = "Missing model to use a representation different from {}"
                raise ValueError(msg.format(VelRepr.Inertial.name))

            if link_names is not None:
                raise ValueError("Link names cannot be provided without a model")

            return W_f_L

        # If we have the model, we can extract the link names, if not provided.
        link_names = link_names if link_names is not None else model.link_names()
        link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

        # In inertial-fixed representation, we already have the link forces.
        if self.velocity_representation is VelRepr.Inertial:
            return W_f_L[link_idxs, :]

        if data is None:
            msg = "Missing model data to use a representation different from {}"
            raise ValueError(msg.format(VelRepr.Inertial.name))

        # The validity check is skipped while tracing (e.g. under `jax.jit`).
        if not_tracing(W_f_L) and not data.valid(model=model):
            raise ValueError("The provided data is not valid for the model")

        # Helper function to convert a single 6D force to the active representation.
        # NOTE(review): the base transform is used for the force of every link;
        # confirm this is the intended frame for non-inertial representations.
        def convert(W_f: jtp.Vector) -> jtp.Vector:
            return JaxSimModelReferences.inertial_to_other_representation(
                array=W_f,
                other_representation=self.velocity_representation,
                transform=data.base_transform(),
                is_force=True,
            )

        # Convert to the desired representation, one link force (row) at a time.
        return jax.vmap(convert)(W_f_L[link_idxs, :])

    def joint_force_references(
        self,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> jtp.Vector:
        """
        Return the joint force references.

        Args:
            model: The model to consider.
            joint_names: The names of the joints corresponding to the forces.

        Returns:
            If no model and no joint names are provided, the joint forces as a
            `(DoFs,)` vector corresponding to the default joint serialization
            of the original model used to build the references object.
            If a model is provided and no joint names are provided, the joint forces
            as a `(DoFs,)` vector corresponding to the serialization of the
            provided model.
            If both a model and joint names are provided, the joint forces as a
            `(len(joint_names),)` vector corresponding to the serialization of
            the passed joint names vector.

        Raises:
            ValueError:
                If joint names are provided without a model, or if the references
                are not valid for the provided model.

        Note:
            The returned joint forces are those passed as user inputs when integrating
            the dynamics of the model. They are summed with other joint forces related
            e.g. to the enforcement of other kinematic constraints. Keep also in mind
            that the presence of joint friction and other similar effects can make the
            actual joint forces different from the references.
        """

        tau = self.input.physics_model.tau

        if model is None:
            if joint_names is not None:
                raise ValueError("Joint names cannot be provided without a model")

            return tau

        # The validity check is skipped while tracing (e.g. under `jax.jit`).
        if not_tracing(tau) and not self.valid(model=model):
            msg = "The references object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()
        joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)

        return jnp.atleast_1d(tau[joint_idxs].squeeze()).astype(float)

    # ================
    # Store quantities
    # ================

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def set_joint_force_references(
        self,
        forces: jtp.VectorLike,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> Self:
        """
        Set the joint force references.

        Args:
            forces: The joint force references.
            model:
                The model to consider, only needed if a joint serialization different
                from the implicit one is used.
            joint_names: The names of the joints corresponding to the forces.

        Returns:
            A new `JaxSimModelReferences` object with the given joint force references.

        Raises:
            ValueError: If the references are not valid for the provided model.
        """

        forces = jnp.array(forces)

        # Helper function to replace the joint force references.
        def replace(forces: jtp.VectorLike) -> JaxSimModelReferences:
            return self.replace(
                validate=True,
                input=self.input.replace(
                    physics_model=self.input.physics_model.replace(
                        tau=jnp.atleast_1d(forces.squeeze()).astype(float)
                    )
                ),
            )

        # Without a model, the forces replace all the references using the
        # implicit joint serialization.
        if model is None:
            return replace(forces=forces)

        # The validity check is skipped while tracing (e.g. under `jax.jit`).
        if not_tracing(forces) and not self.valid(model=model):
            msg = "The references object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()
        joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)

        # Only the entries of the selected joints are overwritten.
        return replace(forces=self.input.physics_model.tau.at[joint_idxs].set(forces))

    @functools.partial(jax.jit, static_argnames=["link_names", "additive"])
    def apply_link_forces(
        self,
        forces: jtp.MatrixLike,
        model: js.model.JaxSimModel | None = None,
        data: js.data.JaxSimModelData | None = None,
        link_names: tuple[str, ...] | None = None,
        additive: bool = False,
    ) -> Self:
        """
        Apply the link forces.

        Args:
            forces: The link 6D forces in the active representation.
            model:
                The model to consider, only needed if a link serialization different
                from the implicit one is used.
            data:
                The data of the considered model, only needed if the velocity
                representation is not inertial-fixed.
            link_names: The names of the links corresponding to the forces.
            additive:
                Whether to add the forces to the existing ones instead of replacing them.

        Returns:
            A new `JaxSimModelReferences` object with the given link forces.

        Raises:
            ValueError:
                If the model is missing and either the active representation is
                not inertial-fixed or link names are provided; if the data is
                missing (or not valid for the model) and the active representation
                is not inertial-fixed.

        Note:
            The link forces must be expressed in the active representation.
            Then, we always convert and store forces in inertial-fixed representation.
        """

        # Always work on a 2D stack of 6D forces so that a single `(6,)` force is
        # handled as one row, also when mapping the conversion over the rows below.
        f_L = jnp.atleast_2d(jnp.array(forces))

        # Helper function to replace the link forces.
        def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
            return self.replace(
                validate=True,
                input=self.input.replace(
                    physics_model=self.input.physics_model.replace(
                        f_ext=jnp.atleast_2d(forces.squeeze()).astype(float)
                    )
                ),
            )

        # In this case, we allow only to set the inertial 6D forces to all links
        # using the implicit link serialization.
        if model is None:
            if self.velocity_representation is not VelRepr.Inertial:
                msg = "Missing model to use a representation different from {}"
                raise ValueError(msg.format(VelRepr.Inertial.name))

            if link_names is not None:
                raise ValueError("Link names cannot be provided without a model")

            W_f_L = f_L

            W_f0_L = (
                jnp.zeros_like(W_f_L)
                if not additive
                else self.input.physics_model.f_ext
            )

            return replace(forces=W_f0_L + W_f_L)

        # If we have the model, we can extract the link names if not provided.
        link_names = link_names if link_names is not None else model.link_names()
        link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

        # Compute the bias depending on whether we either set or add the link forces.
        W_f0_L = (
            jnp.zeros_like(f_L)
            if not additive
            else self.input.physics_model.f_ext[link_idxs, :]
        )

        # If inertial-fixed representation, we can directly store the link forces.
        if self.velocity_representation is VelRepr.Inertial:
            W_f_L = f_L
            return replace(
                forces=self.input.physics_model.f_ext.at[link_idxs, :].set(
                    W_f0_L + W_f_L
                )
            )

        if data is None:
            msg = "Missing model data to use a representation different from {}"
            raise ValueError(msg.format(VelRepr.Inertial.name))

        # The validity check is skipped while tracing (e.g. under `jax.jit`).
        if not_tracing(forces) and not data.valid(model=model):
            raise ValueError("The provided data is not valid for the model")

        # Helper function to convert a single 6D force to the inertial representation.
        # NOTE(review): the base transform is used for the force of every link;
        # confirm this is the intended frame for non-inertial representations.
        def convert(f: jtp.Vector) -> jtp.Vector:
            return JaxSimModelReferences.other_representation_to_inertial(
                array=f,
                other_representation=self.velocity_representation,
                transform=data.base_transform(),
                is_force=True,
            )

        # Convert to inertial-fixed representation, one link force (row) at a time.
        W_f_L = jax.vmap(convert)(f_L)

        return replace(
            forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
        )
@@ -0,0 +1,2 @@
1
+ from . import fixed_step, variable_step
2
+ from .common import Integrator, SystemDynamics, Time, TimeStep