jaxsim 0.4.3.dev68__py3-none-any.whl → 0.4.3.dev77__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -223,7 +223,7 @@ def extract_model_data(
223
223
  child=links_dict[j.child],
224
224
  jtype=utils.joint_to_joint_type(joint=j),
225
225
  axis=(
226
- np.array(j.axis.xyz.xyz)
226
+ np.array(j.axis.xyz.xyz, dtype=float)
227
227
  if j.axis is not None
228
228
  and j.axis.xyz is not None
229
229
  and j.axis.xyz.xyz is not None
@@ -232,39 +232,43 @@ def extract_model_data(
232
232
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
233
233
  initial_position=0.0,
234
234
  position_limit=(
235
- (
236
- float(j.axis.limit.lower)
237
- if j.axis is not None and j.axis.limit is not None
238
- else np.finfo(float).min
235
+ float(
236
+ j.axis.limit.lower
237
+ if j.axis is not None
238
+ and j.axis.limit is not None
239
+ and j.axis.limit.lower is not None
240
+ else jnp.finfo(float).min
239
241
  ),
240
- (
241
- float(j.axis.limit.upper)
242
- if j.axis is not None and j.axis.limit is not None
243
- else np.finfo(float).max
242
+ float(
243
+ j.axis.limit.upper
244
+ if j.axis is not None
245
+ and j.axis.limit is not None
246
+ and j.axis.limit.upper is not None
247
+ else jnp.finfo(float).max
244
248
  ),
245
249
  ),
246
- friction_static=(
250
+ friction_static=float(
247
251
  j.axis.dynamics.friction
248
252
  if j.axis is not None
249
253
  and j.axis.dynamics is not None
250
254
  and j.axis.dynamics.friction is not None
251
255
  else 0.0
252
256
  ),
253
- friction_viscous=(
257
+ friction_viscous=float(
254
258
  j.axis.dynamics.damping
255
259
  if j.axis is not None
256
260
  and j.axis.dynamics is not None
257
261
  and j.axis.dynamics.damping is not None
258
262
  else 0.0
259
263
  ),
260
- position_limit_damper=(
264
+ position_limit_damper=float(
261
265
  j.axis.limit.dissipation
262
266
  if j.axis is not None
263
267
  and j.axis.limit is not None
264
268
  and j.axis.limit.dissipation is not None
265
269
  else 0.0
266
270
  ),
267
- position_limit_spring=(
271
+ position_limit_spring=float(
268
272
  j.axis.limit.stiffness
269
273
  if j.axis is not None
270
274
  and j.axis.limit is not None
@@ -273,7 +277,7 @@ def extract_model_data(
273
277
  ),
274
278
  )
275
279
  for j in sdf_model.joints()
276
- if j.type in {"revolute", "prismatic", "fixed"}
280
+ if j.type in {"revolute", "continuous", "prismatic", "fixed"}
277
281
  and j.parent != "world"
278
282
  and j.child in links_dict.keys()
279
283
  ]
@@ -0,0 +1,409 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from typing import Any
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import jax_dataclasses
9
+ import jaxopt
10
+
11
+ import jaxsim.api as js
12
+ import jaxsim.typing as jtp
13
+ from jaxsim.api.common import VelRepr
14
+ from jaxsim.math import Adjoint
15
+ from jaxsim.terrain.terrain import FlatTerrain, Terrain
16
+
17
+ from .common import ContactModel, ContactsParams, ContactsState
18
+
19
+
20
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
    """Parameters of the relaxed rigid contacts model."""

    # Time constant
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Maximum number of iterations
    max_iterations: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(50, dtype=int)
    )

    # Solver tolerance
    tolerance: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e-6, dtype=float)
    )

    def __hash__(self) -> int:
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
                HashedNumpyArray(self.max_iterations),
                HashedNumpyArray(self.tolerance),
            )
        )

    def __eq__(self, other: RelaxedRigidContactsParams) -> bool:
        # Objects of a different type are never equal, even if their hashes
        # happen to collide.
        if not isinstance(other, RelaxedRigidContactsParams):
            return False

        # Two parameter sets are equal when all their values hash the same.
        return hash(self) == hash(other)

    @classmethod
    def build(
        cls,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
        max_iterations: jtp.IntLike | None = None,
        tolerance: jtp.FloatLike | None = None,
    ) -> RelaxedRigidContactsParams:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Every argument left to `None` takes the default value of the
        corresponding dataclass field. Every provided value is converted to an
        array with the same dtype as that field's default.

        Returns:
            A `RelaxedRigidContactsParams` instance.
        """

        # Capture the arguments explicitly: relying on `locals()` inside a
        # comprehension behaves differently across Python versions.
        provided = dict(
            time_constant=time_constant,
            damping_coefficient=damping_coefficient,
            d_min=d_min,
            d_max=d_max,
            width=width,
            midpoint=midpoint,
            power=power,
            stiffness=stiffness,
            damping=damping,
            mu=mu,
            max_iterations=max_iterations,
            tolerance=tolerance,
        )

        def default(name: str) -> jax.Array:
            # All the fields are declared with a `default_factory`, hence their
            # `Field.default` attribute is `dataclasses.MISSING`.
            return cls.__dataclass_fields__[name].default_factory()

        kwargs = {}
        for name, value in provided.items():
            fallback = default(name)
            kwargs[name] = jnp.array(
                value if value is not None else fallback, dtype=fallback.dtype
            )

        return cls(**kwargs)

    def valid(self) -> bool:
        """
        Check whether the parameters are valid.

        `width` must be strictly positive, because it is used as a divisor
        when computing the contact impedance; a zero width yields NaN values.

        Returns:
            `True` if all the parameters are within their admissible ranges.
        """

        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            and jnp.all(self.width > 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
            and jnp.all(self.max_iterations > 0)
            and jnp.all(self.tolerance > 0.0)
        )
149
+
150
+
151
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsState(ContactsState):
    """
    State of the relaxed rigid contacts model.

    This class declares no fields of its own; it exists to provide the
    `ContactsState` interface for the relaxed rigid contacts model.
    """

    def __eq__(self, other: RelaxedRigidContactsState) -> bool:
        # Equality is delegated to the hashes of the two objects.
        own_hash = hash(self)
        return own_hash == hash(other)

    @staticmethod
    def build() -> RelaxedRigidContactsState:
        """Create a `RelaxedRigidContactsState` instance"""

        state = RelaxedRigidContactsState()
        return state

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
        """
        Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.

        The model argument is not used: the state is built exactly as `build()` does.
        """

        return RelaxedRigidContactsState.build()

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """Return whether the state is valid for the given model; always `True`."""

        return True
171
+
172
+
173
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(ContactModel):
    """Relaxed rigid contacts model."""

    # Parameters of the contact model.
    parameters: RelaxedRigidContactsParams = dataclasses.field(
        default_factory=RelaxedRigidContactsParams
    )

    # Terrain used both for the contact detection (height and normal).
    terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
        default_factory=FlatTerrain
    )

    def compute_contact_forces(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Vector, tuple[Any, ...]]:
        """
        Compute the contact forces.

        Args:
            position:
                The `(n_points, 3)` positions of the collidable points.
            velocity:
                The `(n_points, 3)` linear velocities of the collidable points.
            model: The `JaxSimModel` instance.
            data: The `JaxSimModelData` instance.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple containing the 6D contact forces of the collidable points,
            converted from the mixed C[W] frame to the W frame, and a tuple of
            auxiliary data (always `(None,)`).
        """

        link_forces = (
            link_forces
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = (
            joint_force_references
            if joint_force_references is not None
            else jnp.zeros(model.number_of_joints())
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
            x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))

            # Take both the normal and the height from the same terrain, so that
            # the two quantities are always consistent with each other.
            n̂ = self.terrain.normal(x=x, y=y).squeeze()
            h = jnp.array([0, 0, z - self.terrain.height(x=x, y=y)])

            return jnp.dot(h, n̂)

        # Signed distance of the collidable points from the terrain along its
        # normal. Points with a negative value are treated as in contact.
        δ = jax.vmap(_detect_contact)(*position.T)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians, zeroed for inactive points.
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, height: J * (height < 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )
            W_H_C = js.contact.transforms(model=model, data=data)

            # Acceleration without contact forces, considering both the joint
            # force references and the external link forces.
            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    joint_forces=references.joint_force_references(model=model),
                    link_forces=references.link_forces(model=model, data=data),
                )
            )
            BW_ν = data.generalized_velocity()

            # Linear part of the contact Jacobian derivatives, zeroed for
            # inactive points.
            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, height: J̇ * (height < 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

            a_ref, R, K, D = self._regularizers(
                model=model,
                penetration=δ,
                velocity=velocity,
                parameters=self.parameters,
            )

        # Delassus matrix J M⁻¹ Jᵀ and free linear acceleration of the points.
        G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the linear optimization problem.
        A = G + R
        b = CW_al_free_WC - a_ref

        def objective(x: jtp.Array) -> jtp.Array:
            # Squared residual of the regularized contact acceleration problem.
            return jnp.sum(jnp.square(A @ x + b))

        # Compute the 3D linear force in C[W] frame
        opt = jaxopt.LBFGS(
            fun=objective,
            maxiter=self.parameters.max_iterations,
            tol=self.parameters.tolerance,
            maxls=30,
            history_size=10,
            max_stepsize=100.0,
        )

        # Initial guess: spring-damper forces built from the computed
        # stiffness/damping, the penetration, and the point velocities.
        init_params = (
            K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
            + D[:, jnp.newaxis] * velocity
        ).flatten()

        CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)

        def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
            # Force transform from the mixed frame C[W] (origin of C, orientation
            # of W) to W, applied to the 6D force with zero angular part.
            W_Xf_CW = Adjoint.from_transform(
                W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
                inverse=True,
            ).T
            return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])

        W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)

        return W_f_C, (None,)

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        penetration: jtp.Array,
        velocity: jtp.Array,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            penetration: The penetration of the collidable points.
            velocity: The velocity of the collidable points.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the regularization
            matrix, the stiffness, and the damping. All terms are zeroed for the
            collidable points whose penetration is not negative.
        """

        # Read the parameters by name, so the code does not depend on the
        # order of the dataclass fields.
        Ω = parameters.time_constant
        ζ = parameters.damping_coefficient
        ξ_min = parameters.d_min
        ξ_max = parameters.d_max
        width = parameters.width
        mid = parameters.midpoint
        p = parameters.power
        K = parameters.stiffness
        D = parameters.damping
        μ = parameters.mu

        def _imp_aref(
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """
            Calculates impedance and offset acceleration in constraint frame.

            The impedance follows the shape of MuJoCo's `solimp` function,
            parameterized by d_min, d_max, width, midpoint and power.

            Args:
                penetration: penetration in constraint frame
                velocity: velocity in constraint frame

            Returns:
                imp: computed impedance
                a_ref: offset acceleration in constraint frame
                K: computed stiffness
                D: computed damping
            """
            position = jnp.zeros(shape=(3,)).at[2].set(penetration)

            imp_x = jnp.abs(position) / width
            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)

            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)

            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
            imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))

            # When passing negative values, K and D represent a spring and damper, respectively.
            K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
            D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))

            a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)

            return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)

        def _compute_row(
            *,
            link_idx: jtp.Int,
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:

            # Compute the reference acceleration.
            ξ, a_ref, K, D = _imp_aref(
                penetration=penetration,
                velocity=velocity,
            )

            # Compute the regularization terms.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Zero all the terms of the points that are not in contact.
            return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))

        M_L = js.model.link_spatial_inertia_matrices(model=model)

        # Compute the rows of all the collidable points and stack them.
        a_ref, R, K, D = jax.tree.map(
            jnp.concatenate,
            jax.vmap(_compute_row)(
                link_idx=jnp.array(model.kin_dyn_parameters.contact_parameters.body),
                penetration=penetration,
                velocity=velocity,
            ),
        )
        return a_ref, jnp.diag(R), K, D
