jaxsim 0.4.3.dev68__py3-none-any.whl → 0.4.3.dev70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -223,7 +223,7 @@ def extract_model_data(
223
223
  child=links_dict[j.child],
224
224
  jtype=utils.joint_to_joint_type(joint=j),
225
225
  axis=(
226
- np.array(j.axis.xyz.xyz)
226
+ np.array(j.axis.xyz.xyz, dtype=float)
227
227
  if j.axis is not None
228
228
  and j.axis.xyz is not None
229
229
  and j.axis.xyz.xyz is not None
@@ -232,39 +232,43 @@ def extract_model_data(
232
232
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
233
233
  initial_position=0.0,
234
234
  position_limit=(
235
- (
236
- float(j.axis.limit.lower)
237
- if j.axis is not None and j.axis.limit is not None
238
- else np.finfo(float).min
235
+ float(
236
+ j.axis.limit.lower
237
+ if j.axis is not None
238
+ and j.axis.limit is not None
239
+ and j.axis.limit.lower is not None
240
+ else jnp.finfo(float).min
239
241
  ),
240
- (
241
- float(j.axis.limit.upper)
242
- if j.axis is not None and j.axis.limit is not None
243
- else np.finfo(float).max
242
+ float(
243
+ j.axis.limit.upper
244
+ if j.axis is not None
245
+ and j.axis.limit is not None
246
+ and j.axis.limit.upper is not None
247
+ else jnp.finfo(float).max
244
248
  ),
245
249
  ),
246
- friction_static=(
250
+ friction_static=float(
247
251
  j.axis.dynamics.friction
248
252
  if j.axis is not None
249
253
  and j.axis.dynamics is not None
250
254
  and j.axis.dynamics.friction is not None
251
255
  else 0.0
252
256
  ),
253
- friction_viscous=(
257
+ friction_viscous=float(
254
258
  j.axis.dynamics.damping
255
259
  if j.axis is not None
256
260
  and j.axis.dynamics is not None
257
261
  and j.axis.dynamics.damping is not None
258
262
  else 0.0
259
263
  ),
260
- position_limit_damper=(
264
+ position_limit_damper=float(
261
265
  j.axis.limit.dissipation
262
266
  if j.axis is not None
263
267
  and j.axis.limit is not None
264
268
  and j.axis.limit.dissipation is not None
265
269
  else 0.0
266
270
  ),
267
- position_limit_spring=(
271
+ position_limit_spring=float(
268
272
  j.axis.limit.stiffness
269
273
  if j.axis is not None
270
274
  and j.axis.limit is not None
@@ -273,7 +277,7 @@ def extract_model_data(
273
277
  ),
274
278
  )
275
279
  for j in sdf_model.joints()
276
- if j.type in {"revolute", "prismatic", "fixed"}
280
+ if j.type in {"revolute", "continuous", "prismatic", "fixed"}
277
281
  and j.parent != "world"
278
282
  and j.child in links_dict.keys()
279
283
  ]
@@ -0,0 +1,384 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from typing import Any
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import jax_dataclasses
9
+ import jaxopt
10
+
11
+ import jaxsim.api as js
12
+ import jaxsim.typing as jtp
13
+ from jaxsim.api.common import VelRepr
14
+ from jaxsim.math import Adjoint
15
+ from jaxsim.terrain.terrain import FlatTerrain, Terrain
16
+
17
+ from .common import ContactModel, ContactsParams, ContactsState
18
+
19
+
20
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
    """Parameters of the relaxed rigid contacts model."""

    # Time constant
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width of the impedance function
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint of the impedance function
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent of the impedance function
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness (negative values are interpreted as a spring, see `_regularizers`)
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping (negative values are interpreted as a damper, see `_regularizers`)
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Maximum number of iterations of the solver
    max_iterations: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(50, dtype=int)
    )

    # Solver tolerance
    tolerance: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e-6, dtype=float)
    )

    def __hash__(self) -> int:
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
                HashedNumpyArray(self.max_iterations),
                HashedNumpyArray(self.tolerance),
            )
        )

    def __eq__(self, other: object) -> bool:
        # Guard first: hashing an arbitrary (possibly unhashable) object would raise.
        if not isinstance(other, RelaxedRigidContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
        max_iterations: jtp.IntLike | None = None,
        tolerance: jtp.FloatLike | None = None,
    ) -> RelaxedRigidContactsParams:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Every argument left to `None` takes the default value of the
        corresponding field. Provided values are cast to the dtype of the
        field default.

        Args:
            time_constant: The time constant.
            damping_coefficient: The adimensional damping coefficient.
            d_min: The minimum impedance.
            d_max: The maximum impedance.
            width: The width of the impedance function.
            midpoint: The midpoint of the impedance function.
            power: The power exponent of the impedance function.
            stiffness: The stiffness.
            damping: The damping.
            mu: The friction coefficient.
            max_iterations: The maximum number of solver iterations.
            tolerance: The solver tolerance.

        Returns:
            A `RelaxedRigidContactsParams` instance.
        """

        # Collect the arguments explicitly. Calling `locals()` inside a
        # comprehension would see the comprehension scope on Python < 3.12,
        # not the arguments of this method.
        provided = {
            "time_constant": time_constant,
            "damping_coefficient": damping_coefficient,
            "d_min": d_min,
            "d_max": d_max,
            "width": width,
            "midpoint": midpoint,
            "power": power,
            "stiffness": stiffness,
            "damping": damping,
            "mu": mu,
            "max_iterations": max_iterations,
            "tolerance": tolerance,
        }

        def _default(name: str) -> jtp.Array:
            # Fields are declared with `default_factory`: their `Field.default`
            # is `dataclasses.MISSING`, so the factory has to be invoked.
            return cls.__dataclass_fields__[name].default_factory()

        values = {}
        for name, value in provided.items():
            default = _default(name)
            values[name] = jnp.array(
                default if value is None else value, dtype=default.dtype
            )

        return cls(**values)

    def valid(self) -> bool:
        """
        Check whether the parameters lie within their admissible ranges.

        The result is converted with `bool`, so this method requires concrete
        (non-traced) parameter values.

        Returns:
            True if all the parameters are valid, False otherwise.
        """
        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            and jnp.all(self.width >= 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
            and jnp.all(self.max_iterations > 0)
            and jnp.all(self.tolerance > 0.0)
        )
149
+
150
+
151
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsState(ContactsState):
    """
    Class storing the state of the relaxed rigid contacts model.

    This class declares no fields of its own: the relaxed rigid contacts
    model does not carry any state across calls.
    """

    def __eq__(self, other: object) -> bool:
        # Guard first: hashing an arbitrary (possibly unhashable) object would raise.
        if not isinstance(other, RelaxedRigidContactsState):
            return False

        return hash(self) == hash(other)

    @staticmethod
    def build() -> RelaxedRigidContactsState:
        """Create a `RelaxedRigidContactsState` instance"""

        return RelaxedRigidContactsState()

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
        """
        Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.

        Args:
            model: The considered model. Unused, since the state has no fields.

        Returns:
            A `RelaxedRigidContactsState` instance.
        """
        return RelaxedRigidContactsState.build()

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check whether the state is valid for the given model.

        Args:
            model: The considered model. Unused.

        Returns:
            Always True, since there is no state to validate.
        """
        return True
171
+
172
+
173
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(ContactModel):
    """Relaxed rigid contacts model."""

    # The parameters of the contact model.
    parameters: RelaxedRigidContactsParams = dataclasses.field(
        default_factory=RelaxedRigidContactsParams
    )

    # NOTE(review): contact detection below reads `model.terrain` for both the
    # terrain height and normal; confirm whether this field is still needed.
    terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
        default_factory=FlatTerrain
    )

    def compute_contact_forces(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        link_forces: jtp.MatrixLike | None = None,
    ) -> tuple[jtp.Vector, tuple[Any, ...]]:
        """
        Compute the contact forces acting on the collidable points.

        Args:
            position: The positions of the collidable points, one row per
                point with (x, y, z) columns.
            velocity: The linear velocities of the collidable points, one row
                per point.
            model: The model to consider.
            data: The data of the considered model.
            link_forces: Optional external forces applied to the links, one
                6D row per link. Treated as zero when not given.

        Returns:
            A tuple containing the 6D contact forces expressed in the inertial
            representation (one row per collidable point) and the auxiliary
            tuple `(None,)`.
        """

        link_forces = (
            link_forces
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
        )

        def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
            # Signed distance of the point from the terrain, projected on the
            # terrain normal. Negative values are treated as active contacts below.
            x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))

            # Use the same terrain for both the normal and the height, otherwise
            # the signed distance is inconsistent when the two terrains differ.
            n̂ = model.terrain.normal(x=x, y=y).squeeze()
            h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])

            return jnp.dot(h, n̂)

        # Compute the activation state of the collidable points
        δ = jax.vmap(_detect_contact)(*position.T)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians, zeroed for inactive points.
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, height: J * (height < 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )
            W_H_C = js.contact.transforms(model=model, data=data)
            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                )
            )
            BW_ν = data.generalized_velocity()

            # Linear part of the Jacobian derivatives, zeroed for inactive points.
            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, height: J̇ * (height < 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

        a_ref, R, K, D = self._regularizers(
            model=model,
            penetration=δ,
            velocity=velocity,
            parameters=self.parameters,
        )

        G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the linear optimization problem.
        A = G + R
        b = CW_al_free_WC - a_ref

        def objective(x: jtp.Array) -> jtp.Array:
            return jnp.sum(jnp.square(A @ x + b))

        # Compute the 3D linear force in C[W] frame
        opt = jaxopt.LBFGS(
            fun=objective,
            maxiter=self.parameters.max_iterations,
            tol=self.parameters.tolerance,
            maxls=30,
            history_size=10,
            max_stepsize=100.0,
        )

        # Warm start: spring-damper force from the penetration and the velocity.
        init_params = (
            K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
            + D[:, jnp.newaxis] * velocity
        ).flatten()

        CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)

        def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
            # Convert a linear force in C[W] to a 6D force in the inertial frame.
            W_Xf_CW = Adjoint.from_transform(
                W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
                inverse=True,
            ).T
            return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])

        W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)

        return W_f_C, (None,)

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        penetration: jtp.Array,
        velocity: jtp.Array,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            penetration: The penetration of the collidable points.
            velocity: The velocity of the collidable points.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the regularization
            matrix, the stiffness, and the damping.
        """

        # Read the parameters by name rather than by field position, so that
        # reordering fields or adding base-class fields cannot shift them.
        Ω = parameters.time_constant
        ζ = parameters.damping_coefficient
        ξ_min = parameters.d_min
        ξ_max = parameters.d_max
        width = parameters.width
        mid = parameters.midpoint
        p = parameters.power
        K = parameters.stiffness
        D = parameters.damping
        μ = parameters.mu

        def _imp_aref(
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """
            Calculates impedance and offset acceleration in constraint frame.

            Args:
                penetration: penetration in constraint frame
                velocity: velocity in constraint frame

            Returns:
                imp: computed impedance
                a_ref: offset acceleration in constraint frame
                K: computed stiffness
                D: computed damping
            """
            position = jnp.zeros(shape=(3,)).at[2].set(penetration)

            imp_x = jnp.abs(position) / width
            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)

            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)

            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
            imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))

            # When passing negative values, K and D represent a spring and damper, respectively.
            K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
            D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))

            a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)

            return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)

        def _compute_row(
            *,
            link_idx: jtp.Int,
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:

            # Compute the reference acceleration.
            ξ, a_ref, K, D = _imp_aref(
                penetration=penetration,
                velocity=velocity,
            )

            # Compute the regularization terms.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Zero out all the terms of the points that are not in contact.
            return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))

        M_L = js.model.link_spatial_inertia_matrices(model=model)

        rows = jax.vmap(_compute_row)(
            link_idx=jnp.array(model.kin_dyn_parameters.contact_parameters.body),
            penetration=penetration,
            velocity=velocity,
        )

        # Stack the per-point terms into flat vectors.
        a_ref, R, K, D = jax.tree.map(jnp.concatenate, rows)

        return a_ref, jnp.diag(R), K, D
@@ -9,7 +9,6 @@ import jax_dataclasses
9
9
 
10
10
  import jaxsim.api as js
11
11
  import jaxsim.typing as jtp
12
- from jaxsim import math
13
12
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
14
13
  from jaxsim.terrain import FlatTerrain, Terrain
15
14
 
@@ -272,9 +271,17 @@ class RigidContacts(ContactModel):
272
271
  link_forces=link_forces,
273
272
  )
274
273
 
275
- with references.switch_velocity_representation(VelRepr.Mixed):
276
- BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
277
- model, data, references=references
274
+ with (
275
+ references.switch_velocity_representation(VelRepr.Mixed),
276
+ data.switch_velocity_representation(VelRepr.Mixed),
277
+ ):
278
+ BW_ν̇_free = jnp.hstack(
279
+ js.ode.system_acceleration(
280
+ model=model,
281
+ data=data,
282
+ joint_forces=references.joint_force_references(model=model),
283
+ link_forces=references.link_forces(model=model, data=data),
284
+ )
278
285
  )
279
286
 
280
287
  free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
@@ -380,43 +387,6 @@ class RigidContacts(ContactModel):
380
387
  n_constraints = 6 * n_collidable_points
381
388
  return jnp.zeros(shape=(n_constraints,))
382
389
 
383
- @staticmethod
384
- def _compute_mixed_nu_dot_free(
385
- model: js.model.JaxSimModel,
386
- data: js.data.JaxSimModelData,
387
- references: js.references.JaxSimModelReferences | None = None,
388
- ) -> jtp.Array:
389
- references = (
390
- references
391
- if references is not None
392
- else js.references.JaxSimModelReferences.zero(model=model, data=data)
393
- )
394
-
395
- with (
396
- data.switch_velocity_representation(VelRepr.Mixed),
397
- references.switch_velocity_representation(VelRepr.Mixed),
398
- ):
399
- BW_v_WB = data.base_velocity()
400
- W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
401
- W_v̇_WB, s̈ = js.ode.system_acceleration(
402
- model=model,
403
- data=data,
404
- joint_forces=references.joint_force_references(model=model),
405
- link_forces=references.link_forces(model=model, data=data),
406
- )
407
-
408
- # Convert the inertial-fixed base acceleration to a mixed base acceleration.
409
- W_H_B = data.base_transform()
410
- W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
411
- BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
412
- term1 = BW_X_W @ W_v̇_WB
413
- term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
414
- BW_v̇_WB = term1 - term2
415
-
416
- BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
417
-
418
- return BW_ν̇
419
-
420
390
  @staticmethod
421
391
  def _linear_acceleration_of_collidable_points(
422
392
  model: js.model.JaxSimModel,
jaxsim/terrain/terrain.py CHANGED
@@ -46,66 +46,82 @@ class Terrain(abc.ABC):
46
@jax_dataclasses.pytree_dataclass
class FlatTerrain(Terrain):
    """Terrain with a constant height and a constant upward normal."""

    # Height of the terrain surface, the same at every (x, y) location.
    _height: float = dataclasses.field(default=0.0, kw_only=True)

    @staticmethod
    def build(height: jtp.FloatLike) -> FlatTerrain:
        """
        Create a FlatTerrain instance.

        Args:
            height: The height of the terrain, converted to a Python float.

        Returns:
            FlatTerrain: A FlatTerrain instance.
        """

        terrain_height = float(height)
        return FlatTerrain(_height=terrain_height)

    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """Return the terrain height, which does not depend on (x, y)."""

        return jnp.array(self._height, dtype=float)

    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
        """Return the terrain normal, always the z-axis regardless of (x, y)."""

        return jnp.array([0.0, 0.0, 1.0], dtype=float)

    def __hash__(self) -> int:

        return hash(self._height)

    def __eq__(self, other: FlatTerrain) -> bool:

        # Short-circuits to False for any object that is not a FlatTerrain.
        return isinstance(other, FlatTerrain) and self._height == other._height
70
74
 
71
75
 
72
76
  @jax_dataclasses.pytree_dataclass
73
77
  class PlaneTerrain(FlatTerrain):
74
78
 
75
- plane_normal: tuple[float, float, float] = jax_dataclasses.field(
79
+ _normal: tuple[float, float, float] = jax_dataclasses.field(
76
80
  default=(0.0, 0.0, 1.0), kw_only=True
77
81
  )
78
82
 
79
83
  @staticmethod
80
- def build(
81
- plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
82
- ) -> PlaneTerrain:
84
+ def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
83
85
  """
