jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1.dev401.dist-info/METADATA +0 -167
- jaxsim-0.1.dev401.dist-info/RECORD +0 -64
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/references.py
ADDED
@@ -0,0 +1,421 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import functools
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax_dataclasses
|
8
|
+
|
9
|
+
import jaxsim.api as js
|
10
|
+
import jaxsim.typing as jtp
|
11
|
+
from jaxsim.utils.tracing import not_tracing
|
12
|
+
|
13
|
+
from .common import VelRepr
|
14
|
+
from .ode_data import ODEInput
|
15
|
+
|
16
|
+
try:
|
17
|
+
from typing import Self
|
18
|
+
except ImportError:
|
19
|
+
from typing_extensions import Self
|
20
|
+
|
21
|
+
|
22
|
+
@jax_dataclasses.pytree_dataclass
class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
    """
    Class containing the references for a `JaxSimModel` object.

    The references are the user inputs of the model dynamics: the joint force
    references and the 6D forces applied to the links. Both are held by the
    wrapped `ODEInput`. Link forces are always stored in inertial-fixed
    representation, and are converted from/to the active velocity
    representation when they are applied/extracted.
    """

    # ODE input holding the joint force references (`physics_model.tau`)
    # and the link forces (`physics_model.f_ext`).
    input: ODEInput

    @staticmethod
    def zero(
        model: js.model.JaxSimModel,
        velocity_representation: VelRepr = VelRepr.Inertial,
    ) -> JaxSimModelReferences:
        """
        Create a `JaxSimModelReferences` object with zero references.

        Args:
            model: The model for which to create the zero references.
            velocity_representation: The velocity representation to use.

        Returns:
            A `JaxSimModelReferences` object with zero references.
        """

        return JaxSimModelReferences.build(
            model=model, velocity_representation=velocity_representation
        )

    @staticmethod
    def build(
        model: js.model.JaxSimModel,
        joint_force_references: jtp.Vector | None = None,
        link_forces: jtp.Matrix | None = None,
        data: js.data.JaxSimModelData | None = None,
        velocity_representation: VelRepr | None = None,
    ) -> JaxSimModelReferences:
        """
        Create a `JaxSimModelReferences` object with the given references.

        Args:
            model: The model for which to create the references.
            joint_force_references:
                The joint force references. If not provided, they are set to zero.
            link_forces:
                The link 6D forces in the desired representation. If not provided,
                they are left to zero.
            data:
                The data of the model, only needed if link forces are provided
                and the velocity representation is not inertial-fixed.
            velocity_representation:
                The velocity representation to use. If not provided, the one of
                `data` is used when available, otherwise the inertial-fixed one.

        Returns:
            A `JaxSimModelReferences` object with the given references.

        Raises:
            ValueError:
                If link forces are provided in a representation different from
                the inertial-fixed one and `data` is missing.
        """

        # Create or adjust joint force references.
        # The conversion with `jnp.array` allows passing any array-like object.
        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references).squeeze()
            if joint_force_references is not None
            else jnp.zeros(model.dofs())
        ).astype(float)

        # Select the velocity representation.
        velocity_representation = (
            velocity_representation
            if velocity_representation is not None
            else (
                data.velocity_representation if data is not None else VelRepr.Inertial
            )
        )

        # Create a zero references object.
        references = JaxSimModelReferences(
            input=ODEInput.zero(model=model),
            velocity_representation=velocity_representation,
        )

        # Store the joint force references.
        references = references.set_joint_force_references(
            forces=joint_force_references,
            model=model,
            joint_names=model.joint_names(),
        )

        # A zero 6D force is zero in every representation, therefore there is
        # nothing to convert nor to store when no link forces are given. Skipping
        # the conversion allows building zero references in any representation
        # without requiring the model data.
        # NOTE(review): this relies on `ODEInput.zero` returning zero link forces;
        # confirm against its implementation.
        if link_forces is None:
            return references

        # Adjust the link forces to a (n_links, 6) matrix.
        f_L = jnp.atleast_2d(jnp.array(link_forces).squeeze()).astype(float)

        # Apply the link forces.
        return references.apply_link_forces(
            forces=f_L,
            model=model,
            data=data,
            link_names=model.link_names(),
            additive=False,
        )

    def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
        """
        Check if the current references are valid for the given model.

        Args:
            model: The model to check against.

        Returns:
            `True` if no model is given, otherwise the outcome of the validity
            check of the stored input against the given model.
        """

        # Without a model there is nothing to check against.
        if model is None:
            return True

        return self.input.valid(model=model)

    # ==================
    # Extract quantities
    # ==================

    @functools.partial(jax.jit, static_argnames=["link_names"])
    def link_forces(
        self,
        model: js.model.JaxSimModel | None = None,
        data: js.data.JaxSimModelData | None = None,
        link_names: tuple[str, ...] | None = None,
    ) -> jtp.Matrix:
        """
        Return the link forces expressed in the frame of the active representation.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_names: The names of the links corresponding to the forces.

        Returns:
            If no model and no link names are provided, the link forces as a
            `(n_links,6)` matrix corresponding to the default link serialization
            of the original model used to build the references object.
            If a model is provided and no link names are provided, the link forces
            as a `(n_links,6)` matrix corresponding to the serialization of the
            provided model.
            If both a model and link names are provided, the link forces as a
            `(len(link_names),6)` matrix corresponding to the serialization of
            the passed link names vector.

        Raises:
            ValueError:
                If no model is provided and either the active representation is
                not inertial-fixed or link names are passed. If the active
                representation is not inertial-fixed and the data is missing
                or, when the check can be performed, not valid for the model.

        Note:
            The returned link forces are those passed as user inputs when integrating
            the dynamics of the model. They are summed with other forces related
            e.g. to the contact model and other kinematic constraints.
        """

        # The link forces are always stored in inertial-fixed representation.
        W_f_L = self.input.physics_model.f_ext

        # Return all link forces in inertial-fixed representation using the implicit
        # serialization.
        if model is None:
            if self.velocity_representation is not VelRepr.Inertial:
                msg = "Missing model to use a representation different from {}"
                raise ValueError(msg.format(VelRepr.Inertial.name))

            if link_names is not None:
                raise ValueError("Link names cannot be provided without a model")

            return W_f_L

        # If we have the model, we can extract the link names, if not provided.
        link_names = link_names if link_names is not None else model.link_names()
        link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

        # In inertial-fixed representation, we already have the link forces.
        if self.velocity_representation is VelRepr.Inertial:
            return W_f_L[link_idxs, :]

        if data is None:
            msg = "Missing model data to use a representation different from {}"
            raise ValueError(msg.format(VelRepr.Inertial.name))

        # The data is validated only when `not_tracing` allows it.
        if not_tracing(W_f_L) and not data.valid(model=model):
            raise ValueError("The provided data is not valid for the model")

        # Helper function to convert a single 6D force to the active representation.
        def convert(f_L: jtp.Vector) -> jtp.Vector:
            return JaxSimModelReferences.inertial_to_other_representation(
                array=f_L,
                other_representation=self.velocity_representation,
                transform=data.base_transform(),
                is_force=True,
            )

        # Convert each selected link force to the desired representation.
        return jax.vmap(convert)(W_f_L[link_idxs, :])

    def joint_force_references(
        self,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> jtp.Vector:
        """
        Return the joint force references.

        Args:
            model: The model to consider.
            joint_names: The names of the joints corresponding to the forces.

        Returns:
            If no model and no joint names are provided, the joint forces as a
            `(DoFs,)` vector corresponding to the default joint serialization
            of the original model used to build the references object.
            If a model is provided and no joint names are provided, the joint forces
            as a `(DoFs,)` vector corresponding to the serialization of the
            provided model.
            If both a model and joint names are provided, the joint forces as a
            `(len(joint_names),)` vector corresponding to the serialization of
            the passed joint names vector.

        Raises:
            ValueError:
                If joint names are passed without a model or, when the check can
                be performed, if the references are not valid for the model.

        Note:
            The returned joint forces are those passed as user inputs when integrating
            the dynamics of the model. They are summed with other joint forces related
            e.g. to the enforcement of other kinematic constraints. Keep also in mind
            that the presence of joint friction and other similar effects can make the
            actual joint forces different from the references.
        """

        tau = self.input.physics_model.tau

        # Without a model, only the implicit joint serialization is available.
        if model is None:
            if joint_names is not None:
                raise ValueError("Joint names cannot be provided without a model")

            return tau

        # The references are validated only when `not_tracing` allows it.
        if not_tracing(tau) and not self.valid(model=model):
            msg = "The references object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()
        joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)

        return jnp.atleast_1d(tau[joint_idxs].squeeze()).astype(float)

    # ================
    # Store quantities
    # ================

    @functools.partial(jax.jit, static_argnames=["joint_names"])
    def set_joint_force_references(
        self,
        forces: jtp.VectorLike,
        model: js.model.JaxSimModel | None = None,
        joint_names: tuple[str, ...] | None = None,
    ) -> Self:
        """
        Set the joint force references.

        Args:
            forces: The joint force references.
            model:
                The model to consider, only needed if a joint serialization different
                from the implicit one is used.
            joint_names: The names of the joints corresponding to the forces.

        Returns:
            A new `JaxSimModelReferences` object with the given joint force references.

        Raises:
            ValueError:
                When the check can be performed, if the references are not valid
                for the provided model.
        """

        # Normalize the input to a 1D float vector right away, so that row/column
        # vectors are accepted also when they are written through the joint
        # indices below (an indexed update does not broadcast a column vector).
        forces = jnp.atleast_1d(jnp.array(forces).squeeze()).astype(float)

        # Helper function to replace the joint force references.
        def replace(forces: jtp.VectorLike) -> JaxSimModelReferences:
            return self.replace(
                validate=True,
                input=self.input.replace(
                    physics_model=self.input.physics_model.replace(
                        tau=jnp.atleast_1d(forces.squeeze()).astype(float)
                    )
                ),
            )

        # Without a model, the forces follow the implicit joint serialization.
        if model is None:
            return replace(forces=forces)

        # The references are validated only when `not_tracing` allows it.
        if not_tracing(forces) and not self.valid(model=model):
            msg = "The references object is not compatible with the provided model"
            raise ValueError(msg)

        joint_names = joint_names if joint_names is not None else model.joint_names()
        joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)

        # Update only the entries of the selected joints.
        return replace(forces=self.input.physics_model.tau.at[joint_idxs].set(forces))

    @functools.partial(jax.jit, static_argnames=["link_names", "additive"])
    def apply_link_forces(
        self,
        forces: jtp.MatrixLike,
        model: js.model.JaxSimModel | None = None,
        data: js.data.JaxSimModelData | None = None,
        link_names: tuple[str, ...] | None = None,
        additive: bool = False,
    ) -> Self:
        """
        Apply the link forces.

        Args:
            forces: The link 6D forces in the active representation.
            model:
                The model to consider, only needed if a link serialization different
                from the implicit one is used.
            data:
                The data of the considered model, only needed if the velocity
                representation is not inertial-fixed.
            link_names: The names of the links corresponding to the forces.
            additive:
                Whether to add the forces to the existing ones instead of replacing them.

        Returns:
            A new `JaxSimModelReferences` object with the given link forces.

        Raises:
            ValueError:
                If no model is provided and either the active representation is
                not inertial-fixed or link names are passed. If the active
                representation is not inertial-fixed and the data is missing
                or, when the check can be performed, not valid for the model.

        Note:
            The link forces must be expressed in the active representation.
            Then, we always convert and store forces in inertial-fixed representation.
        """

        # Normalize the input to a (n, 6) float matrix right away. This is needed
        # for the representation conversion below, that maps over the first axis:
        # a single 6D force passed as a (6,) vector would otherwise be mapped
        # over its scalar components.
        f_L = jnp.atleast_2d(jnp.array(forces).squeeze()).astype(float)

        # Helper function to replace the link forces.
        def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
            return self.replace(
                validate=True,
                input=self.input.replace(
                    physics_model=self.input.physics_model.replace(
                        f_ext=jnp.atleast_2d(forces.squeeze()).astype(float)
                    )
                ),
            )

        # In this case, we allow only to set the inertial 6D forces to all links
        # using the implicit link serialization.
        if model is None:
            if self.velocity_representation is not VelRepr.Inertial:
                msg = "Missing model to use a representation different from {}"
                raise ValueError(msg.format(VelRepr.Inertial.name))

            if link_names is not None:
                raise ValueError("Link names cannot be provided without a model")

            W_f_L = f_L

            W_f0_L = (
                jnp.zeros_like(W_f_L)
                if not additive
                else self.input.physics_model.f_ext
            )

            return replace(forces=W_f0_L + W_f_L)

        # If we have the model, we can extract the link names if not provided.
        link_names = link_names if link_names is not None else model.link_names()
        link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

        # Compute the bias depending on whether we either set or add the link forces.
        W_f0_L = (
            jnp.zeros_like(f_L)
            if not additive
            else self.input.physics_model.f_ext[link_idxs, :]
        )

        # If inertial-fixed representation, we can directly store the link forces.
        if self.velocity_representation is VelRepr.Inertial:
            W_f_L = f_L
            return replace(
                forces=self.input.physics_model.f_ext.at[link_idxs, :].set(
                    W_f0_L + W_f_L
                )
            )

        if data is None:
            msg = "Missing model data to use a representation different from {}"
            raise ValueError(msg.format(VelRepr.Inertial.name))

        # The data is validated only when `not_tracing` allows it.
        if not_tracing(forces) and not data.valid(model=model):
            raise ValueError("The provided data is not valid for the model")

        # The same base transform is used to convert all the link forces.
        base_transform = data.base_transform()

        # Helper function to convert a single 6D force to the inertial representation.
        def convert(f_L: jtp.Vector) -> jtp.Vector:
            return JaxSimModelReferences.other_representation_to_inertial(
                array=f_L,
                other_representation=self.velocity_representation,
                transform=base_transform,
                is_force=True,
            )

        # Convert each link force to the inertial-fixed representation.
        W_f_L = jax.vmap(convert)(f_L)

        return replace(
            forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
        )
|