jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py
CHANGED
@@ -2,28 +2,23 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import dataclasses
|
4
4
|
import functools
|
5
|
-
from
|
5
|
+
from collections.abc import Sequence
|
6
6
|
|
7
7
|
import jax
|
8
8
|
import jax.numpy as jnp
|
9
|
+
import jax.scipy.spatial.transform
|
9
10
|
import jax_dataclasses
|
10
|
-
|
11
|
-
import
|
12
|
-
|
13
|
-
import jaxsim.
|
14
|
-
import jaxsim.physics.algos.aba
|
15
|
-
import jaxsim.physics.algos.crba
|
16
|
-
import jaxsim.physics.algos.forward_kinematics
|
17
|
-
import jaxsim.physics.algos.rnea
|
18
|
-
import jaxsim.physics.model.physics_model
|
19
|
-
import jaxsim.physics.model.physics_model_state
|
11
|
+
|
12
|
+
import jaxsim.api as js
|
13
|
+
import jaxsim.math
|
14
|
+
import jaxsim.rbda
|
20
15
|
import jaxsim.typing as jtp
|
21
|
-
from jaxsim.high_level.common import VelRepr
|
22
|
-
from jaxsim.physics.algos import soft_contacts
|
23
|
-
from jaxsim.simulation.ode_data import ODEState
|
24
16
|
from jaxsim.utils import Mutability
|
17
|
+
from jaxsim.utils.tracing import not_tracing
|
25
18
|
|
26
19
|
from . import common
|
20
|
+
from .common import VelRepr
|
21
|
+
from .ode_data import ODEState
|
27
22
|
|
28
23
|
try:
|
29
24
|
from typing import Self
|
@@ -34,21 +29,35 @@ except ImportError:
|
|
34
29
|
@jax_dataclasses.pytree_dataclass
|
35
30
|
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
36
31
|
"""
|
37
|
-
Class containing the
|
32
|
+
Class containing the data of a `JaxSimModel` object.
|
38
33
|
"""
|
39
34
|
|
40
35
|
state: ODEState
|
41
36
|
|
42
|
-
gravity: jtp.
|
37
|
+
gravity: jtp.Vector
|
43
38
|
|
44
|
-
|
45
|
-
|
46
|
-
)
|
47
|
-
|
48
|
-
|
49
|
-
|
39
|
+
contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
|
40
|
+
|
41
|
+
def __hash__(self) -> int:
|
42
|
+
|
43
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
44
|
+
|
45
|
+
return hash(
|
46
|
+
(
|
47
|
+
hash(self.state),
|
48
|
+
HashedNumpyArray.hash_of_array(self.gravity),
|
49
|
+
hash(self.contacts_params),
|
50
|
+
)
|
51
|
+
)
|
52
|
+
|
53
|
+
def __eq__(self, other: JaxSimModelData) -> bool:
|
50
54
|
|
51
|
-
|
55
|
+
if not isinstance(other, JaxSimModelData):
|
56
|
+
return False
|
57
|
+
|
58
|
+
return hash(self) == hash(other)
|
59
|
+
|
60
|
+
def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
|
52
61
|
"""
|
53
62
|
Check if the current state is valid for the given model.
|
54
63
|
|
@@ -60,15 +69,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
60
69
|
"""
|
61
70
|
|
62
71
|
valid = True
|
72
|
+
valid = valid and self.standard_gravity() > 0
|
63
73
|
|
64
74
|
if model is not None:
|
65
|
-
valid = valid and self.state.valid(
|
75
|
+
valid = valid and self.state.valid(model=model)
|
66
76
|
|
67
77
|
return valid
|
68
78
|
|
69
79
|
@staticmethod
|
70
80
|
def zero(
|
71
|
-
model:
|
81
|
+
model: js.model.JaxSimModel,
|
72
82
|
velocity_representation: VelRepr = VelRepr.Inertial,
|
73
83
|
) -> JaxSimModelData:
|
74
84
|
"""
|
@@ -88,18 +98,17 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
88
98
|
|
89
99
|
@staticmethod
|
90
100
|
def build(
|
91
|
-
model:
|
92
|
-
base_position: jtp.
|
93
|
-
base_quaternion: jtp.
|
94
|
-
joint_positions: jtp.
|
95
|
-
base_linear_velocity: jtp.
|
96
|
-
base_angular_velocity: jtp.
|
97
|
-
joint_velocities: jtp.
|
98
|
-
|
99
|
-
|
100
|
-
soft_contacts_params: soft_contacts.SoftContactsParams | None = None,
|
101
|
+
model: js.model.JaxSimModel,
|
102
|
+
base_position: jtp.VectorLike | None = None,
|
103
|
+
base_quaternion: jtp.VectorLike | None = None,
|
104
|
+
joint_positions: jtp.VectorLike | None = None,
|
105
|
+
base_linear_velocity: jtp.VectorLike | None = None,
|
106
|
+
base_angular_velocity: jtp.VectorLike | None = None,
|
107
|
+
joint_velocities: jtp.VectorLike | None = None,
|
108
|
+
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
|
109
|
+
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
|
101
110
|
velocity_representation: VelRepr = VelRepr.Inertial,
|
102
|
-
|
111
|
+
extended_ode_state: dict[str, jtp.PyTree] | None = None,
|
103
112
|
) -> JaxSimModelData:
|
104
113
|
"""
|
105
114
|
Create a `JaxSimModelData` object with the given state.
|
@@ -114,97 +123,119 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
114
123
|
base_angular_velocity:
|
115
124
|
The base angular velocity in the selected representation.
|
116
125
|
joint_velocities: The joint velocities.
|
117
|
-
|
118
|
-
|
119
|
-
soft_contacts_params: The parameters of the soft contacts.
|
126
|
+
standard_gravity: The standard gravity constant.
|
127
|
+
contacts_params: The parameters of the soft contacts.
|
120
128
|
velocity_representation: The velocity representation to use.
|
121
|
-
|
129
|
+
extended_ode_state:
|
130
|
+
Additional user-defined state variables that are not part of the
|
131
|
+
standard `ODEState` object. Useful to extend the system dynamics
|
132
|
+
considered by default in JaxSim.
|
122
133
|
|
123
134
|
Returns:
|
124
|
-
A `JaxSimModelData`
|
135
|
+
A `JaxSimModelData` initialized with the given state.
|
125
136
|
"""
|
126
137
|
|
127
138
|
base_position = jnp.array(
|
128
|
-
base_position if base_position is not None else jnp.zeros(3)
|
139
|
+
base_position if base_position is not None else jnp.zeros(3),
|
140
|
+
dtype=float,
|
129
141
|
).squeeze()
|
130
142
|
|
131
143
|
base_quaternion = jnp.array(
|
132
|
-
|
133
|
-
|
134
|
-
|
144
|
+
(
|
145
|
+
base_quaternion
|
146
|
+
if base_quaternion is not None
|
147
|
+
else jnp.array([1.0, 0, 0, 0])
|
148
|
+
),
|
149
|
+
dtype=float,
|
135
150
|
).squeeze()
|
136
151
|
|
137
152
|
base_linear_velocity = jnp.array(
|
138
|
-
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
|
153
|
+
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
|
154
|
+
dtype=float,
|
139
155
|
).squeeze()
|
140
156
|
|
141
157
|
base_angular_velocity = jnp.array(
|
142
|
-
|
158
|
+
(
|
159
|
+
base_angular_velocity
|
160
|
+
if base_angular_velocity is not None
|
161
|
+
else jnp.zeros(3)
|
162
|
+
),
|
163
|
+
dtype=float,
|
143
164
|
).squeeze()
|
144
165
|
|
145
|
-
gravity = jnp.
|
146
|
-
gravity if gravity is not None else model.physics_model.gravity[0:3]
|
147
|
-
).squeeze()
|
166
|
+
gravity = jnp.zeros(3).at[2].set(-standard_gravity)
|
148
167
|
|
149
168
|
joint_positions = jnp.atleast_1d(
|
150
|
-
|
151
|
-
|
152
|
-
|
169
|
+
jnp.array(
|
170
|
+
(
|
171
|
+
joint_positions
|
172
|
+
if joint_positions is not None
|
173
|
+
else jnp.zeros(model.dofs())
|
174
|
+
),
|
175
|
+
dtype=float,
|
176
|
+
).squeeze()
|
153
177
|
)
|
154
178
|
|
155
179
|
joint_velocities = jnp.atleast_1d(
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
else jnp.array(0, dtype=jnp.uint64)
|
180
|
+
jnp.array(
|
181
|
+
(
|
182
|
+
joint_velocities
|
183
|
+
if joint_velocities is not None
|
184
|
+
else jnp.zeros(model.dofs())
|
185
|
+
),
|
186
|
+
dtype=float,
|
187
|
+
).squeeze()
|
165
188
|
)
|
166
189
|
|
167
|
-
|
168
|
-
|
169
|
-
if soft_contacts_params is not None
|
170
|
-
else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model)
|
190
|
+
W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
|
191
|
+
translation=base_position, quaternion=base_quaternion
|
171
192
|
)
|
172
193
|
|
173
|
-
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
174
|
-
translation=base_position,
|
175
|
-
rotation=jaxlie.SO3.from_quaternion_xyzw(
|
176
|
-
base_quaternion[jnp.array([1, 2, 3, 0])]
|
177
|
-
),
|
178
|
-
).as_matrix()
|
179
|
-
|
180
194
|
v_WB = JaxSimModelData.other_representation_to_inertial(
|
181
195
|
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
|
182
196
|
other_representation=velocity_representation,
|
183
197
|
transform=W_H_B,
|
184
198
|
is_force=False,
|
199
|
+
).astype(float)
|
200
|
+
|
201
|
+
ode_state = ODEState.build_from_jaxsim_model(
|
202
|
+
model=model,
|
203
|
+
base_position=base_position,
|
204
|
+
base_quaternion=base_quaternion,
|
205
|
+
joint_positions=joint_positions,
|
206
|
+
base_linear_velocity=v_WB[0:3],
|
207
|
+
base_angular_velocity=v_WB[3:6],
|
208
|
+
joint_velocities=joint_velocities,
|
209
|
+
# Unpack all the additional ODE states. If the contact model requires an
|
210
|
+
# additional state that is not explicitly passed to this builder, ODEState
|
211
|
+
# automatically populates that state with zeroed variables.
|
212
|
+
# This is not true for any other custom state that the user might want to
|
213
|
+
# pass to the integrator.
|
214
|
+
**(extended_ode_state if extended_ode_state else {}),
|
185
215
|
)
|
186
216
|
|
187
|
-
|
188
|
-
physics_model=model.physics_model,
|
189
|
-
physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState(
|
190
|
-
base_position=base_position.astype(float),
|
191
|
-
base_quaternion=base_quaternion.astype(float),
|
192
|
-
joint_positions=joint_positions.astype(float),
|
193
|
-
base_linear_velocity=v_WB[0:3].astype(float),
|
194
|
-
base_angular_velocity=v_WB[3:6].astype(float),
|
195
|
-
joint_velocities=joint_velocities.astype(float),
|
196
|
-
),
|
197
|
-
soft_contacts_state=soft_contacts_state,
|
198
|
-
)
|
199
|
-
|
200
|
-
if not ode_state.valid(physics_model=model.physics_model):
|
217
|
+
if not ode_state.valid(model=model):
|
201
218
|
raise ValueError(ode_state)
|
202
219
|
|
220
|
+
if contacts_params is None:
|
221
|
+
|
222
|
+
if isinstance(
|
223
|
+
model.contact_model,
|
224
|
+
jaxsim.rbda.contacts.SoftContacts
|
225
|
+
| jaxsim.rbda.contacts.ViscoElasticContacts,
|
226
|
+
):
|
227
|
+
|
228
|
+
contacts_params = js.contact.estimate_good_contact_parameters(
|
229
|
+
model=model, standard_gravity=standard_gravity
|
230
|
+
)
|
231
|
+
|
232
|
+
else:
|
233
|
+
contacts_params = model.contact_model._parameters_class()
|
234
|
+
|
203
235
|
return JaxSimModelData(
|
204
|
-
time_ns=time_ns,
|
205
236
|
state=ode_state,
|
206
|
-
gravity=gravity
|
207
|
-
|
237
|
+
gravity=gravity,
|
238
|
+
contacts_params=contacts_params,
|
208
239
|
velocity_representation=velocity_representation,
|
209
240
|
)
|
210
241
|
|
@@ -212,20 +243,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
212
243
|
# Extract quantities
|
213
244
|
# ==================
|
214
245
|
|
215
|
-
def
|
246
|
+
def standard_gravity(self) -> jtp.Float:
|
216
247
|
"""
|
217
|
-
Get the
|
248
|
+
Get the standard gravity constant.
|
218
249
|
|
219
250
|
Returns:
|
220
|
-
The
|
251
|
+
The standard gravity constant.
|
221
252
|
"""
|
222
253
|
|
223
|
-
return self.
|
254
|
+
return -self.gravity[2]
|
224
255
|
|
256
|
+
@js.common.named_scope
|
225
257
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
226
258
|
def joint_positions(
|
227
259
|
self,
|
228
|
-
model:
|
260
|
+
model: js.model.JaxSimModel | None = None,
|
229
261
|
joint_names: tuple[str, ...] | None = None,
|
230
262
|
) -> jtp.Vector:
|
231
263
|
"""
|
@@ -250,22 +282,30 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
250
282
|
"""
|
251
283
|
|
252
284
|
if model is None:
|
285
|
+
if joint_names is not None:
|
286
|
+
raise ValueError("Joint names cannot be provided without a model")
|
287
|
+
|
253
288
|
return self.state.physics_model.joint_positions
|
254
289
|
|
255
|
-
if not self.valid(
|
290
|
+
if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
|
291
|
+
model=model
|
292
|
+
):
|
256
293
|
msg = "The data object is not compatible with the provided model"
|
257
294
|
raise ValueError(msg)
|
258
295
|
|
259
|
-
|
296
|
+
joint_idxs = (
|
297
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
298
|
+
if joint_names is not None
|
299
|
+
else jnp.arange(model.number_of_joints())
|
300
|
+
)
|
260
301
|
|
261
|
-
return self.state.physics_model.joint_positions[
|
262
|
-
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
263
|
-
]
|
302
|
+
return self.state.physics_model.joint_positions[joint_idxs]
|
264
303
|
|
304
|
+
@js.common.named_scope
|
265
305
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
266
306
|
def joint_velocities(
|
267
307
|
self,
|
268
|
-
model:
|
308
|
+
model: js.model.JaxSimModel | None = None,
|
269
309
|
joint_names: tuple[str, ...] | None = None,
|
270
310
|
) -> jtp.Vector:
|
271
311
|
"""
|
@@ -290,18 +330,26 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
290
330
|
"""
|
291
331
|
|
292
332
|
if model is None:
|
333
|
+
if joint_names is not None:
|
334
|
+
raise ValueError("Joint names cannot be provided without a model")
|
335
|
+
|
293
336
|
return self.state.physics_model.joint_velocities
|
294
337
|
|
295
|
-
if not self.valid(
|
338
|
+
if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
|
339
|
+
model=model
|
340
|
+
):
|
296
341
|
msg = "The data object is not compatible with the provided model"
|
297
342
|
raise ValueError(msg)
|
298
343
|
|
299
|
-
|
344
|
+
joint_idxs = (
|
345
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
346
|
+
if joint_names is not None
|
347
|
+
else jnp.arange(model.number_of_joints())
|
348
|
+
)
|
300
349
|
|
301
|
-
return self.state.physics_model.joint_velocities[
|
302
|
-
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
303
|
-
]
|
350
|
+
return self.state.physics_model.joint_velocities[joint_idxs]
|
304
351
|
|
352
|
+
@js.common.named_scope
|
305
353
|
@jax.jit
|
306
354
|
def base_position(self) -> jtp.Vector:
|
307
355
|
"""
|
@@ -313,6 +361,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
313
361
|
|
314
362
|
return self.state.physics_model.base_position.squeeze()
|
315
363
|
|
364
|
+
@js.common.named_scope
|
316
365
|
@functools.partial(jax.jit, static_argnames=["dcm"])
|
317
366
|
def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
|
318
367
|
"""
|
@@ -325,29 +374,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
325
374
|
The base orientation.
|
326
375
|
"""
|
327
376
|
|
377
|
+
# Extract the base quaternion.
|
378
|
+
W_Q_B = self.state.physics_model.base_quaternion.squeeze()
|
379
|
+
|
328
380
|
# Always normalize the quaternion to avoid numerical issues.
|
329
381
|
# If the active scheme does not integrate the quaternion on its manifold,
|
330
382
|
# we introduce a Baumgarte stabilization to let the quaternion converge to
|
331
383
|
# a unit quaternion. In this case, it is not guaranteed that the quaternion
|
332
384
|
# stored in the state is a unit quaternion.
|
333
|
-
|
334
|
-
|
335
|
-
/ jnp.linalg.norm(self.state.physics_model.base_quaternion)
|
336
|
-
)
|
385
|
+
norm = jaxsim.math.safe_norm(W_Q_B)
|
386
|
+
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
|
337
387
|
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
return (
|
342
|
-
base_unit_quaternion
|
343
|
-
if not dcm
|
344
|
-
else jaxlie.SO3.from_quaternion_xyzw(
|
345
|
-
base_unit_quaternion[to_xyzw]
|
346
|
-
).as_matrix()
|
388
|
+
return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
|
389
|
+
float
|
347
390
|
)
|
348
391
|
|
392
|
+
@js.common.named_scope
|
349
393
|
@jax.jit
|
350
|
-
def base_transform(self) -> jtp.
|
394
|
+
def base_transform(self) -> jtp.Matrix:
|
351
395
|
"""
|
352
396
|
Get the base transform.
|
353
397
|
|
@@ -365,6 +409,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
365
409
|
]
|
366
410
|
)
|
367
411
|
|
412
|
+
@js.common.named_scope
|
368
413
|
@jax.jit
|
369
414
|
def base_velocity(self) -> jtp.Vector:
|
370
415
|
"""
|
@@ -394,9 +439,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
394
439
|
.astype(float)
|
395
440
|
)
|
396
441
|
|
442
|
+
@js.common.named_scope
|
397
443
|
@jax.jit
|
398
444
|
def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
|
399
|
-
"""
|
445
|
+
r"""
|
400
446
|
Get the generalized position
|
401
447
|
:math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SO}(3) \times \mathbb{R}^n`.
|
402
448
|
|
@@ -406,10 +452,12 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
406
452
|
|
407
453
|
return self.base_transform(), self.joint_positions()
|
408
454
|
|
455
|
+
@js.common.named_scope
|
409
456
|
@jax.jit
|
410
457
|
def generalized_velocity(self) -> jtp.Vector:
|
411
|
-
"""
|
412
|
-
Get the generalized velocity
|
458
|
+
r"""
|
459
|
+
Get the generalized velocity.
|
460
|
+
|
413
461
|
:math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
|
414
462
|
|
415
463
|
Returns:
|
@@ -426,11 +474,12 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
426
474
|
# Store quantities
|
427
475
|
# ================
|
428
476
|
|
477
|
+
@js.common.named_scope
|
429
478
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
430
479
|
def reset_joint_positions(
|
431
480
|
self,
|
432
481
|
positions: jtp.VectorLike,
|
433
|
-
model:
|
482
|
+
model: js.model.JaxSimModel | None = None,
|
434
483
|
joint_names: tuple[str, ...] | None = None,
|
435
484
|
) -> Self:
|
436
485
|
"""
|
@@ -460,23 +509,26 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
460
509
|
if model is None:
|
461
510
|
return replace(s=positions)
|
462
511
|
|
463
|
-
if not self.valid(model=model):
|
512
|
+
if not_tracing(positions) and not self.valid(model=model):
|
464
513
|
msg = "The data object is not compatible with the provided model"
|
465
514
|
raise ValueError(msg)
|
466
515
|
|
467
|
-
|
516
|
+
joint_idxs = (
|
517
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
518
|
+
if joint_names is not None
|
519
|
+
else jnp.arange(model.number_of_joints())
|
520
|
+
)
|
468
521
|
|
469
522
|
return replace(
|
470
|
-
s=self.state.physics_model.joint_positions.at[
|
471
|
-
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
472
|
-
].set(positions)
|
523
|
+
s=self.state.physics_model.joint_positions.at[joint_idxs].set(positions)
|
473
524
|
)
|
474
525
|
|
526
|
+
@js.common.named_scope
|
475
527
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
476
528
|
def reset_joint_velocities(
|
477
529
|
self,
|
478
530
|
velocities: jtp.VectorLike,
|
479
|
-
model:
|
531
|
+
model: js.model.JaxSimModel | None = None,
|
480
532
|
joint_names: tuple[str, ...] | None = None,
|
481
533
|
) -> Self:
|
482
534
|
"""
|
@@ -506,18 +558,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
506
558
|
if model is None:
|
507
559
|
return replace(ṡ=velocities)
|
508
560
|
|
509
|
-
if not self.valid(model=model):
|
561
|
+
if not_tracing(velocities) and not self.valid(model=model):
|
510
562
|
msg = "The data object is not compatible with the provided model"
|
511
563
|
raise ValueError(msg)
|
512
564
|
|
513
|
-
|
565
|
+
joint_idxs = (
|
566
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
567
|
+
if joint_names is not None
|
568
|
+
else jnp.arange(model.number_of_joints())
|
569
|
+
)
|
514
570
|
|
515
571
|
return replace(
|
516
|
-
ṡ=self.state.physics_model.joint_velocities.at[
|
517
|
-
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
|
518
|
-
].set(velocities)
|
572
|
+
ṡ=self.state.physics_model.joint_velocities.at[joint_idxs].set(velocities)
|
519
573
|
)
|
520
574
|
|
575
|
+
@js.common.named_scope
|
521
576
|
@jax.jit
|
522
577
|
def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
|
523
578
|
"""
|
@@ -541,6 +596,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
541
596
|
),
|
542
597
|
)
|
543
598
|
|
599
|
+
@js.common.named_scope
|
544
600
|
@jax.jit
|
545
601
|
def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
|
546
602
|
"""
|
@@ -553,19 +609,19 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
553
609
|
The updated `JaxSimModelData` object.
|
554
610
|
"""
|
555
611
|
|
556
|
-
|
612
|
+
W_Q_B = jnp.array(base_quaternion, dtype=float)
|
613
|
+
|
614
|
+
norm = jaxsim.math.safe_norm(W_Q_B)
|
615
|
+
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
|
557
616
|
|
558
617
|
return self.replace(
|
559
618
|
validate=True,
|
560
619
|
state=self.state.replace(
|
561
|
-
physics_model=self.state.physics_model.replace(
|
562
|
-
base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
|
563
|
-
float
|
564
|
-
)
|
565
|
-
)
|
620
|
+
physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
|
566
621
|
),
|
567
622
|
)
|
568
623
|
|
624
|
+
@js.common.named_scope
|
569
625
|
@jax.jit
|
570
626
|
def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
|
571
627
|
"""
|
@@ -582,14 +638,13 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
582
638
|
|
583
639
|
W_p_B = base_pose[0:3, 3]
|
584
640
|
|
585
|
-
|
586
|
-
W_R_B: jaxlie.SO3 = jaxlie.SO3.from_matrix(base_pose[0:3, 0:3]) # noqa
|
587
|
-
W_Q_B = W_R_B.as_quaternion_xyzw()[to_wxyz]
|
641
|
+
W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
|
588
642
|
|
589
643
|
return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
|
590
644
|
base_quaternion=W_Q_B
|
591
645
|
)
|
592
646
|
|
647
|
+
@js.common.named_scope
|
593
648
|
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
594
649
|
def reset_base_linear_velocity(
|
595
650
|
self,
|
@@ -613,11 +668,15 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
613
668
|
|
614
669
|
return self.reset_base_velocity(
|
615
670
|
base_velocity=jnp.hstack(
|
616
|
-
[
|
671
|
+
[
|
672
|
+
linear_velocity.squeeze(),
|
673
|
+
self.base_velocity()[3:6],
|
674
|
+
]
|
617
675
|
),
|
618
676
|
velocity_representation=velocity_representation,
|
619
677
|
)
|
620
678
|
|
679
|
+
@js.common.named_scope
|
621
680
|
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
622
681
|
def reset_base_angular_velocity(
|
623
682
|
self,
|
@@ -641,11 +700,15 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
641
700
|
|
642
701
|
return self.reset_base_velocity(
|
643
702
|
base_velocity=jnp.hstack(
|
644
|
-
[
|
703
|
+
[
|
704
|
+
self.base_velocity()[0:3],
|
705
|
+
angular_velocity.squeeze(),
|
706
|
+
]
|
645
707
|
),
|
646
708
|
velocity_representation=velocity_representation,
|
647
709
|
)
|
648
710
|
|
711
|
+
@js.common.named_scope
|
649
712
|
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
|
650
713
|
def reset_base_velocity(
|
651
714
|
self,
|
@@ -691,8 +754,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
691
754
|
)
|
692
755
|
|
693
756
|
|
757
|
+
@functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
|
694
758
|
def random_model_data(
|
695
|
-
model:
|
759
|
+
model: js.model.JaxSimModel,
|
696
760
|
*,
|
697
761
|
key: jax.Array | None = None,
|
698
762
|
velocity_representation: VelRepr | None = None,
|
@@ -700,6 +764,18 @@ def random_model_data(
|
|
700
764
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
701
765
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
702
766
|
] = ((-1, -1, 0.5), 1.0),
|
767
|
+
base_rpy_bounds: tuple[
|
768
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
769
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
770
|
+
] = (-jnp.pi, jnp.pi),
|
771
|
+
base_rpy_seq: str = "XYZ",
|
772
|
+
joint_pos_bounds: (
|
773
|
+
tuple[
|
774
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
775
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
776
|
+
]
|
777
|
+
| None
|
778
|
+
) = None,
|
703
779
|
base_vel_lin_bounds: tuple[
|
704
780
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
705
781
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
@@ -712,6 +788,11 @@ def random_model_data(
|
|
712
788
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
713
789
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
714
790
|
] = (-1.0, 1.0),
|
791
|
+
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
|
792
|
+
standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
|
793
|
+
jaxsim.math.StandardGravity,
|
794
|
+
jaxsim.math.StandardGravity,
|
795
|
+
),
|
715
796
|
) -> JaxSimModelData:
|
716
797
|
"""
|
717
798
|
Randomly generate a `JaxSimModelData` object.
|
@@ -721,19 +802,29 @@ def random_model_data(
|
|
721
802
|
key: The random key.
|
722
803
|
velocity_representation: The velocity representation to use.
|
723
804
|
base_pos_bounds: The bounds for the base position.
|
805
|
+
base_rpy_bounds:
|
806
|
+
The bounds for the euler angles used to build the base orientation.
|
807
|
+
base_rpy_seq:
|
808
|
+
The sequence of axes for rotation (using `Rotation` from scipy).
|
809
|
+
joint_pos_bounds:
|
810
|
+
The bounds for the joint positions (reading the joint limits if None).
|
724
811
|
base_vel_lin_bounds: The bounds for the base linear velocity.
|
725
812
|
base_vel_ang_bounds: The bounds for the base angular velocity.
|
726
813
|
joint_vel_bounds: The bounds for the joint velocities.
|
814
|
+
contacts_params: The parameters of the contact model.
|
815
|
+
standard_gravity_bounds: The bounds for the standard gravity.
|
727
816
|
|
728
817
|
Returns:
|
729
818
|
A `JaxSimModelData` object with random data.
|
730
819
|
"""
|
731
820
|
|
732
821
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
733
|
-
k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=
|
822
|
+
k1, k2, k3, k4, k5, k6, k7 = jax.random.split(key, num=7)
|
734
823
|
|
735
824
|
p_min = jnp.array(base_pos_bounds[0], dtype=float)
|
736
825
|
p_max = jnp.array(base_pos_bounds[1], dtype=float)
|
826
|
+
rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)
|
827
|
+
rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)
|
737
828
|
v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
|
738
829
|
v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
|
739
830
|
ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
|
@@ -749,7 +840,9 @@ def random_model_data(
|
|
749
840
|
),
|
750
841
|
)
|
751
842
|
|
752
|
-
with random_data.mutable_context(
|
843
|
+
with random_data.mutable_context(
|
844
|
+
mutability=Mutability.MUTABLE, restore_after_exception=False
|
845
|
+
):
|
753
846
|
|
754
847
|
physics_model_state = random_data.state.physics_model
|
755
848
|
|
@@ -757,24 +850,76 @@ def random_model_data(
|
|
757
850
|
key=k1, shape=(3,), minval=p_min, maxval=p_max
|
758
851
|
)
|
759
852
|
|
760
|
-
physics_model_state.base_quaternion =
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
853
|
+
physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz(
|
854
|
+
xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
|
855
|
+
seq=base_rpy_seq,
|
856
|
+
angles=jax.random.uniform(
|
857
|
+
key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
|
858
|
+
),
|
859
|
+
).as_quat()
|
766
860
|
)
|
767
861
|
|
768
|
-
|
769
|
-
key=k4, shape=(3,), minval=v_min, maxval=v_max
|
770
|
-
)
|
862
|
+
if model.number_of_joints() > 0:
|
771
863
|
|
772
|
-
|
773
|
-
|
774
|
-
|
864
|
+
s_min, s_max = (
|
865
|
+
jnp.array(joint_pos_bounds, dtype=float)
|
866
|
+
if joint_pos_bounds is not None
|
867
|
+
else (None, None)
|
868
|
+
)
|
869
|
+
|
870
|
+
physics_model_state.joint_positions = (
|
871
|
+
js.joint.random_joint_positions(model=model, key=k3)
|
872
|
+
if (s_min is None or s_max is None)
|
873
|
+
else jax.random.uniform(
|
874
|
+
key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
|
875
|
+
)
|
876
|
+
)
|
877
|
+
|
878
|
+
physics_model_state.joint_velocities = jax.random.uniform(
|
879
|
+
key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
|
880
|
+
)
|
881
|
+
|
882
|
+
if model.floating_base():
|
883
|
+
physics_model_state.base_linear_velocity = jax.random.uniform(
|
884
|
+
key=k5, shape=(3,), minval=v_min, maxval=v_max
|
885
|
+
)
|
886
|
+
|
887
|
+
physics_model_state.base_angular_velocity = jax.random.uniform(
|
888
|
+
key=k6, shape=(3,), minval=ω_min, maxval=ω_max
|
889
|
+
)
|
775
890
|
|
776
|
-
|
777
|
-
|
891
|
+
random_data.gravity = (
|
892
|
+
jnp.zeros(3, dtype=random_data.gravity.dtype)
|
893
|
+
.at[2]
|
894
|
+
.set(
|
895
|
+
-jax.random.uniform(
|
896
|
+
key=k7,
|
897
|
+
shape=(),
|
898
|
+
minval=standard_gravity_bounds[0],
|
899
|
+
maxval=standard_gravity_bounds[1],
|
900
|
+
)
|
901
|
+
)
|
778
902
|
)
|
779
903
|
|
904
|
+
if contacts_params is None:
|
905
|
+
|
906
|
+
if isinstance(
|
907
|
+
model.contact_model,
|
908
|
+
jaxsim.rbda.contacts.SoftContacts
|
909
|
+
| jaxsim.rbda.contacts.ViscoElasticContacts,
|
910
|
+
):
|
911
|
+
|
912
|
+
random_data = random_data.replace(
|
913
|
+
contacts_params=js.contact.estimate_good_contact_parameters(
|
914
|
+
model=model, standard_gravity=random_data.gravity
|
915
|
+
),
|
916
|
+
validate=False,
|
917
|
+
)
|
918
|
+
|
919
|
+
else:
|
920
|
+
random_data = random_data.replace(
|
921
|
+
contacts_params=model.contact_model._parameters_class(),
|
922
|
+
validate=False,
|
923
|
+
)
|
924
|
+
|
780
925
|
return random_data
|