jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,523 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
- from typing import Tuple
5
-
6
- import jax
7
- import jax.flatten_util
8
- import jax.numpy as jnp
9
- import jax_dataclasses
10
- import numpy as np
11
-
12
- import jaxsim.physics.model.physics_model
13
- import jaxsim.typing as jtp
14
- from jaxsim.math.adjoint import Adjoint
15
- from jaxsim.math.skew import Skew
16
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
17
- from jaxsim.physics.model.physics_model import PhysicsModel
18
- from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass
19
-
20
- from . import utils
21
-
22
-
23
@jax_dataclasses.pytree_dataclass
class SoftContactsState(JaxsimDataclass):
    """
    State of the soft contacts model.

    Attributes:
        tangential_deformation:
            The tangential deformation of the material at each collidable point.
    """

    tangential_deformation: jtp.Matrix

    @staticmethod
    def build(
        tangential_deformation: jtp.Matrix | None = None,
        number_of_collidable_points: int | None = None,
    ) -> SoftContactsState:
        """
        Create a SoftContactsState instance.

        Args:
            tangential_deformation:
                The tangential deformation to store. When omitted, a zero matrix
                of shape (3, number_of_collidable_points) is used instead.
            number_of_collidable_points:
                The number of collidable points. Only used (and required) when
                `tangential_deformation` is omitted.

        Returns:
            A SoftContactsState instance storing the deformation as a float array.

        Raises:
            ValueError: If both arguments are omitted, since the shape of the
                zero deformation cannot be determined.
        """

        if tangential_deformation is None:
            # Without an explicit deformation, the number of points is the only
            # way to size the zero matrix. Fail early with a clear message.
            if number_of_collidable_points is None:
                raise ValueError(
                    "Either 'tangential_deformation' or "
                    "'number_of_collidable_points' must be provided"
                )

            tangential_deformation = jnp.zeros(
                shape=(3, number_of_collidable_points)
            )

        return SoftContactsState(
            tangential_deformation=jnp.array(tangential_deformation, dtype=float)
        )

    @staticmethod
    def build_from_physics_model(
        tangential_deformation: jtp.Matrix | None = None,
        physics_model: jaxsim.physics.model.physics_model.PhysicsModel | None = None,
    ) -> SoftContactsState:
        """
        Create a SoftContactsState instance sized after a physics model.

        Args:
            tangential_deformation:
                The tangential deformation to store. When omitted, a zero
                deformation sized after `physics_model` is used.
            physics_model:
                The physics model. Its `gc.body` sequence provides the number of
                collidable points. It can be omitted when `tangential_deformation`
                is passed explicitly.

        Returns:
            A SoftContactsState instance.

        Raises:
            ValueError: If both arguments are omitted.
        """

        # The model is declared optional, so only read it when it is available.
        number_of_collidable_points = (
            len(physics_model.gc.body) if physics_model is not None else None
        )

        return SoftContactsState.build(
            tangential_deformation=tangential_deformation,
            number_of_collidable_points=number_of_collidable_points,
        )

    @staticmethod
    def zero(
        physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
    ) -> SoftContactsState:
        """
        Build a SoftContactsState instance with zero tangential deformation.

        Args:
            physics_model: The physics model.

        Returns:
            A SoftContactsState instance with zero tangential deformation.
        """

        return SoftContactsState.build_from_physics_model(physics_model=physics_model)

    def valid(
        self, physics_model: jaxsim.physics.model.physics_model.PhysicsModel
    ) -> bool:
        """
        Check if the soft contacts state has valid shape.

        The expected shape is (3, len(physics_model.gc.body)). The check is
        delegated to `jaxsim.simulation.utils.check_valid_shape`.

        Args:
            physics_model: The physics model.

        Returns:
            True if the state has a valid shape, otherwise False.
        """

        # NOTE(review): imported locally, presumably to avoid a circular import
        # between the physics and simulation packages -- confirm.
        from jaxsim.simulation.utils import check_valid_shape

        return check_valid_shape(
            what="tangential_deformation",
            shape=self.tangential_deformation.shape,
            expected_shape=(3, len(physics_model.gc.body)),
            valid=True,
        )
