jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +64 -30
- jaxsim/math/cross.py +18 -9
- jaxsim/math/inertia.py +11 -9
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +59 -25
- jaxsim/math/rotation.py +30 -24
- jaxsim/math/skew.py +18 -7
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/math/quaternion.py
CHANGED
@@ -1,21 +1,27 @@
|
|
1
1
|
import jax.lax
|
2
2
|
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
3
4
|
|
4
5
|
import jaxsim.typing as jtp
|
5
|
-
|
6
|
+
|
7
|
+
from .utils import safe_norm
|
6
8
|
|
7
9
|
|
8
10
|
class Quaternion:
|
11
|
+
"""
|
12
|
+
A utility class for quaternion operations.
|
13
|
+
"""
|
14
|
+
|
9
15
|
@staticmethod
|
10
16
|
def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:
|
11
17
|
"""
|
12
18
|
Convert a quaternion from WXYZ to XYZW representation.
|
13
19
|
|
14
20
|
Args:
|
15
|
-
wxyz
|
21
|
+
wxyz: Quaternion in WXYZ representation.
|
16
22
|
|
17
23
|
Returns:
|
18
|
-
|
24
|
+
Quaternion in XYZW representation.
|
19
25
|
"""
|
20
26
|
return wxyz.squeeze()[jnp.array([1, 2, 3, 0])]
|
21
27
|
|
@@ -25,10 +31,10 @@ class Quaternion:
|
|
25
31
|
Convert a quaternion from XYZW to WXYZ representation.
|
26
32
|
|
27
33
|
Args:
|
28
|
-
xyzw
|
34
|
+
xyzw: Quaternion in XYZW representation.
|
29
35
|
|
30
36
|
Returns:
|
31
|
-
|
37
|
+
Quaternion in WXYZ representation.
|
32
38
|
"""
|
33
39
|
return xyzw.squeeze()[jnp.array([3, 0, 1, 2])]
|
34
40
|
|
@@ -38,14 +44,12 @@ class Quaternion:
|
|
38
44
|
Convert a quaternion to a direction cosine matrix (DCM).
|
39
45
|
|
40
46
|
Args:
|
41
|
-
quaternion
|
47
|
+
quaternion: Quaternion in XYZW representation.
|
42
48
|
|
43
49
|
Returns:
|
44
|
-
|
50
|
+
The Direction cosine matrix (DCM).
|
45
51
|
"""
|
46
|
-
return
|
47
|
-
xyzw=Quaternion.to_xyzw(quaternion)
|
48
|
-
).as_matrix()
|
52
|
+
return jaxlie.SO3(wxyz=quaternion).as_matrix()
|
49
53
|
|
50
54
|
@staticmethod
|
51
55
|
def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
|
@@ -53,14 +57,12 @@ class Quaternion:
|
|
53
57
|
Convert a direction cosine matrix (DCM) to a quaternion.
|
54
58
|
|
55
59
|
Args:
|
56
|
-
dcm
|
60
|
+
dcm: Direction cosine matrix (DCM).
|
57
61
|
|
58
62
|
Returns:
|
59
|
-
|
63
|
+
Quaternion in WXYZ representation.
|
60
64
|
"""
|
61
|
-
return
|
62
|
-
xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
|
63
|
-
)
|
65
|
+
return jaxlie.SO3.from_matrix(matrix=dcm).wxyz
|
64
66
|
|
65
67
|
@staticmethod
|
66
68
|
def derivative(
|
@@ -73,13 +75,13 @@ class Quaternion:
|
|
73
75
|
Compute the derivative of a quaternion given angular velocity.
|
74
76
|
|
75
77
|
Args:
|
76
|
-
quaternion
|
77
|
-
omega
|
78
|
+
quaternion: Quaternion in XYZW representation.
|
79
|
+
omega: Angular velocity vector.
|
78
80
|
omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame.
|
79
81
|
K (float): A scaling factor.
|
80
82
|
|
81
83
|
Returns:
|
82
|
-
|
84
|
+
The derivative of the quaternion.
|
83
85
|
"""
|
84
86
|
ω = omega.squeeze()
|
85
87
|
quaternion = quaternion.squeeze()
|
@@ -115,21 +117,53 @@ class Quaternion:
|
|
115
117
|
operand=quaternion,
|
116
118
|
)
|
117
119
|
|
118
|
-
norm_ω =
|
119
|
-
pred=ω.dot(ω) < (1e-6) ** 2,
|
120
|
-
true_fun=lambda _: 1e-6,
|
121
|
-
false_fun=lambda _: jnp.linalg.norm(ω),
|
122
|
-
operand=None,
|
123
|
-
)
|
120
|
+
norm_ω = safe_norm(ω)
|
124
121
|
|
125
122
|
qd = 0.5 * (
|
126
123
|
Q
|
127
124
|
@ jnp.hstack(
|
128
125
|
[
|
129
|
-
K * norm_ω * (1 -
|
126
|
+
K * norm_ω * (1 - safe_norm(quaternion)),
|
130
127
|
ω,
|
131
128
|
]
|
132
129
|
)
|
133
130
|
)
|
134
131
|
|
135
132
|
return jnp.vstack(qd)
|
133
|
+
|
134
|
+
@staticmethod
def integration(
    quaternion: jtp.VectorLike,
    dt: jtp.FloatLike,
    omega: jtp.VectorLike,
    omega_in_body_fixed: jtp.BoolLike = False,
) -> jtp.Vector:
    """
    Integrate a quaternion in SO(3) given an angular velocity.

    The angular velocity is held constant over ``dt``. The rotation increment
    ``exp(dt * omega)`` is composed on the right of the initial rotation when
    the velocity is body-fixed, and on the left when it is inertial-fixed.

    Args:
        quaternion: The quaternion to integrate, in WXYZ order.
        dt: The time step.
        omega: The angular velocity vector.
        omega_in_body_fixed:
            Whether the angular velocity is in body-fixed representation
            as opposed to the default inertial-fixed representation.

    Returns:
        The integrated quaternion, in WXYZ order.
    """

    angular_velocity = jnp.array(omega).squeeze().astype(float)
    initial_wxyz = jnp.array(quaternion).squeeze().astype(float)

    # Orientation at the beginning of the step, as an SO(3) element.
    initial_rotation = jaxlie.SO3(wxyz=initial_wxyz)

    # Rotation accumulated over the step by the (constant) angular velocity.
    increment = jaxlie.SO3.exp(tangent=dt * angular_velocity)

    # Both compositions are computed; `lax.select` then picks one, which keeps
    # the flag usable as a traced value.
    body_fixed_result = (initial_rotation @ increment).wxyz
    inertial_fixed_result = (increment @ initial_rotation).wxyz

    return jax.lax.select(
        pred=omega_in_body_fixed,
        on_true=body_fixed_result,
        on_false=inertial_fixed_result,
    )
|
jaxsim/math/rotation.py
CHANGED
@@ -1,27 +1,30 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
import jax
|
4
1
|
import jax.numpy as jnp
|
2
|
+
import jaxlie
|
5
3
|
|
6
4
|
import jaxsim.typing as jtp
|
7
|
-
from jaxsim.sixd import so3
|
8
5
|
|
9
6
|
from .skew import Skew
|
7
|
+
from .utils import safe_norm
|
10
8
|
|
11
9
|
|
12
10
|
class Rotation:
|
11
|
+
"""
|
12
|
+
A utility class for rotation matrix operations.
|
13
|
+
"""
|
14
|
+
|
13
15
|
@staticmethod
|
14
16
|
def x(theta: jtp.Float) -> jtp.Matrix:
|
15
17
|
"""
|
16
18
|
Generate a 3D rotation matrix around the X-axis.
|
17
19
|
|
18
20
|
Args:
|
19
|
-
theta
|
21
|
+
theta: Rotation angle in radians.
|
20
22
|
|
21
23
|
Returns:
|
22
|
-
|
24
|
+
The 3D rotation matrix.
|
23
25
|
"""
|
24
|
-
|
26
|
+
|
27
|
+
return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()
|
25
28
|
|
26
29
|
@staticmethod
|
27
30
|
def y(theta: jtp.Float) -> jtp.Matrix:
|
@@ -29,12 +32,13 @@ class Rotation:
|
|
29
32
|
Generate a 3D rotation matrix around the Y-axis.
|
30
33
|
|
31
34
|
Args:
|
32
|
-
theta
|
35
|
+
theta: Rotation angle in radians.
|
33
36
|
|
34
37
|
Returns:
|
35
|
-
|
38
|
+
The 3D rotation matrix.
|
36
39
|
"""
|
37
|
-
|
40
|
+
|
41
|
+
return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()
|
38
42
|
|
39
43
|
@staticmethod
|
40
44
|
def z(theta: jtp.Float) -> jtp.Matrix:
|
@@ -42,12 +46,13 @@ class Rotation:
|
|
42
46
|
Generate a 3D rotation matrix around the Z-axis.
|
43
47
|
|
44
48
|
Args:
|
45
|
-
theta
|
49
|
+
theta: Rotation angle in radians.
|
46
50
|
|
47
51
|
Returns:
|
48
|
-
|
52
|
+
The 3D rotation matrix.
|
49
53
|
"""
|
50
|
-
|
54
|
+
|
55
|
+
return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()
|
51
56
|
|
52
57
|
@staticmethod
|
53
58
|
def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
|
@@ -55,17 +60,18 @@ class Rotation:
|
|
55
60
|
Generate a 3D rotation matrix from an axis-angle representation.
|
56
61
|
|
57
62
|
Args:
|
58
|
-
vector
|
63
|
+
vector: Axis-angle representation or the rotation as a 3D vector.
|
59
64
|
|
60
65
|
Returns:
|
61
|
-
|
62
|
-
|
66
|
+
The SO(3) rotation matrix.
|
63
67
|
"""
|
68
|
+
|
64
69
|
vector = vector.squeeze()
|
65
|
-
theta = jnp.linalg.norm(vector)
|
66
70
|
|
67
|
-
def theta_is_not_zero(
|
68
|
-
|
71
|
+
def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
|
72
|
+
|
73
|
+
v = axis
|
74
|
+
theta = safe_norm(v)
|
69
75
|
|
70
76
|
s = jnp.sin(theta)
|
71
77
|
c = jnp.cos(theta)
|
@@ -79,9 +85,9 @@ class Rotation:
|
|
79
85
|
|
80
86
|
return R.transpose()
|
81
87
|
|
82
|
-
return
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
88
|
+
return jnp.where(
|
89
|
+
jnp.allclose(vector, 0.0),
|
90
|
+
# Return an identity rotation matrix when the input vector is zero.
|
91
|
+
jnp.eye(3),
|
92
|
+
theta_is_not_zero(axis=vector),
|
87
93
|
)
|
jaxsim/math/skew.py
CHANGED
@@ -14,15 +14,26 @@ class Skew:
|
|
14
14
|
Compute the skew-symmetric matrix (wedge operator) of a 3D vector.
|
15
15
|
|
16
16
|
Args:
|
17
|
-
vector
|
17
|
+
vector: A 3D vector.
|
18
18
|
|
19
19
|
Returns:
|
20
|
-
|
20
|
+
The skew-symmetric matrix corresponding to the input vector.
|
21
21
|
|
22
22
|
"""
|
23
|
-
|
24
|
-
|
25
|
-
|
23
|
+
|
24
|
+
vector = vector.reshape(-1, 3)
|
25
|
+
|
26
|
+
x, y, z = jnp.split(vector, 3, axis=-1)
|
27
|
+
|
28
|
+
skew = jnp.stack(
|
29
|
+
[
|
30
|
+
jnp.concatenate([jnp.zeros_like(x), -z, y], axis=-1),
|
31
|
+
jnp.concatenate([z, jnp.zeros_like(x), -x], axis=-1),
|
32
|
+
jnp.concatenate([-y, x, jnp.zeros_like(x)], axis=-1),
|
33
|
+
],
|
34
|
+
axis=-2,
|
35
|
+
).squeeze()
|
36
|
+
|
26
37
|
return skew
|
27
38
|
|
28
39
|
@staticmethod
|
@@ -31,10 +42,10 @@ class Skew:
|
|
31
42
|
Extract the 3D vector from a skew-symmetric matrix (vee operator).
|
32
43
|
|
33
44
|
Args:
|
34
|
-
matrix
|
45
|
+
matrix: A 3x3 skew-symmetric matrix.
|
35
46
|
|
36
47
|
Returns:
|
37
|
-
|
48
|
+
The 3D vector extracted from the input matrix.
|
38
49
|
|
39
50
|
"""
|
40
51
|
vector = 0.5 * jnp.vstack(
|
jaxsim/math/transform.py
ADDED
@@ -0,0 +1,102 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
import jaxlie
|
3
|
+
|
4
|
+
import jaxsim.typing as jtp
|
5
|
+
|
6
|
+
|
7
|
+
class Transform:
    """
    A utility class for transformation matrix operations.

    All methods are static helpers built on top of jaxlie's SO(3)/SE(3) types
    and return plain 4x4 homogeneous transformation matrices.
    """

    @staticmethod
    def from_quaternion_and_translation(
        quaternion: jtp.VectorLike | None = None,
        translation: jtp.VectorLike | None = None,
        inverse: jtp.BoolLike = False,
        normalize_quaternion: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Create a transformation matrix from a quaternion and a translation.

        Args:
            quaternion:
                A 4D vector representing a SO(3) orientation, in WXYZ order.
                Defaults to the identity rotation.
            translation: A 3D vector representing a translation. Defaults to zero.
            inverse: Whether to compute the inverse transformation.
            normalize_quaternion:
                Whether to normalize the quaternion before creating the transformation.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.

        Note:
            ``inverse`` and ``normalize_quaternion`` are evaluated in Python
            conditionals, so they must be concrete (non-traced) values.
        """

        quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
        translation = translation if translation is not None else jnp.zeros(3)

        W_Q_B = jnp.array(quaternion).astype(float)
        W_p_B = jnp.array(translation).astype(float)

        assert W_p_B.size == 3
        assert W_Q_B.size == 4

        # Flatten row/column vectors (e.g. shape (4, 1) or (1, 3)). jaxlie
        # interprets leading axes as batch axes, so a non-flat input would
        # otherwise not be treated as a single rotation/translation.
        W_Q_B = W_Q_B.reshape(4)
        W_p_B = W_p_B.reshape(3)

        A_R_B = jaxlie.SO3(wxyz=W_Q_B)
        A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()

        A_H_B = jaxlie.SE3.from_rotation_and_translation(
            rotation=A_R_B, translation=W_p_B
        )

        return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()

    @staticmethod
    def from_rotation_and_translation(
        rotation: jtp.MatrixLike | None = None,
        translation: jtp.VectorLike | None = None,
        inverse: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Create a transformation matrix from a rotation matrix and a translation vector.

        Args:
            rotation:
                A 3x3 rotation matrix representing a SO(3) orientation.
                Defaults to the identity matrix.
            translation: A 3D vector representing a translation. Defaults to zero.
            inverse: Whether to compute the inverse transformation.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.

        Note:
            ``inverse`` is evaluated in a Python conditional, so it must be a
            concrete (non-traced) value.
        """

        rotation = rotation if rotation is not None else jnp.eye(3)
        translation = translation if translation is not None else jnp.zeros(3)

        A_R_B = jnp.array(rotation).astype(float)
        W_p_B = jnp.array(translation).astype(float)

        assert W_p_B.size == 3
        assert A_R_B.shape == (3, 3)

        # Flatten the translation so that a (3, 1) or (1, 3) input is not
        # interpreted by jaxlie as carrying an extra batch axis.
        W_p_B = W_p_B.reshape(3)

        A_H_B = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.from_matrix(A_R_B), translation=W_p_B
        )

        return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()

    @staticmethod
    def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
        """
        Compute the inverse transformation matrix.

        Args:
            transform:
                A 4x4 transformation matrix, or a stack of them with arbitrary
                leading batch axes (shape ``(..., 4, 4)``).

        Returns:
            The 4x4 inverse transformation matrix, with the same shape as the input.
        """

        # Convert first: a `MatrixLike` input may be a nested sequence that
        # has no `.shape`, which is needed below to restore the batch axes.
        transform = jnp.asarray(transform)

        # Collapse any leading batch axes into a single one for jaxlie.
        A_H_B = jnp.reshape(transform, (-1, 4, 4))

        return (
            jaxlie.SE3.from_matrix(matrix=A_H_B)
            .inverse()
            .as_matrix()
            .reshape(transform.shape[:-2] + (4, 4))
        )
|
jaxsim/math/utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
|
3
|
+
import jaxsim.typing as jtp
|
4
|
+
|
5
|
+
|
6
|
+
def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
    """
    Compute an array norm that is safe to differentiate at zero.

    The gradient of the Euclidean norm is undefined at the origin, so
    differentiating ``jnp.linalg.norm`` at a zero input produces NaNs. This
    helper temporarily replaces (near-)zero inputs with ones before computing
    the norm, then forces the result back to ``0.0``. Both the value and the
    gradient are therefore finite.

    The zero check is done independently for every slice reduced along
    ``axis``. In a batch of vectors, only the all-zero vectors are
    substituted, while the others keep their exact norm and gradient. With
    ``axis=None`` the whole array is checked at once.

    Args:
        array: The array for which to compute the norm.
        axis:
            The axis (or axes) along which to compute the norm, with the same
            meaning as in ``jnp.linalg.norm``. ``None`` reduces the whole array.

    Returns:
        The norm of the array, equal to ``0.0`` wherever the reduced slice is
        all (close to) zero.
    """

    array = jnp.asarray(array)

    # Element-wise zero test, using the same tolerance as `jnp.allclose(x, 0.0)`.
    is_close_to_zero = jnp.isclose(array, 0.0)

    # Per-slice mask, broadcastable against `array`: True where every element
    # of the slice reduced by the norm is zero.
    is_zero_slice = jnp.all(is_close_to_zero, axis=axis, keepdims=True)

    # The same mask, shaped like the output of the norm.
    is_zero_norm = jnp.all(is_close_to_zero, axis=axis)

    # Replace zero slices with ones so that neither the norm nor its gradient
    # divides by zero (the "double where" pattern).
    safe_array = jnp.where(is_zero_slice, jnp.ones_like(array), array)

    norm = jnp.linalg.norm(safe_array, axis=axis)

    # Restore the exact value 0.0 for the slices that were substituted.
    return jnp.where(is_zero_norm, 0.0, norm)
|
jaxsim/mujoco/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
-
from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf
|
1
|
+
from .loaders import ModelToMjcf, RodModelToMjcf, SdfToMjcf, UrdfToMjcf
|
2
2
|
from .model import MujocoModelHelper
|
3
|
+
from .utils import mujoco_data_from_jaxsim
|
3
4
|
from .visualizer import MujocoVideoRecorder, MujocoVisualizer
|