@@ -9,7 +9,6 @@ import jax_dataclasses
9
9
 
10
10
  import jaxsim.api as js
11
11
  import jaxsim.typing as jtp
12
- from jaxsim import math
13
12
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
14
13
  from jaxsim.terrain import FlatTerrain, Terrain
15
14
 
@@ -214,6 +213,7 @@ class RigidContacts(ContactModel):
214
213
  model: js.model.JaxSimModel,
215
214
  data: js.data.JaxSimModelData,
216
215
  link_forces: jtp.MatrixLike | None = None,
216
+ joint_force_references: jtp.VectorLike | None = None,
217
217
  regularization_term: jtp.FloatLike = 1e-6,
218
218
  ) -> tuple[jtp.Vector, tuple[Any, ...]]:
219
219
  """
@@ -227,6 +227,8 @@ class RigidContacts(ContactModel):
227
227
  link_forces:
228
228
  Optional `(n_links, 6)` matrix of external forces acting on the links,
229
229
  expressed in the same representation of data.
230
+ joint_force_references:
231
+ Optional `(n_joints,)` vector of joint forces.
230
232
  regularization_term:
231
233
  The regularization term to add to the diagonal of the Delassus
232
234
  matrix for better numerical conditioning.
@@ -244,6 +246,12 @@ class RigidContacts(ContactModel):
244
246
  else jnp.zeros((model.number_of_links(), 6))
245
247
  )
246
248
 
249
+ joint_force_references = (
250
+ joint_force_references
251
+ if joint_force_references is not None
252
+ else jnp.zeros((model.number_of_joints(),))
253
+ )
254
+
247
255
  # Compute kin-dyn quantities used in the contact model
248
256
  with data.switch_velocity_representation(VelRepr.Mixed):
249
257
  M = js.model.free_floating_mass_matrix(model=model, data=data)
@@ -270,11 +278,20 @@ class RigidContacts(ContactModel):
270
278
  data=data,
271
279
  velocity_representation=data.velocity_representation,
272
280
  link_forces=link_forces,
281
+ joint_force_references=joint_force_references,
273
282
  )
274
283
 
275
- with references.switch_velocity_representation(VelRepr.Mixed):
276
- BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
277
- model, data, references=references
284
+ with (
285
+ references.switch_velocity_representation(VelRepr.Mixed),
286
+ data.switch_velocity_representation(VelRepr.Mixed),
287
+ ):
288
+ BW_ν̇_free = jnp.hstack(
289
+ js.ode.system_acceleration(
290
+ model=model,
291
+ data=data,
292
+ joint_forces=references.joint_force_references(model=model),
293
+ link_forces=references.link_forces(model=model, data=data),
294
+ )
278
295
  )
279
296
 
280
297
  free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
@@ -380,43 +397,6 @@ class RigidContacts(ContactModel):
380
397
  n_constraints = 6 * n_collidable_points
381
398
  return jnp.zeros(shape=(n_constraints,))
382
399
 
383
- @staticmethod
384
- def _compute_mixed_nu_dot_free(
385
- model: js.model.JaxSimModel,
386
- data: js.data.JaxSimModelData,
387
- references: js.references.JaxSimModelReferences | None = None,
388
- ) -> jtp.Array:
389
- references = (
390
- references
391
- if references is not None
392
- else js.references.JaxSimModelReferences.zero(model=model, data=data)
393
- )
394
-
395
- with (
396
- data.switch_velocity_representation(VelRepr.Mixed),
397
- references.switch_velocity_representation(VelRepr.Mixed),
398
- ):
399
- BW_v_WB = data.base_velocity()
400
- W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
401
- W_v̇_WB, s̈ = js.ode.system_acceleration(
402
- model=model,
403
- data=data,
404
- joint_forces=references.joint_force_references(model=model),
405
- link_forces=references.link_forces(model=model, data=data),
406
- )
407
-
408
- # Convert the inertial-fixed base acceleration to a mixed base acceleration.
409
- W_H_B = data.base_transform()
410
- W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
411
- BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
412
- term1 = BW_X_W @ W_v̇_WB
413
- term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
414
- BW_v̇_WB = term1 - term2
415
-
416
- BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
417
-
418
- return BW_ν̇
419
-
420
400
  @staticmethod
421
401
  def _linear_acceleration_of_collidable_points(
422
402
  model: js.model.JaxSimModel,