jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,196 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector = jnp.zeros(6),
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Recursive Newton-Euler Algorithm (RNEA) for inverse dynamics with joint motors.

    Besides the rigid-body velocities, accelerations and forces, a parallel set
    of "motor" quantities is propagated using the motion subspaces scaled by the
    motor gear ratios. Their projection is added to the joint torques.

    Args:
        model: The physics model.
        xfb: The floating-base state. Slices used here: ``xfb[0:4]`` base
            quaternion, ``xfb[4:7]`` base position, ``xfb[7:13]`` base velocity
            (rearranged as ``[xfb[10:13], xfb[7:10]]`` to build the 6D velocity).
        q: The joint positions.
        qd: The joint velocities.
        qdd: The joint accelerations.
        a0fb: The 6D acceleration of the base. It is only used by
            floating-base models.
        f_ext: The external 6D forces, one entry per link, transformed from the
            world frame into each link frame below. ``None`` is forwarded to
            ``utils.process_inputs``.

    Returns:
        A tuple ``(W_f0, tau)``:
            - ``W_f0``: the 6D base force expressed in the world frame, as a
              (6, 1) column.
            - ``tau``: the joint torques, as an (n, 1) column.

    Raises:
        ValueError: If ``a0fb``, once squeezed, is not a 6-element vector.
    """

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    a0fb = a0fb.squeeze()
    gravity = model.gravity.squeeze()

    # Check the full shape so that 0-d or 2-D inputs also raise ValueError
    # (indexing shape[0] would raise IndexError for a 0-d array).
    if a0fb.shape != (6,):
        raise ValueError(a0fb.shape)

    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    i_X_λi = jnp.zeros_like(pre_X_λi)

    # Motor parameters: one entry per joint. Joint ii = i - 1 drives link i, as
    # enforced by the alignment of Γ with S[1:] below. IM must therefore be
    # indexed with ii, not with the link index i.
    Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
    IM = jnp.array([*model._joint_motor_inertia.values()])

    # Motor motion subspaces: the joint subspaces scaled by the gear ratios
    # (the base entry is left unscaled).
    m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # Per-link buffers of 6D rigid-body and motor quantities.
    v = jnp.zeros((model.NB, 6, 1))
    a = jnp.zeros((model.NB, 6, 1))
    f = jnp.zeros((model.NB, 6, 1))

    v_m = jnp.zeros((model.NB, 6, 1))
    a_m = jnp.zeros((model.NB, 6, 1))
    f_m = jnp.zeros((model.NB, 6, 1))

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Gravity is modelled as a fictitious base acceleration.
    a_0 = -B_X_W @ jnp.vstack(gravity)
    a = a.at[0].set(a_0)

    if model.is_floating_base:
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))

        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
        a = a.at[0].set(a_0)

        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )
        f = f.at[0].set(f_0)

    ForwardPassCarry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]
    forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:
        # Index of the joint driving link i.
        ii = i - 1
        i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry

        vJ = S[i] * qd[ii]
        vJ_m = m_S[i] * qd[ii]

        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
        v_m = v_m.at[i].set(v_i_m)

        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
        a_m = a_m.at[i].set(a_i_m)

        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Force transform from the world frame to link i.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(f_ext[i])
        )
        f = f.at[i].set(f_i)

        # Motor force, using the inertia of the motor of joint ii.
        f_i_m = IM[ii] * a_m[i] + Cross.vx_star(v_m[i]) * IM[ii] @ v_m[i]
        f_m = f_m.at[i].set(f_i_m)

        return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None

    (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
        f=forward_pass,
        init=forward_pass_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    tau = jnp.zeros_like(q)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry = (tau, f, f_m)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
        ii = i - 1
        tau, f, f_m = carry

        # Project the rigid-body and motor forces onto their motion subspaces.
        # NOTE(review): motor viscous friction (model._joint_motor_viscous_friction)
        # is not included in tau; confirm this is intended.
        value = S[i].T @ f[i] + m_S[i].T @ f_m[i]
        tau = tau.at[ii].set(value.squeeze())

        def update_f(
            ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            f, f_m = ffm
            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
            f_m = f_m.at[λ[i]].set(f_m_λi)
            return f, f_m

        # Accumulate forces into the parent link. The only case skipped is
        # propagating into the base (λ[i] == 0) of a fixed-base model.
        f, f_m = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=update_f,
            false_fun=lambda ffm: ffm,
            operand=(f, f_m),
        )

        return (tau, f, f_m), None

    (tau, f, f_m), _ = jax.lax.scan(
        f=backward_pass,
        init=backward_pass_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    # Handle 1 DoF models
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))

    # Express the base 6D force in the world frame
    W_f0 = B_X_W.T @ jnp.vstack(f[0])

    return W_f0, tau
@@ -1,454 +0,0 @@
1
- import dataclasses
2
- from typing import Tuple
3
-
4
- import jax
5
- import jax.flatten_util
6
- import jax.numpy as jnp
7
- import jax_dataclasses
8
- import numpy as np
9
-
10
- import jaxsim.physics.model.physics_model
11
- import jaxsim.typing as jtp
12
- from jaxsim.math.adjoint import Adjoint
13
- from jaxsim.math.skew import Skew
14
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
15
- from jaxsim.physics.model.physics_model import PhysicsModel
16
-
17
- from . import utils
18
-
19
-
20
@jax_dataclasses.pytree_dataclass
class SoftContactsState:
    """
    State of the soft contacts model.

    Attributes:
        tangential_deformation (jtp.Matrix): The tangential deformation of the
            material at each collidable point, stored column-wise with shape
            (3, physics_model.gc.body.size).
    """

    tangential_deformation: jtp.Matrix

    @staticmethod
    def zero(
        physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
    ) -> "SoftContactsState":
        """
        Build a new SoftContactsState with zero tangential deformation.

        Args:
            physics_model (jaxsim.physics.model.physics_model.PhysicsModel): The
                physics model, used to size the state (one column per
                collidable point).

        Returns:
            SoftContactsState: A SoftContactsState instance with zero tangential deformation.
        """

        return SoftContactsState(
            tangential_deformation=jnp.zeros(shape=(3, physics_model.gc.body.size))
        )

    def valid(
        self, physics_model: jaxsim.physics.model.physics_model.PhysicsModel
    ) -> bool:
        """
        Check if the soft contacts state has valid shape.

        Args:
            physics_model (jaxsim.physics.model.physics_model.PhysicsModel): The physics model.

        Returns:
            bool: True if the state has a valid shape, otherwise False.
        """

        # Local import, kept as in the original module layout.
        from jaxsim.simulation.utils import check_valid_shape

        return check_valid_shape(
            what="tangential_deformation",
            shape=self.tangential_deformation.shape,
            expected_shape=(3, physics_model.gc.body.size),
            valid=True,
        )

    def replace(self, validate: bool = True, **kwargs) -> "SoftContactsState":
        """
        Replace attributes of the soft contacts state.

        Args:
            validate (bool, optional): Whether to validate the state after replacement. Defaults to True.
            **kwargs: Attribute names and their new values.

        Returns:
            SoftContactsState: A new SoftContactsState instance with replaced attributes.
        """

        with jax_dataclasses.copy_and_mutate(self, validate=validate) as updated_state:
            # Plain loop: the assignments are performed only for their side effects.
            for name, value in kwargs.items():
                setattr(updated_state, name, value)

        return updated_state
86
-
87
-
88
def collidable_points_pos_vel(
    model: PhysicsModel,
    q: jtp.Vector,
    qd: jtp.Vector,
    xfb: jtp.Vector | None = None,
) -> Tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of collidable points in the world frame.

    Args:
        model (PhysicsModel): The physics model.
        q (jtp.Vector): The joint positions.
        qd (jtp.Vector): The joint velocities.
        xfb (jtp.Vector, optional): The floating base state. Defaults to None,
            which is forwarded to ``utils.process_inputs``.

    Returns:
        Tuple[jtp.Matrix, jtp.Matrix]: The positions of the collidable points
        and the linear part of their mixed velocities. Both are
        (3, n_points) matrices with one column per collidable point.
    """

    # Make sure that shape and size are correct
    xfb, q, qd, _, _, _ = utils.process_inputs(physics_model=model, xfb=xfb, q=q, qd=qd)

    # Initialize buffers of link transforms (W_X_i) and 6D inertial velocities (W_v_Wi)
    W_X_i = jnp.zeros(shape=[model.NB, 6, 6])
    W_v_Wi = jnp.zeros(shape=[model.NB, 6, 1])

    # 6D transform of the base link
    W_X_0 = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4], translation=xfb[4:7], normalize_quaternion=True
    )
    W_X_i = W_X_i.at[0].set(W_X_0)

    # Store the 6D inertial velocity W_v_W0 of the base link
    W_v_W0 = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))
    W_v_Wi = W_v_Wi.at[0].set(W_v_W0)

    # Compute useful resources from the model
    S = model.motion_subspaces(q=q)

    # Get the 6D transform between the parent link λi and the joint's predecessor frame
    pre_X_λi = model.tree_transforms

    # Compute the 6D transform of the joints (from predecessor to successor)
    i_X_pre = model.joint_transforms(q=q)

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # ====================
    # Propagate kinematics
    # ====================

    PropagateTransformsCarry = Tuple[jtp.MatrixJax]
    propagate_transforms_carry: PropagateTransformsCarry = (W_X_i,)

    def propagate_transforms(
        carry: PropagateTransformsCarry, i: jtp.Int
    ) -> Tuple[PropagateTransformsCarry, None]:
        # Unpack the carry
        (W_X_i,) = carry

        # We need the inverse transforms (from parent to child direction)
        pre_Xi_i = Adjoint.inverse(i_X_pre[i])
        λi_Xi_pre = Adjoint.inverse(pre_X_λi[i])

        # Compute the parent to child 6D transform
        λi_X_i = λi_Xi_pre @ pre_Xi_i

        # Compute the world to child 6D transform
        W_Xi_i = W_X_i[λ[i]] @ λi_X_i
        W_X_i = W_X_i.at[i].set(W_Xi_i)

        # Pack and return the carry
        return (W_X_i,), None

    (W_X_i,), _ = jax.lax.scan(
        f=propagate_transforms,
        init=propagate_transforms_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    # ====================
    # Propagate velocities
    # ====================

    PropagateVelocitiesCarry = Tuple[jtp.MatrixJax]
    propagate_velocities_carry: PropagateVelocitiesCarry = (W_v_Wi,)

    def propagate_velocities(
        carry: PropagateVelocitiesCarry,
        j_vel_and_j_idx: Tuple[jtp.VectorJax, jtp.VectorJax],
    ) -> Tuple[PropagateVelocitiesCarry, None]:
        # Unpack the scanned data: the joint velocity and its integer index.
        # The two are scanned as a pytree so the index never passes through
        # the float dtype of qd.
        qd_ii, ii = j_vel_and_j_idx

        # Joint ii drives link i = ii + 1 (its child link); λ[i] is the parent.
        i = ii + 1

        # Unpack the carry
        (W_v_Wi,) = carry

        # Propagate the 6D velocity
        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * qd_ii)
        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)

        # Pack and return the carry
        return (W_v_Wi,), None

    (W_v_Wi,), _ = jax.lax.scan(
        f=propagate_velocities,
        init=propagate_velocities_carry,
        xs=(qd, jnp.arange(start=0, stop=qd.size)),
    )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.VectorJax, parent_body: jtp.Int
    ) -> Tuple[jtp.VectorJax, jtp.VectorJax]:
        # Compute the position of the collidable point
        W_p_Ci = (
            Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
        )[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ W_v_Wi[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel
    W_p_Ci, CW_v_WC = jax.vmap(process_point_kinematics)(
        model.gc.point.T, model.gc.body
    )

    return W_p_Ci.transpose(), CW_v_WC.transpose()
229
-
230
-
231
@jax_dataclasses.pytree_dataclass
class SoftContactsParams:
    """
    Parameters of the soft contacts model.

    Attributes:
        K: The stiffness parameter (default 1e6).
        D: The damping parameter (default 2000).
        mu: The friction coefficient (default 0.5).

    Every field is stored as a float ``jnp`` array.
    """

    K: float = dataclasses.field(default_factory=lambda: jnp.array(1e6, dtype=float))
    D: float = dataclasses.field(default_factory=lambda: jnp.array(2_000, dtype=float))
    mu: float = dataclasses.field(default_factory=lambda: jnp.array(0.5, dtype=float))

    @staticmethod
    def build(
        K: float = 1e6, D: float = 2_000, mu: float = 0.5
    ) -> "SoftContactsParams":
        """
        Construct a SoftContactsParams from plain numbers.

        Args:
            K (float, optional): The stiffness parameter. Defaults to 1e6.
            D (float, optional): The damping parameter. Defaults to 2000.
            mu (float, optional): The friction coefficient. Defaults to 0.5.

        Returns:
            SoftContactsParams: An instance whose fields are the given values
            converted to float arrays.
        """

        raw_values = {"K": K, "D": D, "mu": mu}
        as_arrays = {
            field_name: jnp.array(raw, dtype=float)
            for field_name, raw in raw_values.items()
        }
        return SoftContactsParams(**as_arrays)
260
-
261
-
262
@jax_dataclasses.pytree_dataclass
class SoftContacts:
    """
    Soft contacts model.

    Attributes:
        parameters: Stiffness ``K``, damping ``D`` and friction coefficient
            ``mu`` (defaults to ``SoftContactsParams()``).
        terrain: The terrain providing height and normal at a given (x, y)
            (defaults to ``FlatTerrain()``).
    """

    parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)

    def contact_model(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        tangential_deformation: jtp.Vector,
    ) -> Tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact forces and material deformation rate.

        Args:
            position (jtp.Vector): The 3D position of the collidable point.
            velocity (jtp.Vector): The 3D linear velocity of the collidable point.
            tangential_deformation (jtp.Vector): The 3D tangential deformation
                of the material at the collidable point.

        Returns:
            Tuple[jtp.Vector, jtp.Vector]: The 6D contact force expressed in
            the inertial (world) frame, and the derivative of the tangential
            deformation.
        """

        # Short name of parameters
        K = self.parameters.K
        D = self.parameters.D
        μ = self.parameters.mu

        # Material 3D tangential deformation and its derivative
        m = tangential_deformation.squeeze()
        ṁ = jnp.zeros_like(m)

        # Note: all the small hardcoded tolerances in this method have been introduced
        # to allow jax differentiating through this algorithm. They should not affect
        # the accuracy of the simulation, although they might make it less readable.

        # ========================
        # Normal force computation
        # ========================

        # Unpack the position and the velocity of the collidable point
        px, py, pz = W_p_C = position.squeeze()
        vx, vy, vz = W_ṗ_C = velocity.squeeze()

        # Compute the terrain normal and the contact depth
        n̂ = self.terrain.normal(x=px, y=py).squeeze()
        h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])

        # Compute the penetration depth normal to the terrain (zero when above it)
        δ = jnp.maximum(0.0, jnp.dot(h, n̂))

        # Compute the penetration normal velocity
        δ̇ = -jnp.dot(W_ṗ_C, n̂)

        # Non-linear spring-damper model.
        # This is the force magnitude along the direction normal to the terrain.
        # The 1e-9 threshold and the 1e-12 offset inside sqrt are differentiability
        # tolerances (see the note above).
        force_normal_mag = jax.lax.select(
            pred=δ >= 1e-9,
            on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
            on_false=jnp.array(0.0),
        )

        # Prevent negative normal forces that might occur when δ̇ is largely negative
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame
        force_normal = force_normal_mag * n̂

        # ====================================
        # No friction and no tangential forces
        # ====================================

        # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
        # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
        W_Xf_CW = jnp.vstack(
            [
                jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
                jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
            ]
        )

        def with_no_friction():
            # Compute 6D mixed force in C[W]
            CW_f_lin = force_normal
            CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])

            # Compute lin-ang 6D forces (inertial representation)
            W_f = W_Xf_CW @ CW_f

            # Without friction the deformation derivative stays zero.
            return W_f, ṁ

        # =========================
        # Compute tangential forces
        # =========================

        def with_friction():
            # Initialize the tangential deformation rate ṁ.
            # For inactive contacts with m≠0, this is the dynamics of the material
            # relaxation converging exponentially to steady state.
            ṁ = (-K / D) * m

            # Check if the collidable point is below ground.
            # Note: when δ=0, we consider the point still not in contact such that
            # we prevent divisions by 0 in the computations below.
            active_contact = pz < self.terrain.height(x=px, y=py)

            def above_terrain():
                # No force; the deformation relaxes.
                return jnp.zeros(6), ṁ

            def below_terrain():
                # Decompose the velocity in normal and tangential components
                v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
                v_tangential = W_ṗ_C - v_normal

                # Compute the tangential force. If it lies inside the friction cone
                # the contact is sticking; otherwise it is slipping and the force
                # is projected onto the cone boundary (see the cond below).
                f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

                def sticking_contact():
                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_stick = force_normal + f_tangential
                    CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])

                    # In this case the 3D material deformation is the tangential velocity
                    ṁ = v_tangential

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                def slipping_contact():
                    # Clip the tangential force if too small, allowing jax to
                    # differentiate through the norm computation
                    f_tangential_no_nan = jax.lax.select(
                        pred=f_tangential.dot(f_tangential) >= 1e-9**2,
                        on_true=f_tangential,
                        on_false=jnp.array([1e-12, 0, 0]),
                    )

                    # Project the force to the friction cone boundary
                    f_tangential_projected = (μ * force_normal_mag) * (
                        f_tangential / jnp.linalg.norm(f_tangential_no_nan)
                    )

                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_slip = force_normal + f_tangential_projected
                    CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])

                    # Correct the material deformation derivative for slipping contacts.
                    # Basically we compute ṁ such that we get `f_tangential` on the cone
                    # given the current (m, δ), i.e. solve f = α m + β ṁ for ṁ.
                    ε = 1e-9
                    δε = jnp.maximum(δ, ε)
                    α = -K * jnp.sqrt(δε)
                    β = -D * jnp.sqrt(δε)
                    ṁ = (f_tangential_projected - α * m) / β

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                # Slipping when the squared tangential force exceeds (μ f_n)².
                CW_f, ṁ = jax.lax.cond(
                    pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
                    true_fun=lambda _: slipping_contact(),
                    false_fun=lambda _: sticking_contact(),
                    operand=None,
                )

                # Express the 6D force in the world frame
                W_f = W_Xf_CW @ CW_f

                # Return the 6D force in the world frame and the deformation derivative
                return W_f, ṁ

            # (W_f, ṁ)
            return jax.lax.cond(
                pred=active_contact,
                true_fun=lambda _: below_terrain(),
                false_fun=lambda _: above_terrain(),
                operand=None,
            )

        # (W_f, ṁ)
        return jax.lax.cond(
            pred=(μ == 0.0),
            true_fun=lambda _: with_no_friction(),
            false_fun=lambda _: with_friction(),
            operand=None,
        )