jaxsim 0.6.1.dev13__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.1.dev13.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.1.dev13.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
@@ -1,480 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
- import functools
5
-
6
- import jax
7
- import jax.numpy as jnp
8
- import jax_dataclasses
9
-
10
- import jaxsim.api as js
11
- import jaxsim.math
12
- import jaxsim.typing as jtp
13
- from jaxsim import logging
14
- from jaxsim.math import StandardGravity
15
- from jaxsim.terrain import Terrain
16
-
17
- from . import common
18
-
19
- try:
20
- from typing import Self
21
- except ImportError:
22
- from typing_extensions import Self
23
-
24
-
25
@jax_dataclasses.pytree_dataclass
class SoftContactsParams(common.ContactsParams):
    """
    Parameters of the soft contacts model.

    The Hunt/Crossley normal force computed in this module is
    ``K * δ^p * δ + D * δ^q * δ̇`` (clamped at zero). Therefore the exponent
    ``p`` shapes the spring term and the exponent ``q`` shapes the damping term.
    """

    # Stiffness of the contact spring.
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e6, dtype=float)
    )

    # Damping of the contact damper.
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(2000, dtype=float)
    )

    # Static friction coefficient.
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Exponent of the spring-related non-linearity (spring term ∝ K * δ^(1+p)).
    p: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Exponent of the damping-related non-linearity (damping term ∝ D * δ^q * δ̇).
    q: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    def __hash__(self) -> int:
        """
        Hash the parameters from the content of their array fields.

        Returns:
            The hash of the tuple of per-field array hashes.
        """

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
                HashedNumpyArray.hash_of_array(self.mu),
                HashedNumpyArray.hash_of_array(self.p),
                HashedNumpyArray.hash_of_array(self.q),
            )
        )

    def __eq__(self, other: SoftContactsParams) -> bool:
        """
        Compare two parameter sets.

        Equality is decided by comparing the hashes returned by `__hash__`.

        Args:
            other: The object to compare with.

        Returns:
            ``NotImplemented`` if ``other`` is not a `SoftContactsParams`,
            otherwise whether the two hashes match.
        """

        if not isinstance(other, SoftContactsParams):
            return NotImplemented

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        K: jtp.FloatLike = 1e6,
        D: jtp.FloatLike = 2_000,
        mu: jtp.FloatLike = 0.5,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
    ) -> Self:
        """
        Create a SoftContactsParams instance with specified parameters.

        Args:
            K: The stiffness parameter.
            D: The damping parameter of the soft contacts model.
            mu: The static friction coefficient.
            p:
                The exponent p corresponding to the spring-related non-linearity
                of the Hunt/Crossley model (the spring term is K * δ^p * δ).
            q:
                The exponent q corresponding to the damping-related non-linearity
                of the Hunt/Crossley model (the damping term is D * δ^q * δ̇).

        Returns:
            A SoftContactsParams instance (or an instance of the calling
            subclass) with the specified parameters.
        """

        # Use `cls` so that subclasses calling `build` get their own type back.
        return cls(
            K=jnp.array(K, dtype=float),
            D=jnp.array(D, dtype=float),
            mu=jnp.array(mu, dtype=float),
            p=jnp.array(p, dtype=float),
            q=jnp.array(q, dtype=float),
        )

    @classmethod
    def build_default_from_jaxsim_model(
        cls: type[Self],
        model: js.model.JaxSimModel,
        *,
        standard_gravity: jtp.FloatLike = StandardGravity,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
    ) -> Self:
        """
        Create a SoftContactsParams instance with good default parameters.

        Args:
            model: The target model.
            standard_gravity: The standard gravity constant.
            static_friction_coefficient:
                The static friction coefficient between the model and the terrain.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state:
                The number of contacts supporting the weight of the model
                in steady state.
            damping_ratio: The ratio controlling the damping behavior.
            p:
                The exponent p corresponding to the spring-related non-linearity
                of the Hunt/Crossley model.
            q:
                The exponent q corresponding to the damping-related non-linearity
                of the Hunt/Crossley model.

        Returns:
            A `SoftContactsParams` instance with the specified parameters.

        Note:
            The `damping_ratio` parameter allows to operate on the following conditions:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Use symbols for input parameters.
        ξ = damping_ratio
        δ_max = max_penetration
        μc = static_friction_coefficient

        # Compute the total mass of the model.
        m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()

        # Rename the standard gravity.
        g = standard_gravity

        # Compute the average support force on each collidable point.
        f_average = m * g / number_of_active_collidable_points_steady_state

        # Compute the stiffness to get the desired steady-state penetration.
        # At steady state (δ̇ = 0) only the spring term K * δ^p * δ = K * δ^(1+p)
        # balances the load, hence the dependency on the spring exponent p.
        K = f_average / jnp.power(δ_max, 1 + p)

        # Compute the damping using the damping ratio.
        # NOTE(review): K is a per-point stiffness while m is the total model
        # mass, and the critical-damping formula is the linear-spring one;
        # this is a heuristic — confirm it is the intended tuning rule.
        critical_damping = 2 * jnp.sqrt(K * m)
        D = ξ * critical_damping

        return cls.build(K=K, D=D, mu=μc, p=p, q=q)

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            `True` if all the parameters are non-negative, `False` otherwise.
        """

        return jnp.hstack(
            [
                self.K >= 0.0,
                self.D >= 0.0,
                self.mu >= 0.0,
                self.p >= 0.0,
                self.q >= 0.0,
            ]
        ).all()
