jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +88 -72
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/collision.py +14 -0
  26. jaxsim/parsers/descriptions/link.py +13 -2
  27. jaxsim/parsers/kinematic_graph.py +5 -0
  28. jaxsim/parsers/rod/utils.py +7 -8
  29. jaxsim/rbda/__init__.py +7 -0
  30. jaxsim/rbda/aba.py +295 -0
  31. jaxsim/rbda/collidable_points.py +142 -0
  32. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  33. jaxsim/rbda/forward_kinematics.py +113 -0
  34. jaxsim/rbda/jacobian.py +201 -0
  35. jaxsim/rbda/rnea.py +237 -0
  36. jaxsim/rbda/soft_contacts.py +296 -0
  37. jaxsim/rbda/utils.py +152 -0
  38. jaxsim/terrain/__init__.py +2 -0
  39. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  40. jaxsim/utils/__init__.py +1 -4
  41. jaxsim/utils/hashless.py +18 -0
  42. jaxsim/utils/jaxsim_dataclass.py +281 -30
  43. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  44. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  45. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  46. jaxsim/high_level/__init__.py +0 -2
  47. jaxsim/high_level/common.py +0 -11
  48. jaxsim/high_level/joint.py +0 -148
  49. jaxsim/high_level/link.py +0 -259
  50. jaxsim/high_level/model.py +0 -1686
  51. jaxsim/math/conv.py +0 -114
  52. jaxsim/math/joint.py +0 -102
  53. jaxsim/math/plucker.py +0 -100
  54. jaxsim/physics/__init__.py +0 -12
  55. jaxsim/physics/algos/__init__.py +0 -0
  56. jaxsim/physics/algos/aba.py +0 -254
  57. jaxsim/physics/algos/aba_motors.py +0 -284
  58. jaxsim/physics/algos/forward_kinematics.py +0 -79
  59. jaxsim/physics/algos/jacobian.py +0 -98
  60. jaxsim/physics/algos/rnea.py +0 -180
  61. jaxsim/physics/algos/rnea_motors.py +0 -196
  62. jaxsim/physics/algos/soft_contacts.py +0 -523
  63. jaxsim/physics/algos/utils.py +0 -69
  64. jaxsim/physics/model/__init__.py +0 -0
  65. jaxsim/physics/model/ground_contact.py +0 -55
  66. jaxsim/physics/model/physics_model.py +0 -388
  67. jaxsim/physics/model/physics_model_state.py +0 -283
  68. jaxsim/simulation/__init__.py +0 -4
  69. jaxsim/simulation/integrators.py +0 -393
  70. jaxsim/simulation/ode.py +0 -290
  71. jaxsim/simulation/ode_data.py +0 -96
  72. jaxsim/simulation/ode_integration.py +0 -62
  73. jaxsim/simulation/simulator.py +0 -543
  74. jaxsim/simulation/simulator_callbacks.py +0 -79
  75. jaxsim/simulation/utils.py +0 -15
  76. jaxsim/sixd/__init__.py +0 -2
  77. jaxsim/utils/oop.py +0 -536
  78. jaxsim/utils/vmappable.py +0 -117
  79. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  80. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  81. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,296 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+
