jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/math/quaternion.py CHANGED
@@ -1,21 +1,27 @@
1
1
  import jax.lax
2
2
  import jax.numpy as jnp
3
+ import jaxlie
3
4
 
4
5
  import jaxsim.typing as jtp
5
- from jaxsim.sixd import so3
6
+
7
+ from .utils import safe_norm
6
8
 
7
9
 
8
10
  class Quaternion:
11
+ """
12
+ A utility class for quaternion operations.
13
+ """
14
+
9
15
  @staticmethod
10
16
  def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:
11
17
  """
12
18
  Convert a quaternion from WXYZ to XYZW representation.
13
19
 
14
20
  Args:
15
- wxyz (jtp.Vector): Quaternion in WXYZ representation.
21
+ wxyz: Quaternion in WXYZ representation.
16
22
 
17
23
  Returns:
18
- jtp.Vector: Quaternion in XYZW representation.
24
+ Quaternion in XYZW representation.
19
25
  """
20
26
  return wxyz.squeeze()[jnp.array([1, 2, 3, 0])]
21
27
 
@@ -25,10 +31,10 @@ class Quaternion:
25
31
  Convert a quaternion from XYZW to WXYZ representation.
26
32
 
27
33
  Args:
28
- xyzw (jtp.Vector): Quaternion in XYZW representation.
34
+ xyzw: Quaternion in XYZW representation.
29
35
 
30
36
  Returns:
31
- jtp.Vector: Quaternion in WXYZ representation.
37
+ Quaternion in WXYZ representation.
32
38
  """
33
39
  return xyzw.squeeze()[jnp.array([3, 0, 1, 2])]
34
40
 
@@ -38,14 +44,12 @@ class Quaternion:
38
44
  Convert a quaternion to a direction cosine matrix (DCM).
39
45
 
40
46
  Args:
41
- quaternion (jtp.Vector): Quaternion in XYZW representation.
47
+ quaternion: Quaternion in XYZW representation.
42
48
 
43
49
  Returns:
44
- jtp.Matrix: Direction cosine matrix (DCM).
50
+ The Direction cosine matrix (DCM).
45
51
  """
46
- return so3.SO3.from_quaternion_xyzw(
47
- xyzw=Quaternion.to_xyzw(quaternion)
48
- ).as_matrix()
52
+ return jaxlie.SO3(wxyz=quaternion).as_matrix()
49
53
 
50
54
  @staticmethod
51
55
  def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
@@ -53,14 +57,12 @@ class Quaternion:
53
57
  Convert a direction cosine matrix (DCM) to a quaternion.
54
58
 
55
59
  Args:
56
- dcm (jtp.Matrix): Direction cosine matrix (DCM).
60
+ dcm: Direction cosine matrix (DCM).
57
61
 
58
62
  Returns:
59
- jtp.Vector: Quaternion in XYZW representation.
63
+ Quaternion in WXYZ representation.
60
64
  """
61
- return Quaternion.to_wxyz(
62
- xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
63
- )
65
+ return jaxlie.SO3.from_matrix(matrix=dcm).wxyz
64
66
 
65
67
  @staticmethod
66
68
  def derivative(
@@ -73,13 +75,13 @@ class Quaternion:
73
75
  Compute the derivative of a quaternion given angular velocity.
74
76
 
75
77
  Args:
76
- quaternion (jtp.Vector): Quaternion in XYZW representation.
77
- omega (jtp.Vector): Angular velocity vector.
78
+ quaternion: Quaternion in XYZW representation.
79
+ omega: Angular velocity vector.
78
80
  omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame.
79
81
  K (float): A scaling factor.
80
82
 
81
83
  Returns:
82
- jtp.Vector: The derivative of the quaternion.
84
+ The derivative of the quaternion.
83
85
  """
84
86
  ω = omega.squeeze()
85
87
  quaternion = quaternion.squeeze()
@@ -115,21 +117,53 @@ class Quaternion:
115
117
  operand=quaternion,
116
118
  )
117
119
 
118
- norm_ω = jax.lax.cond(
119
- pred=ω.dot(ω) < (1e-6) ** 2,
120
- true_fun=lambda _: 1e-6,
121
- false_fun=lambda _: jnp.linalg.norm(ω),
122
- operand=None,
123
- )
120
+ norm_ω = safe_norm(ω)
124
121
 
