jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/rbda/rnea.py ADDED
@@ -0,0 +1,235 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jaxlie
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.math import Adjoint, Cross, StandardGravity
8
+
9
+ from . import utils
10
+
11
+
12
def rnea(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
    base_linear_acceleration: jtp.Vector | None = None,
    base_angular_acceleration: jtp.Vector | None = None,
    joint_accelerations: jtp.Vector | None = None,
    link_forces: jtp.Matrix | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    The algorithm runs two sweeps over the kinematic tree:

    1. A forward sweep (root to leaves) propagating the 6D link velocities and
       accelerations and computing, for each link, the 6D force required to
       produce its acceleration, net of the external link forces.
    2. A backward sweep (leaves to root) accumulating those forces into the
       parent links and projecting them onto the joint motion subspaces.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration:
            The linear acceleration of the base link in inertial-fixed representation.
        base_angular_acceleration:
            The angular acceleration of the base link in inertial-fixed representation.
        joint_accelerations: The accelerations of the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D force applied to the base link expressed in the
        world frame and the joint forces that, when applied respectively to the base
        link and joints, produce the given base and joint accelerations.
    """

    (
        W_p_B,
        W_Q_B,
        s,
        W_v_WB,
        ṡ,
        W_v̇_WB,
        s̈,
        _,
        W_f,
        W_g,
    ) = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=base_linear_acceleration,
        base_angular_acceleration=base_angular_acceleration,
        joint_accelerations=joint_accelerations,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Turn the 6D base quantities into (6, 1) column vectors.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T
    W_v̇_WB = jnp.atleast_2d(W_v̇_WB).T

    n_links = model.number_of_links()

    # 6D spatial inertia matrix of every link.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Parent array λ(i); the entry of the base (λ(0)) is a placeholder (-1)
    # and must never be used.
    λ = model.kin_dyn_parameters.parent_array

    # Homogeneous transform of the base and the 6D (adjoint) transforms
    # between the world and base frames.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Parent-to-child adjoints and joint motion subspaces. They describe the
    # relative kinematics of the whole tree, base transform included, for both
    # floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Per-link buffers of 6D velocities, accelerations, and forces.
    v = jnp.zeros(shape=(n_links, 6, 1))
    a = jnp.zeros(shape=(n_links, 6, 1))
    f = jnp.zeros(shape=(n_links, 6, 1))

    # Per-link buffer of link -> base transforms (identity for the base).
    i_X_0 = jnp.zeros(shape=(n_links, 6, 6)).at[0].set(jnp.eye(6))

    if model.floating_base():
        # Base velocity in body-fixed representation.
        v = v.at[0].set(B_X_W @ W_v_WB)

        # Base acceleration in body-fixed representation, with gravity folded in.
        a = a.at[0].set(B_X_W @ (W_v̇_WB - W_g))

        # Base force producing the base acceleration, net of the external force.
        f = f.at[0].set(
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - W_X_B.T @ jnp.vstack(W_f[0])
        )
    else:
        # A fixed base does not move: only the fictitious gravity acceleration
        # is injected at the root.
        a = a.at[0].set(-B_X_W @ W_g)

    # =====================================
    # Forward sweep: from the root to leaves
    # =====================================

    def propagate_forward(state, link_idx: jtp.Int):

        v, a, i_X_0, f = state

        # Links are numbered from 1, joints from 0.
        joint_idx = link_idx - 1
        parent_idx = λ[link_idx]

        # Joint velocity mapped through the joint motion subspace.
        vJ = S[link_idx] * ṡ[joint_idx]

        v = v.at[link_idx].set(i_X_λi[link_idx] @ v[parent_idx] + vJ)

        a = a.at[link_idx].set(
            i_X_λi[link_idx] @ a[parent_idx]
            + S[link_idx] * s̈[joint_idx]
            + Cross.vx(v[link_idx]) @ vJ
        )

        i_X_0 = i_X_0.at[link_idx].set(i_X_λi[link_idx] @ i_X_0[parent_idx])

        # Transform mapping 6D forces from the world frame to the link frame.
        i_Xf_W = Adjoint.inverse(i_X_0[link_idx] @ B_X_W).T

        f = f.at[link_idx].set(
            M[link_idx] @ a[link_idx]
            + Cross.vx_star(v[link_idx]) @ M[link_idx] @ v[link_idx]
            - i_Xf_W @ jnp.vstack(W_f[link_idx])
        )

        return (v, a, i_X_0, f), None

    if n_links > 1:
        (v, a, i_X_0, f), _ = jax.lax.scan(
            f=propagate_forward,
            init=(v, a, i_X_0, f),
            xs=jnp.arange(start=1, stop=n_links),
        )

    # ======================================
    # Backward sweep: from the leaves to root
    # ======================================

    τ = jnp.zeros_like(s)

    def accumulate_backward(state, link_idx: jtp.Int):

        τ, f = state

        # Joint torque: projection of the link force on the joint subspace.
        τ = τ.at[link_idx - 1].set((S[link_idx].T @ f[link_idx]).squeeze())

        def push_to_parent(forces: jtp.Matrix) -> jtp.Matrix:
            parent_idx = λ[link_idx]
            return forces.at[parent_idx].set(
                forces[parent_idx] + i_X_λi[link_idx].T @ forces[link_idx]
            )

        # The base link collects child forces only when it is floating.
        f = jax.lax.cond(
            pred=jnp.logical_or(λ[link_idx] != 0, model.floating_base()),
            true_fun=push_to_parent,
            false_fun=lambda forces: forces,
            operand=f,
        )

        return (τ, f), None

    if n_links > 1:
        (τ, f), _ = jax.lax.scan(
            f=accumulate_backward,
            init=(τ, f),
            xs=jnp.flip(jnp.arange(start=1, stop=n_links)),
        )

    # The base force is computed in the base frame: express it in the world frame.
    W_f0 = B_X_W.T @ f[0]

    return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze())
jaxsim/rbda/utils.py ADDED
@@ -0,0 +1,160 @@
1
+ import jax.numpy as jnp
2
+
3
+ import jaxsim.api as js
4
+ import jaxsim.typing as jtp
5
+ from jaxsim import exceptions
6
+ from jaxsim.math import StandardGravity
7
+
8
+
9
def process_inputs(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    base_linear_acceleration: jtp.VectorLike | None = None,
    base_angular_acceleration: jtp.VectorLike | None = None,
    joint_accelerations: jtp.VectorLike | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.ScalarLike | None = None,
) -> tuple[
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Matrix,
    jtp.Vector,
]:
    """
    Adjust the inputs to rigid-body dynamics algorithms.

    Missing inputs are replaced by defaults: zeros for positions, velocities,
    accelerations and forces; the identity quaternion ``[1, 0, 0, 0]`` for the
    base orientation; ``StandardGravity`` for the gravity constant. Provided
    inputs may be JAX/NumPy arrays, lists, tuples, or Python scalars. They are
    converted to arrays, squeezed, and their shapes validated.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration: The linear acceleration of the base link.
        base_angular_acceleration: The angular acceleration of the base link.
        joint_accelerations: The accelerations of the joints.
        joint_forces: The forces applied to the joints.
        link_forces: The forces applied to the links.
        standard_gravity: The standard gravity constant.

    Returns:
        The adjusted inputs, cast to float, as the tuple
        ``(W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, τ, f, W_g)``:

        - ``W_v_WB`` and ``W_v̇_WB`` are 6D vectors stacking the linear and
          angular parts of the base velocity and acceleration.
        - ``f`` has shape ``(number_of_links, 6)``.
        - ``W_g`` is the 6D gravity acceleration ``[0, 0, -g, 0, 0, 0]``.

    Raises:
        ValueError: If any input does not have the expected shape.
    """

    dofs = model.dofs()
    nl = model.number_of_links()

    def squeezed_1d(x, default):
        # `jnp.asarray` is needed because lists, tuples, and Python scalars
        # (all admitted by the `*Like` type hints) have no `.squeeze()` method.
        return default if x is None else jnp.atleast_1d(jnp.asarray(x).squeeze())

    # Joint positions, velocities, accelerations, and forces.
    s = squeezed_1d(joint_positions, jnp.zeros(dofs))
    ṡ = squeezed_1d(joint_velocities, jnp.zeros(dofs))
    s̈ = squeezed_1d(joint_accelerations, jnp.zeros(dofs))
    τ = squeezed_1d(joint_forces, jnp.zeros(dofs))

    # Floating-base position and orientation.
    W_p_B = squeezed_1d(base_position, jnp.zeros(3))
    W_Q_B = squeezed_1d(base_quaternion, jnp.array([1.0, 0, 0, 0]))

    # Floating-base velocity and acceleration in inertial-fixed representation.
    W_vl_WB = squeezed_1d(base_linear_velocity, jnp.zeros(3))
    W_ω_WB = squeezed_1d(base_angular_velocity, jnp.zeros(3))
    W_v̇l_WB = squeezed_1d(base_linear_acceleration, jnp.zeros(3))
    W_ω̇_WB = squeezed_1d(base_angular_acceleration, jnp.zeros(3))

    # One 6D force per link. `atleast_2d` restores the link axis that `squeeze`
    # removes for single-link models.
    f = (
        jnp.atleast_2d(jnp.asarray(link_forces).squeeze())
        if link_forces is not None
        else jnp.zeros(shape=(nl, 6))
    )

    standard_gravity = (
        jnp.array(standard_gravity).squeeze()
        if standard_gravity is not None
        else StandardGravity
    )

    if s.shape != (dofs,):
        raise ValueError(s.shape, dofs)

    if ṡ.shape != (dofs,):
        raise ValueError(ṡ.shape, dofs)

    if s̈.shape != (dofs,):
        raise ValueError(s̈.shape, dofs)

    if τ.shape != (dofs,):
        raise ValueError(τ.shape, dofs)

    if W_p_B.shape != (3,):
        raise ValueError(W_p_B.shape, (3,))

    if W_vl_WB.shape != (3,):
        raise ValueError(W_vl_WB.shape, (3,))

    if W_ω_WB.shape != (3,):
        raise ValueError(W_ω_WB.shape, (3,))

    if W_v̇l_WB.shape != (3,):
        raise ValueError(W_v̇l_WB.shape, (3,))

    if W_ω̇_WB.shape != (3,):
        raise ValueError(W_ω̇_WB.shape, (3,))

    if f.shape != (nl, 6):
        raise ValueError(f.shape, (nl, 6))

    if W_Q_B.shape != (4,):
        raise ValueError(W_Q_B.shape, (4,))

    # Check that the quaternion is unary since our RBDAs make this assumption in order
    # to prevent introducing additional normalizations that would affect AD.
    exceptions.raise_value_error_if(
        condition=~jnp.allclose(W_Q_B.dot(W_Q_B), 1.0),
        msg="A RBDA received a quaternion that is not normalized.",
    )

    # Pack the 6D base velocity and acceleration (linear part first).
    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])

    # Create the 6D gravity acceleration (acting along -z of the world frame).
    W_g = jnp.zeros(6).at[2].set(-standard_gravity)

    return (
        W_p_B.astype(float),
        W_Q_B.astype(float),
        s.astype(float),
        W_v_WB.astype(float),
        ṡ.astype(float),
        W_v̇_WB.astype(float),
        s̈.astype(float),
        τ.astype(float),
        f.astype(float),
        W_g.astype(float),
    )
jaxsim/terrain/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import terrain
2
+ from .terrain import FlatTerrain, PlaneTerrain, Terrain
jaxsim/terrain/terrain.py ADDED
@@ -0,0 +1,238 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import dataclasses
5
+
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+ import numpy as np
9
+
10
+ import jaxsim.math
11
+ import jaxsim.typing as jtp
12
+ from jaxsim import exceptions
13
+
14
+
15
class Terrain(abc.ABC):
    """
    Abstract base class of all terrain models.

    Subclasses must provide `height`. The default `normal` derives the surface
    normal from `height` using central finite differences.

    Attributes:
        delta: The step used by the finite-difference approximation of the normal.
    """

    delta = 0.010

    @abc.abstractmethod
    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Compute the height of the terrain at a specific (x, y) location.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The height of the terrain at the specified location.
        """

    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
        """
        Compute the normal vector of the terrain at a specific (x, y) location.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The unit normal vector of the terrain surface at the specified location.
        """

        # https://stackoverflow.com/a/5282364
        step = self.delta

        # Sample the height around (x, y) along both horizontal axes.
        height_x_plus = self.height(x=x + step, y=y)
        height_x_minus = self.height(x=x - step, y=y)
        height_y_plus = self.height(x=x, y=y + step)
        height_y_minus = self.height(x=x, y=y - step)

        # Unnormalized normal: (-∂h/∂x, -∂h/∂y, 1) via central differences.
        slope_x = (height_x_minus - height_x_plus) / (2 * step)
        slope_y = (height_y_minus - height_y_plus) / (2 * step)
        unnormalized = jnp.array([slope_x, slope_y, 1.0])

        return unnormalized / jaxsim.math.safe_norm(unnormalized)
63
+
64
+
65
@jax_dataclasses.pytree_dataclass
class FlatTerrain(Terrain):
    """
    Terrain model describing a horizontal plane at a constant height.
    """

    _height: float = dataclasses.field(default=0.0, kw_only=True)

    @staticmethod
    def build(height: jtp.FloatLike = 0.0) -> FlatTerrain:
        """
        Create a FlatTerrain instance with a specified height.

        Args:
            height: The height of the flat terrain.

        Returns:
            FlatTerrain: A FlatTerrain instance.
        """

        terrain_height = float(height)
        return FlatTerrain(_height=terrain_height)

    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Compute the height of the terrain at a specific (x, y) location.

        The terrain is flat, so the coordinates are ignored.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The constant height of the terrain.
        """

        return jnp.array(self._height, dtype=float)

    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
        """
        Compute the normal vector of the terrain at a specific (x, y) location.

        The terrain is horizontal, so the normal is always the world z axis.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The unit vector ``[0, 0, 1]``.
        """

        return jnp.array([0.0, 0.0, 1.0], dtype=float)

    def __hash__(self) -> int:

        return hash(self._height)

    def __eq__(self, other: FlatTerrain) -> bool:

        # Non-FlatTerrain operands are never equal; otherwise compare heights.
        return isinstance(other, FlatTerrain) and self._height == other._height
125
+
126
+
127
@jax_dataclasses.pytree_dataclass
class PlaneTerrain(FlatTerrain):
    """
    Represents a terrain model with a flat surface defined by a normal vector.
    """

    _normal: tuple[float, float, float] = jax_dataclasses.field(
        default=(0.0, 0.0, 1.0), kw_only=True
    )

    @staticmethod
    def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
        """
        Create a PlaneTerrain instance with a specified plane normal vector.

        Args:
            height: The height of the plane over the origin.
            normal:
                The normal vector of the terrain plane. It is normalized here,
                so it only needs to be finite and non-zero.

        Returns:
            PlaneTerrain: A PlaneTerrain instance.

        Raises:
            ValueError:
                If the normal is not a 3D vector, or if its norm is zero or
                not finite.
        """

        normal = jnp.array(normal, dtype=float)
        height = jnp.array(height, dtype=float)

        if normal.shape != (3,):
            msg = "Expected a 3D vector for the plane normal, got '{}'."
            raise ValueError(msg.format(normal.shape))

        # Dividing by a zero or non-finite norm would silently produce NaN
        # components, which would then propagate to every `height` query.
        norm = jnp.linalg.norm(normal)
        norm_value = float(norm)

        if not np.isfinite(norm_value) or norm_value == 0.0:
            msg = "Expected a finite, non-zero plane normal, got '{}'."
            raise ValueError(msg.format(normal.tolist()))

        # Make sure that the plane normal is a unit vector.
        normal = normal / norm

        return PlaneTerrain(
            _height=height.item(),
            _normal=tuple(normal.tolist()),
        )

    def normal(
        self, x: jtp.FloatLike | None = None, y: jtp.FloatLike | None = None
    ) -> jtp.Vector:
        """
        Compute the normal vector of the terrain at a specific (x, y) location.

        The plane has the same normal everywhere, so the coordinates are ignored.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The normal vector of the terrain surface at the specified location.
        """

        return jnp.array(self._normal, dtype=float)

    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Compute the height of the terrain at a specific (x, y) location on a plane.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The height of the terrain at the specified location on the plane.
        """

        # Equation of the plane: A x + B y + C z + D = 0
        # Normal vector coordinates: (A, B, C)
        # The height over the origin: -D/C

        # Get the plane equation coefficients from the terrain normal.
        A, B, C = self._normal

        # A vertical plane (C == 0) does not define z as a function of (x, y).
        exceptions.raise_value_error_if(
            condition=jnp.allclose(C, 0.0),
            msg="The z component of the normal cannot be zero.",
        )

        # Compute the final coefficient D considering the terrain height.
        D = -C * self._height

        # Invert the plane equation to get the height at the given (x, y) coordinates.
        return jnp.array(-(A * x + B * y + D) / C).astype(float)

    def __hash__(self) -> int:

        # Imported locally, as in the original implementation.
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                hash(self._height),
                HashedNumpyArray.hash_of_array(
                    array=np.array(self._normal, dtype=float)
                ),
            )
        )

    def __eq__(self, other: PlaneTerrain) -> bool:

        if not isinstance(other, PlaneTerrain):
            return False

        # NOTE(review): equality is tolerance-based (np.allclose) while
        # `__hash__` hashes the exact height; confirm that `hash_of_array` and
        # this tolerance are compatible, otherwise `a == b` may not imply
        # `hash(a) == hash(b)`.
        return bool(
            np.allclose(self._height, other._height)
            and np.allclose(
                np.array(self._normal, dtype=float),
                np.array(other._normal, dtype=float),
            )
        )
