jaxsim 0.4.3.dev31__py3-none-any.whl → 0.4.3.dev64__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +27 -1
- jaxsim/api/data.py +30 -8
- jaxsim/api/joint.py +62 -2
- jaxsim/api/model.py +1 -1
- jaxsim/api/ode.py +19 -24
- jaxsim/api/ode_data.py +11 -1
- jaxsim/integrators/common.py +1 -1
- jaxsim/math/inertia.py +1 -1
- jaxsim/mujoco/loaders.py +3 -3
- jaxsim/parsers/kinematic_graph.py +3 -3
- jaxsim/rbda/contacts/relaxed_rigid.py +384 -0
- jaxsim/rbda/contacts/rigid.py +11 -41
- jaxsim/terrain/terrain.py +41 -25
- jaxsim/typing.py +1 -1
- jaxsim/utils/jaxsim_dataclass.py +12 -9
- jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev31.dist-info → jaxsim-0.4.3.dev64.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev31.dist-info → jaxsim-0.4.3.dev64.dist-info}/RECORD +22 -21
- {jaxsim-0.4.3.dev31.dist-info → jaxsim-0.4.3.dev64.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev31.dist-info → jaxsim-0.4.3.dev64.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev31.dist-info → jaxsim-0.4.3.dev64.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.dev31'
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, 'dev31')
|
15
|
+
__version__ = version = '0.4.3.dev64'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev64')
|
jaxsim/api/contact.py
CHANGED
@@ -131,7 +131,8 @@ def collidable_point_dynamics(
|
|
131
131
|
Returns:
|
132
132
|
The 6D force applied to each collidable point and additional data based on the contact model configured:
|
133
133
|
- Soft: the material deformation rate.
|
134
|
-
- Rigid:
|
134
|
+
- Rigid: no additional data.
|
135
|
+
- QuasiRigid: no additional data.
|
135
136
|
|
136
137
|
Note:
|
137
138
|
The material deformation rate is always returned in the mixed frame
|
@@ -144,6 +145,10 @@ def collidable_point_dynamics(
|
|
144
145
|
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
|
145
146
|
|
146
147
|
# Import privately the contacts classes.
|
148
|
+
from jaxsim.rbda.contacts.relaxed_rigid import (
|
149
|
+
RelaxedRigidContacts,
|
150
|
+
RelaxedRigidContactsState,
|
151
|
+
)
|
147
152
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
148
153
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
149
154
|
|
@@ -190,6 +195,27 @@ def collidable_point_dynamics(
|
|
190
195
|
|
191
196
|
aux_data = dict()
|
192
197
|
|
198
|
+
case RelaxedRigidContacts():
|
199
|
+
assert isinstance(model.contact_model, RelaxedRigidContacts)
|
200
|
+
assert isinstance(data.state.contact, RelaxedRigidContactsState)
|
201
|
+
|
202
|
+
# Build the contact model.
|
203
|
+
relaxed_rigid_contacts = RelaxedRigidContacts(
|
204
|
+
parameters=data.contacts_params, terrain=model.terrain
|
205
|
+
)
|
206
|
+
|
207
|
+
# Compute the 6D force expressed in the inertial frame and applied to each
|
208
|
+
# collidable point.
|
209
|
+
W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
|
210
|
+
position=W_p_Ci,
|
211
|
+
velocity=W_ṗ_Ci,
|
212
|
+
model=model,
|
213
|
+
data=data,
|
214
|
+
link_forces=link_forces,
|
215
|
+
)
|
216
|
+
|
217
|
+
aux_data = dict()
|
218
|
+
|
193
219
|
case _:
|
194
220
|
raise ValueError(f"Invalid contact model {model.contact_model}")
|
195
221
|
|
jaxsim/api/data.py
CHANGED
@@ -593,16 +593,18 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
593
593
|
The updated `JaxSimModelData` object.
|
594
594
|
"""
|
595
595
|
|
596
|
-
|
596
|
+
W_Q_B = jnp.array(base_quaternion, dtype=float)
|
597
|
+
|
598
|
+
W_Q_B = jax.lax.select(
|
599
|
+
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
|
600
|
+
on_true=W_Q_B,
|
601
|
+
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
|
602
|
+
)
|
597
603
|
|
598
604
|
return self.replace(
|
599
605
|
validate=True,
|
600
606
|
state=self.state.replace(
|
601
|
-
physics_model=self.state.physics_model.replace(
|
602
|
-
base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
|
603
|
-
float
|
604
|
-
)
|
605
|
-
)
|
607
|
+
physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
|
606
608
|
),
|
607
609
|
)
|
608
610
|
|
@@ -744,6 +746,13 @@ def random_model_data(
|
|
744
746
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
745
747
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
746
748
|
] = ((-1, -1, 0.5), 1.0),
|
749
|
+
joint_pos_bounds: (
|
750
|
+
tuple[
|
751
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
752
|
+
jtp.FloatLike | Sequence[jtp.FloatLike],
|
753
|
+
]
|
754
|
+
| None
|
755
|
+
) = None,
|
747
756
|
base_vel_lin_bounds: tuple[
|
748
757
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
749
758
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
@@ -769,6 +778,8 @@ def random_model_data(
|
|
769
778
|
key: The random key.
|
770
779
|
velocity_representation: The velocity representation to use.
|
771
780
|
base_pos_bounds: The bounds for the base position.
|
781
|
+
joint_pos_bounds:
|
782
|
+
The bounds for the joint positions (reading the joint limits if None).
|
772
783
|
base_vel_lin_bounds: The bounds for the base linear velocity.
|
773
784
|
base_vel_ang_bounds: The bounds for the base angular velocity.
|
774
785
|
joint_vel_bounds: The bounds for the joint velocities.
|
@@ -813,8 +824,19 @@ def random_model_data(
|
|
813
824
|
).wxyz
|
814
825
|
|
815
826
|
if model.number_of_joints() > 0:
|
816
|
-
|
817
|
-
|
827
|
+
|
828
|
+
s_min, s_max = (
|
829
|
+
jnp.array(joint_pos_bounds, dtype=float)
|
830
|
+
if joint_pos_bounds is not None
|
831
|
+
else (None, None)
|
832
|
+
)
|
833
|
+
|
834
|
+
physics_model_state.joint_positions = (
|
835
|
+
js.joint.random_joint_positions(model=model, key=k3)
|
836
|
+
if (s_min is None or s_max is None)
|
837
|
+
else jax.random.uniform(
|
838
|
+
key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
|
839
|
+
)
|
818
840
|
)
|
819
841
|
|
820
842
|
physics_model_state.joint_velocities = jax.random.uniform(
|
jaxsim/api/joint.py
CHANGED
@@ -180,17 +180,77 @@ def random_joint_positions(
|
|
180
180
|
|
181
181
|
Args:
|
182
182
|
model: The model to consider.
|
183
|
-
joint_names: The names of the joints.
|
184
|
-
key: The random key.
|
183
|
+
joint_names: The names of the considered joints (all if None).
|
184
|
+
key: The random key (initialized from seed 0 if None).
|
185
|
+
|
186
|
+
Note:
|
187
|
+
If the joint range or revolute joints is larger than 2π, their joint positions
|
188
|
+
will be sampled from an interval of size 2π.
|
185
189
|
|
186
190
|
Returns:
|
187
191
|
The random joint positions.
|
188
192
|
"""
|
189
193
|
|
194
|
+
# Consider the key corresponding to a zero seed if it was not passed.
|
190
195
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
191
196
|
|
197
|
+
# Get the joint limits parsed from the model description.
|
192
198
|
s_min, s_max = position_limits(model=model, joint_names=joint_names)
|
193
199
|
|
200
|
+
# Get the joint indices.
|
201
|
+
# Note that it will trigger an exception if the given `joint_names` are not valid.
|
202
|
+
joint_names = joint_names if joint_names is not None else model.joint_names()
|
203
|
+
joint_indices = names_to_idxs(model=model, joint_names=joint_names)
|
204
|
+
|
205
|
+
from jaxsim.parsers.descriptions.joint import JointType
|
206
|
+
|
207
|
+
# Filter for revolute joints.
|
208
|
+
is_revolute = jnp.where(
|
209
|
+
jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
|
210
|
+
== JointType.Revolute,
|
211
|
+
True,
|
212
|
+
False,
|
213
|
+
)
|
214
|
+
|
215
|
+
# Shorthand for π.
|
216
|
+
π = jnp.pi
|
217
|
+
|
218
|
+
# Filter for revolute with full range (or continuous).
|
219
|
+
is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
|
220
|
+
|
221
|
+
# Clip the lower limit to -π if the joint range is larger than [-π, π].
|
222
|
+
s_min = jnp.where(
|
223
|
+
jnp.logical_and(
|
224
|
+
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
225
|
+
),
|
226
|
+
-π,
|
227
|
+
s_min,
|
228
|
+
)
|
229
|
+
|
230
|
+
# Clip the upper limit to +π if the joint range is larger than [-π, π].
|
231
|
+
s_max = jnp.where(
|
232
|
+
jnp.logical_and(
|
233
|
+
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
234
|
+
),
|
235
|
+
π,
|
236
|
+
s_max,
|
237
|
+
)
|
238
|
+
|
239
|
+
# Shift the lower limit if the upper limit is smaller than +π.
|
240
|
+
s_min = jnp.where(
|
241
|
+
jnp.logical_and(is_revolute_full_range, s_max < π),
|
242
|
+
s_max - 2 * π,
|
243
|
+
s_min,
|
244
|
+
)
|
245
|
+
|
246
|
+
# Shift the upper limit if the lower limit is larger than -π.
|
247
|
+
s_max = jnp.where(
|
248
|
+
jnp.logical_and(is_revolute_full_range, s_min > -π),
|
249
|
+
s_min + 2 * π,
|
250
|
+
s_max,
|
251
|
+
)
|
252
|
+
|
253
|
+
# Sample the joint positions.
|
194
254
|
s_random = jax.random.uniform(
|
195
255
|
minval=s_min,
|
196
256
|
maxval=s_max,
|
jaxsim/api/model.py
CHANGED
@@ -1935,7 +1935,7 @@ def step(
|
|
1935
1935
|
tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
|
1936
1936
|
|
1937
1937
|
jax.lax.cond(
|
1938
|
-
pred=tf_ns
|
1938
|
+
pred=tf_ns < t0_ns,
|
1939
1939
|
true_fun=lambda: jax.debug.print(
|
1940
1940
|
"The simulation time overflowed, resetting simulation time to 0."
|
1941
1941
|
),
|
jaxsim/api/ode.py
CHANGED
@@ -175,17 +175,15 @@ def system_velocity_dynamics(
|
|
175
175
|
forces=W_f_Li_terrain,
|
176
176
|
additive=True,
|
177
177
|
)
|
178
|
-
|
179
|
-
|
178
|
+
|
179
|
+
# Get the link forces in inertial representation
|
180
180
|
f_L_total = references.link_forces(model=model, data=data)
|
181
181
|
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
186
|
-
)
|
182
|
+
v̇_WB, s̈ = system_acceleration(
|
183
|
+
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
184
|
+
)
|
187
185
|
|
188
|
-
return
|
186
|
+
return v̇_WB, s̈, aux_data
|
189
187
|
|
190
188
|
|
191
189
|
def system_acceleration(
|
@@ -196,7 +194,7 @@ def system_acceleration(
|
|
196
194
|
link_forces: jtp.MatrixLike | None = None,
|
197
195
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
198
196
|
"""
|
199
|
-
Compute the system acceleration in
|
197
|
+
Compute the system acceleration in the active representation.
|
200
198
|
|
201
199
|
Args:
|
202
200
|
model: The model to consider.
|
@@ -206,7 +204,7 @@ def system_acceleration(
|
|
206
204
|
The 6D forces to apply to the links expressed in the same representation of data.
|
207
205
|
|
208
206
|
Returns:
|
209
|
-
A tuple containing the base 6D acceleration in
|
207
|
+
A tuple containing the base 6D acceleration in in the active representation
|
210
208
|
and the joint accelerations.
|
211
209
|
"""
|
212
210
|
|
@@ -272,18 +270,15 @@ def system_acceleration(
|
|
272
270
|
)
|
273
271
|
|
274
272
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
275
|
-
# - Base
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
link_forces=references.link_forces(),
|
285
|
-
)
|
286
|
-
return W_v̇_WB, s̈
|
273
|
+
# - Base acceleration: v̇_WB ∈ ℝ⁶
|
274
|
+
v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
275
|
+
model=model,
|
276
|
+
data=data,
|
277
|
+
joint_forces=references.joint_force_references(model=model),
|
278
|
+
link_forces=references.link_forces(model=model, data=data),
|
279
|
+
)
|
280
|
+
|
281
|
+
return v̇_WB, s̈
|
287
282
|
|
288
283
|
|
289
284
|
@jax.jit
|
@@ -353,7 +348,7 @@ def system_dynamics(
|
|
353
348
|
corresponding derivative, and the dictionary of auxiliary data returned
|
354
349
|
by the system dynamics evaluation.
|
355
350
|
"""
|
356
|
-
|
351
|
+
from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
|
357
352
|
from jaxsim.rbda.contacts.rigid import RigidContacts
|
358
353
|
from jaxsim.rbda.contacts.soft import SoftContacts
|
359
354
|
|
@@ -371,7 +366,7 @@ def system_dynamics(
|
|
371
366
|
case SoftContacts():
|
372
367
|
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
|
373
368
|
|
374
|
-
case RigidContacts():
|
369
|
+
case RigidContacts() | RelaxedRigidContacts():
|
375
370
|
pass
|
376
371
|
|
377
372
|
case _:
|
jaxsim/api/ode_data.py
CHANGED
@@ -6,6 +6,10 @@ import jax_dataclasses
|
|
6
6
|
import jaxsim.api as js
|
7
7
|
import jaxsim.typing as jtp
|
8
8
|
from jaxsim.rbda import ContactsState
|
9
|
+
from jaxsim.rbda.contacts.relaxed_rigid import (
|
10
|
+
RelaxedRigidContacts,
|
11
|
+
RelaxedRigidContactsState,
|
12
|
+
)
|
9
13
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
10
14
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
11
15
|
from jaxsim.utils import JaxsimDataclass
|
@@ -173,6 +177,10 @@ class ODEState(JaxsimDataclass):
|
|
173
177
|
)
|
174
178
|
case RigidContacts():
|
175
179
|
contact = RigidContactsState.build()
|
180
|
+
|
181
|
+
case RelaxedRigidContacts():
|
182
|
+
contact = RelaxedRigidContactsState.build()
|
183
|
+
|
176
184
|
case _:
|
177
185
|
raise ValueError("Unable to determine contact state class prefix.")
|
178
186
|
|
@@ -216,7 +224,9 @@ class ODEState(JaxsimDataclass):
|
|
216
224
|
|
217
225
|
# Get the contact model from the `JaxSimModel`.
|
218
226
|
match contact:
|
219
|
-
case
|
227
|
+
case (
|
228
|
+
SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
|
229
|
+
):
|
220
230
|
pass
|
221
231
|
case None:
|
222
232
|
contact = SoftContactsState.zero(model=model)
|
jaxsim/integrators/common.py
CHANGED
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
497
497
|
b: jtp.Matrix,
|
498
498
|
c: jtp.Vector,
|
499
499
|
index_of_solution: jtp.IntLike = 0,
|
500
|
-
) -> [bool, int | None]:
|
500
|
+
) -> tuple[bool, int | None]:
|
501
501
|
"""
|
502
502
|
Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
|
503
503
|
|
jaxsim/math/inertia.py
CHANGED
@@ -45,7 +45,7 @@ class Inertia:
|
|
45
45
|
M (jtp.Matrix): The 6x6 inertia matrix.
|
46
46
|
|
47
47
|
Returns:
|
48
|
-
|
48
|
+
tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
|
49
49
|
|
50
50
|
Raises:
|
51
51
|
ValueError: If the input matrix M has an unexpected shape.
|
jaxsim/mujoco/loaders.py
CHANGED
@@ -211,7 +211,7 @@ class RodModelToMjcf:
|
|
211
211
|
joints_dict = {j.name: j for j in rod_model.joints()}
|
212
212
|
|
213
213
|
# Convert all the joints not considered to fixed joints.
|
214
|
-
for joint_name in
|
214
|
+
for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
|
215
215
|
joints_dict[joint_name].type = "fixed"
|
216
216
|
|
217
217
|
# Convert the ROD model to URDF.
|
@@ -289,10 +289,10 @@ class RodModelToMjcf:
|
|
289
289
|
mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
|
290
290
|
|
291
291
|
# Get the joint names.
|
292
|
-
mj_joint_names =
|
292
|
+
mj_joint_names = {
|
293
293
|
mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
|
294
294
|
for idx in range(mj_model.njnt)
|
295
|
-
|
295
|
+
}
|
296
296
|
|
297
297
|
# Check that the Mujoco model only has the considered joints.
|
298
298
|
if mj_joint_names != considered_joints:
|
jaxsim/parsers/kinematic_graph.py
CHANGED
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
394
394
|
return copy.deepcopy(self)
|
395
395
|
|
396
396
|
# Check if all considered joints are part of the full kinematic graph
|
397
|
-
if len(set(considered_joints) -
|
397
|
+
if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
|
398
398
|
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
|
399
399
|
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
|
400
400
|
raise ValueError(msg)
|
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
536
536
|
root_link_name=full_graph.root.name,
|
537
537
|
)
|
538
538
|
|
539
|
-
assert
|
540
|
-
|
539
|
+
assert {f.name for f in self.frames}.isdisjoint(
|
540
|
+
{f.name for f in unconnected_frames + reduced_frames}
|
541
541
|
)
|
542
542
|
|
543
543
|
for link in unconnected_links:
|
jaxsim/rbda/contacts/relaxed_rigid.py
ADDED
@@ -0,0 +1,384 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import jax_dataclasses
|
9
|
+
import jaxopt
|
10
|
+
|
11
|
+
import jaxsim.api as js
|
12
|
+
import jaxsim.typing as jtp
|
13
|
+
from jaxsim.api.common import VelRepr
|
14
|
+
from jaxsim.math import Adjoint
|
15
|
+
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
16
|
+
|
17
|
+
from .common import ContactModel, ContactsParams, ContactsState
|
18
|
+
|
19
|
+
|
20
|
+
@jax_dataclasses.pytree_dataclass
|
21
|
+
class RelaxedRigidContactsParams(ContactsParams):
|
22
|
+
"""Parameters of the relaxed rigid contacts model."""
|
23
|
+
|
24
|
+
# Time constant
|
25
|
+
time_constant: jtp.Float = dataclasses.field(
|
26
|
+
default_factory=lambda: jnp.array(0.01, dtype=float)
|
27
|
+
)
|
28
|
+
|
29
|
+
# Adimensional damping coefficient
|
30
|
+
damping_coefficient: jtp.Float = dataclasses.field(
|
31
|
+
default_factory=lambda: jnp.array(1.0, dtype=float)
|
32
|
+
)
|
33
|
+
|
34
|
+
# Minimum impedance
|
35
|
+
d_min: jtp.Float = dataclasses.field(
|
36
|
+
default_factory=lambda: jnp.array(0.9, dtype=float)
|
37
|
+
)
|
38
|
+
|
39
|
+
# Maximum impedance
|
40
|
+
d_max: jtp.Float = dataclasses.field(
|
41
|
+
default_factory=lambda: jnp.array(0.95, dtype=float)
|
42
|
+
)
|
43
|
+
|
44
|
+
# Width
|
45
|
+
width: jtp.Float = dataclasses.field(
|
46
|
+
default_factory=lambda: jnp.array(0.0001, dtype=float)
|
47
|
+
)
|
48
|
+
|
49
|
+
# Midpoint
|
50
|
+
midpoint: jtp.Float = dataclasses.field(
|
51
|
+
default_factory=lambda: jnp.array(0.1, dtype=float)
|
52
|
+
)
|
53
|
+
|
54
|
+
# Power exponent
|
55
|
+
power: jtp.Float = dataclasses.field(
|
56
|
+
default_factory=lambda: jnp.array(1.0, dtype=float)
|
57
|
+
)
|
58
|
+
|
59
|
+
# Stiffness
|
60
|
+
stiffness: jtp.Float = dataclasses.field(
|
61
|
+
default_factory=lambda: jnp.array(0.0, dtype=float)
|
62
|
+
)
|
63
|
+
|
64
|
+
# Damping
|
65
|
+
damping: jtp.Float = dataclasses.field(
|
66
|
+
default_factory=lambda: jnp.array(0.0, dtype=float)
|
67
|
+
)
|
68
|
+
|
69
|
+
# Friction coefficient
|
70
|
+
mu: jtp.Float = dataclasses.field(
|
71
|
+
default_factory=lambda: jnp.array(0.5, dtype=float)
|
72
|
+
)
|
73
|
+
|
74
|
+
# Maximum number of iterations
|
75
|
+
max_iterations: jtp.Int = dataclasses.field(
|
76
|
+
default_factory=lambda: jnp.array(50, dtype=int)
|
77
|
+
)
|
78
|
+
|
79
|
+
# Solver tolerance
|
80
|
+
tolerance: jtp.Float = dataclasses.field(
|
81
|
+
default_factory=lambda: jnp.array(1e-6, dtype=float)
|
82
|
+
)
|
83
|
+
|
84
|
+
def __hash__(self) -> int:
|
85
|
+
from jaxsim.utils.wrappers import HashedNumpyArray
|
86
|
+
|
87
|
+
return hash(
|
88
|
+
(
|
89
|
+
HashedNumpyArray(self.time_constant),
|
90
|
+
HashedNumpyArray(self.damping_coefficient),
|
91
|
+
HashedNumpyArray(self.d_min),
|
92
|
+
HashedNumpyArray(self.d_max),
|
93
|
+
HashedNumpyArray(self.width),
|
94
|
+
HashedNumpyArray(self.midpoint),
|
95
|
+
HashedNumpyArray(self.power),
|
96
|
+
HashedNumpyArray(self.stiffness),
|
97
|
+
HashedNumpyArray(self.damping),
|
98
|
+
HashedNumpyArray(self.mu),
|
99
|
+
HashedNumpyArray(self.max_iterations),
|
100
|
+
HashedNumpyArray(self.tolerance),
|
101
|
+
)
|
102
|
+
)
|
103
|
+
|
104
|
+
def __eq__(self, other: RelaxedRigidContactsParams) -> bool:
|
105
|
+
return hash(self) == hash(other)
|
106
|
+
|
107
|
+
@classmethod
|
108
|
+
def build(
|
109
|
+
cls,
|
110
|
+
time_constant: jtp.FloatLike | None = None,
|
111
|
+
damping_coefficient: jtp.FloatLike | None = None,
|
112
|
+
d_min: jtp.FloatLike | None = None,
|
113
|
+
d_max: jtp.FloatLike | None = None,
|
114
|
+
width: jtp.FloatLike | None = None,
|
115
|
+
midpoint: jtp.FloatLike | None = None,
|
116
|
+
power: jtp.FloatLike | None = None,
|
117
|
+
stiffness: jtp.FloatLike | None = None,
|
118
|
+
damping: jtp.FloatLike | None = None,
|
119
|
+
mu: jtp.FloatLike | None = None,
|
120
|
+
max_iterations: jtp.IntLike | None = None,
|
121
|
+
tolerance: jtp.FloatLike | None = None,
|
122
|
+
) -> RelaxedRigidContactsParams:
|
123
|
+
"""Create a `RelaxedRigidContactsParams` instance"""
|
124
|
+
|
125
|
+
return cls(
|
126
|
+
**{
|
127
|
+
field: jnp.array(locals().get(field, default), dtype=default.dtype)
|
128
|
+
for field, default in map(
|
129
|
+
lambda f: (f, cls.__dataclass_fields__[f].default),
|
130
|
+
filter(lambda f: f != "__mutability__", cls.__dataclass_fields__),
|
131
|
+
)
|
132
|
+
}
|
133
|
+
)
|
134
|
+
|
135
|
+
def valid(self) -> bool:
|
136
|
+
return bool(
|
137
|
+
jnp.all(self.time_constant >= 0.0)
|
138
|
+
and jnp.all(self.damping_coefficient > 0.0)
|
139
|
+
and jnp.all(self.d_min >= 0.0)
|
140
|
+
and jnp.all(self.d_max <= 1.0)
|
141
|
+
and jnp.all(self.d_min <= self.d_max)
|
142
|
+
and jnp.all(self.width >= 0.0)
|
143
|
+
and jnp.all(self.midpoint >= 0.0)
|
144
|
+
and jnp.all(self.power >= 0.0)
|
145
|
+
and jnp.all(self.mu >= 0.0)
|
146
|
+
and jnp.all(self.max_iterations > 0)
|
147
|
+
and jnp.all(self.tolerance > 0.0)
|
148
|
+
)
|
149
|
+
|
150
|
+
|
151
|
+
@jax_dataclasses.pytree_dataclass
|
152
|
+
class RelaxedRigidContactsState(ContactsState):
|
153
|
+
"""Class storing the state of the relaxed rigid contacts model."""
|
154
|
+
|
155
|
+
def __eq__(self, other: RelaxedRigidContactsState) -> bool:
|
156
|
+
return hash(self) == hash(other)
|
157
|
+
|
158
|
+
@staticmethod
|
159
|
+
def build() -> RelaxedRigidContactsState:
|
160
|
+
"""Create a `RelaxedRigidContactsState` instance"""
|
161
|
+
|
162
|
+
return RelaxedRigidContactsState()
|
163
|
+
|
164
|
+
@staticmethod
|
165
|
+
def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
|
166
|
+
"""Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
|
167
|
+
return RelaxedRigidContactsState.build()
|
168
|
+
|
169
|
+
def valid(self, model: js.model.JaxSimModel) -> bool:
|
170
|
+
return True
|
171
|
+
|
172
|
+
|
173
|
+
@jax_dataclasses.pytree_dataclass
|
174
|
+
class RelaxedRigidContacts(ContactModel):
|
175
|
+
"""Relaxed rigid contacts model."""
|
176
|
+
|
177
|
+
parameters: RelaxedRigidContactsParams = dataclasses.field(
|
178
|
+
default_factory=RelaxedRigidContactsParams
|
179
|
+
)
|
180
|
+
|
181
|
+
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
182
|
+
default_factory=FlatTerrain
|
183
|
+
)
|
184
|
+
|
185
|
+
def compute_contact_forces(
|
186
|
+
self,
|
187
|
+
position: jtp.Vector,
|
188
|
+
velocity: jtp.Vector,
|
189
|
+
model: js.model.JaxSimModel,
|
190
|
+
data: js.data.JaxSimModelData,
|
191
|
+
link_forces: jtp.MatrixLike | None = None,
|
192
|
+
) -> tuple[jtp.Vector, tuple[Any, ...]]:
|
193
|
+
|
194
|
+
link_forces = (
|
195
|
+
link_forces
|
196
|
+
if link_forces is not None
|
197
|
+
else jnp.zeros((model.number_of_links(), 6))
|
198
|
+
)
|
199
|
+
|
200
|
+
references = js.references.JaxSimModelReferences.build(
|
201
|
+
model=model,
|
202
|
+
data=data,
|
203
|
+
velocity_representation=data.velocity_representation,
|
204
|
+
link_forces=link_forces,
|
205
|
+
)
|
206
|
+
|
207
|
+
def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
|
208
|
+
x, y, z = jax.tree_map(jnp.squeeze, (x, y, z))
|
209
|
+
|
210
|
+
n̂ = self.terrain.normal(x=x, y=y).squeeze()
|
211
|
+
h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
|
212
|
+
|
213
|
+
return jnp.dot(h, n̂)
|
214
|
+
|
215
|
+
# Compute the activation state of the collidable points
|
216
|
+
δ = jax.vmap(_detect_contact)(*position.T)
|
217
|
+
|
218
|
+
with (
|
219
|
+
references.switch_velocity_representation(VelRepr.Mixed),
|
220
|
+
data.switch_velocity_representation(VelRepr.Mixed),
|
221
|
+
):
|
222
|
+
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
223
|
+
Jl_WC = jnp.vstack(
|
224
|
+
jax.vmap(lambda J, height: J * (height < 0))(
|
225
|
+
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
|
226
|
+
)
|
227
|
+
)
|
228
|
+
W_H_C = js.contact.transforms(model=model, data=data)
|
229
|
+
BW_ν̇_free = jnp.hstack(
|
230
|
+
js.ode.system_acceleration(
|
231
|
+
model=model,
|
232
|
+
data=data,
|
233
|
+
link_forces=references.link_forces(model=model, data=data),
|
234
|
+
)
|
235
|
+
)
|
236
|
+
BW_ν = data.generalized_velocity()
|
237
|
+
J̇_WC = jnp.vstack(
|
238
|
+
jax.vmap(lambda J̇, height: J̇ * (height < 0))(
|
239
|
+
js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
|
240
|
+
),
|
241
|
+
)
|
242
|
+
|
243
|
+
a_ref, R, K, D = self._regularizers(
|
244
|
+
model=model,
|
245
|
+
penetration=δ,
|
246
|
+
velocity=velocity,
|
247
|
+
parameters=self.parameters,
|
248
|
+
)
|
249
|
+
|
250
|
+
G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
|
251
|
+
CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν
|
252
|
+
|
253
|
+
# Calculate quantities for the linear optimization problem.
|
254
|
+
A = G + R
|
255
|
+
b = CW_al_free_WC - a_ref
|
256
|
+
|
257
|
+
objective = lambda x: jnp.sum(jnp.square(A @ x + b))
|
258
|
+
|
259
|
+
# Compute the 3D linear force in C[W] frame
|
260
|
+
opt = jaxopt.LBFGS(
|
261
|
+
fun=objective,
|
262
|
+
maxiter=self.parameters.max_iterations,
|
263
|
+
tol=self.parameters.tolerance,
|
264
|
+
maxls=30,
|
265
|
+
history_size=10,
|
266
|
+
max_stepsize=100.0,
|
267
|
+
)
|
268
|
+
|
269
|
+
init_params = (
|
270
|
+
K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
|
271
|
+
+ D[:, jnp.newaxis] * velocity
|
272
|
+
).flatten()
|
273
|
+
|
274
|
+
CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)
|
275
|
+
|
276
|
+
def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
|
277
|
+
W_Xf_CW = Adjoint.from_transform(
|
278
|
+
W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
|
279
|
+
inverse=True,
|
280
|
+
).T
|
281
|
+
return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])
|
282
|
+
|
283
|
+
W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)
|
284
|
+
|
285
|
+
return W_f_C, (None,)
|
286
|
+
|
287
|
+
@staticmethod
|
288
|
+
def _regularizers(
|
289
|
+
model: js.model.JaxSimModel,
|
290
|
+
penetration: jtp.Array,
|
291
|
+
velocity: jtp.Array,
|
292
|
+
parameters: RelaxedRigidContactsParams,
|
293
|
+
) -> tuple:
|
294
|
+
"""
|
295
|
+
Compute the contact jacobian and the reference acceleration.
|
296
|
+
|
297
|
+
Args:
|
298
|
+
model: The jaxsim model.
|
299
|
+
penetration: The penetration of the collidable points.
|
300
|
+
velocity: The velocity of the collidable points.
|
301
|
+
parameters: The parameters of the relaxed rigid contacts model.
|
302
|
+
|
303
|
+
Returns:
|
304
|
+
A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
|
305
|
+
"""
|
306
|
+
|
307
|
+
Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ, *_ = jax_dataclasses.astuple(
|
308
|
+
parameters
|
309
|
+
)
|
310
|
+
|
311
|
+
def _imp_aref(
|
312
|
+
penetration: jtp.Array,
|
313
|
+
velocity: jtp.Array,
|
314
|
+
) -> tuple[jtp.Array, jtp.Array]:
|
315
|
+
"""
|
316
|
+
Calculates impedance and offset acceleration in constraint frame.
|
317
|
+
|
318
|
+
Args:
|
319
|
+
penetration: penetration in constraint frame
|
320
|
+
velocity: velocity in constraint frame
|
321
|
+
|
322
|
+
Returns:
|
323
|
+
a_ref: offset acceleration in constraint frame
|
324
|
+
R: regularization matrix
|
325
|
+
K: computed stiffness
|
326
|
+
D: computed damping
|
327
|
+
"""
|
328
|
+
position = jnp.zeros(shape=(3,)).at[2].set(penetration)
|
329
|
+
|
330
|
+
imp_x = jnp.abs(position) / width
|
331
|
+
imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
|
332
|
+
|
333
|
+
imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
|
334
|
+
|
335
|
+
imp_y = jnp.where(imp_x < mid, imp_a, imp_b)
|
336
|
+
|
337
|
+
imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
|
338
|
+
imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))
|
339
|
+
|
340
|
+
# When passing negative values, K and D represent a spring and damper, respectively.
|
341
|
+
K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
|
342
|
+
D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))
|
343
|
+
|
344
|
+
a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)
|
345
|
+
|
346
|
+
return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)
|
347
|
+
|
348
|
+
def _compute_row(
|
349
|
+
*,
|
350
|
+
link_idx: jtp.Float,
|
351
|
+
penetration: jtp.Array,
|
352
|
+
velocity: jtp.Array,
|
353
|
+
) -> tuple[jtp.Array, jtp.Array]:
|
354
|
+
|
355
|
+
# Compute the reference acceleration.
|
356
|
+
ξ, a_ref, K, D = _imp_aref(
|
357
|
+
penetration=penetration,
|
358
|
+
velocity=velocity,
|
359
|
+
)
|
360
|
+
|
361
|
+
# Compute the regularization terms.
|
362
|
+
R = (
|
363
|
+
(2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
|
364
|
+
* (1 + μ**2)
|
365
|
+
@ jnp.linalg.inv(M_L[link_idx, :3, :3])
|
366
|
+
)
|
367
|
+
|
368
|
+
return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))
|
369
|
+
|
370
|
+
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
371
|
+
|
372
|
+
a_ref, R, K, D = jax.tree.map(
|
373
|
+
jnp.concatenate,
|
374
|
+
(
|
375
|
+
*jax.vmap(_compute_row)(
|
376
|
+
link_idx=jnp.array(
|
377
|
+
model.kin_dyn_parameters.contact_parameters.body
|
378
|
+
),
|
379
|
+
penetration=penetration,
|
380
|
+
velocity=velocity,
|
381
|
+
),
|
382
|
+
),
|
383
|
+
)
|
384
|
+
return a_ref, jnp.diag(R), K, D
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -9,7 +9,6 @@ import jax_dataclasses
|
|
9
9
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.typing as jtp
|
12
|
-
from jaxsim import math
|
13
12
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
14
13
|
from jaxsim.terrain import FlatTerrain, Terrain
|
15
14
|
|
@@ -272,9 +271,17 @@ class RigidContacts(ContactModel):
|
|
272
271
|
link_forces=link_forces,
|
273
272
|
)
|
274
273
|
|
275
|
-
with
|
276
|
-
|
277
|
-
|
274
|
+
with (
|
275
|
+
references.switch_velocity_representation(VelRepr.Mixed),
|
276
|
+
data.switch_velocity_representation(VelRepr.Mixed),
|
277
|
+
):
|
278
|
+
BW_ν̇_free = jnp.hstack(
|
279
|
+
js.ode.system_acceleration(
|
280
|
+
model=model,
|
281
|
+
data=data,
|
282
|
+
joint_forces=references.joint_force_references(model=model),
|
283
|
+
link_forces=references.link_forces(model=model, data=data),
|
284
|
+
)
|
278
285
|
)
|
279
286
|
|
280
287
|
free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
|
@@ -380,43 +387,6 @@ class RigidContacts(ContactModel):
|
|
380
387
|
n_constraints = 6 * n_collidable_points
|
381
388
|
return jnp.zeros(shape=(n_constraints,))
|
382
389
|
|
383
|
-
@staticmethod
|
384
|
-
def _compute_mixed_nu_dot_free(
|
385
|
-
model: js.model.JaxSimModel,
|
386
|
-
data: js.data.JaxSimModelData,
|
387
|
-
references: js.references.JaxSimModelReferences | None = None,
|
388
|
-
) -> jtp.Array:
|
389
|
-
references = (
|
390
|
-
references
|
391
|
-
if references is not None
|
392
|
-
else js.references.JaxSimModelReferences.zero(model=model, data=data)
|
393
|
-
)
|
394
|
-
|
395
|
-
with (
|
396
|
-
data.switch_velocity_representation(VelRepr.Mixed),
|
397
|
-
references.switch_velocity_representation(VelRepr.Mixed),
|
398
|
-
):
|
399
|
-
BW_v_WB = data.base_velocity()
|
400
|
-
W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
|
401
|
-
W_v̇_WB, s̈ = js.ode.system_acceleration(
|
402
|
-
model=model,
|
403
|
-
data=data,
|
404
|
-
joint_forces=references.joint_force_references(model=model),
|
405
|
-
link_forces=references.link_forces(model=model, data=data),
|
406
|
-
)
|
407
|
-
|
408
|
-
# Convert the inertial-fixed base acceleration to a mixed base acceleration.
|
409
|
-
W_H_B = data.base_transform()
|
410
|
-
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
411
|
-
BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
|
412
|
-
term1 = BW_X_W @ W_v̇_WB
|
413
|
-
term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
|
414
|
-
BW_v̇_WB = term1 - term2
|
415
|
-
|
416
|
-
BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
|
417
|
-
|
418
|
-
return BW_ν̇
|
419
|
-
|
420
390
|
@staticmethod
|
421
391
|
def _linear_acceleration_of_collidable_points(
|
422
392
|
model: js.model.JaxSimModel,
|
jaxsim/terrain/terrain.py
CHANGED
@@ -46,66 +46,82 @@ class Terrain(abc.ABC):
|
|
46
46
|
@jax_dataclasses.pytree_dataclass
|
47
47
|
class FlatTerrain(Terrain):
|
48
48
|
|
49
|
-
|
49
|
+
_height: float = dataclasses.field(default=0.0, kw_only=True)
|
50
50
|
|
51
51
|
@staticmethod
|
52
52
|
def build(height: jtp.FloatLike) -> FlatTerrain:
|
53
53
|
|
54
|
-
return FlatTerrain(
|
54
|
+
return FlatTerrain(_height=float(height))
|
55
55
|
|
56
56
|
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
57
57
|
|
58
|
-
return jnp.array(self.
|
58
|
+
return jnp.array(self._height, dtype=float)
|
59
|
+
|
60
|
+
def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
|
61
|
+
|
62
|
+
return jnp.array([0.0, 0.0, 1.0], dtype=float)
|
59
63
|
|
60
64
|
def __hash__(self) -> int:
|
61
65
|
|
62
|
-
return hash(self.
|
66
|
+
return hash(self._height)
|
63
67
|
|
64
68
|
def __eq__(self, other: FlatTerrain) -> bool:
|
65
69
|
|
66
70
|
if not isinstance(other, FlatTerrain):
|
67
71
|
return False
|
68
72
|
|
69
|
-
return self.
|
73
|
+
return self._height == other._height
|
70
74
|
|
71
75
|
|
72
76
|
@jax_dataclasses.pytree_dataclass
|
73
77
|
class PlaneTerrain(FlatTerrain):
|
74
78
|
|
75
|
-
|
79
|
+
_normal: tuple[float, float, float] = jax_dataclasses.field(
|
76
80
|
default=(0.0, 0.0, 1.0), kw_only=True
|
77
81
|
)
|
78
82
|
|
79
83
|
@staticmethod
|
80
|
-
def build(
|
81
|
-
plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
|
82
|
-
) -> PlaneTerrain:
|
84
|
+
def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
|
83
85
|
"""
|
84
86
|
Create a PlaneTerrain instance with a specified plane normal vector.
|
85
87
|
|
86
88
|
Args:
|
87
|
-
|
88
|
-
|
89
|
+
normal: The normal vector of the terrain plane.
|
90
|
+
height: The height of the plane over the origin.
|
89
91
|
|
90
92
|
Returns:
|
91
93
|
PlaneTerrain: A PlaneTerrain instance.
|
92
94
|
"""
|
93
95
|
|
94
|
-
|
95
|
-
|
96
|
+
normal = jnp.array(normal, dtype=float)
|
97
|
+
height = jnp.array(height, dtype=float)
|
96
98
|
|
97
|
-
if
|
99
|
+
if normal.shape != (3,):
|
98
100
|
msg = "Expected a 3D vector for the plane normal, got '{}'."
|
99
|
-
raise ValueError(msg.format(
|
101
|
+
raise ValueError(msg.format(normal.shape))
|
100
102
|
|
101
103
|
# Make sure that the plane normal is a unit vector.
|
102
|
-
|
104
|
+
normal = normal / jnp.linalg.norm(normal)
|
103
105
|
|
104
106
|
return PlaneTerrain(
|
105
|
-
|
106
|
-
|
107
|
+
_height=height.item(),
|
108
|
+
_normal=tuple(normal.tolist()),
|
107
109
|
)
|
108
110
|
|
111
|
+
def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
|
112
|
+
"""
|
113
|
+
Compute the normal vector of the terrain at a specific (x, y) location.
|
114
|
+
|
115
|
+
Args:
|
116
|
+
x: The x-coordinate of the location.
|
117
|
+
y: The y-coordinate of the location.
|
118
|
+
|
119
|
+
Returns:
|
120
|
+
The normal vector of the terrain surface at the specified location.
|
121
|
+
"""
|
122
|
+
|
123
|
+
return jnp.array(self._normal, dtype=float)
|
124
|
+
|
109
125
|
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
110
126
|
"""
|
111
127
|
Compute the height of the terrain at a specific (x, y) location on a plane.
|
@@ -123,10 +139,10 @@ class PlaneTerrain(FlatTerrain):
|
|
123
139
|
# The height over the origin: -D/C
|
124
140
|
|
125
141
|
# Get the plane equation coefficients from the terrain normal.
|
126
|
-
A, B, C = self.
|
142
|
+
A, B, C = self._normal
|
127
143
|
|
128
144
|
# Compute the final coefficient D considering the terrain height.
|
129
|
-
D = -C * self.
|
145
|
+
D = -C * self._height
|
130
146
|
|
131
147
|
# Invert the plane equation to get the height at the given (x, y) coordinates.
|
132
148
|
return jnp.array(-(A * x + B * y + D) / C).astype(float)
|
@@ -137,9 +153,9 @@ class PlaneTerrain(FlatTerrain):
|
|
137
153
|
|
138
154
|
return hash(
|
139
155
|
(
|
140
|
-
hash(self.
|
156
|
+
hash(self._height),
|
141
157
|
HashedNumpyArray.hash_of_array(
|
142
|
-
array=jnp.array(self.
|
158
|
+
array=jnp.array(self._normal, dtype=float)
|
143
159
|
),
|
144
160
|
)
|
145
161
|
)
|
@@ -150,10 +166,10 @@ class PlaneTerrain(FlatTerrain):
|
|
150
166
|
return False
|
151
167
|
|
152
168
|
if not (
|
153
|
-
np.allclose(self.
|
169
|
+
np.allclose(self._height, other._height)
|
154
170
|
and np.allclose(
|
155
|
-
np.array(self.
|
156
|
-
np.array(other.
|
171
|
+
np.array(self._normal, dtype=float),
|
172
|
+
np.array(other._normal, dtype=float),
|
157
173
|
)
|
158
174
|
):
|
159
175
|
return False
|
jaxsim/typing.py
CHANGED
jaxsim/utils/jaxsim_dataclass.py
CHANGED
@@ -135,9 +135,10 @@ class JaxsimDataclass(abc.ABC):
|
|
135
135
|
"""
|
136
136
|
|
137
137
|
return tuple(
|
138
|
-
|
139
|
-
|
140
|
-
|
138
|
+
map(
|
139
|
+
lambda leaf: getattr(leaf, "shape", None),
|
140
|
+
jax.tree_util.tree_leaves(tree),
|
141
|
+
)
|
141
142
|
)
|
142
143
|
|
143
144
|
@staticmethod
|
@@ -154,9 +155,10 @@ class JaxsimDataclass(abc.ABC):
|
|
154
155
|
"""
|
155
156
|
|
156
157
|
return tuple(
|
157
|
-
|
158
|
-
|
159
|
-
|
158
|
+
map(
|
159
|
+
lambda leaf: getattr(leaf, "dtype", None),
|
160
|
+
jax.tree_util.tree_leaves(tree),
|
161
|
+
)
|
160
162
|
)
|
161
163
|
|
162
164
|
@staticmethod
|
@@ -172,9 +174,10 @@ class JaxsimDataclass(abc.ABC):
|
|
172
174
|
"""
|
173
175
|
|
174
176
|
return tuple(
|
175
|
-
|
176
|
-
|
177
|
-
|
177
|
+
map(
|
178
|
+
lambda leaf: getattr(leaf, "weak_type", None),
|
179
|
+
jax.tree_util.tree_leaves(tree),
|
180
|
+
)
|
178
181
|
)
|
179
182
|
|
180
183
|
@staticmethod
|
jaxsim/utils/wrappers.py
CHANGED
@@ -110,7 +110,7 @@ class HashedNumpyArray:
|
|
110
110
|
return np.allclose(
|
111
111
|
self.array,
|
112
112
|
other.array,
|
113
|
-
**(
|
113
|
+
**(dict(atol=self.precision) if self.precision is not None else {}),
|
114
114
|
)
|
115
115
|
|
116
116
|
return hash(self) == hash(other)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev64
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -61,6 +61,7 @@ Description-Content-Type: text/markdown
|
|
61
61
|
License-File: LICENSE
|
62
62
|
Requires-Dist: coloredlogs
|
63
63
|
Requires-Dist: jax>=0.4.13
|
64
|
+
Requires-Dist: jaxopt>=0.8.0
|
64
65
|
Requires-Dist: jaxlib>=0.4.13
|
65
66
|
Requires-Dist: jaxlie>=1.3.0
|
66
67
|
Requires-Dist: jax-dataclasses>=1.4.0
|
@@ -1,29 +1,29 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=lLNskxtfHW1HqvnLRuhux3LlK89fMiZFUWknSYopw7k,426
|
3
3
|
jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
|
-
jaxsim/typing.py,sha256=
|
5
|
+
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
6
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
7
7
|
jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
|
8
8
|
jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
|
9
|
-
jaxsim/api/contact.py,sha256=
|
10
|
-
jaxsim/api/data.py,sha256=
|
9
|
+
jaxsim/api/contact.py,sha256=C_PgMjWYYiqpA7Oz3IxHeFgrp855-xG6AQr6Ze98CtI,21863
|
10
|
+
jaxsim/api/data.py,sha256=mFUw2mj8AIXduW6HnkGN7eooZHfJhwnWbtYZfLF6gk4,28206
|
11
11
|
jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
|
12
|
-
jaxsim/api/joint.py,sha256=
|
12
|
+
jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
|
13
13
|
jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
|
14
14
|
jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
|
15
|
-
jaxsim/api/model.py,sha256=
|
16
|
-
jaxsim/api/ode.py,sha256=
|
17
|
-
jaxsim/api/ode_data.py,sha256=
|
15
|
+
jaxsim/api/model.py,sha256=K0q8-j-04f6B3MEXsctDGtWiuWlN3HbDrsS7zoPYStk,65871
|
16
|
+
jaxsim/api/ode.py,sha256=VuOLvCFoyGLmhNf2vFP5BI9BAPz78V_RW5tJ4hrizsw,13041
|
17
|
+
jaxsim/api/ode_data.py,sha256=7RSoBhfCJdP6P9InQbDwdBVpClPMMuetewI-6AWm-_0,20276
|
18
18
|
jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
|
-
jaxsim/integrators/common.py,sha256=
|
20
|
+
jaxsim/integrators/common.py,sha256=XIrJVJDO0ldaZ93WgoGNlFoRvazsRJTpO3DrK9kIXqM,20437
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
22
|
jaxsim/integrators/variable_step.py,sha256=5StkFh9oQba34zlkIoXG2fUN78gbxkHePWbrpQ-QZOI,21274
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
24
24
|
jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
|
25
25
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
26
|
-
jaxsim/math/inertia.py,sha256=
|
26
|
+
jaxsim/math/inertia.py,sha256=01hz6wMFreN2jBA0rVoBS1YMVh77KvwuzXSOpI3pxNk,1614
|
27
27
|
jaxsim/math/joint_model.py,sha256=EzAveaG5B6ZnCFNUzN30KEQUVesd83lfWXJarYR-kUw,9989
|
28
28
|
jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,4754
|
29
29
|
jaxsim/math/rotation.py,sha256=k-nwT79zmWrys3NNAB-lGWxat7Kqm_6JnFRoimJ8rBg,2156
|
@@ -31,11 +31,11 @@ jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
|
|
31
31
|
jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
|
32
32
|
jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
|
33
33
|
jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
|
34
|
-
jaxsim/mujoco/loaders.py,sha256=
|
34
|
+
jaxsim/mujoco/loaders.py,sha256=_8Af_5Yo0-lWHE-46BBMcrqSJnDNxr3peyc519DExtA,25322
|
35
35
|
jaxsim/mujoco/model.py,sha256=AQksXemXWACJ3yvefV2G5HLwwBU9ISoJrOD1wlxdY5w,16386
|
36
36
|
jaxsim/mujoco/visualizer.py,sha256=T1vU-w4NKSmgEkZ0FqVcGmIvYrYO0len2UBSsU4MOZ0,6978
|
37
37
|
jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
38
|
-
jaxsim/parsers/kinematic_graph.py,sha256=
|
38
|
+
jaxsim/parsers/kinematic_graph.py,sha256=wT2bgaCS8VQJTHy2H9sENkVPDOiMkRikxEF1t_WaahQ,34748
|
39
39
|
jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
|
40
40
|
jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
|
41
41
|
jaxsim/parsers/descriptions/joint.py,sha256=VSb6C0FBBKMqwrHBKfc-Bbn4rl_J0RzUxMQlhIEvOPM,5185
|
@@ -54,16 +54,17 @@ jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
|
54
54
|
jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
|
55
55
|
jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
56
|
jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
|
57
|
-
jaxsim/rbda/contacts/
|
57
|
+
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=9YkPLbK6Kk0wPkuj47r7NBqY2tARyJsiCbrvDlOWHSI,12700
|
58
|
+
jaxsim/rbda/contacts/rigid.py,sha256=fbZk7sC6YOnTs_tzQRfsyBpHyT22XF-wB-EvOSZmhos,14746
|
58
59
|
jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
|
59
60
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
60
|
-
jaxsim/terrain/terrain.py,sha256=
|
61
|
+
jaxsim/terrain/terrain.py,sha256=xUQg47yGxIOcTkLPbnO3sruEGBhoCd16j1evTGlmNjI,5010
|
61
62
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
62
|
-
jaxsim/utils/jaxsim_dataclass.py,sha256=
|
63
|
+
jaxsim/utils/jaxsim_dataclass.py,sha256=FSiUvdnq4Y1T9Jaa_mw4ZBQJe8H7deLr3Kupxtlh4iI,11322
|
63
64
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
64
|
-
jaxsim/utils/wrappers.py,sha256=
|
65
|
-
jaxsim-0.4.3.
|
66
|
-
jaxsim-0.4.3.
|
67
|
-
jaxsim-0.4.3.
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
65
|
+
jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
|
66
|
+
jaxsim-0.4.3.dev64.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
67
|
+
jaxsim-0.4.3.dev64.dist-info/METADATA,sha256=0-JS1eJjFMSaMzwqbCSpWYU2GcrZkxT1LBDo7lhWICo,17276
|
68
|
+
jaxsim-0.4.3.dev64.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
69
|
+
jaxsim-0.4.3.dev64.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
70
|
+
jaxsim-0.4.3.dev64.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|