125
122
  qd = 0.5 * (
126
123
  Q
127
124
  @ jnp.hstack(
128
125
  [
129
- K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
126
+ K * norm_ω * (1 - safe_norm(quaternion)),
130
127
  ω,
131
128
  ]
132
129
  )
133
130
  )
134
131
 
135
132
  return jnp.vstack(qd)
133
+
134
+ @staticmethod
135
+ def integration(
136
+ quaternion: jtp.VectorLike,
137
+ dt: jtp.FloatLike,
138
+ omega: jtp.VectorLike,
139
+ omega_in_body_fixed: jtp.BoolLike = False,
140
+ ) -> jtp.Vector:
141
+ """
142
+ Integrate a quaternion in SO(3) given an angular velocity.
143
+
144
+ Args:
145
+ quaternion: The quaternion to integrate.
146
+ dt: The time step.
147
+ omega: The angular velocity vector.
148
+ omega_in_body_fixed:
149
+ Whether the angular velocity is in body-fixed representation
150
+ as opposed to the default inertial-fixed representation.
151
+
152
+ Returns:
153
+ The integrated quaternion.
154
+ """
155
+
156
+ ω_AB = jnp.array(omega).squeeze().astype(float)
157
+ A_Q_B = jnp.array(quaternion).squeeze().astype(float)
158
+
159
+ # Build the initial SO(3) quaternion.
160
+ W_Q_B_t0 = jaxlie.SO3(wxyz=A_Q_B)
161
+
162
+ # Integrate the quaternion on the manifold.
163
+ W_Q_B_tf = jax.lax.select(
164
+ pred=omega_in_body_fixed,
165
+ on_true=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).wxyz,
166
+ on_false=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).wxyz,
167
+ )
168
+
169
+ return W_Q_B_tf
jaxsim/math/rotation.py CHANGED
@@ -1,27 +1,30 @@
1
- from typing import Tuple
2
-
3
- import jax
4
1
  import jax.numpy as jnp
2
+ import jaxlie
5
3
 
6
4
  import jaxsim.typing as jtp
7
- from jaxsim.sixd import so3
8
5
 
9
6
  from .skew import Skew
7
+ from .utils import safe_norm
10
8
 
11
9
 
12
10
  class Rotation:
11
+ """
12
+ A utility class for rotation matrix operations.
13
+ """
14
+
13
15
  @staticmethod
14
16
  def x(theta: jtp.Float) -> jtp.Matrix:
15
17
  """
16
18
  Generate a 3D rotation matrix around the X-axis.
17
19
 
18
20
  Args:
19
- theta (jtp.Float): Rotation angle in radians.
21
+ theta: Rotation angle in radians.
20
22
 
21
23
  Returns:
22
- jtp.Matrix: 3D rotation matrix.
24
+ The 3D rotation matrix.
23
25
  """
24
- return so3.SO3.from_x_radians(theta=theta).as_matrix()
26
+
27
+ return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()
25
28
 
26
29
  @staticmethod
27
30
  def y(theta: jtp.Float) -> jtp.Matrix:
@@ -29,12 +32,13 @@ class Rotation:
29
32
  Generate a 3D rotation matrix around the Y-axis.
30
33
 
31
34
  Args:
32
- theta (jtp.Float): Rotation angle in radians.
35
+ theta: Rotation angle in radians.
33
36
 
34
37
  Returns:
35
- jtp.Matrix: 3D rotation matrix.
38
+ The 3D rotation matrix.
36
39
  """
37
- return so3.SO3.from_y_radians(theta=theta).as_matrix()
40
+
41
+ return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()
38
42
 
39
43
  @staticmethod
40
44
  def z(theta: jtp.Float) -> jtp.Matrix:
@@ -42,12 +46,13 @@ class Rotation:
42
46
  Generate a 3D rotation matrix around the Z-axis.
43
47
 
44
48
  Args:
45
- theta (jtp.Float): Rotation angle in radians.
49
+ theta: Rotation angle in radians.
46
50
 
47
51
  Returns:
48
- jtp.Matrix: 3D rotation matrix.
52
+ The 3D rotation matrix.
49
53
  """
50
- return so3.SO3.from_z_radians(theta=theta).as_matrix()
54
+
55
+ return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()
51
56
 
52
57
  @staticmethod
53
58
  def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
@@ -55,17 +60,18 @@ class Rotation:
55
60
  Generate a 3D rotation matrix from an axis-angle representation.