84
86
  Create a PlaneTerrain instance with a specified plane normal vector.
85
87
 
86
88
  Args:
87
- plane_normal: The normal vector of the terrain plane.
88
- plane_height_over_origin: The height of the plane over the origin.
89
+ normal: The normal vector of the terrain plane.
90
+ height: The height of the plane over the origin.
89
91
 
90
92
  Returns:
91
93
  PlaneTerrain: A PlaneTerrain instance.
92
94
  """
93
95
 
94
- plane_normal = jnp.array(plane_normal, dtype=float)
95
- plane_height_over_origin = jnp.array(plane_height_over_origin, dtype=float)
96
+ normal = jnp.array(normal, dtype=float)
97
+ height = jnp.array(height, dtype=float)
96
98
 
97
- if plane_normal.shape != (3,):
99
+ if normal.shape != (3,):
98
100
  msg = "Expected a 3D vector for the plane normal, got '{}'."
99
- raise ValueError(msg.format(plane_normal.shape))
101
+ raise ValueError(msg.format(normal.shape))
100
102
 
101
103
  # Make sure that the plane normal is a unit vector.
102
- plane_normal = plane_normal / jnp.linalg.norm(plane_normal)
104
+ normal = normal / jnp.linalg.norm(normal)
103
105
 
104
106
  return PlaneTerrain(
105
- z=float(plane_height_over_origin),
106
- plane_normal=tuple(plane_normal.tolist()),
107
+ _height=height.item(),
108
+ _normal=tuple(normal.tolist()),
107
109
  )
108
110
 
111
+ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
112
+ """
113
+ Compute the normal vector of the terrain at a specific (x, y) location.
114
+
115
+ Args:
116
+ x: The x-coordinate of the location.
117
+ y: The y-coordinate of the location.
118
+
119
+ Returns:
120
+ The normal vector of the terrain surface at the specified location.
121
+ """
122
+
123
+ return jnp.array(self._normal, dtype=float)
124
+
109
125
  def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
110
126
  """
