jaxsim 0.6.2.dev182__py3-none-any.whl → 0.6.2.dev225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +0 -1
- jaxsim/api/com.py +1 -3
- jaxsim/api/common.py +26 -38
- jaxsim/api/contact.py +140 -24
- jaxsim/api/data.py +96 -33
- jaxsim/api/integrators.py +18 -11
- jaxsim/api/model.py +25 -43
- jaxsim/api/ode.py +28 -6
- jaxsim/api/references.py +9 -16
- jaxsim/math/__init__.py +1 -1
- jaxsim/math/adjoint.py +2 -2
- jaxsim/math/transform.py +2 -2
- jaxsim/math/utils.py +3 -2
- jaxsim/mujoco/visualizer.py +1 -1
- jaxsim/parsers/kinematic_graph.py +1 -1
- jaxsim/rbda/__init__.py +1 -1
- jaxsim/rbda/contacts/__init__.py +6 -2
- jaxsim/rbda/contacts/common.py +114 -4
- jaxsim/rbda/contacts/relaxed_rigid.py +57 -177
- jaxsim/rbda/contacts/rigid.py +538 -0
- jaxsim/rbda/contacts/soft.py +448 -0
- jaxsim/rbda/forward_kinematics.py +0 -29
- jaxsim/rbda/utils.py +2 -2
- jaxsim/terrain/terrain.py +1 -1
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/METADATA +3 -2
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/RECORD +30 -29
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/WHEEL +1 -1
- jaxsim/api/contact_model.py +0 -101
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info/licenses}/LICENSE +0 -0
- {jaxsim-0.6.2.dev182.dist-info → jaxsim-0.6.2.dev225.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
18
18
|
version_tuple: VERSION_TUPLE
|
19
19
|
|
20
|
-
__version__ = version = '0.6.2.
|
21
|
-
__version_tuple__ = version_tuple = (0, 6, 2, '
|
20
|
+
__version__ = version = '0.6.2.dev225'
|
21
|
+
__version_tuple__ = version_tuple = (0, 6, 2, 'dev225')
|
jaxsim/api/__init__.py
CHANGED
jaxsim/api/com.py
CHANGED
@@ -301,9 +301,7 @@ def bias_acceleration(
|
|
301
301
|
C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
|
302
302
|
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
|
303
303
|
|
304
|
-
L_H_C = L_H_W = jax.vmap( # noqa: F841
|
305
|
-
lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
|
306
|
-
)(W_H_L)
|
304
|
+
L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841
|
307
305
|
|
308
306
|
L_v_LC = L_v_LW = jax.vmap( # noqa: F841
|
309
307
|
lambda i: -js.link.velocity(
|
jaxsim/api/common.py
CHANGED
@@ -121,14 +121,8 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
121
121
|
The 6D quantity in the other representation.
|
122
122
|
"""
|
123
123
|
|
124
|
-
W_array = array
|
125
|
-
W_H_O = transform
|
126
|
-
|
127
|
-
if W_array.size != 6:
|
128
|
-
raise ValueError(W_array.size, 6)
|
129
|
-
|
130
|
-
if W_H_O.shape != (4, 4):
|
131
|
-
raise ValueError(W_H_O.shape, (4, 4))
|
124
|
+
W_array = array
|
125
|
+
W_H_O = transform
|
132
126
|
|
133
127
|
match other_representation:
|
134
128
|
|
@@ -139,25 +133,24 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
139
133
|
|
140
134
|
if not is_force:
|
141
135
|
O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
|
142
|
-
O_array = O_Xv_W
|
136
|
+
O_array = jnp.einsum("...ij,...j->...i", O_Xv_W, W_array)
|
143
137
|
|
144
138
|
else:
|
145
|
-
O_Xf_W = Adjoint.from_transform(transform=W_H_O).
|
146
|
-
O_array = O_Xf_W
|
139
|
+
O_Xf_W = Adjoint.from_transform(transform=W_H_O).swapaxes(-1, -2)
|
140
|
+
O_array = jnp.einsum("...ij,...j->...i", O_Xf_W, W_array)
|
147
141
|
|
148
142
|
return O_array
|
149
143
|
|
150
144
|
case VelRepr.Mixed:
|
151
|
-
|
152
|
-
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
|
145
|
+
W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))
|
153
146
|
|
154
147
|
if not is_force:
|
155
148
|
OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
|
156
|
-
OW_array = OW_Xv_W
|
149
|
+
OW_array = jnp.einsum("...ij,...j->...i", OW_Xv_W, W_array)
|
157
150
|
|
158
151
|
else:
|
159
|
-
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).
|
160
|
-
OW_array = OW_Xf_W
|
152
|
+
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).swapaxes(-1, -2)
|
153
|
+
OW_array = jnp.einsum("...ij,...j->...i", OW_Xf_W, W_array)
|
161
154
|
|
162
155
|
return OW_array
|
163
156
|
|
@@ -188,45 +181,40 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
|
|
188
181
|
The 6D quantity in the inertial-fixed representation.
|
189
182
|
"""
|
190
183
|
|
191
|
-
|
192
|
-
W_H_O = transform
|
193
|
-
|
194
|
-
if W_array.size != 6:
|
195
|
-
raise ValueError(W_array.size, 6)
|
196
|
-
|
197
|
-
if W_H_O.shape != (4, 4):
|
198
|
-
raise ValueError(W_H_O.shape, (4, 4))
|
184
|
+
O_array = array
|
185
|
+
W_H_O = transform
|
199
186
|
|
200
187
|
match other_representation:
|
201
188
|
case VelRepr.Inertial:
|
202
|
-
|
203
|
-
return W_array
|
189
|
+
return O_array
|
204
190
|
|
205
191
|
case VelRepr.Body:
|
206
|
-
O_array = array
|
207
192
|
|
208
193
|
if not is_force:
|
209
|
-
W_Xv_O
|
210
|
-
W_array = W_Xv_O
|
194
|
+
W_Xv_O = Adjoint.from_transform(W_H_O)
|
195
|
+
W_array = jnp.einsum("...ij,...j->...i", W_Xv_O, O_array)
|
211
196
|
|
212
197
|
else:
|
213
|
-
W_Xf_O = Adjoint.from_transform(
|
214
|
-
|
198
|
+
W_Xf_O = Adjoint.from_transform(
|
199
|
+
transform=W_H_O, inverse=True
|
200
|
+
).swapaxes(-1, -2)
|
201
|
+
W_array = jnp.einsum("...ij,...j->...i", W_Xf_O, O_array)
|
215
202
|
|
216
203
|
return W_array
|
217
204
|
|
218
205
|
case VelRepr.Mixed:
|
219
|
-
|
220
|
-
|
221
|
-
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
|
206
|
+
|
207
|
+
W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))
|
222
208
|
|
223
209
|
if not is_force:
|
224
|
-
W_Xv_BW
|
225
|
-
W_array = W_Xv_BW
|
210
|
+
W_Xv_BW = Adjoint.from_transform(W_H_OW)
|
211
|
+
W_array = jnp.einsum("...ij,...j->...i", W_Xv_BW, O_array)
|
226
212
|
|
227
213
|
else:
|
228
|
-
W_Xf_BW = Adjoint.from_transform(
|
229
|
-
|
214
|
+
W_Xf_BW = Adjoint.from_transform(
|
215
|
+
transform=W_H_OW, inverse=True
|
216
|
+
).swapaxes(-1, -2)
|
217
|
+
W_array = jnp.einsum("...ij,...j->...i", W_Xf_BW, O_array)
|
230
218
|
|
231
219
|
return W_array
|
232
220
|
|
jaxsim/api/contact.py
CHANGED
@@ -11,7 +11,7 @@ import jaxsim.terrain
|
|
11
11
|
import jaxsim.typing as jtp
|
12
12
|
from jaxsim import logging
|
13
13
|
from jaxsim.math import Adjoint, Cross, Transform
|
14
|
-
from jaxsim.rbda import
|
14
|
+
from jaxsim.rbda.contacts import SoftContacts
|
15
15
|
|
16
16
|
from .common import VelRepr
|
17
17
|
|
@@ -37,14 +37,11 @@ def collidable_point_kinematics(
|
|
37
37
|
the linear component of the mixed 6D frame velocity.
|
38
38
|
"""
|
39
39
|
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
link_transforms=data._link_transforms,
|
46
|
-
link_velocities=data._link_velocities,
|
47
|
-
)
|
40
|
+
W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
|
41
|
+
model=model,
|
42
|
+
link_transforms=data._link_transforms,
|
43
|
+
link_velocities=data._link_velocities,
|
44
|
+
)
|
48
45
|
|
49
46
|
return W_p_Ci, W_ṗ_Ci
|
50
47
|
|
@@ -164,18 +161,23 @@ def estimate_good_soft_contacts_parameters(
|
|
164
161
|
def estimate_good_contact_parameters(
|
165
162
|
model: js.model.JaxSimModel,
|
166
163
|
*,
|
164
|
+
standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
|
167
165
|
static_friction_coefficient: jtp.FloatLike = 0.5,
|
168
|
-
|
166
|
+
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
|
167
|
+
damping_ratio: jtp.FloatLike = 1.0,
|
168
|
+
max_penetration: jtp.FloatLike | None = None,
|
169
169
|
) -> jaxsim.rbda.contacts.ContactParamsTypes:
|
170
170
|
"""
|
171
171
|
Estimate good contact parameters.
|
172
172
|
|
173
173
|
Args:
|
174
174
|
model: The model to consider.
|
175
|
+
standard_gravity: The standard gravity acceleration.
|
175
176
|
static_friction_coefficient: The static friction coefficient.
|
176
|
-
|
177
|
-
|
178
|
-
|
177
|
+
number_of_active_collidable_points_steady_state:
|
178
|
+
The number of active collidable points in steady state.
|
179
|
+
damping_ratio: The damping ratio.
|
180
|
+
max_penetration: The maximum penetration allowed.
|
179
181
|
|
180
182
|
Returns:
|
181
183
|
The estimated good contacts parameters.
|
@@ -190,20 +192,41 @@ def estimate_good_contact_parameters(
|
|
190
192
|
specific application.
|
191
193
|
"""
|
192
194
|
|
193
|
-
|
195
|
+
def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
|
196
|
+
"""
|
197
|
+
Displacement between the CoM and the lowest collidable point using zero
|
198
|
+
joint positions.
|
199
|
+
"""
|
200
|
+
|
201
|
+
zero_data = js.data.JaxSimModelData.build(
|
202
|
+
model=model,
|
203
|
+
)
|
204
|
+
|
205
|
+
W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
|
194
206
|
|
195
|
-
|
196
|
-
|
207
|
+
if model.floating_base():
|
208
|
+
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
|
209
|
+
return 2 * (W_pz_CoM - W_pz_C.min())
|
197
210
|
|
198
|
-
|
199
|
-
mu=static_friction_coefficient,
|
200
|
-
**kwargs,
|
201
|
-
)
|
211
|
+
return 2 * W_pz_CoM
|
202
212
|
|
203
|
-
|
204
|
-
|
213
|
+
max_δ = (
|
214
|
+
max_penetration
|
215
|
+
if max_penetration is not None
|
216
|
+
# Consider as default a 0.5% of the model height.
|
217
|
+
else 0.005 * estimate_model_height(model=model)
|
218
|
+
)
|
205
219
|
|
206
|
-
|
220
|
+
nc = number_of_active_collidable_points_steady_state
|
221
|
+
|
222
|
+
return model.contact_model._parameters_class().build_default_from_jaxsim_model(
|
223
|
+
model=model,
|
224
|
+
standard_gravity=standard_gravity,
|
225
|
+
static_friction_coefficient=static_friction_coefficient,
|
226
|
+
max_penetration=max_δ,
|
227
|
+
number_of_active_collidable_points_steady_state=nc,
|
228
|
+
damping_ratio=damping_ratio,
|
229
|
+
)
|
207
230
|
|
208
231
|
|
209
232
|
@jax.jit
|
@@ -244,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
|
|
244
267
|
|
245
268
|
# Build the link-to-point transform from the displacement between the link frame L
|
246
269
|
# and the implicit contact frame C.
|
247
|
-
L_H_C = jax.vmap(
|
270
|
+
L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci)
|
248
271
|
|
249
272
|
# Compose the work-to-link and link-to-point transforms.
|
250
273
|
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
|
@@ -504,3 +527,96 @@ def jacobian_derivative(
|
|
504
527
|
)
|
505
528
|
|
506
529
|
return O_J̇_WC
|
530
|
+
|
531
|
+
|
532
|
+
@jax.jit
|
533
|
+
@js.common.named_scope
|
534
|
+
def link_contact_forces(
|
535
|
+
model: js.model.JaxSimModel,
|
536
|
+
data: js.data.JaxSimModelData,
|
537
|
+
*,
|
538
|
+
link_forces: jtp.MatrixLike | None = None,
|
539
|
+
joint_torques: jtp.VectorLike | None = None,
|
540
|
+
) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
|
541
|
+
"""
|
542
|
+
Compute the 6D contact forces of all links of the model in inertial representation.
|
543
|
+
|
544
|
+
Args:
|
545
|
+
model: The model to consider.
|
546
|
+
data: The data of the considered model.
|
547
|
+
link_forces:
|
548
|
+
The 6D external forces to apply to the links expressed in inertial representation
|
549
|
+
joint_torques:
|
550
|
+
The joint torques acting on the joints.
|
551
|
+
|
552
|
+
Returns:
|
553
|
+
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
|
554
|
+
expressed in inertial representation.
|
555
|
+
"""
|
556
|
+
|
557
|
+
# Compute the contact forces for each collidable point with the active contact model.
|
558
|
+
W_f_C, aux_dict = model.contact_model.compute_contact_forces(
|
559
|
+
model=model,
|
560
|
+
data=data,
|
561
|
+
**(
|
562
|
+
dict(link_forces=link_forces, joint_force_references=joint_torques)
|
563
|
+
if not isinstance(model.contact_model, SoftContacts)
|
564
|
+
else {}
|
565
|
+
),
|
566
|
+
)
|
567
|
+
|
568
|
+
# Compute the 6D forces applied to the links equivalent to the forces applied
|
569
|
+
# to the frames associated to the collidable points.
|
570
|
+
W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)
|
571
|
+
|
572
|
+
return W_f_L, aux_dict
|
573
|
+
|
574
|
+
|
575
|
+
@staticmethod
|
576
|
+
def link_forces_from_contact_forces(
|
577
|
+
model: js.model.JaxSimModel,
|
578
|
+
*,
|
579
|
+
contact_forces: jtp.MatrixLike,
|
580
|
+
) -> jtp.Matrix:
|
581
|
+
"""
|
582
|
+
Compute the link forces from the contact forces.
|
583
|
+
|
584
|
+
Args:
|
585
|
+
model: The robot model considered by the contact model.
|
586
|
+
contact_forces: The contact forces computed by the contact model.
|
587
|
+
|
588
|
+
Returns:
|
589
|
+
The 6D contact forces applied to the links and expressed in the frame of
|
590
|
+
the velocity representation of data.
|
591
|
+
"""
|
592
|
+
|
593
|
+
# Get the object storing the contact parameters of the model.
|
594
|
+
contact_parameters = model.kin_dyn_parameters.contact_parameters
|
595
|
+
|
596
|
+
# Extract the indices corresponding to the enabled collidable points.
|
597
|
+
indices_of_enabled_collidable_points = (
|
598
|
+
contact_parameters.indices_of_enabled_collidable_points
|
599
|
+
)
|
600
|
+
|
601
|
+
# Convert the contact forces to a JAX array.
|
602
|
+
W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
|
603
|
+
|
604
|
+
# Construct the vector defining the parent link index of each collidable point.
|
605
|
+
# We use this vector to sum the 6D forces of all collidable points rigidly
|
606
|
+
# attached to the same link.
|
607
|
+
parent_link_index_of_collidable_points = jnp.array(
|
608
|
+
contact_parameters.body, dtype=int
|
609
|
+
)[indices_of_enabled_collidable_points]
|
610
|
+
|
611
|
+
# Create the mask that associate each collidable point to their parent link.
|
612
|
+
# We use this mask to sum the collidable points to the right link.
|
613
|
+
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
|
614
|
+
model.number_of_links()
|
615
|
+
)
|
616
|
+
|
617
|
+
# Sum the forces of all collidable points rigidly attached to a body.
|
618
|
+
# Since the contact forces W_f_C are expressed in the world frame,
|
619
|
+
# we don't need any coordinate transformation.
|
620
|
+
W_f_L = mask.T @ W_f_C
|
621
|
+
|
622
|
+
return W_f_L
|
jaxsim/api/data.py
CHANGED
@@ -5,9 +5,9 @@ import functools
|
|
5
5
|
from collections.abc import Sequence
|
6
6
|
|
7
7
|
try:
|
8
|
-
from typing import override
|
8
|
+
from typing import Self, override
|
9
9
|
except ImportError:
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import override, Self
|
11
11
|
|
12
12
|
import jax
|
13
13
|
import jax.numpy as jnp
|
@@ -22,11 +22,6 @@ import jaxsim.typing as jtp
|
|
22
22
|
from . import common
|
23
23
|
from .common import VelRepr
|
24
24
|
|
25
|
-
try:
|
26
|
-
from typing import Self
|
27
|
-
except ImportError:
|
28
|
-
from typing_extensions import Self
|
29
|
-
|
30
25
|
|
31
26
|
@jax_dataclasses.pytree_dataclass
|
32
27
|
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
@@ -64,6 +59,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
64
59
|
_link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
|
65
60
|
_link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
|
66
61
|
|
62
|
+
# Extended state for soft and rigid contact models.
|
63
|
+
contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)
|
64
|
+
|
67
65
|
@staticmethod
|
68
66
|
def build(
|
69
67
|
model: js.model.JaxSimModel,
|
@@ -73,6 +71,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
73
71
|
base_linear_velocity: jtp.VectorLike | None = None,
|
74
72
|
base_angular_velocity: jtp.VectorLike | None = None,
|
75
73
|
joint_velocities: jtp.VectorLike | None = None,
|
74
|
+
contact_state: dict[str, jtp.Array] | None = None,
|
76
75
|
velocity_representation: VelRepr = VelRepr.Mixed,
|
77
76
|
) -> JaxSimModelData:
|
78
77
|
"""
|
@@ -89,6 +88,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
89
88
|
The base angular velocity in the selected representation.
|
90
89
|
joint_velocities: The joint velocities.
|
91
90
|
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
|
91
|
+
contact_state: The optional contact state.
|
92
92
|
|
93
93
|
Returns:
|
94
94
|
A `JaxSimModelData` initialized with the given state.
|
@@ -171,6 +171,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
171
171
|
)
|
172
172
|
)
|
173
173
|
|
174
|
+
contact_state = contact_state or {}
|
175
|
+
|
176
|
+
if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
|
177
|
+
contact_state.setdefault(
|
178
|
+
"tangential_deformation",
|
179
|
+
jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
|
180
|
+
)
|
181
|
+
|
174
182
|
model_data = JaxSimModelData(
|
175
183
|
velocity_representation=velocity_representation,
|
176
184
|
_base_quaternion=base_quaternion,
|
@@ -183,6 +191,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
183
191
|
_joint_transforms=joint_transforms,
|
184
192
|
_link_transforms=link_transforms,
|
185
193
|
_link_velocities=link_velocities_inertial,
|
194
|
+
contact_state=contact_state,
|
186
195
|
)
|
187
196
|
|
188
197
|
if not model_data.valid(model=model):
|
@@ -265,14 +274,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
265
274
|
"""
|
266
275
|
|
267
276
|
# Extract the base quaternion.
|
268
|
-
W_Q_B = self.base_quaternion
|
277
|
+
W_Q_B = self.base_quaternion
|
269
278
|
|
270
279
|
# Always normalize the quaternion to avoid numerical issues.
|
271
280
|
# If the active scheme does not integrate the quaternion on its manifold,
|
272
281
|
# we introduce a Baumgarte stabilization to let the quaternion converge to
|
273
282
|
# a unit quaternion. In this case, it is not guaranteed that the quaternion
|
274
283
|
# stored in the state is a unit quaternion.
|
275
|
-
norm = jaxsim.math.safe_norm(W_Q_B)
|
284
|
+
norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True)
|
276
285
|
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
|
277
286
|
return W_Q_B
|
278
287
|
|
@@ -285,11 +294,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
285
294
|
The base 6D velocity in the active representation.
|
286
295
|
"""
|
287
296
|
|
288
|
-
W_v_WB = jnp.
|
289
|
-
[
|
290
|
-
self._base_linear_velocity,
|
291
|
-
self._base_angular_velocity,
|
292
|
-
]
|
297
|
+
W_v_WB = jnp.concatenate(
|
298
|
+
[self._base_linear_velocity, self._base_angular_velocity], axis=-1
|
293
299
|
)
|
294
300
|
|
295
301
|
W_H_B = self._base_transform
|
@@ -350,11 +356,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
350
356
|
|
351
357
|
@js.common.named_scope
|
352
358
|
@jax.jit
|
353
|
-
def reset_base_quaternion(
|
359
|
+
def reset_base_quaternion(
|
360
|
+
self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike
|
361
|
+
) -> Self:
|
354
362
|
"""
|
355
363
|
Reset the base quaternion.
|
356
364
|
|
357
365
|
Args:
|
366
|
+
model: The JaxSim model to use.
|
358
367
|
base_quaternion: The base orientation as a quaternion.
|
359
368
|
|
360
369
|
Returns:
|
@@ -363,18 +372,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
363
372
|
|
364
373
|
W_Q_B = jnp.array(base_quaternion, dtype=float)
|
365
374
|
|
366
|
-
norm = jaxsim.math.safe_norm(W_Q_B)
|
375
|
+
norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)
|
367
376
|
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))
|
368
377
|
|
369
|
-
return self.replace(
|
378
|
+
return self.replace(model=model, base_quaternion=W_Q_B)
|
370
379
|
|
371
380
|
@js.common.named_scope
|
372
381
|
@jax.jit
|
373
|
-
def reset_base_pose(
|
382
|
+
def reset_base_pose(
|
383
|
+
self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike
|
384
|
+
) -> Self:
|
374
385
|
"""
|
375
386
|
Reset the base pose.
|
376
387
|
|
377
388
|
Args:
|
389
|
+
model: The JaxSim model to use.
|
378
390
|
base_pose: The base pose as an SE(3) matrix.
|
379
391
|
|
380
392
|
Returns:
|
@@ -385,6 +397,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
385
397
|
W_p_B = base_pose[0:3, 3]
|
386
398
|
W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
|
387
399
|
return self.replace(
|
400
|
+
model=model,
|
388
401
|
base_position=W_p_B,
|
389
402
|
base_quaternion=W_Q_B,
|
390
403
|
)
|
@@ -399,11 +412,19 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
399
412
|
base_linear_velocity: jtp.Vector | None = None,
|
400
413
|
base_angular_velocity: jtp.Vector | None = None,
|
401
414
|
base_position: jtp.Vector | None = None,
|
415
|
+
*,
|
416
|
+
contact_state: dict[str, jtp.Array] | None = None,
|
402
417
|
validate: bool = False,
|
403
418
|
) -> Self:
|
404
419
|
"""
|
405
420
|
Replace the attributes of the `JaxSimModelData` object.
|
406
421
|
"""
|
422
|
+
|
423
|
+
# Extract the batch size.
|
424
|
+
batch_size = (
|
425
|
+
self._base_transform.shape[0] if self._base_transform.ndim > 2 else 1
|
426
|
+
)
|
427
|
+
|
407
428
|
if joint_positions is None:
|
408
429
|
joint_positions = self.joint_positions
|
409
430
|
if joint_velocities is None:
|
@@ -412,6 +433,22 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
412
433
|
base_quaternion = self.base_quaternion
|
413
434
|
if base_position is None:
|
414
435
|
base_position = self.base_position
|
436
|
+
if contact_state is None:
|
437
|
+
contact_state = self.contact_state
|
438
|
+
|
439
|
+
if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
|
440
|
+
contact_state.setdefault(
|
441
|
+
"tangential_deformation",
|
442
|
+
jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
|
443
|
+
)
|
444
|
+
|
445
|
+
# Normalize the quaternion to avoid numerical issues.
|
446
|
+
base_quaternion_norm = jaxsim.math.safe_norm(
|
447
|
+
base_quaternion, axis=-1, keepdims=True
|
448
|
+
)
|
449
|
+
base_quaternion = base_quaternion / jnp.where(
|
450
|
+
base_quaternion_norm == 0, 1.0, base_quaternion_norm
|
451
|
+
)
|
415
452
|
|
416
453
|
joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
|
417
454
|
joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
|
@@ -421,44 +458,70 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
421
458
|
base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
|
422
459
|
translation=base_position, quaternion=base_quaternion
|
423
460
|
)
|
424
|
-
|
425
|
-
|
461
|
+
|
462
|
+
joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(
|
463
|
+
joint_positions=jnp.broadcast_to(
|
464
|
+
joint_positions, (batch_size, model.dofs())
|
465
|
+
),
|
466
|
+
base_transform=jnp.broadcast_to(base_transform, (batch_size, 4, 4)),
|
426
467
|
)
|
427
468
|
|
428
469
|
if base_linear_velocity is None and base_angular_velocity is None:
|
429
|
-
|
430
|
-
|
470
|
+
base_linear_velocity_inertial = self._base_linear_velocity
|
471
|
+
base_angular_velocity_inertial = self._base_angular_velocity
|
431
472
|
else:
|
432
473
|
if base_linear_velocity is None:
|
433
474
|
base_linear_velocity = self.base_velocity[:3]
|
434
475
|
if base_angular_velocity is None:
|
435
476
|
base_angular_velocity = self.base_velocity[3:]
|
477
|
+
|
436
478
|
base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())
|
437
479
|
base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())
|
480
|
+
|
438
481
|
W_v_WB = JaxSimModelData.other_representation_to_inertial(
|
439
482
|
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
|
440
483
|
other_representation=self.velocity_representation,
|
441
484
|
transform=base_transform,
|
442
485
|
is_force=False,
|
443
486
|
).astype(float)
|
444
|
-
base_linear_velocity, base_angular_velocity = W_v_WB[:3], W_v_WB[3:]
|
445
487
|
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
488
|
+
base_linear_velocity_inertial, base_angular_velocity_inertial = (
|
489
|
+
W_v_WB[..., :3],
|
490
|
+
W_v_WB[..., 3:],
|
491
|
+
)
|
492
|
+
|
493
|
+
link_transforms, link_velocities = jax.vmap(
|
494
|
+
jaxsim.rbda.forward_kinematics_model, in_axes=(None,)
|
495
|
+
)(
|
496
|
+
model,
|
497
|
+
base_position=jnp.broadcast_to(base_position, (batch_size, 3)),
|
498
|
+
base_quaternion=jnp.broadcast_to(base_quaternion, (batch_size, 4)),
|
499
|
+
joint_positions=jnp.broadcast_to(
|
500
|
+
joint_positions, (batch_size, model.dofs())
|
501
|
+
),
|
502
|
+
joint_velocities=jnp.broadcast_to(
|
503
|
+
joint_velocities, (batch_size, model.dofs())
|
504
|
+
),
|
505
|
+
base_linear_velocity_inertial=jnp.broadcast_to(
|
506
|
+
base_linear_velocity_inertial, (batch_size, 3)
|
507
|
+
),
|
508
|
+
base_angular_velocity_inertial=jnp.broadcast_to(
|
509
|
+
base_angular_velocity_inertial, (batch_size, 3)
|
510
|
+
),
|
454
511
|
)
|
455
512
|
|
513
|
+
# Adjust the output shapes.
|
514
|
+
if batch_size == 1:
|
515
|
+
link_transforms = link_transforms.reshape(self._link_transforms.shape)
|
516
|
+
link_velocities = link_velocities.reshape(self._link_velocities.shape)
|
517
|
+
joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
|
518
|
+
|
456
519
|
return super().replace(
|
457
520
|
_joint_positions=joint_positions,
|
458
521
|
_joint_velocities=joint_velocities,
|
459
522
|
_base_quaternion=base_quaternion,
|
460
|
-
_base_linear_velocity=
|
461
|
-
_base_angular_velocity=
|
523
|
+
_base_linear_velocity=base_linear_velocity_inertial,
|
524
|
+
_base_angular_velocity=base_angular_velocity_inertial,
|
462
525
|
_base_position=base_position,
|
463
526
|
_base_transform=base_transform,
|
464
527
|
_joint_transforms=joint_transforms,
|