9
+ import jaxsim.api as js
10
+ import jaxsim.typing as jtp
11
+ from jaxsim.math import Skew, StandardGravity
12
+ from jaxsim.terrain import FlatTerrain, Terrain
13
+ from jaxsim.utils import JaxsimDataclass
14
+
15
+
16
@jax_dataclasses.pytree_dataclass
class SoftContactsParams(JaxsimDataclass):
    """Parameters of the soft contacts model."""

    # Stiffness of the contact model.
    K: jtp.Float = dataclasses.field(default_factory=lambda: jnp.array(1e6, dtype=float))

    # Damping of the contact model.
    D: jtp.Float = dataclasses.field(default_factory=lambda: jnp.array(2000, dtype=float))

    # Static friction coefficient.
    mu: jtp.Float = dataclasses.field(default_factory=lambda: jnp.array(0.5, dtype=float))

    @staticmethod
    def build(
        K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
    ) -> SoftContactsParams:
        """
        Build the soft contact parameters from scalar-like values.

        Args:
            K: The stiffness parameter.
            D: The damping parameter of the soft contacts model.
            mu: The static friction coefficient.

        Returns:
            A SoftContactsParams instance whose fields are float arrays
            holding the given values.
        """

        stiffness, damping, friction = (
            jnp.array(value, dtype=float) for value in (K, D, mu)
        )

        return SoftContactsParams(K=stiffness, D=damping, mu=friction)

    @staticmethod
    def build_default_from_jaxsim_model(
        model: js.model.JaxSimModel,
        *,
        standard_gravity: jtp.FloatLike = StandardGravity,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
    ) -> SoftContactsParams:
        """
        Derive soft contact parameters from the total mass of a model.

        The stiffness is chosen so that, at rest, the model weight shared among
        the given number of supporting collidable points produces the requested
        maximum penetration. The damping is the critical damping 2*sqrt(K*m)
        scaled by the damping ratio.

        Args:
            model: The target model.
            standard_gravity: The standard gravity constant.
            static_friction_coefficient:
                The static friction coefficient between the model and the terrain.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state:
                The number of contacts supporting the weight of the model
                in steady state.
            damping_ratio: The ratio controlling the damping behavior.

        Returns:
            A `SoftContactsParams` instance with the specified parameters.

        Note:
            The `damping_ratio` parameter allows to operate on the following conditions:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Total mass of the model, summed over all its links.
        total_mass = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()

        # Weight evenly shared by the collidable points active in steady state.
        weight_per_point = (
            total_mass
            * standard_gravity
            / number_of_active_collidable_points_steady_state
        )

        # At rest (zero penetration rate), the normal force of the soft contact
        # model is sqrt(δ) * K * δ = K * δ^(3/2): pick K so that it balances the
        # per-point weight at the maximum penetration.
        stiffness = weight_per_point / jnp.power(max_penetration, 3 / 2)

        # Scale the critical damping by the requested damping ratio.
        damping = damping_ratio * (2 * jnp.sqrt(stiffness * total_mass))

        return SoftContactsParams.build(
            K=stiffness, D=damping, mu=static_friction_coefficient
        )
110
+
111
+
112
@jax_dataclasses.pytree_dataclass
class SoftContacts:
    """Soft contacts model."""

    parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)

    def contact_model(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        tangential_deformation: jtp.Vector,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact forces and material deformation rate.

        Args:
            position: The position of the collidable point.
            velocity: The linear velocity of the collidable point.
            tangential_deformation: The tangential deformation.

        Returns:
            A tuple containing the 6D contact force applied to the collidable
            point, expressed in the inertial frame with the linear part first,
            and the rate of the material tangential deformation.
        """

        # Short name of parameters
        K = self.parameters.K
        D = self.parameters.D
        μ = self.parameters.mu

        # Material 3D tangential deformation and its derivative
        m = tangential_deformation.squeeze()
        ṁ = jnp.zeros_like(m)

        # Note: all the small hardcoded tolerances in this method have been introduced
        # to allow jax differentiating through this algorithm. They should not affect
        # the accuracy of the simulation, although they might make it less readable.

        # ========================
        # Normal force computation
        # ========================

        # Unpack the position and the velocity of the collidable point
        px, py, pz = W_p_C = position.squeeze()
        vx, vy, vz = W_ṗ_C = velocity.squeeze()

        # Query the terrain height once: it is used both for the penetration
        # depth and for the contact activation check in the friction branch.
        terrain_height = self.terrain.height(x=px, y=py)

        # Compute the terrain normal and the contact depth
        n̂ = self.terrain.normal(x=px, y=py).squeeze()
        h = jnp.array([0, 0, terrain_height - pz])

        # Compute the penetration depth normal to the terrain
        δ = jnp.maximum(0.0, jnp.dot(h, n̂))

        # Compute the penetration normal velocity
        δ̇ = -jnp.dot(W_ṗ_C, n̂)

        # Non-linear spring-damper model.
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = jax.lax.select(
            pred=δ >= 1e-9,
            on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
            on_false=jnp.array(0.0),
        )

        # Prevent negative normal forces that might occur when δ̇ is largely negative
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame
        force_normal = force_normal_mag * n̂

        # ====================================
        # No friction and no tangential forces
        # ====================================

        # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
        # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
        W_Xf_CW = jnp.vstack(
            [
                jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
                jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
            ]
        )

        def with_no_friction():
            # Compute 6D mixed force in C[W]
            CW_f_lin = force_normal
            CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])

            # Compute lin-ang 6D forces (inertial representation)
            W_f = W_Xf_CW @ CW_f

            return W_f, ṁ

        # =========================
        # Compute tangential forces
        # =========================

        def with_friction():
            # Initialize the tangential deformation rate ṁ.
            # For inactive contacts with m≠0, this is the dynamics of the material
            # relaxation converging exponentially to steady state.
            ṁ = (-K / D) * m

            # Check if the collidable point is below ground.
            # Note: when δ=0, we consider the point still not in contact such that
            # we prevent divisions by 0 in the computations below.
            active_contact = pz < terrain_height

            def above_terrain():
                return jnp.zeros(6), ṁ

            def below_terrain():
                # Decompose the velocity in normal and tangential components
                v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
                v_tangential = W_ṗ_C - v_normal

                # Compute the tangential force. If it lies inside the friction
                # cone, the contact is sticking and the force is applied as is;
                # otherwise, the contact is slipping and the force is projected
                # onto the cone boundary.
                f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

                def sticking_contact():
                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_stick = force_normal + f_tangential
                    CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])

                    # In this case the 3D material deformation is the tangential velocity
                    ṁ = v_tangential

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                def slipping_contact():
                    ε = 1e-9

                    # Norm of the tangential force, lower-bounded by ε to prevent
                    # divisions by zero (same value as max(‖f_tangential‖, ε)).
                    # The inner `where` keeps the gradient finite when
                    # f_tangential is zero: sqrt is then evaluated at 1.0 rather
                    # than at 0.0, where its derivative is infinite. This matters
                    # because, when the enclosing cond is batched (e.g. under
                    # vmap), JAX may evaluate this branch also for sticking
                    # contacts, whose f_tangential can be exactly zero.
                    f_tangential_sq = jnp.sum(f_tangential * f_tangential)
                    is_nonzero = f_tangential_sq > ε**2
                    f_tangential_norm = jnp.where(
                        is_nonzero,
                        jnp.sqrt(jnp.where(is_nonzero, f_tangential_sq, 1.0)),
                        ε,
                    )

                    # Project the force to the friction cone boundary
                    f_tangential_projected = (μ * force_normal_mag) * (
                        f_tangential / f_tangential_norm
                    )

                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_slip = force_normal + f_tangential_projected
                    CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])

                    # Correct the material deformation derivative for slipping contacts.
                    # Basically we compute ṁ such that we get `f_tangential` on the cone
                    # given the current (m, δ).
                    δε = jnp.maximum(δ, ε)
                    α = -K * jnp.sqrt(δε)
                    β = -D * jnp.sqrt(δε)
                    ṁ = (f_tangential_projected - α * m) / β

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                CW_f, ṁ = jax.lax.cond(
                    pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
                    true_fun=lambda _: slipping_contact(),
                    false_fun=lambda _: sticking_contact(),
                    operand=None,
                )

                # Express the 6D force in the world frame
                W_f = W_Xf_CW @ CW_f

                # Return the 6D force in the world frame and the deformation derivative
                return W_f, ṁ

            # (W_f, ṁ)
            return jax.lax.cond(
                pred=active_contact,
                true_fun=lambda _: below_terrain(),
                false_fun=lambda _: above_terrain(),
                operand=None,
            )

        # (W_f, ṁ)
        return jax.lax.cond(
            pred=(μ == 0.0),
            true_fun=lambda _: with_no_friction(),
            false_fun=lambda _: with_friction(),
            operand=None,
        )