101
-
102
-
103
def collidable_points_pos_vel(
    model: PhysicsModel,
    q: jtp.Vector,
    qd: jtp.Vector,
    xfb: jtp.Vector | None = None,
) -> Tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of collidable points in the world frame.

    The function first propagates the link transforms from the base to all the
    links, then propagates the 6D link velocities, and finally evaluates the
    kinematics of every collidable point in parallel with `jax.vmap`.

    Args:
        model (PhysicsModel): The physics model.
        q (jtp.Vector): The joint positions.
        qd (jtp.Vector): The joint velocities.
        xfb (jtp.Vector, optional): The floating base state. Defaults to None.
            Here xfb[0:4] is used as the base quaternion, xfb[4:7] as the base
            translation, and xfb[7:10] / xfb[10:13] as the two halves of the
            base velocity.

    Returns:
        Tuple[jtp.Matrix, jtp.Matrix]: A tuple containing the position and velocity of collidable points.
        Both arrays are the transpose of the stacked per-point 3D results, i.e.
        each column corresponds to one collidable point.
    """

    # Make sure that shape and size are correct
    xfb, q, qd, _, _, _ = utils.process_inputs(physics_model=model, xfb=xfb, q=q, qd=qd)

    # Initialize buffers of link transforms (W_X_i) and 6D inertial velocities (W_v_Wi)
    W_X_i = jnp.zeros(shape=[model.NB, 6, 6])
    W_v_Wi = jnp.zeros(shape=[model.NB, 6, 1])

    # 6D transform of base velocity
    W_X_0 = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4], translation=xfb[4:7], normalize_quaternion=True
    )
    W_X_i = W_X_i.at[0].set(W_X_0)

    # Store the 6D inertial velocity W_v_W0 of the base link.
    # The two 3D halves of the velocity stored in xfb are swapped while
    # building the 6D column vector.
    W_v_W0 = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))
    W_v_Wi = W_v_Wi.at[0].set(W_v_W0)

    # Compute useful resources from the model
    S = model.motion_subspaces(q=q)

    # Get the 6D transform between the parent link λi and the joint's predecessor frame
    pre_X_λi = model.tree_transforms

    # Compute the 6D transform of the joints (from predecessor to successor)
    i_X_pre = model.joint_transforms(q=q)

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # ====================
    # Propagate kinematics
    # ====================

    PropagateTransformsCarry = Tuple[jtp.MatrixJax]
    propagate_transforms_carry: PropagateTransformsCarry = (W_X_i,)

    def propagate_transforms(
        carry: PropagateTransformsCarry, i: jtp.Int
    ) -> Tuple[PropagateTransformsCarry, None]:
        # Scan body: fill W_X_i[i] from the already computed transform of the
        # parent link λ[i].

        # Unpack the carry
        (W_X_i,) = carry

        # We need the inverse transforms (from parent to child direction)
        pre_Xi_i = Adjoint.inverse(i_X_pre[i])
        λi_Xi_pre = Adjoint.inverse(pre_X_λi[i])

        # Compute the parent to child 6D transform
        λi_X_i = λi_Xi_pre @ pre_Xi_i

        # Compute the world to child 6D transform
        W_Xi_i = W_X_i[λ[i]] @ λi_X_i
        W_X_i = W_X_i.at[i].set(W_Xi_i)

        # Pack and return the carry
        return (W_X_i,), None

    # The scan starts from index 1 since the base transform (index 0) has
    # already been stored above.
    (W_X_i,), _ = jax.lax.scan(
        f=propagate_transforms,
        init=propagate_transforms_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    # ====================
    # Propagate velocities
    # ====================

    PropagateVelocitiesCarry = Tuple[jtp.MatrixJax]
    propagate_velocities_carry: PropagateVelocitiesCarry = (W_v_Wi,)

    def propagate_velocities(
        carry: PropagateVelocitiesCarry, j_vel_and_j_idx: jtp.VectorJax
    ) -> Tuple[PropagateVelocitiesCarry, None]:
        # Scan body: each scanned row packs (joint velocity, joint index).

        # Unpack the scanned data
        qd_ii = j_vel_and_j_idx[0]
        ii = jnp.array(j_vel_and_j_idx[1], dtype=int)

        # Given a joint whose velocity is qd[ii], the index of its parent link is ii + 1
        # NOTE(review): index i is used below as the link moved by the joint
        # (its own parent is λ[i]) -- confirm the wording of the line above.
        i = ii + 1

        # Unpack the carry
        (W_v_Wi,) = carry

        # Propagate the 6D velocity
        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * qd_ii)
        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)

        # Pack and return the carry
        return (W_v_Wi,), None

    # The joint indices are stacked next to the joint velocities so that both
    # can be scanned together as a single array.
    (W_v_Wi,), _ = jax.lax.scan(
        f=propagate_velocities,
        init=propagate_velocities_carry,
        xs=jnp.vstack([qd, jnp.arange(start=0, stop=qd.size)]).T,
    )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.VectorJax, parent_body: jtp.Int
    ) -> Tuple[jtp.VectorJax, jtp.VectorJax]:
        # Compute the position of the collidable point
        # (homogeneous coordinates, keeping only the first three elements)
        W_p_Ci = (
            Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
        )[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ W_v_Wi[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel
    W_p_Ci, CW_v_WC = jax.vmap(process_point_kinematics)(
        model.gc.point.T, np.array(model.gc.body, dtype=int)
    )

    # vmap stacks the per-point results along the first axis, transpose them
    # so that points are stored column-wise
    return W_p_Ci.transpose(), CW_v_WC.transpose()
244
-
245
-
246
@jax_dataclasses.pytree_dataclass
class SoftContactsParams:
    """
    Parameters of the soft contacts model.

    The fields hold the stiffness (K), the damping (D) and the friction
    coefficient (mu), each one stored as a scalar float array.
    """

    K: float = dataclasses.field(default_factory=lambda: jnp.array(1e6, dtype=float))
    D: float = dataclasses.field(default_factory=lambda: jnp.array(2000, dtype=float))
    mu: float = dataclasses.field(default_factory=lambda: jnp.array(0.5, dtype=float))

    @staticmethod
    def build(
        K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
    ) -> SoftContactsParams:
        """
        Build the parameters from plain numbers or arrays.

        Args:
            K (float, optional): The stiffness parameter. Defaults to 1e6.
            D (float, optional): The damping parameter. Defaults to 2000.
            mu (float, optional): The friction coefficient. Defaults to 0.5.

        Returns:
            SoftContactsParams: The parameters, with every value converted to
            a float array.
        """

        def as_float_array(value: jtp.FloatLike) -> jtp.Float:
            # All the parameters are stored with the same float dtype
            return jnp.array(value, dtype=float)

        stiffness = as_float_array(K)
        damping = as_float_array(D)
        friction = as_float_array(mu)

        return SoftContactsParams(K=stiffness, D=damping, mu=friction)

    @staticmethod
    def build_default_from_physics_model(
        physics_model: PhysicsModel,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
    ) -> SoftContactsParams:
        """
        Build default parameters tuned on the mass and gravity of a model.

        The stiffness is selected such that the average steady-state support
        force is reached at `max_penetration`, using K = f / δ_max^(3/2).
        The damping is `damping_ratio` times the critical damping 2 * sqrt(K * m).

        Args:
            physics_model: The target physics model.
            static_friction_coefficient: The static friction coefficient.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state: The number of contacts
                supporting the weight of the model in steady state.
            damping_ratio: The ratio controlling the damping behavior.

        Returns:
            A SoftContactsParams instance with the computed parameters.

        Note:
            The `damping_ratio` parameter allows to operate on the following conditions:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Total mass of the model, summed over all the links of its description
        all_links = physics_model.description.links_dict.values()
        total_mass = jnp.array([link.mass for link in all_links]).sum()

        # Gravity magnitude, from the last element of the 3D gravity vector
        gravity_magnitude = -physics_model.gravity[0:3][-1]

        # Average support force acting on each active collidable point
        support_force = (
            total_mass
            * gravity_magnitude
            / number_of_active_collidable_points_steady_state
        )

        # Stiffness producing the desired steady-state penetration
        stiffness = support_force / jnp.power(max_penetration, 3 / 2)

        # Damping obtained by scaling the critical damping with the ratio
        damping = damping_ratio * (2 * jnp.sqrt(stiffness * total_mass))

        return SoftContactsParams.build(
            K=stiffness, D=damping, mu=static_friction_coefficient
        )
