jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/rbda/rnea.py ADDED
@@ -0,0 +1,237 @@
1
+ from typing import Tuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jaxlie
6
+
7
+ import jaxsim.api as js
8
+ import jaxsim.typing as jtp
9
+ from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
10
+
11
+ from . import utils
12
+
13
+
14
def rnea(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
    base_linear_acceleration: jtp.Vector | None = None,
    base_angular_acceleration: jtp.Vector | None = None,
    joint_accelerations: jtp.Vector | None = None,
    link_forces: jtp.Matrix | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    The algorithm runs a forward pass (root to leaves) that propagates link
    velocities and accelerations and computes the net force on each link. It then
    runs a backward pass (leaves to root) that accumulates child forces into their
    parents and projects them onto the joint motion subspaces.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion:
            The quaternion of the base link, in wxyz order (it is converted with
            `Quaternion.to_xyzw(wxyz=...)` before building the base transform).
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration:
            The linear acceleration of the base link in inertial-fixed representation.
            Zeros are used when None.
        base_angular_acceleration:
            The angular acceleration of the base link in inertial-fixed representation.
            Zeros are used when None.
        joint_accelerations:
            The accelerations of the joints. Zeros are used when None.
        link_forces:
            The forces applied to the links expressed in the world frame, with one
            6D row per link. Zeros are used when None.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D force applied to the base link expressed in the
        world frame and the joint forces that, when applied respectively to the base
        link and joints, produce the given base and joint accelerations.
        For fixed-base models the returned base force is always zero, because the
        base force buffer is neither initialized nor accumulated in that case.
    """

    # Validate the inputs, fill missing ones with defaults and pack the 6D base
    # velocity, base acceleration and gravity vectors. Joint forces are not needed.
    W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, _, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=base_linear_acceleration,
        base_angular_acceleration=base_angular_acceleration,
        joint_accelerations=joint_accelerations,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Convert the 6D vectors to (6, 1) column vectors for the matrix products below.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T
    W_v̇_WB = jnp.atleast_2d(W_v̇_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffers: per-link 6D velocity, acceleration and force (body-fixed).
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    a = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    f = jnp.zeros(shape=(model.number_of_links(), 6, 1))

    # Allocate the buffer of transforms link -> base.
    # Entry i is built in the forward pass as i_X_λi[i] @ i_X_0[λ[i]].
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize the acceleration of the base link.
    # Gravity is introduced as a fictitious base acceleration equal to -g, so that
    # it propagates to all links through the forward pass.
    a_0 = -B_X_W @ W_g
    a = a.at[0].set(a_0)

    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Base acceleration a₀ in body-fixed representation w/o gravity.
        # This overwrites the fixed-base initialization above: the commanded base
        # acceleration is added to the fictitious -g term.
        a_0 = B_X_W @ (W_v̇_WB - W_g)
        a = a.at[0].set(a_0)

        # Force applied to the base link that produce the base acceleration w/o gravity.
        # The external base force W_f[0] is moved from world to body coordinates
        # through W_X_B.T (the force transform associated with B_X_W).
        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - W_X_B.T @ jnp.vstack(W_f[0])
        )
        f = f.at[0].set(f_0)

    # ======
    # Pass 1
    # ======

    ForwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:

        # Link i is connected to its parent by the joint with index i - 1.
        ii = i - 1
        v, a, i_X_0, f = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * s̈[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        # Compute the link-to-base transform.
        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Compute link-to-world transform for the 6D force.
        # This is the inverse-transpose of the motion transform i_X_0[i] @ B_X_W
        # (assuming Adjoint.inverse inverts a 6D motion transform).
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Compute the force acting on the link.
        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(W_f[i])
        )
        f = f.at[i].set(f_i)

        return (v, a, i_X_0, f), None

    # Scan over links 1..nl-1. With a single link there is nothing to propagate,
    # so the scan is skipped and the buffers are kept as they are.
    (v, a, i_X_0, f), _ = (
        jax.lax.scan(
            f=forward_pass,
            init=forward_pass_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, a, i_X_0, f), None]
    )

    # ======
    # Pass 2
    # ======

    τ = jnp.zeros_like(s)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry: BackwardPassCarry = (τ, f)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:

        # Link i is connected to its parent by the joint with index i - 1.
        ii = i - 1
        τ, f = carry

        # Project the 6D force to the DoF of the joint.
        τ_i = S[i].T @ f[i]
        τ = τ.at[ii].set(τ_i.squeeze())

        # Propagate the force to the parent link.
        def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:

            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            return f

        # Forces are accumulated into the base (λ[i] == 0) only for floating-base
        # models. For fixed-base models f[0] therefore stays zero.
        f = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=update_f,
            false_fun=lambda f: f,
            operand=f,
        )

        return (τ, f), None

    # Scan over links nl-1..1 (leaves first), skipped for single-link models.
    (τ, f), _ = (
        jax.lax.scan(
            f=backward_pass,
            init=backward_pass_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(τ, f), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # Express the base 6D force in the world frame.
    W_f0 = B_X_W.T @ f[0]

    return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze())
jaxsim/rbda/soft_contacts.py ADDED
@@ -0,0 +1,296 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import jax_dataclasses
8
+
9
+ import jaxsim.api as js
10
+ import jaxsim.typing as jtp
11
+ from jaxsim.math import Skew, StandardGravity
12
+ from jaxsim.terrain import FlatTerrain, Terrain
13
+ from jaxsim.utils import JaxsimDataclass
14
+
15
+
16
@jax_dataclasses.pytree_dataclass
class SoftContactsParams(JaxsimDataclass):
    """
    Parameters of the soft contacts model.

    Attributes:
        K: The stiffness of the contact model (defaults to 1e6).
        D: The damping of the contact model (defaults to 2000).
        mu: The static friction coefficient (defaults to 0.5).
    """

    # Every field defaults to a freshly allocated scalar float array.
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e6, dtype=float)
    )

    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(2000, dtype=float)
    )

    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    @staticmethod
    def build(
        K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
    ) -> SoftContactsParams:
        """
        Build the soft contacts parameters from scalar-like values.

        Each value is cast to a float array before being stored.

        Args:
            K: The stiffness parameter.
            D: The damping parameter of the soft contacts model.
            mu: The static friction coefficient.

        Returns:
            The `SoftContactsParams` instance holding the given values.
        """

        raw_values = {"K": K, "D": D, "mu": mu}

        return SoftContactsParams(
            **{
                field_name: jnp.array(field_value, dtype=float)
                for field_name, field_value in raw_values.items()
            }
        )

    @staticmethod
    def build_default_from_jaxsim_model(
        model: js.model.JaxSimModel,
        *,
        standard_gravity: jtp.FloatLike = StandardGravity,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
    ) -> SoftContactsParams:
        """
        Derive sensible soft contacts parameters from the mass of a model.

        The stiffness is chosen so that, in steady state, the weight of the model
        shared among the given number of collidable points produces a penetration
        equal to `max_penetration`. The damping is a multiple of the critical damping.

        Args:
            model: The target model.
            standard_gravity: The standard gravity constant.
            static_friction_coefficient:
                The static friction coefficient between the model and the terrain.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state:
                The number of contacts supporting the weight of the model
                in steady state.
            damping_ratio: The ratio controlling the damping behavior.

        Returns:
            A `SoftContactsParams` instance with the derived parameters.

        Note:
            The `damping_ratio` parameter selects the damping regime:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Total mass of the model, i.e. the sum of the masses of all its links.
        total_mass = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()

        # Share of the weight supported, on average, by each active collidable point.
        load_per_point = (
            total_mass
            * standard_gravity
            / number_of_active_collidable_points_steady_state
        )

        # Stiffness giving the target penetration under the steady-state law
        # f = K δ^(3/2) of the non-linear spring-damper.
        stiffness = load_per_point / jnp.power(max_penetration, 3 / 2)

        # Damping obtained by scaling the critical damping 2√(K m).
        damping = damping_ratio * (2 * jnp.sqrt(stiffness * total_mass))

        return SoftContactsParams.build(
            K=stiffness, D=damping, mu=static_friction_coefficient
        )
110
+
111
+
112
@jax_dataclasses.pytree_dataclass
class SoftContacts:
    """
    Soft contacts model.

    The normal force comes from a non-linear spring-damper acting on the
    penetration depth of a collidable point into the terrain. When the friction
    coefficient is non-zero, the tangential force comes from a spring-damper acting
    on a 3D material deformation state, saturated on the friction cone.
    """

    # Parameters of the model (stiffness K, damping D, friction coefficient mu).
    parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    # Terrain providing the height and the normal used to detect contacts.
    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)

    def contact_model(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        tangential_deformation: jtp.Vector,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact forces and material deformation rate.

        Args:
            position: The position of the collidable point in the world frame.
            velocity:
                The linear velocity of the collidable point in the world frame.
            tangential_deformation: The tangential deformation.

        Returns:
            A tuple containing the 6D contact force expressed in the world frame
            (inertial representation, linear components first) and the material
            deformation rate.
        """

        # Short name of parameters
        K = self.parameters.K
        D = self.parameters.D
        μ = self.parameters.mu

        # Material 3D tangential deformation and its derivative
        m = tangential_deformation.squeeze()
        ṁ = jnp.zeros_like(m)

        # Note: all the small hardcoded tolerances in this method have been introduced
        # to allow jax differentiating through this algorithm. They should not affect
        # the accuracy of the simulation, although they might make it less readable.

        # ========================
        # Normal force computation
        # ========================

        # Unpack the position of the collidable point and get its linear velocity.
        px, py, pz = W_p_C = position.squeeze()
        W_ṗ_C = velocity.squeeze()

        # Compute the terrain normal and the contact depth
        n̂ = self.terrain.normal(x=px, y=py).squeeze()
        h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])

        # Compute the penetration depth normal to the terrain
        δ = jnp.maximum(0.0, jnp.dot(h, n̂))

        # Compute the penetration normal velocity
        δ̇ = -jnp.dot(W_ṗ_C, n̂)

        # Non-linear spring-damper model.
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = jax.lax.select(
            pred=δ >= 1e-9,
            on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
            on_false=jnp.array(0.0),
        )

        # Prevent negative normal forces that might occur when δ̇ is largely negative
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame
        force_normal = force_normal_mag * n̂

        # ====================================
        # No friction and no tangential forces
        # ====================================

        # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
        # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
        W_Xf_CW = jnp.vstack(
            [
                jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
                jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
            ]
        )

        def with_no_friction():
            # Compute 6D mixed force in C[W]
            CW_f_lin = force_normal
            CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])

            # Compute lin-ang 6D forces (inertial representation)
            W_f = W_Xf_CW @ CW_f

            # Without friction the material deformation rate stays zero.
            return W_f, ṁ

        # =========================
        # Compute tangential forces
        # =========================

        def with_friction():
            # Initialize the tangential deformation rate ṁ.
            # For inactive contacts with m≠0, this is the dynamics of the material
            # relaxation converging exponentially to steady state.
            # NOTE(review): if D == 0 this division (and β in slipping_contact)
            # yields inf/NaN; guard it if undamped contacts must be supported.
            ṁ = (-K / D) * m

            # Check if the collidable point is below ground.
            # Note: when δ=0, we consider the point still not in contact such that
            # we prevent divisions by 0 in the computations below.
            active_contact = pz < self.terrain.height(x=px, y=py)

            def above_terrain():
                # No force; the deformation relaxes with the rate computed above.
                return jnp.zeros(6), ṁ

            def below_terrain():
                # Decompose the velocity in normal and tangential components
                v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
                v_tangential = W_ṗ_C - v_normal

                # Compute the tangential force. If it lies inside the friction cone,
                # the contact is sticking; otherwise it is slipping and the force is
                # projected onto the boundary of the cone.
                f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

                def sticking_contact():
                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_stick = force_normal + f_tangential
                    CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])

                    # In this case the 3D material deformation is the tangential velocity
                    ṁ = v_tangential

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                def slipping_contact():
                    # Project the force to the friction cone boundary
                    f_tangential_projected = (μ * force_normal_mag) * (
                        f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9)
                    )

                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_slip = force_normal + f_tangential_projected
                    CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])

                    # Correct the material deformation derivative for slipping contacts.
                    # Basically we compute ṁ such that we get `f_tangential` on the cone
                    # given the current (m, δ).
                    ε = 1e-9
                    δε = jnp.maximum(δ, ε)
                    α = -K * jnp.sqrt(δε)
                    β = -D * jnp.sqrt(δε)
                    ṁ = (f_tangential_projected - α * m) / β

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                # The contact slips when ‖f_tangential‖² exceeds (μ f_normal)².
                CW_f, ṁ = jax.lax.cond(
                    pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
                    true_fun=lambda _: slipping_contact(),
                    false_fun=lambda _: sticking_contact(),
                    operand=None,
                )

                # Express the 6D force in the world frame
                W_f = W_Xf_CW @ CW_f

                # Return the 6D force in the world frame and the deformation derivative
                return W_f, ṁ

            # (W_f, ṁ)
            return jax.lax.cond(
                pred=active_contact,
                true_fun=lambda _: below_terrain(),
                false_fun=lambda _: above_terrain(),
                operand=None,
            )

        # (W_f, ṁ)
        return jax.lax.cond(
            pred=(μ == 0.0),
            true_fun=lambda _: with_no_friction(),
            false_fun=lambda _: with_friction(),
            operand=None,
        )
jaxsim/rbda/utils.py ADDED
@@ -0,0 +1,152 @@
1
+ import jax.numpy as jnp
2
+
3
+ import jaxsim.api as js
4
+ import jaxsim.typing as jtp
5
+ from jaxsim.math import StandardGravity
6
+
7
+
8
def _adjust_input(
    value: jtp.VectorLike | jtp.MatrixLike | None,
    *,
    name: str,
    expected_shape: tuple[int, ...],
    default: jnp.ndarray,
) -> jnp.ndarray:
    """
    Convert an optional array-like input to an array with the expected shape.

    Args:
        value: The user-provided input, or None to use `default`.
        name: The name of the input, used in the error message.
        expected_shape: The shape (of rank 1 or 2) the adjusted input must have.
        default: The array to use when `value` is None.

    Returns:
        The input with singleton dimensions squeezed and then promoted back to
        at least the rank of `expected_shape`.

    Raises:
        ValueError: If the adjusted input does not have `expected_shape`.
    """

    if value is None:
        array = default
    else:
        # `jnp.asarray` also accepts lists, tuples and NumPy arrays, which do not
        # necessarily expose a `.squeeze()` method.
        array = jnp.asarray(value).squeeze()
        array = (
            jnp.atleast_2d(array)
            if len(expected_shape) == 2
            else jnp.atleast_1d(array)
        )

    if array.shape != expected_shape:
        raise ValueError(
            f"Wrong shape for '{name}': got {array.shape}, expected {expected_shape}"
        )

    return array


def process_inputs(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    base_linear_acceleration: jtp.VectorLike | None = None,
    base_angular_acceleration: jtp.VectorLike | None = None,
    joint_accelerations: jtp.VectorLike | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike | None = None,
) -> tuple[
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Matrix,
    jtp.Vector,
]:
    """
    Adjust the inputs to rigid-body dynamics algorithms.

    Missing inputs are replaced by zeros, except for the base quaternion (identity,
    wxyz) and the standard gravity (`StandardGravity`). Array-like inputs are
    squeezed and their shapes validated.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration: The linear acceleration of the base link.
        base_angular_acceleration: The angular acceleration of the base link.
        joint_accelerations: The accelerations of the joints.
        joint_forces: The forces applied to the joints.
        link_forces: The forces applied to the links, one 6D row per link.
        standard_gravity: The standard gravity constant (a scalar).

    Returns:
        The adjusted inputs as float arrays, in this order: base position, base
        quaternion, joint positions, 6D base velocity, joint velocities, 6D base
        acceleration, joint accelerations, joint forces, link forces and 6D
        gravity acceleration.

    Raises:
        ValueError: If any input does not have the expected shape.
    """

    dofs = model.dofs()
    nl = model.number_of_links()

    # Joint positions, velocities, accelerations and forces.
    s = _adjust_input(
        joint_positions,
        name="joint_positions",
        expected_shape=(dofs,),
        default=jnp.zeros(dofs),
    )
    ṡ = _adjust_input(
        joint_velocities,
        name="joint_velocities",
        expected_shape=(dofs,),
        default=jnp.zeros(dofs),
    )
    s̈ = _adjust_input(
        joint_accelerations,
        name="joint_accelerations",
        expected_shape=(dofs,),
        default=jnp.zeros(dofs),
    )
    τ = _adjust_input(
        joint_forces,
        name="joint_forces",
        expected_shape=(dofs,),
        default=jnp.zeros(dofs),
    )

    # Floating-base position.
    W_p_B = _adjust_input(
        base_position,
        name="base_position",
        expected_shape=(3,),
        default=jnp.zeros(3),
    )
    W_Q_B = _adjust_input(
        base_quaternion,
        name="base_quaternion",
        expected_shape=(4,),
        default=jnp.array([1.0, 0, 0, 0]),
    )

    # Floating-base velocity in inertial-fixed representation.
    W_vl_WB = _adjust_input(
        base_linear_velocity,
        name="base_linear_velocity",
        expected_shape=(3,),
        default=jnp.zeros(3),
    )
    W_ω_WB = _adjust_input(
        base_angular_velocity,
        name="base_angular_velocity",
        expected_shape=(3,),
        default=jnp.zeros(3),
    )

    # Floating-base acceleration in inertial-fixed representation.
    W_v̇l_WB = _adjust_input(
        base_linear_acceleration,
        name="base_linear_acceleration",
        expected_shape=(3,),
        default=jnp.zeros(3),
    )
    W_ω̇_WB = _adjust_input(
        base_angular_acceleration,
        name="base_angular_acceleration",
        expected_shape=(3,),
        default=jnp.zeros(3),
    )

    # External 6D forces applied to the links.
    f = _adjust_input(
        link_forces,
        name="link_forces",
        expected_shape=(nl, 6),
        default=jnp.zeros(shape=(nl, 6)),
    )

    standard_gravity = (
        jnp.array(standard_gravity).squeeze()
        if standard_gravity is not None
        else StandardGravity
    )

    # Pack the 6D base velocity and acceleration.
    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])

    # Create the 6D gravity acceleration, pointing along -z.
    W_g = jnp.zeros(6).at[2].set(-standard_gravity)

    return (
        W_p_B.astype(float),
        W_Q_B.astype(float),
        s.astype(float),
        W_v_WB.astype(float),
        ṡ.astype(float),
        W_v̇_WB.astype(float),
        s̈.astype(float),
        τ.astype(float),
        f.astype(float),
        W_g.astype(float),
    )
jaxsim/terrain/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import terrain
2
+ from .terrain import FlatTerrain, PlaneTerrain, Terrain