jaxsim/typing.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import Any, Hashable
1
+ from collections.abc import Hashable
2
+ from typing import Any, TypeVar
2
3
 
3
4
  import jax
4
5
 
@@ -6,34 +7,33 @@ import jax
6
7
  # JAX types
7
8
  # =========
8
9
 
9
- ScalarJax = jax.Array
10
- IntJax = ScalarJax
11
- BoolJax = ScalarJax
12
- FloatJax = ScalarJax
13
-
14
- ArrayJax = jax.Array
15
- VectorJax = ArrayJax
16
- MatrixJax = ArrayJax
10
+ Array = jax.Array
11
+ Scalar = Array
12
+ Vector = Array
13
+ Matrix = Array
17
14
 
18
- PyTree = (
19
- dict[Hashable, "PyTree"] | list["PyTree"] | tuple["PyTree"] | None | jax.Array | Any
15
+ Int = Scalar
16
+ Bool = Scalar
17
+ Float = Scalar
18
+
19
+ PyTree: object = (
20
+ dict[Hashable, TypeVar("PyTree")]
21
+ | list[TypeVar("PyTree")]
22
+ | tuple[TypeVar("PyTree")]
23
+ | jax.Array
24
+ | Any
25
+ | None
20
26
  )
21
27
 
22
28
  # =======================
23
29
  # Mixed JAX / NumPy types
24
30
  # =======================
25
31
 
26
- Array = jax.typing.ArrayLike
27
- Vector = Array
28
- Matrix = Array
29
-
30
- Int = int | IntJax
31
- Bool = bool | ArrayJax
32
- Float = float | FloatJax
32
+ ArrayLike = jax.typing.ArrayLike | tuple
33
+ ScalarLike = int | float | Scalar | ArrayLike
34
+ VectorLike = Vector | ArrayLike | tuple
35
+ MatrixLike = Matrix | ArrayLike
33
36
 
34
- ArrayLike = Array
35
- VectorLike = Vector
36
- MatrixLike = Matrix
37
- IntLike = Int
38
- BoolLike = Bool
39
- FloatLike = Float
37
+ IntLike = int | Int | jax.typing.ArrayLike
38
+ BoolLike = bool | Bool | jax.typing.ArrayLike
39
+ FloatLike = float | Float | jax.typing.ArrayLike
jaxsim/utils/__init__.py CHANGED
@@ -2,7 +2,4 @@ from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
2
2
 
3
3
  from .jaxsim_dataclass import JaxsimDataclass
4
4
  from .tracing import not_tracing, tracing
5
- from .vmappable import Vmappable
6
-
7
- # Leave this below the others to prevent circular imports
8
- from .oop import jax_tf # isort: skip
5
+ from .wrappers import HashedNumpyArray, HashlessObject