329
-
330
-
331
@jax_dataclasses.pytree_dataclass
class SoftContacts:
    """
    Soft contacts model.

    The model combines the contact parameters (stiffness, damping, friction)
    with a terrain, and computes the force acting on a single collidable point.
    """

    # Parameters of the contact model (K, D, mu)
    parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    # Terrain providing the height and the normal at a given (x, y) location
    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)

    def contact_model(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        tangential_deformation: jtp.Vector,
    ) -> Tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact forces and material deformation rate.

        Args:
            position (jtp.Vector): The position of the collidable point.
            velocity (jtp.Vector): The linear velocity of the collidable point.
            tangential_deformation (jtp.Vector): The tangential deformation.

        Returns:
            Tuple[jtp.Vector, jtp.Vector]: A tuple containing the contact force and material deformation rate.
            The force is a 6D vector obtained by applying the `W_Xf_CW`
            transform to the force computed at the contact point. The
            deformation rate has the same shape as the squeezed deformation.
        """

        # Short name of parameters
        K = self.parameters.K
        D = self.parameters.D
        μ = self.parameters.mu

        # Material 3D tangential deformation and its derivative
        m = tangential_deformation.squeeze()
        ṁ = jnp.zeros_like(m)

        # Note: all the small hardcoded tolerances in this method have been introduced
        # to allow jax differentiating through this algorithm. They should not affect
        # the accuracy of the simulation, although they might make it less readable.

        # ========================
        # Normal force computation
        # ========================

        # Unpack the position and the velocity of the collidable point
        px, py, pz = W_p_C = position.squeeze()
        vx, vy, vz = W_ṗ_C = velocity.squeeze()

        # Compute the terrain normal and the contact depth
        n̂ = self.terrain.normal(x=px, y=py).squeeze()
        h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])

        # Compute the penetration depth normal to the terrain
        δ = jnp.maximum(0.0, jnp.dot(h, n̂))

        # Compute the penetration normal velocity
        δ̇ = -jnp.dot(W_ṗ_C, n̂)

        # Non-linear spring-damper model.
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = jax.lax.select(
            pred=δ >= 1e-9,
            on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
            on_false=jnp.array(0.0),
        )

        # Prevent negative normal forces that might occur when δ̇ is largely negative
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame
        force_normal = force_normal_mag * n̂

        # ====================================
        # No friction and no tangential forces
        # ====================================

        # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
        # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
        W_Xf_CW = jnp.vstack(
            [
                jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
                jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
            ]
        )

        def with_no_friction():
            # Branch selected when μ == 0: only the normal force is applied and
            # the deformation rate is the zero vector initialized above.

            # Compute 6D mixed force in C[W]
            CW_f_lin = force_normal
            CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])

            # Compute lin-ang 6D forces (inertial representation)
            W_f = W_Xf_CW @ CW_f

            return W_f, ṁ

        # =========================
        # Compute tangential forces
        # =========================

        def with_friction():
            # Initialize the tangential deformation rate ṁ.
            # For inactive contacts with m≠0, this is the dynamics of the material
            # relaxation converging exponentially to steady state.
            ṁ = (-K / D) * m

            # Check if the collidable point is below ground.
            # Note: when δ=0, we consider the point still not in contact such that
            # we prevent divisions by 0 in the computations below.
            active_contact = pz < self.terrain.height(x=px, y=py)

            def above_terrain():
                # Inactive contact: no force, the material only relaxes
                return jnp.zeros(6), ṁ

            def below_terrain():
                # Decompose the velocity in normal and tangential components
                v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
                v_tangential = W_ṗ_C - v_normal

                # Compute the tangential force. It is used as is when it lies
                # inside the friction cone (sticking contact), otherwise it is
                # projected to the cone boundary (slipping contact).
                f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

                def sticking_contact():
                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_stick = force_normal + f_tangential
                    CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])

                    # In this case the 3D material deformation is the tangential velocity
                    ṁ = v_tangential

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                def slipping_contact():
                    # Clip the tangential force if too small, allowing jax to
                    # differentiate through the norm computation
                    f_tangential_no_nan = jax.lax.select(
                        pred=f_tangential.dot(f_tangential) >= 1e-9**2,
                        on_true=f_tangential,
                        on_false=jnp.array([1e-12, 0, 0]),
                    )

                    # Project the force to the friction cone boundary
                    f_tangential_projected = (μ * force_normal_mag) * (
                        f_tangential / jnp.linalg.norm(f_tangential_no_nan)
                    )

                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_slip = force_normal + f_tangential_projected
                    CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])

                    # Correct the material deformation derivative for slipping contacts.
                    # Basically we compute ṁ such that we get `f_tangential` on the cone
                    # given the current (m, δ).
                    ε = 1e-9
                    δε = jnp.maximum(δ, ε)
                    α = -K * jnp.sqrt(δε)
                    β = -D * jnp.sqrt(δε)
                    ṁ = (f_tangential_projected - α * m) / β

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                # The contact slips when the squared norm of the tangential force
                # exceeds the squared radius of the friction cone.
                CW_f, ṁ = jax.lax.cond(
                    pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
                    true_fun=lambda _: slipping_contact(),
                    false_fun=lambda _: sticking_contact(),
                    operand=None,
                )

                # Express the 6D force in the world frame
                W_f = W_Xf_CW @ CW_f

                # Return the 6D force in the world frame and the deformation derivative
                return W_f, ṁ

            # (W_f, ṁ)
            return jax.lax.cond(
                pred=active_contact,
                true_fun=lambda _: below_terrain(),
                false_fun=lambda _: above_terrain(),
                operand=None,
            )

        # (W_f, ṁ)
        return jax.lax.cond(
            pred=(μ == 0.0),
            true_fun=lambda _: with_no_friction(),
            false_fun=lambda _: with_friction(),
            operand=None,
        )
@@ -1,78 +0,0 @@
1
- import abc
2
-
3
- import jax.numpy as jnp
4
- import jax_dataclasses
5
-
6
- import jaxsim.typing as jtp
7
-
8
-
9
class Terrain(abc.ABC):
    """
    Abstract terrain, described by its height over the (x, y) plane.

    Subclasses only need to implement `height`. The surface normal is
    estimated numerically from it.
    """

    # Step used by the finite differences that estimate the surface normal
    delta = 0.010

    @abc.abstractmethod
    def height(self, x: float, y: float) -> float:
        """
        Return the height of the terrain at the (x, y) location.

        Args:
            x (float): The x-coordinate of the location.
            y (float): The y-coordinate of the location.

        Returns:
            float: The height of the terrain at the specified location.
        """

    def normal(self, x: float, y: float) -> jtp.Vector:
        """
        Estimate the unit normal of the terrain at the (x, y) location.

        The slopes along x and y are approximated with central finite
        differences of `height`, evaluated at a distance `delta` from the
        query location.

        Args:
            x (float): The x-coordinate of the location.
            y (float): The y-coordinate of the location.

        Returns:
            jtp.Vector: The normalized normal vector of the terrain surface.
        """

        # https://stackoverflow.com/a/5282364
        step = self.delta

        # Sample the height on both sides of the query location along each axis
        height_x_plus = self.height(x=x + step, y=y)
        height_x_minus = self.height(x=x - step, y=y)
        height_y_plus = self.height(x=x, y=y + step)
        height_y_minus = self.height(x=x, y=y - step)

        # The first two components are the negated central-difference slopes
        direction = jnp.array(
            [
                (height_x_minus - height_x_plus) / (2 * step),
                (height_y_minus - height_y_plus) / (2 * step),
                1.0,
            ]
        )

        return direction / jnp.linalg.norm(direction)
39
-
40
-
41
@jax_dataclasses.pytree_dataclass
class FlatTerrain(Terrain):
    # Terrain whose height is zero at every (x, y) location.

    def height(self, x: float, y: float) -> float:
        # The height does not depend on the query location
        return 0.0
45
-
46
-
47
@jax_dataclasses.pytree_dataclass
class PlaneTerrain(Terrain):
    """
    Planar terrain described by the three components of its normal vector.

    The height is computed as z = -(a * x + b * y) / c, where (a, b, c) are
    the components of `plane_normal`.
    """

    plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0])

    @staticmethod
    def build(plane_normal: list) -> "PlaneTerrain":
        """
        Build a PlaneTerrain from the normal vector of the plane.

        Args:
            plane_normal (list): The normal vector of the terrain plane.

        Returns:
            PlaneTerrain: A PlaneTerrain instance.
        """

        terrain = PlaneTerrain(plane_normal=plane_normal)
        return terrain

    def height(self, x: float, y: float) -> float:
        """
        Return the height of the plane at the (x, y) location.

        Args:
            x (float): The x-coordinate of the location.
            y (float): The y-coordinate of the location.

        Returns:
            float: The height of the terrain at the specified location on the plane.
        """

        normal_x, normal_y, normal_z = self.plane_normal

        # Contribution of the horizontal location to the plane equation
        horizontal_term = normal_x * x + normal_y * y

        return -horizontal_term / normal_z
@@ -1,69 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax.numpy as jnp
4
-
5
- import jaxsim.typing as jtp
6
- from jaxsim.physics.model.physics_model import PhysicsModel
7
-
8
-
9
def process_inputs(
    physics_model: PhysicsModel,
    xfb: jtp.Vector | None = None,
    q: jtp.Vector | None = None,
    qd: jtp.Vector | None = None,
    qdd: jtp.Vector | None = None,
    tau: jtp.Vector | None = None,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Matrix]:
    """
    Adjust the inputs to the physics model.

    Missing inputs are replaced with defaults (zeros, and a 13-element base
    state whose first element is 1), extra singleton dimensions are removed,
    and the resulting sizes are validated against the model.

    Args:
        physics_model: The physics model.
        xfb: The variables of the base link.
        q: The generalized coordinates.
        qd: The generalized velocities.
        qdd: The generalized accelerations.
        tau: The generalized forces.
        f_ext: The link external forces.

    Returns:
        The adjusted inputs, in the order (xfb, q, qd, qdd, tau, f_ext).

    Raises:
        ValueError: If the first dimension of `xfb` is not 13, if the first
            dimension of `q`, `qd`, `qdd` or `tau` differs from the number of
            degrees of freedom of the model, or if the shape of `f_ext` is not
            (physics_model.NB, 6).
    """

    dofs = physics_model.dofs()
    number_of_bodies = physics_model.NB

    # Remove extra dimensions
    q = q.squeeze() if q is not None else jnp.zeros(dofs)
    qd = qd.squeeze() if qd is not None else jnp.zeros(dofs)
    qdd = qdd.squeeze() if qdd is not None else jnp.zeros(dofs)
    tau = tau.squeeze() if tau is not None else jnp.zeros(dofs)
    xfb = xfb.squeeze() if xfb is not None else jnp.zeros(13).at[0].set(1)
    f_ext = (
        f_ext.squeeze()
        if f_ext is not None
        else jnp.zeros(shape=(number_of_bodies, 6))
    )

    # Fix case with just 1 DoF (squeezing produced a 0-dimensional array)
    q = jnp.atleast_1d(q)
    qd = jnp.atleast_1d(qd)
    qdd = jnp.atleast_1d(qdd)
    tau = jnp.atleast_1d(tau)

    # Fix case with just 1 body (squeezing removed the body dimension)
    f_ext = jnp.atleast_2d(f_ext)

    # Validate dimensions. All the inputs have a value at this point, so no
    # check against None is necessary.
    if xfb.shape[0] != 13:
        raise ValueError(xfb.shape)

    # All the joint-space quantities must match the DoFs of the model. This
    # includes qdd, consistently with the other joint-space inputs.
    for joint_quantity in (q, qd, qdd, tau):
        if joint_quantity.shape[0] != dofs:
            raise ValueError(joint_quantity.shape, dofs)

    if f_ext.shape != (number_of_bodies, 6):
        raise ValueError(f_ext.shape, (number_of_bodies, 6))

    return xfb, q, qd, qdd, tau, f_ext
File without changes