56
61
 
57
62
  Args:
58
- vector (jtp.Vector): Axis-angle representation as a 3D vector.
63
+ vector: Axis-angle representation or the rotation as a 3D vector.
59
64
 
60
65
  Returns:
61
- jtp.Matrix: 3D rotation matrix.
62
-
66
+ The SO(3) rotation matrix.
63
67
  """
68
+
64
69
  vector = vector.squeeze()
65
- theta = jnp.linalg.norm(vector)
66
70
 
67
- def theta_is_not_zero(theta_and_v: Tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:
68
- theta, v = theta_and_v
71
+ def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
72
+
73
+ v = axis
74
+ theta = safe_norm(v)
69
75
 
70
76
  s = jnp.sin(theta)
71
77
  c = jnp.cos(theta)
@@ -79,9 +85,9 @@ class Rotation:
79
85
 
80
86
  return R.transpose()
81
87
 
82
- return jax.lax.cond(
83
- pred=(theta == 0.0),
84
- true_fun=lambda operand: jnp.eye(3),
85
- false_fun=theta_is_not_zero,
86
- operand=(theta, vector),
88
+ return jnp.where(
89
+ jnp.allclose(vector, 0.0),
90
+ # Return an identity rotation matrix when the input vector is zero.
91
+ jnp.eye(3),
92
+ theta_is_not_zero(axis=vector),
87
93
  )
jaxsim/math/skew.py CHANGED
@@ -14,15 +14,26 @@ class Skew:
14
14
  Compute the skew-symmetric matrix (wedge operator) of a 3D vector.
15
15
 
16
16
  Args:
17
- vector (jtp.Vector): A 3D vector.
17
+ vector: A 3D vector.
18
18
 
19
19
  Returns:
20
- jtp.Matrix: The skew-symmetric matrix corresponding to the input vector.
20
+ The skew-symmetric matrix corresponding to the input vector.
21
21
 
22
22
  """
23
- vector = vector.squeeze()
24
- x, y, z = vector
25
- skew = jnp.array([[0, -z, y], [z, 0, -x], [-y, x, 0]])
23
+
24
+ vector = vector.reshape(-1, 3)
25
+
26
+ x, y, z = jnp.split(vector, 3, axis=-1)
27
+
28
+ skew = jnp.stack(
29
+ [
30
+ jnp.concatenate([jnp.zeros_like(x), -z, y], axis=-1),
31
+ jnp.concatenate([z, jnp.zeros_like(x), -x], axis=-1),
32
+ jnp.concatenate([-y, x, jnp.zeros_like(x)], axis=-1),
33
+ ],
34
+ axis=-2,
35
+ ).squeeze()
36
+
26
37
  return skew
27
38
 
28
39
  @staticmethod
@@ -31,10 +42,10 @@ class Skew:
31
42
  Extract the 3D vector from a skew-symmetric matrix (vee operator).
32
43
 
33
44
  Args:
34
- matrix (jtp.Matrix): A 3x3 skew-symmetric matrix.
45
+ matrix: A 3x3 skew-symmetric matrix.
35
46
 
36
47
  Returns:
37
- jtp.Vector: The 3D vector extracted from the input matrix.
48
+ The 3D vector extracted from the input matrix.
38
49
 
