jaxsim 0.4.3.dev68__py3-none-any.whl → 0.4.3.dev77__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -0
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +32 -1
- jaxsim/api/data.py +68 -20
- jaxsim/api/joint.py +62 -2
- jaxsim/api/model.py +37 -23
- jaxsim/api/ode.py +29 -25
- jaxsim/api/ode_data.py +11 -1
- jaxsim/integrators/common.py +1 -1
- jaxsim/math/inertia.py +1 -1
- jaxsim/mujoco/loaders.py +3 -3
- jaxsim/parsers/kinematic_graph.py +3 -3
- jaxsim/parsers/rod/parser.py +18 -14
- jaxsim/rbda/contacts/relaxed_rigid.py +409 -0
- jaxsim/rbda/contacts/rigid.py +21 -41
- jaxsim/terrain/terrain.py +41 -25
- jaxsim/typing.py +1 -1
- jaxsim/utils/jaxsim_dataclass.py +12 -9
- jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/RECORD +24 -23
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev77.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py
CHANGED
@@ -20,6 +20,11 @@ def _jnp_options() -> None:
|
|
20
20
|
if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
|
21
21
|
logging.warning("Failed to enable 64bit precision in JAX")
|
22
22
|
|
23
|
+
else:
|
24
|
+
logging.warning(
|
25
|
+
"Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
|
26
|
+
)
|
27
|
+
|
23
28
|
|
24
29
|
def _np_options() -> None:
|
25
30
|
import numpy as np
|
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.dev68'
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, 'dev68')
|
15
|
+
__version__ = version = '0.4.3.dev77'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev77')
|
jaxsim/api/contact.py
CHANGED
@@ -117,6 +117,7 @@ def collidable_point_dynamics(
|
|
117
117
|
model: js.model.JaxSimModel,
|
118
118
|
data: js.data.JaxSimModelData,
|
119
119
|
link_forces: jtp.MatrixLike | None = None,
|
120
|
+
joint_force_references: jtp.VectorLike | None = None,
|
120
121
|
) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
|
121
122
|
r"""
|
122
123
|
Compute the 6D force applied to each collidable point.
|
@@ -127,11 +128,14 @@ def collidable_point_dynamics(
|
|
127
128
|
link_forces:
|
128
129
|
The 6D external forces to apply to the links expressed in the same
|
129
130
|
representation of data.
|
131
|
+
joint_force_references:
|
132
|
+
The joint force references to apply to the joints.
|
130
133
|
|
131
134
|
Returns:
|
132
135
|
The 6D force applied to each collidable point and additional data based on the contact model configured:
|
133
136
|
- Soft: the material deformation rate.
|
134
|
-
- Rigid:
|
137
|
+
- Rigid: no additional data.
|
138
|
+
- QuasiRigid: no additional data.
|
135
139
|
|
136
140
|
Note:
|
137
141
|
The material deformation rate is always returned in the mixed frame
|
@@ -144,6 +148,10 @@ def collidable_point_dynamics(
|
|
144
148
|
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
|
145
149
|
|
146
150
|
# Import privately the contacts classes.
|
151
|
+
from jaxsim.rbda.contacts.relaxed_rigid import (
|
152
|
+
RelaxedRigidContacts,
|
153
|
+
RelaxedRigidContactsState,
|
154
|
+
)
|
147
155
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
148
156
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
149
157
|
|
@@ -186,6 +194,29 @@ def collidable_point_dynamics(
|
|
186
194
|
model=model,
|
187
195
|
data=data,
|
188
196
|
link_forces=link_forces,
|
197
|
+
joint_force_references=joint_force_references,
|
198
|
+
)
|
199
|
+
|
200
|
+
aux_data = dict()
|
201
|
+
|
202
|
+
case RelaxedRigidContacts():
|
203
|
+
assert isinstance(model.contact_model, RelaxedRigidContacts)
|
204
|
+
assert isinstance(data.state.contact, RelaxedRigidContactsState)
|
205
|
+
|
206
|
+
# Build the contact model.
|
207
|
+
relaxed_rigid_contacts = RelaxedRigidContacts(
|
208
|
+
parameters=data.contacts_params, terrain=model.terrain
|
209
|
+
)
|
210
|
+
|
211
|
+
# Compute the 6D force expressed in the inertial frame and applied to each
|
212
|
+
# collidable point.
|
213
|
+
W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
|
214
|
+
position=W_p_Ci,
|
215
|
+
velocity=W_ṗ_Ci,
|
216
|
+
model=model,
|
217
|
+
data=data,
|
218
|
+
link_forces=link_forces,
|
219
|
+
joint_force_references=joint_force_references,
|
189
220
|
)
|
190
221
|
|
191
222
|
aux_data = dict()
|
jaxsim/api/data.py
CHANGED
@@ -6,10 +6,11 @@ from collections.abc import Sequence
|
|
6
6
|
|
7
7
|
import jax
|
8
8
|
import jax.numpy as jnp
|
9
|
+
import jax.scipy.spatial.transform
|
9
10
|
import jax_dataclasses
|
10
|
-
import jaxlie
|
11
11
|
|
12
12
|
import jaxsim.api as js
|
13
|
+
import jaxsim.math
|
13
14
|
import jaxsim.rbda
|
14
15
|
import jaxsim.typing as jtp
|
15
16
|
from jaxsim.rbda.contacts.soft import SoftContacts
|
@@ -39,7 +40,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
39
40
|
contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
|
40
41
|
|
41
42
|
time_ns: jtp.Int = dataclasses.field(
|
42
|
-
default_factory=lambda: jnp.array(
|
43
|
+
default_factory=lambda: jnp.array(
|
44
|
+
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
|
45
|
+
),
|
43
46
|
)
|
44
47
|
|
45
48
|
def __hash__(self) -> int:
|
@@ -172,9 +175,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
172
175
|
)
|
173
176
|
|
174
177
|
time_ns = (
|
175
|
-
jnp.array(
|
178
|
+
jnp.array(
|
179
|
+
time * 1e9,
|
180
|
+
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
|
181
|
+
)
|
176
182
|
if time is not None
|
177
|
-
else jnp.array(
|
183
|
+
else jnp.array(
|
184
|
+
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
|
185
|
+
)
|
178
186
|
)
|
179
187
|
|
180
188
|
if isinstance(model.contact_model, SoftContacts):
|
@@ -188,10 +196,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
188
196
|
else:
|
189
197
|
contacts_params = model.contact_model.parameters
|
190
198
|
|
191
|
-
W_H_B =
|
192
|
-
translation=base_position,
|
193
|
-
|
194
|
-
).as_matrix()
|
199
|
+
W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
|
200
|
+
translation=base_position, quaternion=base_quaternion
|
201
|
+
)
|
195
202
|
|
196
203
|
v_WB = JaxSimModelData.other_representation_to_inertial(
|
197
204
|
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
|
@@ -377,7 +384,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
377
384
|
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
|
378
385
|
)
|
379
386
|
|
380
|
-
return (W_Q_B if not dcm else
|
387
|
+
return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
|
388
|
+
float
|
389
|
+
)
|
381
390
|
|
382
391
|
@jax.jit
|
383
392
|
def base_transform(self) -> jtp.Matrix:
|
@@ -586,16 +595,18 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
586
595
|
The updated `JaxSimModelData` object.
|
587
596
|
"""
|
588
597
|
|
589
|
-
|
598
|
+
W_Q_B = jnp.array(base_quaternion, dtype=float)
|
599
|
+
|
600
|
+
W_Q_B = jax.lax.select(
|
601
|
+
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
|
602
|
+
on_true=W_Q_B,
|
603
|
+
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
|
604
|
+
)
|
590
605
|
|
591
606
|
return self.replace(
|
592
607
|
validate=True,
|
593
608
|
state=self.state.replace(
|
594
|
-
physics_model=self.state.physics_model.replace(
|
595
|
-
base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
|
596
|
-
float
|
597
|
-
)
|
598
|
-
)
|
609
|
+
physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
|
599
610
|
),
|
600
611
|
)
|
601
612
|
|
@@ -728,6 +739,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
728
739
|
)
|
729
740
|
|
730
741
|
|
742
|
+
@functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
|
731
743
|
def random_model_data(
|
732
744
|
model: js.model.JaxSimModel,
|
733
745
|
*,
|
@@ -737,6 +749,18 @@ def random_model_data(
|
|
737
749
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
738
750
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
739
751
|
] = ((-1, -1, 0.5), 1.0),
|
752
|
+
base_rpy_bounds: tuple[
|
753
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
754
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
755
|
+
] = (-jnp.pi, jnp.pi),
|
756
|
+
base_rpy_seq: str = "XYZ",
|
757
|
+
joint_pos_bounds: (
|
758
|
+
tuple[
|
759
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
760
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
761
|
+
]
|
762
|
+
| None
|
763
|
+
) = None,
|
740
764
|
base_vel_lin_bounds: tuple[
|
741
765
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
742
766
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
@@ -762,6 +786,12 @@ def random_model_data(
|
|
762
786
|
key: The random key.
|
763
787
|
velocity_representation: The velocity representation to use.
|
764
788
|
base_pos_bounds: The bounds for the base position.
|
789
|
+
base_rpy_bounds:
|
790
|
+
The bounds for the euler angles used to build the base orientation.
|
791
|
+
base_rpy_seq:
|
792
|
+
The sequence of axes for rotation (using `Rotation` from scipy).
|
793
|
+
joint_pos_bounds:
|
794
|
+
The bounds for the joint positions (reading the joint limits if None).
|
765
795
|
base_vel_lin_bounds: The bounds for the base linear velocity.
|
766
796
|
base_vel_ang_bounds: The bounds for the base angular velocity.
|
767
797
|
joint_vel_bounds: The bounds for the joint velocities.
|
@@ -776,6 +806,8 @@ def random_model_data(
|
|
776
806
|
|
777
807
|
p_min = jnp.array(base_pos_bounds[0], dtype=float)
|
778
808
|
p_max = jnp.array(base_pos_bounds[1], dtype=float)
|
809
|
+
rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)
|
810
|
+
rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)
|
779
811
|
v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
|
780
812
|
v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
|
781
813
|
ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
|
@@ -801,13 +833,29 @@ def random_model_data(
|
|
801
833
|
key=k1, shape=(3,), minval=p_min, maxval=p_max
|
802
834
|
)
|
803
835
|
|
804
|
-
physics_model_state.base_quaternion =
|
805
|
-
|
806
|
-
|
836
|
+
physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz(
|
837
|
+
xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
|
838
|
+
seq=base_rpy_seq,
|
839
|
+
angles=jax.random.uniform(
|
840
|
+
key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
|
841
|
+
),
|
842
|
+
).as_quat()
|
843
|
+
)
|
807
844
|
|
808
845
|
if model.number_of_joints() > 0:
|
809
|
-
|
810
|
-
|
846
|
+
|
847
|
+
s_min, s_max = (
|
848
|
+
jnp.array(joint_pos_bounds, dtype=float)
|
849
|
+
if joint_pos_bounds is not None
|
850
|
+
else (None, None)
|
851
|
+
)
|
852
|
+
|
853
|
+
physics_model_state.joint_positions = (
|
854
|
+
js.joint.random_joint_positions(model=model, key=k3)
|
855
|
+
if (s_min is None or s_max is None)
|
856
|
+
else jax.random.uniform(
|
857
|
+
key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
|
858
|
+
)
|
811
859
|
)
|
812
860
|
|
813
861
|
physics_model_state.joint_velocities = jax.random.uniform(
|
jaxsim/api/joint.py
CHANGED
@@ -180,17 +180,77 @@ def random_joint_positions(
|
|
180
180
|
|
181
181
|
Args:
|
182
182
|
model: The model to consider.
|
183
|
-
joint_names: The names of the joints.
|
184
|
-
key: The random key.
|
183
|
+
joint_names: The names of the considered joints (all if None).
|
184
|
+
key: The random key (initialized from seed 0 if None).
|
185
|
+
|
186
|
+
Note:
|
187
|
+
If the joint range of revolute joints is larger than 2π, their joint positions
|
188
|
+
will be sampled from an interval of size 2π.
|
185
189
|
|
186
190
|
Returns:
|
187
191
|
The random joint positions.
|
188
192
|
"""
|
189
193
|
|
194
|
+
# Consider the key corresponding to a zero seed if it was not passed.
|
190
195
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
191
196
|
|
197
|
+
# Get the joint limits parsed from the model description.
|
192
198
|
s_min, s_max = position_limits(model=model, joint_names=joint_names)
|
193
199
|
|
200
|
+
# Get the joint indices.
|
201
|
+
# Note that it will trigger an exception if the given `joint_names` are not valid.
|
202
|
+
joint_names = joint_names if joint_names is not None else model.joint_names()
|
203
|
+
joint_indices = names_to_idxs(model=model, joint_names=joint_names)
|
204
|
+
|
205
|
+
from jaxsim.parsers.descriptions.joint import JointType
|
206
|
+
|
207
|
+
# Filter for revolute joints.
|
208
|
+
is_revolute = jnp.where(
|
209
|
+
jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
|
210
|
+
== JointType.Revolute,
|
211
|
+
True,
|
212
|
+
False,
|
213
|
+
)
|
214
|
+
|
215
|
+
# Shorthand for π.
|
216
|
+
π = jnp.pi
|
217
|
+
|
218
|
+
# Filter for revolute with full range (or continuous).
|
219
|
+
is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
|
220
|
+
|
221
|
+
# Clip the lower limit to -π if the joint range is larger than [-π, π].
|
222
|
+
s_min = jnp.where(
|
223
|
+
jnp.logical_and(
|
224
|
+
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
225
|
+
),
|
226
|
+
-π,
|
227
|
+
s_min,
|
228
|
+
)
|
229
|
+
|
230
|
+
# Clip the upper limit to +π if the joint range is larger than [-π, π].
|
231
|
+
s_max = jnp.where(
|
232
|
+
jnp.logical_and(
|
233
|
+
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
234
|
+
),
|
235
|
+
π,
|
236
|
+
s_max,
|
237
|
+
)
|
238
|
+
|
239
|
+
# Shift the lower limit if the upper limit is smaller than +π.
|
240
|
+
s_min = jnp.where(
|
241
|
+
jnp.logical_and(is_revolute_full_range, s_max < π),
|
242
|
+
s_max - 2 * π,
|
243
|
+
s_min,
|
244
|
+
)
|
245
|
+
|
246
|
+
# Shift the upper limit if the lower limit is larger than -π.
|
247
|
+
s_max = jnp.where(
|
248
|
+
jnp.logical_and(is_revolute_full_range, s_min > -π),
|
249
|
+
s_min + 2 * π,
|
250
|
+
s_max,
|
251
|
+
)
|
252
|
+
|
253
|
+
# Sample the joint positions.
|
194
254
|
s_random = jax.random.uniform(
|
195
255
|
minval=s_min,
|
196
256
|
maxval=s_max,
|
jaxsim/api/model.py
CHANGED
@@ -1747,14 +1747,18 @@ def link_contact_forces(
|
|
1747
1747
|
data: The data of the considered model.
|
1748
1748
|
|
1749
1749
|
Returns:
|
1750
|
-
A (nL, 6) array containing the stacked 6D contact forces of the links,
|
1750
|
+
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
|
1751
1751
|
expressed in the frame corresponding to the active representation.
|
1752
1752
|
"""
|
1753
1753
|
|
1754
|
+
# Note: the following code should be kept in sync with the function
|
1755
|
+
# `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
|
1756
|
+
# there we need to get also aux_data.
|
1757
|
+
|
1754
1758
|
# Compute the 6D forces applied to each collidable point expressed in the
|
1755
1759
|
# inertial frame.
|
1756
1760
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
1757
|
-
|
1761
|
+
W_f_C = js.contact.collidable_point_forces(model=model, data=data)
|
1758
1762
|
|
1759
1763
|
# Construct the vector defining the parent link index of each collidable point.
|
1760
1764
|
# We use this vector to sum the 6D forces of all collidable points rigidly
|
@@ -1763,29 +1767,28 @@ def link_contact_forces(
|
|
1763
1767
|
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
1764
1768
|
)
|
1765
1769
|
|
1770
|
+
# Create the mask that associates each collidable point with its parent link.
|
1771
|
+
# We use this mask to sum the collidable points to the right link.
|
1772
|
+
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
|
1773
|
+
model.number_of_links()
|
1774
|
+
)
|
1775
|
+
|
1766
1776
|
# Sum the forces of all collidable points rigidly attached to a body.
|
1767
|
-
# Since the contact forces
|
1777
|
+
# Since the contact forces W_f_C are expressed in the world frame,
|
1768
1778
|
# we don't need any coordinate transformation.
|
1769
|
-
|
1770
|
-
lambda nc: (
|
1771
|
-
jnp.vstack(
|
1772
|
-
jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
|
1773
|
-
)
|
1774
|
-
* W_f_Ci
|
1775
|
-
).sum(axis=0)
|
1776
|
-
)(jnp.arange(model.number_of_links()))
|
1777
|
-
|
1778
|
-
# Convert the 6D forces to the active representation.
|
1779
|
-
f_Li = jax.vmap(
|
1780
|
-
lambda W_f_L: data.inertial_to_other_representation(
|
1781
|
-
array=W_f_L,
|
1782
|
-
other_representation=data.velocity_representation,
|
1783
|
-
transform=data.base_transform(),
|
1784
|
-
is_force=True,
|
1785
|
-
)
|
1786
|
-
)(W_f_Li)
|
1779
|
+
W_f_L = mask.T @ W_f_C
|
1787
1780
|
|
1788
|
-
|
1781
|
+
# Create a references object to store the link forces.
|
1782
|
+
references = js.references.JaxSimModelReferences.build(
|
1783
|
+
model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial
|
1784
|
+
)
|
1785
|
+
|
1786
|
+
# Use the references object to convert the link forces to the velocity
|
1787
|
+
# representation of data.
|
1788
|
+
with references.switch_velocity_representation(data.velocity_representation):
|
1789
|
+
f_L = references.link_forces(model=model, data=data)
|
1790
|
+
|
1791
|
+
return f_L
|
1789
1792
|
|
1790
1793
|
|
1791
1794
|
# ======
|
@@ -1931,11 +1934,22 @@ def step(
|
|
1931
1934
|
),
|
1932
1935
|
)
|
1933
1936
|
|
1937
|
+
tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
|
1938
|
+
tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
|
1939
|
+
|
1940
|
+
jax.lax.cond(
|
1941
|
+
pred=tf_ns < t0_ns,
|
1942
|
+
true_fun=lambda: jax.debug.print(
|
1943
|
+
"The simulation time overflowed, resetting simulation time to 0."
|
1944
|
+
),
|
1945
|
+
false_fun=lambda: None,
|
1946
|
+
)
|
1947
|
+
|
1934
1948
|
data_tf = (
|
1935
1949
|
# Store the new state of the model and the new time.
|
1936
1950
|
data.replace(
|
1937
1951
|
state=state_tf,
|
1938
|
-
time_ns=
|
1952
|
+
time_ns=tf_ns,
|
1939
1953
|
)
|
1940
1954
|
)
|
1941
1955
|
|
jaxsim/api/ode.py
CHANGED
@@ -95,7 +95,7 @@ def system_velocity_dynamics(
|
|
95
95
|
Args:
|
96
96
|
model: The model to consider.
|
97
97
|
data: The data of the considered model.
|
98
|
-
joint_forces: The joint
|
98
|
+
joint_forces: The joint force references to apply.
|
99
99
|
link_forces:
|
100
100
|
The 6D forces to apply to the links expressed in the frame corresponding to
|
101
101
|
the velocity representation of `data`.
|
@@ -120,6 +120,7 @@ def system_velocity_dynamics(
|
|
120
120
|
references = js.references.JaxSimModelReferences.build(
|
121
121
|
model=model,
|
122
122
|
link_forces=O_f_L,
|
123
|
+
joint_force_references=joint_forces,
|
123
124
|
data=data,
|
124
125
|
velocity_representation=data.velocity_representation,
|
125
126
|
)
|
@@ -132,9 +133,16 @@ def system_velocity_dynamics(
|
|
132
133
|
# with the terrain.
|
133
134
|
W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
|
134
135
|
|
136
|
+
# Initialize a dictionary of auxiliary data.
|
137
|
+
# This dictionary is used to store additional data computed by the contact model.
|
135
138
|
aux_data = {}
|
139
|
+
|
136
140
|
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
137
141
|
|
142
|
+
# Note: the following code should be kept in sync with the function
|
143
|
+
# `jaxsim.api.model.link_contact_forces`. We cannot merge them since
|
144
|
+
# here we need to get also aux_data.
|
145
|
+
|
138
146
|
# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
|
139
147
|
# along with contact-specific auxiliary states.
|
140
148
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
@@ -142,6 +150,7 @@ def system_velocity_dynamics(
|
|
142
150
|
model=model,
|
143
151
|
data=data,
|
144
152
|
link_forces=references.link_forces(model=model, data=data),
|
153
|
+
joint_force_references=references.joint_force_references(model=model),
|
145
154
|
)
|
146
155
|
|
147
156
|
# Construct the vector defining the parent link index of each collidable point.
|
@@ -175,17 +184,15 @@ def system_velocity_dynamics(
|
|
175
184
|
forces=W_f_Li_terrain,
|
176
185
|
additive=True,
|
177
186
|
)
|
178
|
-
|
179
|
-
|
187
|
+
|
188
|
+
# Get the link forces in inertial representation
|
180
189
|
f_L_total = references.link_forces(model=model, data=data)
|
181
190
|
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
186
|
-
)
|
191
|
+
v̇_WB, s̈ = system_acceleration(
|
192
|
+
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
193
|
+
)
|
187
194
|
|
188
|
-
return
|
195
|
+
return v̇_WB, s̈, aux_data
|
189
196
|
|
190
197
|
|
191
198
|
def system_acceleration(
|
@@ -196,7 +203,7 @@ def system_acceleration(
|
|
196
203
|
link_forces: jtp.MatrixLike | None = None,
|
197
204
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
198
205
|
"""
|
199
|
-
Compute the system acceleration in
|
206
|
+
Compute the system acceleration in the active representation.
|
200
207
|
|
201
208
|
Args:
|
202
209
|
model: The model to consider.
|
@@ -206,7 +213,7 @@ def system_acceleration(
|
|
206
213
|
The 6D forces to apply to the links expressed in the same representation of data.
|
207
214
|
|
208
215
|
Returns:
|
209
|
-
A tuple containing the base 6D acceleration in
|
216
|
+
A tuple containing the base 6D acceleration in the active representation
|
210
217
|
and the joint accelerations.
|
211
218
|
"""
|
212
219
|
|
@@ -272,18 +279,15 @@ def system_acceleration(
|
|
272
279
|
)
|
273
280
|
|
274
281
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
275
|
-
# - Base
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
link_forces=references.link_forces(),
|
285
|
-
)
|
286
|
-
return W_v̇_WB, s̈
|
282
|
+
# - Base acceleration: v̇_WB ∈ ℝ⁶
|
283
|
+
v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
284
|
+
model=model,
|
285
|
+
data=data,
|
286
|
+
joint_forces=references.joint_force_references(model=model),
|
287
|
+
link_forces=references.link_forces(model=model, data=data),
|
288
|
+
)
|
289
|
+
|
290
|
+
return v̇_WB, s̈
|
287
291
|
|
288
292
|
|
289
293
|
@jax.jit
|
@@ -353,7 +357,7 @@ def system_dynamics(
|
|
353
357
|
corresponding derivative, and the dictionary of auxiliary data returned
|
354
358
|
by the system dynamics evaluation.
|
355
359
|
"""
|
356
|
-
|
360
|
+
from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
|
357
361
|
from jaxsim.rbda.contacts.rigid import RigidContacts
|
358
362
|
from jaxsim.rbda.contacts.soft import SoftContacts
|
359
363
|
|
@@ -371,7 +375,7 @@ def system_dynamics(
|
|
371
375
|
case SoftContacts():
|
372
376
|
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
|
373
377
|
|
374
|
-
case RigidContacts():
|
378
|
+
case RigidContacts() | RelaxedRigidContacts():
|
375
379
|
pass
|
376
380
|
|
377
381
|
case _:
|
jaxsim/api/ode_data.py
CHANGED
@@ -6,6 +6,10 @@ import jax_dataclasses
|
|
6
6
|
import jaxsim.api as js
|
7
7
|
import jaxsim.typing as jtp
|
8
8
|
from jaxsim.rbda import ContactsState
|
9
|
+
from jaxsim.rbda.contacts.relaxed_rigid import (
|
10
|
+
RelaxedRigidContacts,
|
11
|
+
RelaxedRigidContactsState,
|
12
|
+
)
|
9
13
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
10
14
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
11
15
|
from jaxsim.utils import JaxsimDataclass
|
@@ -173,6 +177,10 @@ class ODEState(JaxsimDataclass):
|
|
173
177
|
)
|
174
178
|
case RigidContacts():
|
175
179
|
contact = RigidContactsState.build()
|
180
|
+
|
181
|
+
case RelaxedRigidContacts():
|
182
|
+
contact = RelaxedRigidContactsState.build()
|
183
|
+
|
176
184
|
case _:
|
177
185
|
raise ValueError("Unable to determine contact state class prefix.")
|
178
186
|
|
@@ -216,7 +224,9 @@ class ODEState(JaxsimDataclass):
|
|
216
224
|
|
217
225
|
# Get the contact model from the `JaxSimModel`.
|
218
226
|
match contact:
|
219
|
-
case
|
227
|
+
case (
|
228
|
+
SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
|
229
|
+
):
|
220
230
|
pass
|
221
231
|
case None:
|
222
232
|
contact = SoftContactsState.zero(model=model)
|
jaxsim/integrators/common.py
CHANGED
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
497
497
|
b: jtp.Matrix,
|
498
498
|
c: jtp.Vector,
|
499
499
|
index_of_solution: jtp.IntLike = 0,
|
500
|
-
) -> [bool, int | None]:
|
500
|
+
) -> tuple[bool, int | None]:
|
501
501
|
"""
|
502
502
|
Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
|
503
503
|
|
jaxsim/math/inertia.py
CHANGED
@@ -45,7 +45,7 @@ class Inertia:
|
|
45
45
|
M (jtp.Matrix): The 6x6 inertia matrix.
|
46
46
|
|
47
47
|
Returns:
|
48
|
-
|
48
|
+
tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
|
49
49
|
|
50
50
|
Raises:
|
51
51
|
ValueError: If the input matrix M has an unexpected shape.
|
jaxsim/mujoco/loaders.py
CHANGED
@@ -211,7 +211,7 @@ class RodModelToMjcf:
|
|
211
211
|
joints_dict = {j.name: j for j in rod_model.joints()}
|
212
212
|
|
213
213
|
# Convert all the joints not considered to fixed joints.
|
214
|
-
for joint_name in
|
214
|
+
for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
|
215
215
|
joints_dict[joint_name].type = "fixed"
|
216
216
|
|
217
217
|
# Convert the ROD model to URDF.
|
@@ -289,10 +289,10 @@ class RodModelToMjcf:
|
|
289
289
|
mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
|
290
290
|
|
291
291
|
# Get the joint names.
|
292
|
-
mj_joint_names =
|
292
|
+
mj_joint_names = {
|
293
293
|
mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
|
294
294
|
for idx in range(mj_model.njnt)
|
295
|
-
|
295
|
+
}
|
296
296
|
|
297
297
|
# Check that the Mujoco model only has the considered joints.
|
298
298
|
if mj_joint_names != considered_joints:
|
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
394
394
|
return copy.deepcopy(self)
|
395
395
|
|
396
396
|
# Check if all considered joints are part of the full kinematic graph
|
397
|
-
if len(set(considered_joints) -
|
397
|
+
if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
|
398
398
|
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
|
399
399
|
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
|
400
400
|
raise ValueError(msg)
|
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
536
536
|
root_link_name=full_graph.root.name,
|
537
537
|
)
|
538
538
|
|
539
|
-
assert
|
540
|
-
|
539
|
+
assert {f.name for f in self.frames}.isdisjoint(
|
540
|
+
{f.name for f in unconnected_frames + reduced_frames}
|
541
541
|
)
|
542
542
|
|
543
543
|
for link in unconnected_links:
|