jaxsim/rbda/utils.py ADDED
@@ -0,0 +1,152 @@
1
+ import jax.numpy as jnp
2
+
3
+ import jaxsim.api as js
4
+ import jaxsim.typing as jtp
5
+ from jaxsim.math import StandardGravity
6
+
7
+
8
def _as_1d_array(value, default):
    """Convert an optional array-like (array, list, tuple) to an at-least-1D
    array with singleton dimensions removed, or return `default` if None."""

    if value is None:
        return default

    return jnp.atleast_1d(jnp.asarray(value).squeeze())


def _check_shape(name: str, array, expected: tuple[int, ...]) -> None:
    """Raise a ValueError naming `name` if `array.shape` differs from `expected`."""

    if array.shape != expected:
        raise ValueError(
            f"Wrong shape of '{name}': expected {expected}, got {array.shape}"
        )


def process_inputs(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    base_linear_acceleration: jtp.VectorLike | None = None,
    base_angular_acceleration: jtp.VectorLike | None = None,
    joint_accelerations: jtp.VectorLike | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.VectorLike | None = None,
) -> tuple[
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Matrix,
    jtp.Vector,
]:
    """
    Adjust the inputs to rigid-body dynamics algorithms.

    Every input is optional. Missing joint quantities default to zeros of
    length `dofs`, missing base vectors default to 3D zeros, the missing base
    quaternion defaults to [1, 0, 0, 0], missing link forces default to zeros
    of shape (number_of_links, 6), and the missing gravity defaults to
    `StandardGravity`. Array-likes such as lists and tuples are accepted.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration: The linear acceleration of the base link.
        base_angular_acceleration: The angular acceleration of the base link.
        joint_accelerations: The accelerations of the joints.
        joint_forces: The forces applied to the joints.
        link_forces: The forces applied to the links.
        standard_gravity: The standard gravity constant (a scalar).

    Returns:
        The adjusted inputs, all cast to float, in this order: base position,
        base quaternion, joint positions, 6D base velocity (linear first),
        joint velocities, 6D base acceleration (linear first), joint
        accelerations, joint forces, link forces, 6D gravity acceleration.

    Raises:
        ValueError: If any input does not have the expected shape.
    """

    dofs = model.dofs()
    nl = model.number_of_links()

    # Joint positions, velocities, accelerations, and forces.
    s = _as_1d_array(joint_positions, jnp.zeros(dofs))
    ṡ = _as_1d_array(joint_velocities, jnp.zeros(dofs))
    s̈ = _as_1d_array(joint_accelerations, jnp.zeros(dofs))
    τ = _as_1d_array(joint_forces, jnp.zeros(dofs))

    # Floating-base position.
    W_p_B = _as_1d_array(base_position, jnp.zeros(3))
    W_Q_B = _as_1d_array(base_quaternion, jnp.array([1.0, 0, 0, 0]))

    # Floating-base velocity and acceleration in inertial-fixed representation.
    W_vl_WB = _as_1d_array(base_linear_velocity, jnp.zeros(3))
    W_ω_WB = _as_1d_array(base_angular_velocity, jnp.zeros(3))
    W_v̇l_WB = _as_1d_array(base_linear_acceleration, jnp.zeros(3))
    W_ω̇_WB = _as_1d_array(base_angular_acceleration, jnp.zeros(3))

    # External link forces: one 6D row per link. `atleast_2d` restores the
    # row dimension removed by `squeeze` when the model has a single link.
    f = (
        jnp.atleast_2d(jnp.asarray(link_forces).squeeze())
        if link_forces is not None
        else jnp.zeros(shape=(nl, 6))
    )

    standard_gravity = (
        jnp.array(standard_gravity).squeeze()
        if standard_gravity is not None
        else StandardGravity
    )

    # Validate the dimensions of all inputs.
    _check_shape("joint_positions", s, (dofs,))
    _check_shape("joint_velocities", ṡ, (dofs,))
    _check_shape("joint_accelerations", s̈, (dofs,))
    _check_shape("joint_forces", τ, (dofs,))
    _check_shape("base_position", W_p_B, (3,))
    _check_shape("base_linear_velocity", W_vl_WB, (3,))
    _check_shape("base_angular_velocity", W_ω_WB, (3,))
    _check_shape("base_linear_acceleration", W_v̇l_WB, (3,))
    _check_shape("base_angular_acceleration", W_ω̇_WB, (3,))
    _check_shape("link_forces", f, (nl, 6))
    _check_shape("base_quaternion", W_Q_B, (4,))

    # Pack the 6D base velocity and acceleration.
    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])

    # Create the 6D gravity acceleration (only the linear z component is non-zero).
    W_g = jnp.zeros(6).at[2].set(-standard_gravity)

    return (
        W_p_B.astype(float),
        W_Q_B.astype(float),
        s.astype(float),
        W_v_WB.astype(float),
        ṡ.astype(float),
        W_v̇_WB.astype(float),
        s̈.astype(float),
        τ.astype(float),
        f.astype(float),
        W_g.astype(float),
    )