111
127
  Compute the height of the terrain at a specific (x, y) location on a plane.
@@ -123,10 +139,10 @@ class PlaneTerrain(FlatTerrain):
123
139
  # The height over the origin: -D/C
124
140
 
125
141
  # Get the plane equation coefficients from the terrain normal.
126
- A, B, C = self.plane_normal
142
+ A, B, C = self._normal
127
143
 
128
144
  # Compute the final coefficient D considering the terrain height.
129
- D = -C * self.z
145
+ D = -C * self._height
130
146
 
131
147
  # Invert the plane equation to get the height at the given (x, y) coordinates.
132
148
  return jnp.array(-(A * x + B * y + D) / C).astype(float)
@@ -137,9 +153,9 @@ class PlaneTerrain(FlatTerrain):
137
153
 
138
154
  return hash(
139
155
  (
140
- hash(self.z),
156
+ hash(self._height),
141
157
  HashedNumpyArray.hash_of_array(
142
- array=jnp.array(self.plane_normal, dtype=float)
158
+ array=jnp.array(self._normal, dtype=float)
143
159
  ),
144
160
  )
145
161
  )
@@ -150,10 +166,10 @@ class PlaneTerrain(FlatTerrain):
150
166
  return False
151
167
 
152
168
  if not (
153
- np.allclose(self.z, other.z)
169
+ np.allclose(self._height, other._height)
154
170
  and np.allclose(
155
- np.array(self.plane_normal, dtype=float),
156
- np.array(other.plane_normal, dtype=float),
171
+ np.array(self._normal, dtype=float),
172
+ np.array(other._normal, dtype=float),
157
173
  )
158
174
  ):
159
175
  return False
jaxsim/typing.py CHANGED
@@ -16,7 +16,7 @@ Int = Scalar
16
16
  Bool = Scalar
17
17
  Float = Scalar
18
18
 
19
- PyTree = (
19
+ PyTree: object = (
20
20
  dict[Hashable, TypeVar("PyTree")]
21
21
  | list[TypeVar("PyTree")]
22
22
  | tuple[TypeVar("PyTree")]