191
-
192
-
193
@jax_dataclasses.pytree_dataclass
class SoftContacts(common.ContactModel):
    """
    Soft contacts model.

    Each enabled collidable point interacts with the terrain through a
    non-linear spring-damper (Hunt/Crossley) along the terrain normal, and
    through a spring-damper acting on a per-point tangential material
    deformation along the tangent plane. The tangential force is projected
    onto the friction cone when the point is slipping.
    """

    @classmethod
    def build(
        cls: type[Self],
        model: js.model.JaxSimModel | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `SoftContacts` instance.

        Args:
            model:
                The robot model considered by the contact model.
                It is not used by this implementation.
            **kwargs:
                Extra arguments. They are logged at debug level and otherwise
                ignored.

        Returns:
            The `SoftContacts` instance.
        """

        if len(kwargs) != 0:
            logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

        # The extra arguments are ignored, as the log message states.
        # Forwarding them to the constructor would contradict that message and
        # raise a TypeError for any name that is not a dataclass field.
        return cls()

    @classmethod
    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
        """
        Build zero state variables of the contact model.

        Args:
            model: The model whose contact parameters define the state size.

        Returns:
            A dictionary with the key ``"tangential_deformation"`` mapped to a
            zero array with one 3D row for each entry of
            ``model.kin_dyn_parameters.contact_parameters.body``.
        """

        # Initialize the material deformation to zero.
        tangential_deformation = jnp.zeros(
            shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3),
            dtype=float,
        )

        return {"tangential_deformation": tangential_deformation}

    @staticmethod
    @functools.partial(jax.jit, static_argnames=("terrain",))
    def hunt_crossley_contact_model(
        position: jtp.VectorLike,
        velocity: jtp.VectorLike,
        tangential_deformation: jtp.VectorLike,
        terrain: Terrain,
        K: jtp.FloatLike,
        D: jtp.FloatLike,
        mu: jtp.FloatLike,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact force using the Hunt/Crossley model.

        The normal force magnitude is ``max(0, K * δ^p * δ + D * δ^q * δ̇)``.

        Args:
            position: The position of the collidable point.
            velocity: The velocity of the collidable point.
            tangential_deformation: The material deformation of the collidable point.
            terrain: The terrain model.
            K: The stiffness parameter.
            D: The damping parameter of the soft contacts model.
            mu: The static friction coefficient.
            p:
                The exponent p corresponding to the spring-related non-linearity
                of the Hunt/Crossley model.
            q:
                The exponent q corresponding to the damping-related non-linearity
                of the Hunt/Crossley model.

        Returns:
            A tuple containing the computed contact force (linear, in the C[W]
            frame) and the derivative of the material deformation.
        """

        # Convert the input vectors to arrays.
        W_p_C = jnp.array(position, dtype=float).squeeze()
        W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
        m = jnp.array(tangential_deformation, dtype=float).squeeze()

        # Use symbol for the static friction.
        μ = mu

        # Compute the penetration depth, its rate, and the considered terrain normal.
        δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)

        # There are few operations like computing the norm of a vector with zero length
        # or computing the square root of zero that are problematic in an AD context.
        # To avoid these issues, we introduce a small tolerance ε to their arguments
        # and make sure that we do not check them against zero directly.
        ε = jnp.finfo(float).eps

        # Compute the powers of the penetration depth.
        # Inject ε to address AD issues in differentiating the square root when
        # p and q are fractional.
        δp = jnp.power(δ + ε, p)
        δq = jnp.power(δ + ε, q)

        # ========================
        # Compute the normal force
        # ========================

        # Non-linear spring-damper model (Hunt/Crossley model).
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = (K * δp) * δ + (D * δq) * δ̇

        # Depending on the magnitude of δ̇, the normal force could be negative.
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame.
        f_normal = force_normal_mag * n̂

        # ============================
        # Compute the tangential force
        # ============================

        # Extract the tangential component of the velocity.
        v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂

        # Extract the normal and tangential components of the material deformation.
        m_normal = jnp.dot(m, n̂) * n̂
        m_tangential = m - m_normal

        # Compute the tangential force in the sticking case.
        # Using the tangential component of the material deformation should not be
        # necessary if the sticking-slipping transition occurs in a terrain area
        # with a locally constant normal. However, this assumption is not true in
        # general, especially for highly uneven terrains.
        f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)

        # Detect the contact type (sticking or slipping).
        # Note that if there is no contact, sticking is set to True, and this detail
        # is exploited in the computation of the `contact_status` variable.
        sticking = jnp.logical_or(
            δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
        )

        # Compute the direction of the tangential force.
        # To prevent dividing by zero, ε is added to the denominator only when
        # the norm is exactly zero.
        norm = jaxsim.math.safe_norm(f_tangential)
        f_tangential_direction = f_tangential / (
            norm + jnp.finfo(float).eps * (norm == 0)
        )

        # Project the tangential force to the friction cone if slipping.
        f_tangential = jnp.where(
            sticking,
            f_tangential,
            jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
        )

        # Set the tangential force to zero if there is no contact.
        f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)

        # =====================================
        # Compute the material deformation rate
        # =====================================

        # Compute the derivative of the material deformation.
        # Note that we included an additional relaxation of `m_normal` in the
        # sticking case, so that the normal deformation that could have accumulated
        # from a previous slipping phase can relax to zero.
        ṁ_no_contact = -(K / D) * m
        ṁ_sticking = v_tangential - (K / D) * m_normal
        ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)

        # Compute the contact status:
        # 0: slipping
        # 1: sticking
        # 2: no contact
        contact_status = sticking.astype(int)
        contact_status += (δ <= 0).astype(int)

        # Select the right material deformation rate depending on the contact status.
        ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)

        # ==========================================
        # Compute and return the final contact force
        # ==========================================

        # Sum the normal and tangential forces.
        CW_fl = f_normal + f_tangential

        return CW_fl, ṁ

    @staticmethod
    @functools.partial(jax.jit, static_argnames=("terrain",))
    def compute_contact_force(
        position: jtp.VectorLike,
        velocity: jtp.VectorLike,
        tangential_deformation: jtp.VectorLike,
        parameters: SoftContactsParams,
        terrain: Terrain,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact force.

        Args:
            position: The position of the collidable point.
            velocity: The velocity of the collidable point.
            tangential_deformation: The material deformation of the collidable point.
            parameters: The parameters of the soft contacts model.
            terrain: The terrain model.

        Returns:
            A tuple containing the computed 6D contact force expressed in the
            inertial-fixed frame and the derivative of the material deformation.
        """

        CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(
            position=position,
            velocity=velocity,
            tangential_deformation=tangential_deformation,
            terrain=terrain,
            K=parameters.K,
            D=parameters.D,
            mu=parameters.mu,
            p=parameters.p,
            q=parameters.q,
        )

        # Pack a mixed 6D force (the contact model produces no torque).
        CW_f = jnp.hstack([CW_fl, jnp.zeros(3)])

        # Compute the 6D force transform from the mixed to the inertial-fixed frame.
        W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation(
            translation=jnp.array(position), inverse=True
        ).T

        # Compute the 6D force in the inertial-fixed frame.
        W_f = W_Xf_CW @ CW_f

        return W_f, ṁ

    @staticmethod
    @jax.jit
    def compute_contact_forces(
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.

        Returns:
            A tuple containing as first element the contact forces of the enabled
            collidable points, and as second element a dictionary whose ``"m_dot"``
            entry is the derivative of the material deformation of all the
            collidable points (zero for the disabled ones).
        """

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        # Compute the position and linear velocities (mixed representation) of
        # the collidable points. They are vmapped together with `m_enabled`
        # below, so they must have one row per enabled collidable point.
        W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data)

        # Extract the material deformation corresponding to the collidable points.
        m = data.state.extended["tangential_deformation"]

        m_enabled = m[indices_of_enabled_collidable_points]

        # Initialize the tangential deformation rate array for every collidable point.
        ṁ = jnp.zeros_like(m)

        # Compute the contact forces only for the enabled collidable points.
        # Since we treat them as independent, we can vmap the computation.
        W_f, ṁ_enabled = jax.vmap(
            lambda p, v, m: SoftContacts.compute_contact_force(
                position=p,
                velocity=v,
                tangential_deformation=m,
                parameters=data.contacts_params,
                terrain=model.terrain,
            )
        )(W_p_C, W_ṗ_C, m_enabled)

        # Scatter the rates of the enabled points back into the full array.
        ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled)

        return W_f, dict(m_dot=ṁ)