@@ -0,0 +1,2 @@
1
+ from . import terrain
2
+ from .terrain import FlatTerrain, Terrain
@@ -46,23 +46,21 @@ class FlatTerrain(Terrain):
46
46
 
47
47
  @jax_dataclasses.pytree_dataclass
48
48
  class PlaneTerrain(Terrain):
49
- plane_normal: jtp.Vector = jax_dataclasses.field(
50
- default_factory=lambda: jnp.array([0, 0, 1.0])
51
- )
49
+ plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0])
52
50
 
53
51
  @staticmethod
54
- def build(plane_normal: jtp.Vector) -> "PlaneTerrain":
52
+ def build(plane_normal: list) -> "PlaneTerrain":
55
53
  """
56
54
  Create a PlaneTerrain instance with a specified plane normal vector.
57
55
 
58
56
  Args:
59
- plane_normal (jtp.Vector): The normal vector of the terrain plane.
57
+ plane_normal (list): The normal vector of the terrain plane.
60
58
 
61
59
  Returns:
62
60
  PlaneTerrain: A PlaneTerrain instance.
63
61
  """
64
62
 
65
- return PlaneTerrain(plane_normal=jnp.array(plane_normal, dtype=float))
63
+ return PlaneTerrain(plane_normal=plane_normal)
66
64
 
67
65
  def height(self, x: float, y: float) -> float:
68
66
  """
jaxsim/utils/__init__.py CHANGED
@@ -1,8 +1,5 @@
1
1
  from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
2
2
 
3
+ from .hashless import HashlessObject
3
4
  from .jaxsim_dataclass import JaxsimDataclass
4
5
  from .tracing import not_tracing, tracing
5
- from .vmappable import Vmappable
6
-
7
- # Leave this below the others to prevent circular imports
8
- from .oop import jax_tf # isort: skip
@@ -0,0 +1,18 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from typing import Generic, TypeVar
5
+
6
T = TypeVar("T")


@dataclasses.dataclass
class HashlessObject(Generic[T]):
    """
    Wrapper whose hash does not depend on the wrapped object.

    Every instance hashes to the constant 0, so objects that are not hashable
    themselves (lists, dicts, arrays, ...) can be wrapped and stored wherever
    a hash is required. Equality is the dataclass-generated comparison of the
    wrapped objects.
    """

    # The wrapped object.
    obj: T

    def get(self) -> T:
        """Return the wrapped object."""

        wrapped = self.obj
        return wrapped

    def __hash__(self) -> int:
        """Return a constant hash, independent of the wrapped object."""

        constant_hash = 0
        return constant_hash