39
50
  """
40
51
  vector = 0.5 * jnp.vstack(
@@ -0,0 +1,102 @@
1
+ import jax.numpy as jnp
2
+ import jaxlie
3
+
4
+ import jaxsim.typing as jtp
5
+
6
+
7
+ class Transform:
8
+ """
9
+ A utility class for transformation matrix operations.
10
+ """
11
+
12
+ @staticmethod
13
+ def from_quaternion_and_translation(
14
+ quaternion: jtp.VectorLike | None = None,
15
+ translation: jtp.VectorLike | None = None,
16
+ inverse: jtp.BoolLike = False,
17
+ normalize_quaternion: jtp.BoolLike = False,
18
+ ) -> jtp.Matrix:
19
+ """
20
+ Create a transformation matrix from a quaternion and a translation.
21
+
22
+ Args:
23
+ quaternion: A 4D vector representing a SO(3) orientation.
24
+ translation: A 3D vector representing a translation.
25
+ inverse: Whether to compute the inverse transformation.
26
+ normalize_quaternion:
27
+ Whether to normalize the quaternion before creating the transformation.
28
+
29
+ Returns:
30
+ The 4x4 transformation matrix representing the SE(3) transformation.
31
+ """
32
+
33
+ quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])
34
+ translation = translation if translation is not None else jnp.zeros(3)
35
+
36
+ W_Q_B = jnp.array(quaternion).astype(float)
37
+ W_p_B = jnp.array(translation).astype(float)
38
+
39
+ assert W_p_B.size == 3
40
+ assert W_Q_B.size == 4
41
+
42
+ A_R_B = jaxlie.SO3(wxyz=W_Q_B)
43
+ A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
44
+
45
+ A_H_B = jaxlie.SE3.from_rotation_and_translation(
46
+ rotation=A_R_B, translation=W_p_B
47
+ )
48
+
49
+ return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()
50
+
51
+ @staticmethod
52
+ def from_rotation_and_translation(
53
+ rotation: jtp.MatrixLike | None = None,
54
+ translation: jtp.VectorLike | None = None,
55
+ inverse: jtp.BoolLike = False,
56
+ ) -> jtp.Matrix:
57
+ """
58
+ Create a transformation matrix from a rotation matrix and a translation vector.
59
+
60
+ Args:
61
+ rotation: A 3x3 rotation matrix representing a SO(3) orientation.
62
+ translation: A 3D vector representing a translation.
63
+ inverse: Whether to compute the inverse transformation.
64
+
65
+ Returns:
66
+ The 4x4 transformation matrix representing the SE(3) transformation.
67
+ """
68
+ rotation = rotation if rotation is not None else jnp.eye(3)
69
+ translation = translation if translation is not None else jnp.zeros(3)
70
+
71
+ A_R_B = jnp.array(rotation).astype(float)
72
+ W_p_B = jnp.array(translation).astype(float)
73
+
74
+ assert W_p_B.size == 3
75
+ assert A_R_B.shape == (3, 3)
76
+
77
+ A_H_B = jaxlie.SE3.from_rotation_and_translation(
78
+ rotation=jaxlie.SO3.from_matrix(A_R_B), translation=W_p_B
79
+ )
80
+
81
+ return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()
82
+
83
+ @staticmethod
84
+ def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
85
+ """
86
+ Compute the inverse transformation matrix.
87
+
88
+ Args:
89
+ transform: A 4x4 transformation matrix.
90
+
91
+ Returns:
92
+ The 4x4 inverse transformation matrix.
93
+ """
94
+
95
+ A_H_B = jnp.reshape(transform, (-1, 4, 4))
96
+
97
+ return (
98
+ jaxlie.SE3.from_matrix(matrix=A_H_B)
99
+ .inverse()
100
+ .as_matrix()
101
+ .reshape(transform.shape[:-2] + (4, 4))
102
+ )
jaxsim/math/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ import jax.numpy as jnp
2
+
3
+ import jaxsim.typing as jtp
4
+
5
+
6
+ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
7
+ """
8
+ Compute an array norm handling NaNs and making sure that
9
+ it is safe to get the gradient.
10
+
11
+ Args:
12
+ array: The array for which to compute the norm.
13
+ axis: The axis for which to compute the norm.
14
+
15
+ Returns:
16
+ The norm of the array with handling for zero arrays to avoid NaNs.
17
+ """
18
+
19
+ # Check if the entire array is composed of zeros.
20
+ is_zero = jnp.allclose(array, 0.0)
21
+
22
+ # Replace zeros with an array of ones temporarily to avoid division by zero.
23
+ # This ensures the computation of norm does not produce NaNs or Infs.
24
+ array = jnp.where(is_zero, jnp.ones_like(array), array)
25
+
26
+ # Compute the norm of the array along the specified axis.
27
+ norm = jnp.linalg.norm(array, axis=axis)
28
+
29
+ # Use `jnp.where` to set the norm to 0.0 where the input array was all zeros.
30
+ # This usage supports potential batch processing for future scalability.
31
+ return jnp.where(is_zero, 0.0, norm)
jaxsim/mujoco/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
- from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf
1
+ from .loaders import ModelToMjcf, RodModelToMjcf, SdfToMjcf, UrdfToMjcf
2
2
  from .model import MujocoModelHelper
3
+ from .utils import mujoco_data_from_jaxsim
3
4
  from .visualizer import MujocoVideoRecorder, MujocoVisualizer