jaxsim 0.4.3.dev312__py3-none-any.whl → 0.4.3.dev350__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +65 -28
- jaxsim/api/joint.py +8 -9
- jaxsim/api/kin_dyn_parameters.py +9 -4
- jaxsim/api/link.py +3 -4
- jaxsim/api/model.py +21 -22
- jaxsim/api/references.py +1 -1
- jaxsim/integrators/common.py +2 -2
- jaxsim/integrators/variable_step.py +6 -12
- jaxsim/mujoco/loaders.py +9 -138
- jaxsim/mujoco/utils.py +123 -1
- jaxsim/parsers/descriptions/joint.py +1 -26
- jaxsim/parsers/kinematic_graph.py +3 -3
- jaxsim/parsers/rod/parser.py +3 -6
- jaxsim/parsers/rod/utils.py +1 -1
- jaxsim/rbda/collidable_points.py +18 -5
- jaxsim/rbda/contacts/common.py +11 -9
- jaxsim/rbda/contacts/relaxed_rigid.py +14 -5
- jaxsim/rbda/contacts/rigid.py +9 -6
- jaxsim/rbda/contacts/soft.py +17 -4
- jaxsim/rbda/jacobian.py +2 -2
- jaxsim/rbda/utils.py +1 -1
- jaxsim/terrain/terrain.py +9 -1
- jaxsim/utils/tracing.py +3 -9
- jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/RECORD +30 -30
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev312.dist-info → jaxsim-0.4.3.dev350.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/utils.py
CHANGED
@@ -1,7 +1,14 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from collections.abc import Sequence
|
5
|
+
|
1
6
|
import mujoco as mj
|
2
7
|
import numpy as np
|
8
|
+
import numpy.typing as npt
|
9
|
+
from scipy.spatial.transform import Rotation
|
3
10
|
|
4
|
-
from . import MujocoModelHelper
|
11
|
+
from .model import MujocoModelHelper
|
5
12
|
|
6
13
|
|
7
14
|
def mujoco_data_from_jaxsim(
|
@@ -99,3 +106,118 @@ def mujoco_data_from_jaxsim(
|
|
99
106
|
mj.mj_forward(mujoco_model, model_helper.data)
|
100
107
|
|
101
108
|
return model_helper.data
|
109
|
+
|
110
|
+
|
111
|
+
@dataclasses.dataclass
|
112
|
+
class MujocoCamera:
|
113
|
+
"""
|
114
|
+
Helper class storing parameters of a Mujoco camera.
|
115
|
+
|
116
|
+
Refer to the official documentation for more details:
|
117
|
+
https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
|
118
|
+
"""
|
119
|
+
|
120
|
+
mode: str = "fixed"
|
121
|
+
|
122
|
+
target: str | None = None
|
123
|
+
fovy: str = "45"
|
124
|
+
pos: str = "0 0 0"
|
125
|
+
|
126
|
+
quat: str | None = None
|
127
|
+
axisangle: str | None = None
|
128
|
+
xyaxes: str | None = None
|
129
|
+
zaxis: str | None = None
|
130
|
+
euler: str | None = None
|
131
|
+
|
132
|
+
name: str | None = None
|
133
|
+
|
134
|
+
@classmethod
|
135
|
+
def build(cls, **kwargs) -> MujocoCamera:
|
136
|
+
|
137
|
+
if not all(isinstance(value, str) for value in kwargs.values()):
|
138
|
+
raise ValueError(f"Values must be strings: {kwargs}")
|
139
|
+
|
140
|
+
return cls(**kwargs)
|
141
|
+
|
142
|
+
@staticmethod
|
143
|
+
def build_from_target_view(
|
144
|
+
camera_name: str,
|
145
|
+
lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
|
146
|
+
distance: float | int | npt.NDArray = 3,
|
147
|
+
azimut: float | int | npt.NDArray = 90,
|
148
|
+
elevation: float | int | npt.NDArray = -45,
|
149
|
+
fovy: float | int | npt.NDArray = 45,
|
150
|
+
degrees: bool = True,
|
151
|
+
**kwargs,
|
152
|
+
) -> MujocoCamera:
|
153
|
+
"""
|
154
|
+
Create a custom camera that looks at a target point.
|
155
|
+
|
156
|
+
Note:
|
157
|
+
The choice of the parameters is easier if we imagine to consider a target
|
158
|
+
frame `T` whose origin is located over the lookat point and having the same
|
159
|
+
orientation of the world frame `W`. We also introduce a camera frame `C`
|
160
|
+
whose origin is located over the lower-left corner of the image, and having
|
161
|
+
the x-axis pointing right and the y-axis pointing up in image coordinates.
|
162
|
+
The camera renders what it sees in the -z direction of frame `C`.
|
163
|
+
|
164
|
+
Args:
|
165
|
+
camera_name: The name of the camera.
|
166
|
+
lookat: The target point to look at (origin of `T`).
|
167
|
+
distance:
|
168
|
+
The distance from the target point (displacement between the origins
|
169
|
+
of `T` and `C`).
|
170
|
+
azimut:
|
171
|
+
The rotation around z of the camera. With an angle of 0, the camera
|
172
|
+
would loot at the target point towards the positive x-axis of `T`.
|
173
|
+
elevation:
|
174
|
+
The rotation around the x-axis of the camera frame `C`. Note that if
|
175
|
+
you want to lift the view angle, the elevation is negative.
|
176
|
+
fovy: The field of view of the camera.
|
177
|
+
degrees: Whether the angles are in degrees or radians.
|
178
|
+
**kwargs: Additional camera parameters.
|
179
|
+
|
180
|
+
Returns:
|
181
|
+
The custom camera.
|
182
|
+
"""
|
183
|
+
|
184
|
+
# Start from a frame whose origin is located over the lookat point.
|
185
|
+
# We initialize a -90 degrees rotation around the z-axis because due to
|
186
|
+
# the default camera coordinate system (x pointing right, y pointing up).
|
187
|
+
W_H_C = np.eye(4)
|
188
|
+
W_H_C[0:3, 3] = np.array(lookat)
|
189
|
+
W_H_C[0:3, 0:3] = Rotation.from_euler(
|
190
|
+
seq="ZX", angles=[-90, 90], degrees=True
|
191
|
+
).as_matrix()
|
192
|
+
|
193
|
+
# Process the azimut.
|
194
|
+
R_az = Rotation.from_euler(seq="Y", angles=azimut, degrees=degrees).as_matrix()
|
195
|
+
W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
|
196
|
+
|
197
|
+
# Process elevation.
|
198
|
+
R_el = Rotation.from_euler(
|
199
|
+
seq="X", angles=elevation, degrees=degrees
|
200
|
+
).as_matrix()
|
201
|
+
W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
|
202
|
+
|
203
|
+
# Process distance.
|
204
|
+
tf_distance = np.eye(4)
|
205
|
+
tf_distance[2, 3] = distance
|
206
|
+
W_H_C = W_H_C @ tf_distance
|
207
|
+
|
208
|
+
# Extract the position and the quaternion.
|
209
|
+
p = W_H_C[0:3, 3]
|
210
|
+
Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
|
211
|
+
|
212
|
+
return MujocoCamera.build(
|
213
|
+
name=camera_name,
|
214
|
+
mode="fixed",
|
215
|
+
fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
|
216
|
+
pos=" ".join(p.astype(str).tolist()),
|
217
|
+
quat=" ".join(Q.astype(str).tolist()),
|
218
|
+
**kwargs,
|
219
|
+
)
|
220
|
+
|
221
|
+
def asdict(self) -> dict[str, str]:
|
222
|
+
|
223
|
+
return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
|
@@ -100,32 +100,7 @@ class JointDescription(JaxsimDataclass):
|
|
100
100
|
if not isinstance(other, JointDescription):
|
101
101
|
return False
|
102
102
|
|
103
|
-
|
104
|
-
self.name == other.name
|
105
|
-
and self.jtype == other.jtype
|
106
|
-
and self.child == other.child
|
107
|
-
and self.parent == other.parent
|
108
|
-
and self.index == other.index
|
109
|
-
and all(
|
110
|
-
np.allclose(getattr(self, attr), getattr(other, attr))
|
111
|
-
for attr in [
|
112
|
-
"axis",
|
113
|
-
"pose",
|
114
|
-
"friction_static",
|
115
|
-
"friction_viscous",
|
116
|
-
"position_limit_damper",
|
117
|
-
"position_limit_spring",
|
118
|
-
"position_limit",
|
119
|
-
"initial_position",
|
120
|
-
"motor_inertia",
|
121
|
-
"motor_viscous_friction",
|
122
|
-
"motor_gear_ratio",
|
123
|
-
]
|
124
|
-
),
|
125
|
-
):
|
126
|
-
return False
|
127
|
-
|
128
|
-
return True
|
103
|
+
return hash(self) == hash(other)
|
129
104
|
|
130
105
|
def __hash__(self) -> int:
|
131
106
|
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
import copy
|
4
4
|
import dataclasses
|
5
5
|
import functools
|
6
|
-
from collections.abc import Callable, Iterable, Sequence
|
6
|
+
from collections.abc import Callable, Iterable, Iterator, Sequence
|
7
7
|
from typing import Any
|
8
8
|
|
9
9
|
import numpy as np
|
@@ -82,7 +82,7 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
82
82
|
default_factory=list, hash=False, compare=False
|
83
83
|
)
|
84
84
|
|
85
|
-
root_pose: RootPose = dataclasses.field(default_factory=
|
85
|
+
root_pose: RootPose = dataclasses.field(default_factory=RootPose)
|
86
86
|
|
87
87
|
# Private attribute storing optional additional info.
|
88
88
|
_extra_info: dict[str, Any] = dataclasses.field(
|
@@ -700,7 +700,7 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
700
700
|
# Sequence protocol
|
701
701
|
# =================
|
702
702
|
|
703
|
-
def __iter__(self) ->
|
703
|
+
def __iter__(self) -> Iterator[LinkDescription]:
|
704
704
|
yield from KinematicGraph.breadth_first_search(root=self.root)
|
705
705
|
|
706
706
|
def __reversed__(self) -> Iterable[LinkDescription]:
|
jaxsim/parsers/rod/parser.py
CHANGED
@@ -85,10 +85,7 @@ def extract_model_data(
|
|
85
85
|
|
86
86
|
# Log type of base link.
|
87
87
|
logging.debug(
|
88
|
-
msg="Model '{}' is {}"
|
89
|
-
sdf_model.name,
|
90
|
-
"fixed-base" if sdf_model.is_fixed_base() else "floating-base",
|
91
|
-
)
|
88
|
+
msg=f"Model '{sdf_model.name}' is {'fixed-base' if sdf_model.is_fixed_base() else 'floating-base'}"
|
92
89
|
)
|
93
90
|
|
94
91
|
# Log detected base link.
|
@@ -175,7 +172,7 @@ def extract_model_data(
|
|
175
172
|
for j in sdf_model.joints()
|
176
173
|
if j.type == "fixed"
|
177
174
|
and j.parent == "world"
|
178
|
-
and j.child in links_dict
|
175
|
+
and j.child in links_dict
|
179
176
|
and j.pose.relative_to in {"__model__", "world", None}
|
180
177
|
]
|
181
178
|
|
@@ -287,7 +284,7 @@ def extract_model_data(
|
|
287
284
|
for j in sdf_model.joints()
|
288
285
|
if j.type in {"revolute", "continuous", "prismatic", "fixed"}
|
289
286
|
and j.parent != "world"
|
290
|
-
and j.child in links_dict
|
287
|
+
and j.child in links_dict
|
291
288
|
]
|
292
289
|
|
293
290
|
# Create a dictionary to find the parent joint of the links.
|
jaxsim/parsers/rod/utils.py
CHANGED
@@ -179,7 +179,7 @@ def create_sphere_collision(
|
|
179
179
|
|
180
180
|
r = collision.geometry.sphere.radius
|
181
181
|
sphere_points = r * fibonacci_sphere(
|
182
|
-
samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="
|
182
|
+
samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50"))
|
183
183
|
)
|
184
184
|
|
185
185
|
H = collision.pose.transform() if collision.pose is not None else np.eye(4)
|
jaxsim/rbda/collidable_points.py
CHANGED
@@ -21,7 +21,7 @@ def collidable_points_pos_vel(
|
|
21
21
|
) -> tuple[jtp.Matrix, jtp.Matrix]:
|
22
22
|
"""
|
23
23
|
|
24
|
-
Compute the position and linear velocity of collidable points in the world frame.
|
24
|
+
Compute the position and linear velocity of the enabled collidable points in the world frame.
|
25
25
|
|
26
26
|
Args:
|
27
27
|
model: The model to consider.
|
@@ -35,10 +35,23 @@ def collidable_points_pos_vel(
|
|
35
35
|
joint_velocities: The velocities of the joints.
|
36
36
|
|
37
37
|
Returns:
|
38
|
-
A tuple containing the position and linear velocity of collidable points.
|
38
|
+
A tuple containing the position and linear velocity of the enabled collidable points.
|
39
39
|
"""
|
40
40
|
|
41
|
-
|
41
|
+
# Get the indices of the enabled collidable points.
|
42
|
+
indices_of_enabled_collidable_points = (
|
43
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
44
|
+
)
|
45
|
+
|
46
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
47
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
48
|
+
)[indices_of_enabled_collidable_points]
|
49
|
+
|
50
|
+
L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
|
51
|
+
indices_of_enabled_collidable_points
|
52
|
+
]
|
53
|
+
|
54
|
+
if len(indices_of_enabled_collidable_points) == 0:
|
42
55
|
return jnp.array(0).astype(float), jnp.empty(0).astype(float)
|
43
56
|
|
44
57
|
W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
|
@@ -136,8 +149,8 @@ def collidable_points_pos_vel(
|
|
136
149
|
|
137
150
|
# Process all the collidable points in parallel.
|
138
151
|
W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
|
139
|
-
|
140
|
-
|
152
|
+
L_p_Ci,
|
153
|
+
parent_link_idx_of_enabled_collidable_points,
|
141
154
|
)
|
142
155
|
|
143
156
|
return W_p_Ci, CW_vl_WC
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -216,11 +216,21 @@ class ContactModel(JaxsimDataclass):
|
|
216
216
|
the velocity representation of data.
|
217
217
|
"""
|
218
218
|
|
219
|
+
# Get the object storing the contact parameters of the model.
|
220
|
+
contact_parameters = model.kin_dyn_parameters.contact_parameters
|
221
|
+
|
222
|
+
# Extract the indices corresponding to the enabled collidable points.
|
223
|
+
indices_of_enabled_collidable_points = (
|
224
|
+
contact_parameters.indices_of_enabled_collidable_points
|
225
|
+
)
|
226
|
+
|
219
227
|
# Convert the contact forces to a JAX array.
|
220
228
|
f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
|
221
229
|
|
222
230
|
# Get the pose of the enabled collidable points.
|
223
|
-
W_H_C = js.contact.transforms(model=model, data=data)
|
231
|
+
W_H_C = js.contact.transforms(model=model, data=data)[
|
232
|
+
indices_of_enabled_collidable_points
|
233
|
+
]
|
224
234
|
|
225
235
|
# Convert the contact forces to inertial-fixed representation.
|
226
236
|
W_f_C = jax.vmap(
|
@@ -234,14 +244,6 @@ class ContactModel(JaxsimDataclass):
|
|
234
244
|
)
|
235
245
|
)(f_C, W_H_C)
|
236
246
|
|
237
|
-
# Get the object storing the contact parameters of the model.
|
238
|
-
contact_parameters = model.kin_dyn_parameters.contact_parameters
|
239
|
-
|
240
|
-
# Extract the indices corresponding to the enabled collidable points.
|
241
|
-
indices_of_enabled_collidable_points = (
|
242
|
-
contact_parameters.indices_of_enabled_collidable_points
|
243
|
-
)
|
244
|
-
|
245
247
|
# Construct the vector defining the parent link index of each collidable point.
|
246
248
|
# We use this vector to sum the 6D forces of all collidable points rigidly
|
247
249
|
# attached to the same link.
|
@@ -357,13 +357,15 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
357
357
|
|
358
358
|
Jl_WC = jnp.vstack(
|
359
359
|
jax.vmap(lambda J, height: J * (height < 0))(
|
360
|
-
js.contact.jacobian(model=model, data=data)[:, :3, :],
|
360
|
+
js.contact.jacobian(model=model, data=data)[:, :3, :],
|
361
|
+
δ,
|
361
362
|
)
|
362
363
|
)
|
363
364
|
|
364
365
|
J̇_WC = jnp.vstack(
|
365
366
|
jax.vmap(lambda J̇, height: J̇ * (height < 0))(
|
366
|
-
js.contact.jacobian_derivative(model=model, data=data)[:, :3],
|
367
|
+
js.contact.jacobian_derivative(model=model, data=data)[:, :3],
|
368
|
+
δ,
|
367
369
|
),
|
368
370
|
)
|
369
371
|
|
@@ -530,6 +532,15 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
530
532
|
)
|
531
533
|
)
|
532
534
|
|
535
|
+
# Get the indices of the enabled collidable points.
|
536
|
+
indices_of_enabled_collidable_points = (
|
537
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
538
|
+
)
|
539
|
+
|
540
|
+
parent_link_idx_of_enabled_collidable_points = jnp.array(
|
541
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
542
|
+
)[indices_of_enabled_collidable_points]
|
543
|
+
|
533
544
|
# Compute the 6D inertia matrices of all links.
|
534
545
|
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
535
546
|
|
@@ -595,9 +606,7 @@ class RelaxedRigidContacts(common.ContactModel):
|
|
595
606
|
f=jnp.concatenate,
|
596
607
|
tree=(
|
597
608
|
*jax.vmap(compute_row)(
|
598
|
-
link_idx=
|
599
|
-
model.kin_dyn_parameters.contact_parameters.body
|
600
|
-
),
|
609
|
+
link_idx=parent_link_idx_of_enabled_collidable_points,
|
601
610
|
penetration=penetration,
|
602
611
|
velocity=velocity,
|
603
612
|
),
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -285,6 +285,13 @@ class RigidContacts(ContactModel):
|
|
285
285
|
# Import qpax privately just in this method.
|
286
286
|
import qpax
|
287
287
|
|
288
|
+
# Get the indices of the enabled collidable points.
|
289
|
+
indices_of_enabled_collidable_points = (
|
290
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
291
|
+
)
|
292
|
+
|
293
|
+
n_collidable_points = len(indices_of_enabled_collidable_points)
|
294
|
+
|
288
295
|
link_forces = jnp.atleast_2d(
|
289
296
|
jnp.array(link_forces, dtype=float).squeeze()
|
290
297
|
if link_forces is not None
|
@@ -299,7 +306,6 @@ class RigidContacts(ContactModel):
|
|
299
306
|
|
300
307
|
# Compute kin-dyn quantities used in the contact model.
|
301
308
|
with data.switch_velocity_representation(VelRepr.Mixed):
|
302
|
-
|
303
309
|
BW_ν = data.generalized_velocity()
|
304
310
|
|
305
311
|
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
@@ -310,14 +316,11 @@ class RigidContacts(ContactModel):
|
|
310
316
|
W_H_C = js.contact.transforms(model=model, data=data)
|
311
317
|
|
312
318
|
# Compute the position and linear velocities (mixed representation) of
|
313
|
-
# all collidable points belonging to the robot.
|
319
|
+
# all enabled collidable points belonging to the robot.
|
314
320
|
position, velocity = js.contact.collidable_point_kinematics(
|
315
321
|
model=model, data=data
|
316
322
|
)
|
317
323
|
|
318
|
-
# Get the number of collidable points.
|
319
|
-
n_collidable_points = len(model.kin_dyn_parameters.contact_parameters.body)
|
320
|
-
|
321
324
|
# Compute the penetration depth and velocity of the collidable points.
|
322
325
|
# Note that this function considers the penetration in the normal direction.
|
323
326
|
δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
|
@@ -460,7 +463,7 @@ class RigidContacts(ContactModel):
|
|
460
463
|
return G
|
461
464
|
|
462
465
|
@staticmethod
|
463
|
-
def _compute_ineq_bounds(n_collidable_points:
|
466
|
+
def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector:
|
464
467
|
|
465
468
|
n_constraints = 6 * n_collidable_points
|
466
469
|
return jnp.zeros(shape=(n_constraints,))
|
jaxsim/rbda/contacts/soft.py
CHANGED
@@ -445,16 +445,27 @@ class SoftContacts(common.ContactModel):
|
|
445
445
|
# contact parameters are not compatible.
|
446
446
|
model, data = self.initialize_model_and_data(model=model, data=data)
|
447
447
|
|
448
|
+
# Get the indices of the enabled collidable points.
|
449
|
+
indices_of_enabled_collidable_points = (
|
450
|
+
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
|
451
|
+
)
|
452
|
+
|
448
453
|
# Compute the position and linear velocities (mixed representation) of
|
449
|
-
# all collidable points belonging to the robot
|
454
|
+
# all the collidable points belonging to the robot and extract the ones
|
455
|
+
# for the enabled collidable points.
|
450
456
|
W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data)
|
451
457
|
|
452
458
|
# Extract the material deformation corresponding to the collidable points.
|
453
459
|
m = data.state.extended["tangential_deformation"]
|
454
460
|
|
455
|
-
|
461
|
+
m_enabled = m[indices_of_enabled_collidable_points]
|
462
|
+
|
463
|
+
# Initialize the tangential deformation rate array for every collidable point.
|
464
|
+
ṁ = jnp.zeros_like(m)
|
465
|
+
|
466
|
+
# Compute the contact forces only for the enabled collidable points.
|
456
467
|
# Since we treat them as independent, we can vmap the computation.
|
457
|
-
W_f, ṁ = jax.vmap(
|
468
|
+
W_f, ṁ_enabled = jax.vmap(
|
458
469
|
lambda p, v, m: SoftContacts.compute_contact_force(
|
459
470
|
position=p,
|
460
471
|
velocity=v,
|
@@ -462,6 +473,8 @@ class SoftContacts(common.ContactModel):
|
|
462
473
|
parameters=data.contacts_params,
|
463
474
|
terrain=model.terrain,
|
464
475
|
)
|
465
|
-
)(W_p_C, W_ṗ_C,
|
476
|
+
)(W_p_C, W_ṗ_C, m_enabled)
|
477
|
+
|
478
|
+
ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled)
|
466
479
|
|
467
480
|
return W_f, dict(m_dot=ṁ)
|
jaxsim/rbda/jacobian.py
CHANGED
@@ -205,7 +205,7 @@ def jacobian_full_doubly_left(
|
|
205
205
|
# Convert adjoints to SE(3) transforms.
|
206
206
|
# Returning them here prevents calling FK in case the output representation
|
207
207
|
# of the Jacobian needs to be changed.
|
208
|
-
B_H_L = jax.vmap(
|
208
|
+
B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)
|
209
209
|
|
210
210
|
# Adjust shape of doubly-left free-floating full Jacobian.
|
211
211
|
B_J_full_WL_B = J.squeeze().astype(float)
|
@@ -322,7 +322,7 @@ def jacobian_derivative_full_doubly_left(
|
|
322
322
|
# Convert adjoints to SE(3) transforms.
|
323
323
|
# Returning them here prevents calling FK in case the output representation
|
324
324
|
# of the Jacobian needs to be changed.
|
325
|
-
B_H_L = jax.vmap(
|
325
|
+
B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)
|
326
326
|
|
327
327
|
# Adjust shape of doubly-left free-floating full Jacobian derivative.
|
328
328
|
B_J̇_full_WL_B = J̇.squeeze().astype(float)
|
jaxsim/rbda/utils.py
CHANGED
@@ -135,7 +135,7 @@ def process_inputs(
|
|
135
135
|
# Check that the quaternion is unary since our RBDAs make this assumption in order
|
136
136
|
# to prevent introducing additional normalizations that would affect AD.
|
137
137
|
exceptions.raise_value_error_if(
|
138
|
-
condition
|
138
|
+
condition=~jnp.allclose(W_Q_B.dot(W_Q_B), 1.0),
|
139
139
|
msg="A RBDA received a quaternion that is not normalized.",
|
140
140
|
)
|
141
141
|
|
jaxsim/terrain/terrain.py
CHANGED
@@ -8,6 +8,7 @@ import jax_dataclasses
|
|
8
8
|
import numpy as np
|
9
9
|
|
10
10
|
import jaxsim.typing as jtp
|
11
|
+
from jaxsim import exceptions
|
11
12
|
|
12
13
|
|
13
14
|
class Terrain(abc.ABC):
|
@@ -108,7 +109,9 @@ class PlaneTerrain(FlatTerrain):
|
|
108
109
|
_normal=tuple(normal.tolist()),
|
109
110
|
)
|
110
111
|
|
111
|
-
def normal(
|
112
|
+
def normal(
|
113
|
+
self, x: jtp.FloatLike | None = None, y: jtp.FloatLike | None = None
|
114
|
+
) -> jtp.Vector:
|
112
115
|
"""
|
113
116
|
Compute the normal vector of the terrain at a specific (x, y) location.
|
114
117
|
|
@@ -141,6 +144,11 @@ class PlaneTerrain(FlatTerrain):
|
|
141
144
|
# Get the plane equation coefficients from the terrain normal.
|
142
145
|
A, B, C = self._normal
|
143
146
|
|
147
|
+
exceptions.raise_value_error_if(
|
148
|
+
condition=jnp.allclose(C, 0.0),
|
149
|
+
msg="The z component of the normal cannot be zero.",
|
150
|
+
)
|
151
|
+
|
144
152
|
# Compute the final coefficient D considering the terrain height.
|
145
153
|
D = -C * self._height
|
146
154
|
|
jaxsim/utils/tracing.py
CHANGED
@@ -8,15 +8,9 @@ import jax.interpreters.partial_eval
|
|
8
8
|
def tracing(var: Any) -> bool | jax.Array:
|
9
9
|
"""Returns True if the variable is being traced by JAX, False otherwise."""
|
10
10
|
|
11
|
-
return
|
12
|
-
|
13
|
-
|
14
|
-
for t in (
|
15
|
-
jax._src.core.Tracer,
|
16
|
-
jax.interpreters.partial_eval.DynamicJaxprTracer,
|
17
|
-
)
|
18
|
-
]
|
19
|
-
).any()
|
11
|
+
return isinstance(
|
12
|
+
var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer
|
13
|
+
)
|
20
14
|
|
21
15
|
|
22
16
|
def not_tracing(var: Any) -> bool | jax.Array:
|
jaxsim/utils/wrappers.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev350
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
|
6
6
|
Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
|
@@ -1,25 +1,25 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=56tTuqXBlX9UQVJyJ_A9hRmvozgLzRGJ-9ZCppehae8,428
|
3
3
|
jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
6
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
7
7
|
jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
|
8
8
|
jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
|
9
|
-
jaxsim/api/contact.py,sha256=
|
9
|
+
jaxsim/api/contact.py,sha256=D6RucrH9gnoUFLdmAEYwLGrimU0wLmuoDeOONu4ni74,25658
|
10
10
|
jaxsim/api/data.py,sha256=ThRpoBlbdwf1N3xs8SWrY5d8RbfdYRwFcmkdIPgtee4,29004
|
11
11
|
jaxsim/api/frame.py,sha256=yPSgNygHkvWlln4wShNt7vZm_fFobVEm7phsklNNyH8,12922
|
12
|
-
jaxsim/api/joint.py,sha256=
|
13
|
-
jaxsim/api/kin_dyn_parameters.py,sha256=
|
14
|
-
jaxsim/api/link.py,sha256=
|
15
|
-
jaxsim/api/model.py,sha256=
|
12
|
+
jaxsim/api/joint.py,sha256=8rCIxRMeAidsaBbw7kkGp6z3-UmBPtqmYmV_arHDQJ8,7365
|
13
|
+
jaxsim/api/kin_dyn_parameters.py,sha256=Y9wnMshz83Zm4UEPOAOTINdtfkBZ86w853c8Yi2qaVs,29670
|
14
|
+
jaxsim/api/link.py,sha256=nHjffhNdi_xGkteMsqdb_hC9mdV9rNw7k3pl89Uhw_8,12798
|
15
|
+
jaxsim/api/model.py,sha256=A88AaBZpWvQ-L9blFyl1GHvTWI05rvVFKbSaHzD77_k,79563
|
16
16
|
jaxsim/api/ode.py,sha256=_t18avoCJngQk6eMFTGpaeahbpchQP20qJnUOCPkz8s,15360
|
17
17
|
jaxsim/api/ode_data.py,sha256=1SD-x-lYk_YSEnVpxTLd69uOKC0mFUj44ZqpSmEDOxw,20190
|
18
|
-
jaxsim/api/references.py,sha256=
|
18
|
+
jaxsim/api/references.py,sha256=eIOk3MAOc9LJSKfI8M4WA8gGD-meo50vRfhXdea4sNI,20539
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
|
-
jaxsim/integrators/common.py,sha256=
|
20
|
+
jaxsim/integrators/common.py,sha256=ohISUnUWTaNHt2kweg1JyzwYGZgIH_wc-01qJWJsO80,18281
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
|
-
jaxsim/integrators/variable_step.py,sha256=
|
22
|
+
jaxsim/integrators/variable_step.py,sha256=Tqz5ySSgyKak_k6cTXpmtqdPNaFlO7N6zj7jBIlChyM,22681
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
24
24
|
jaxsim/math/adjoint.py,sha256=V7r5VrTCKPLEL5gavNSx9U7xSsrb11a5e4gWqJ2MuRo,4375
|
25
25
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
@@ -31,42 +31,42 @@ jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
|
|
31
31
|
jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
|
32
32
|
jaxsim/mujoco/__init__.py,sha256=fZyRWre49pIhOrYdf6yJk_hOax8qWGe8OCmoq-dMVq8,201
|
33
33
|
jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
|
34
|
-
jaxsim/mujoco/loaders.py,sha256=
|
34
|
+
jaxsim/mujoco/loaders.py,sha256=_CZekIqZNe8oFeH7zSv4gGZAZENRISwMd8dt640zjRI,20860
|
35
35
|
jaxsim/mujoco/model.py,sha256=5_7rWk_WBkNKDHqeewIFj0t2ZGqJpE6RDXHSbRvw4e4,16493
|
36
|
-
jaxsim/mujoco/utils.py,sha256=
|
36
|
+
jaxsim/mujoco/utils.py,sha256=vZ8afASNOSxnxVW9p_1U1J_n-9nVhnBDqlV5k8c1GkM,8256
|
37
37
|
jaxsim/mujoco/visualizer.py,sha256=nD6SNWmn-nxjjjIY9oPAHvL2j8q93DJDjZeepzke_DQ,6988
|
38
38
|
jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
39
|
-
jaxsim/parsers/kinematic_graph.py,sha256=
|
39
|
+
jaxsim/parsers/kinematic_graph.py,sha256=MJkJ7AW1TdLZmxibuiVrTfn6jHjh3OVhEF20DqwsCnM,34748
|
40
40
|
jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
|
41
41
|
jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
|
42
|
-
jaxsim/parsers/descriptions/joint.py,sha256=
|
42
|
+
jaxsim/parsers/descriptions/joint.py,sha256=2KWLP4ILPMV8q1X0J7aS3GGFeZn4zXan0dqGOWc7XuQ,4365
|
43
43
|
jaxsim/parsers/descriptions/link.py,sha256=Eh0W5qL7_Uw0GV-BkNKXhm9Q2dRTfIWCX5D-87zQkxA,3711
|
44
44
|
jaxsim/parsers/descriptions/model.py,sha256=I2Vsbv8Josl4Le7b5rIvhqA2k9Bbv5JxMqwytayxds0,9833
|
45
45
|
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
46
|
-
jaxsim/parsers/rod/parser.py,sha256=
|
47
|
-
jaxsim/parsers/rod/utils.py,sha256=
|
46
|
+
jaxsim/parsers/rod/parser.py,sha256=EXcbtr_vMjAaUzQjfQlD1zLLYLAZXrNeFHaiZVlLwFI,13976
|
47
|
+
jaxsim/parsers/rod/utils.py,sha256=czQ2Y1_I9zGO0y2XDotHSqDorVH6zEcPhkuelApjs3k,5697
|
48
48
|
jaxsim/rbda/__init__.py,sha256=kmy4G9aMkrqPNGdLSaSV3k15dpF52vBEUQXDFDuKIxU,337
|
49
49
|
jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
|
50
|
-
jaxsim/rbda/collidable_points.py,sha256=
|
50
|
+
jaxsim/rbda/collidable_points.py,sha256=0PFLzxWKtRg8-JtfNhGlSjBMv1J98tiLymOdvlvAak4,5325
|
51
51
|
jaxsim/rbda/crba.py,sha256=bXkXESnVbv-lxhU1Y_i0rViEcQA4z2t2_jHwdVj5CBo,5049
|
52
52
|
jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdulqStA,3458
|
53
|
-
jaxsim/rbda/jacobian.py,sha256=
|
53
|
+
jaxsim/rbda/jacobian.py,sha256=L6Vn4Kf9I6wj-MYcFY6o67mgIfLFaaW4i2wNQJ2PDL0,10981
|
54
54
|
jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
55
|
-
jaxsim/rbda/utils.py,sha256=
|
55
|
+
jaxsim/rbda/utils.py,sha256=GLt7XIl1ROkx0_fnBCKUHYdB9_IBF3Yi4OnkHSX3gxA,5365
|
56
56
|
jaxsim/rbda/contacts/__init__.py,sha256=L5MM-2pv76YPGzxExdz2EErgGBATuAjYnNHlq5QOySs,503
|
57
|
-
jaxsim/rbda/contacts/common.py,sha256=
|
58
|
-
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=
|
59
|
-
jaxsim/rbda/contacts/rigid.py,sha256=
|
60
|
-
jaxsim/rbda/contacts/soft.py,sha256=
|
57
|
+
jaxsim/rbda/contacts/common.py,sha256=ai49HeLQOsnckG0H2tUKW2KQ0Au_v9jRuNdnqie-YBk,11234
|
58
|
+
jaxsim/rbda/contacts/relaxed_rigid.py,sha256=tbyskONuUhC6BZnZSpNUnlCjkI7LR6mCtmU_HimOAVE,20893
|
59
|
+
jaxsim/rbda/contacts/rigid.py,sha256=MSzkU6SFbW6CryNlyyxQ7K0-U-8k6VROGKv_DQrwqiw,17156
|
60
|
+
jaxsim/rbda/contacts/soft.py,sha256=t6bqBfGAtV1AWoevY82LAcXy2XW8w_uu7bNywcyxF0s,17001
|
61
61
|
jaxsim/rbda/contacts/visco_elastic.py,sha256=vQkfMuqQ3Qu8nbDTPY4jWBZjV3U7qtoRK1Aya3O3oFA,41424
|
62
62
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
63
|
-
jaxsim/terrain/terrain.py,sha256=
|
63
|
+
jaxsim/terrain/terrain.py,sha256=_G1QS3zWycj089R8fTP5s2VjcZpEdJxREjXZJ-oXIvc,5248
|
64
64
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
65
65
|
jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
|
66
|
-
jaxsim/utils/tracing.py,sha256=
|
67
|
-
jaxsim/utils/wrappers.py,sha256=
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
70
|
-
jaxsim-0.4.3.
|
71
|
-
jaxsim-0.4.3.
|
72
|
-
jaxsim-0.4.3.
|
66
|
+
jaxsim/utils/tracing.py,sha256=eEY28MZW0Lm_jJNt1NkFqZz0ek01tvhR46OXZYCo7tc,532
|
67
|
+
jaxsim/utils/wrappers.py,sha256=ZY7olSORzZRvSzkdeNLj8yjwUIAt9L0Douwl7wItjpk,4008
|
68
|
+
jaxsim-0.4.3.dev350.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
69
|
+
jaxsim-0.4.3.dev350.dist-info/METADATA,sha256=qyh1wWUq5dTCw9iznLWlS4DlKa6kfMnAMqZgYbldbCA,17513
|
70
|
+
jaxsim-0.4.3.dev350.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
|
71
|
+
jaxsim-0.4.3.dev350.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
72
|
+
jaxsim-0.4.3.dev350.dist-info/RECORD,,
|
File without changes
|
File without changes
|