jaxsim 0.2.dev108__py3-none-any.whl → 0.2.dev166__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/api/ode.py ADDED
@@ -0,0 +1,280 @@
1
+ from typing import Any, Protocol
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jaxlie
6
+
7
+ import jaxsim.physics.algos.soft_contacts
8
+ import jaxsim.typing as jtp
9
+ from jaxsim import VelRepr, integrators
10
+ from jaxsim.integrators.common import Time
11
+ from jaxsim.math.quaternion import Quaternion
12
+ from jaxsim.physics.algos.soft_contacts import SoftContactsState
13
+ from jaxsim.physics.model.physics_model_state import PhysicsModelState
14
+ from jaxsim.simulation.ode_data import ODEState
15
+
16
+ from . import contact as Contact
17
+ from . import data as Data
18
+ from . import model as Model
19
+
20
+
21
class SystemDynamicsFromModelAndData(Protocol):
    """
    Structural type of callables computing the system dynamics from a model
    and its data (e.g. `system_dynamics` in this module).

    A conforming callable accepts a `JaxSimModel`, a `JaxSimModelData`, and
    optional keyword arguments. It returns an `ODEState` together with a
    dictionary of auxiliary data.
    """

    def __call__(
        self,
        model: Model.JaxSimModel,
        data: Data.JaxSimModelData,
        # Each keyword value can be of any type (e.g. joint or external forces),
        # so the annotation is `Any` rather than `dict[str, Any]`.
        **kwargs: Any,
    ) -> tuple[ODEState, dict[str, Any]]: ...
28
+
29
+
30
def wrap_system_dynamics_for_integration(
    model: Model.JaxSimModel,
    data: Data.JaxSimModelData,
    *,
    system_dynamics: SystemDynamicsFromModelAndData,
    **kwargs,
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
    """
    Adapt dynamics defined on `JaxSimModel` and `JaxSimModelData` to the
    `f(x, t, **kwargs)` signature used by `jaxsim.integrators`.

    Args:
        model: The model to consider.
        data: The data of the considered model. At every evaluation its state
            is replaced by `x` and its time by `t`.
        system_dynamics: The system dynamics to wrap.
        **kwargs: Additional kwargs to close over the system dynamics. Kwargs
            passed when calling the returned function take precedence.

    Returns:
        The system dynamics closed over the model, the data, and the additional kwargs.
    """

    # Keyword arguments bound once, when the wrapper is created.
    bound_kwargs = kwargs

    def f(x: ODEState, t: Time, **kwargs) -> tuple[ODEState, dict[str, Any]]:

        # Obtain an editable version of `data` holding the integrator state `x`
        # and the time `t` converted to integer nanoseconds.
        with data.editable(validate=True) as data_at_x:
            data_at_x.state = x
            data_at_x.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)

        # The right operand of `|` wins, so call-time kwargs override bound ones.
        merged_kwargs = bound_kwargs | kwargs

        return system_dynamics(model=model, data=data_at_x, **merged_kwargs)

    f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
    return f
66
+
67
+
68
+ # ==================================
69
+ # Functions defining system dynamics
70
+ # ==================================
71
+
72
+
73
@jax.jit
def system_velocity_dynamics(
    model: Model.JaxSimModel,
    data: Data.JaxSimModelData,
    *,
    joint_forces: jtp.Vector | None = None,
    external_forces: jtp.Vector | None = None,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
    """
    Compute the dynamics of the system velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces: The joint forces to apply. Zero when not provided.
        external_forces: The external 6D forces to apply to the links, one row
            per link. Zero when not provided.

    Returns:
        A tuple containing the derivative of the base 6D velocity in inertial-fixed
        representation, the derivative of the joint velocities, the derivative of
        the material deformation (with the same layout as
        `data.state.soft_contacts.tangential_deformation`), and the dictionary of
        auxiliary data returned by the system dynamics evaluation.
    """

    # Build joint torques if not provided
    τ = (
        jnp.atleast_1d(joint_forces.squeeze())
        if joint_forces is not None
        else jnp.zeros_like(data.joint_positions())
    ).astype(float)

    # Build external forces if not provided
    f_ext = (
        jnp.atleast_2d(external_forces.squeeze())
        if external_forces is not None
        else jnp.zeros((model.number_of_links(), 6))
    ).astype(float)

    # ======================
    # Compute contact forces
    # ======================

    # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
    # with the terrain.
    W_f_Li_terrain = jnp.zeros_like(f_ext).astype(float)

    # The state stores the tangential deformation with one column per
    # collidable point. It is transposed to map over the points below.
    tangential_deformation = data.state.soft_contacts.tangential_deformation

    # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
    # ṁ is transposed back to the state layout on return. Initializing it from
    # the transposed state keeps the returned shape consistent with the state
    # even when the model has no collidable points.
    ṁ = jnp.zeros_like(tangential_deformation.T).astype(float)

    if model.physics_model.gc.body.size > 0:
        # Compute the position and linear velocities (mixed representation) of
        # all collidable points belonging to the robot.
        W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data)

        # Compute the 6D forces applied to each collidable point (expressed in
        # the world frame) and the tangential deformation rates.
        W_f_Ci, ṁ = jax.vmap(
            lambda p, ṗ, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
                parameters=data.soft_contacts_params, terrain=model.terrain
            ).contact_model(position=p, velocity=ṗ, tangential_deformation=m)
        )(W_p_Ci, W_ṗ_Ci, tangential_deformation.T)

        # Sum the forces of all collidable points rigidly attached to a body.
        # Since the contact forces W_f_Ci are expressed in the world frame,
        # we don't need any coordinate transformation.
        W_f_Li_terrain = jax.vmap(
            lambda nc: (
                jnp.vstack(jnp.equal(model.physics_model.gc.body, nc).astype(int))
                * W_f_Ci
            ).sum(axis=0)
        )(jnp.arange(model.number_of_links()))

    # ====================
    # Enforce joint limits
    # ====================

    # TODO: enforce joint limits
    τ_position_limit = jnp.zeros_like(τ).astype(float)

    # ====================
    # Joint friction model
    # ====================

    τ_friction = jnp.zeros_like(τ).astype(float)

    if model.dofs() > 0:
        # Static and viscous joint friction parameters
        kc = jnp.array(list(model.physics_model._joint_friction_static.values()))
        kv = jnp.array(list(model.physics_model._joint_friction_viscous.values()))

        ṡ = data.state.physics_model.joint_velocities

        # Coulomb (static) friction opposes the direction of motion, so it
        # depends on the sign of the joint velocities, not of the positions.
        # Element-wise products are equivalent to multiplying by diag(k).
        τ_friction = -(kc * jnp.sign(ṡ) + kv * ṡ)

    # ========================
    # Compute forward dynamics
    # ========================

    # Compute the total joint forces
    τ_total = τ + τ_friction + τ_position_limit

    # Compute the total external 6D forces applied to the links
    W_f_L_total = f_ext + W_f_Li_terrain

    # - Joint accelerations: s̈ ∈ ℝⁿ
    # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
    with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
        W_v̇_WB, s̈ = Model.forward_dynamics_aba(
            model=model,
            data=data,
            joint_forces=τ_total,
            external_forces=W_f_L_total,
        )

    return W_v̇_WB, s̈, ṁ.T, dict()
193
+
194
+
195
@jax.jit
def system_position_dynamics(
    model: Model.JaxSimModel, data: Data.JaxSimModelData
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
    """
    Compute the time derivative of the system position.

    Args:
        model: The model to consider (not used by the computation).
        data: The data of the considered model.

    Returns:
        A tuple containing the derivative of the base position, the derivative of
        the base quaternion, and the derivative of the joint positions.
    """

    physics_state = data.state.physics_model

    # The joint position derivatives are the joint velocities themselves.
    joint_position_rates = physics_state.joint_velocities
    quaternion = physics_state.base_quaternion

    # In mixed representation, the linear part of the base velocity is W_ṗ_B.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
        base_linear_velocity = data.base_velocity()[0:3]

    # The quaternion derivative is computed from the inertial-fixed angular velocity.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
        base_angular_velocity = data.base_velocity()[3:6]

    quaternion_rate = Quaternion.derivative(
        quaternion=quaternion,
        omega=base_angular_velocity,
        omega_in_body_fixed=False,
    ).squeeze()

    return base_linear_velocity, quaternion_rate, joint_position_rates
227
+
228
+
229
@jax.jit
def system_dynamics(
    model: Model.JaxSimModel,
    data: Data.JaxSimModelData,
    *,
    joint_forces: jtp.Vector | None = None,
    external_forces: jtp.Vector | None = None,
) -> tuple[ODEState, dict[str, Any]]:
    """
    Compute the full state derivative of the system.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces: The joint forces to apply.
        external_forces: The external forces to apply to the links.

    Returns:
        A tuple with an `ODEState` object whose attributes hold the derivative
        of the corresponding state entries, and the dictionary of auxiliary data
        returned by the system dynamics evaluation.
    """

    # Velocity part: accelerations and material deformation rate.
    base_acceleration, joint_accelerations, deformation_rate, aux_data = (
        system_velocity_dynamics(
            model=model,
            data=data,
            joint_forces=joint_forces,
            external_forces=external_forces,
        )
    )

    # Position part: derivatives of base position, base quaternion, joint positions.
    base_position_rate, quaternion_rate, joint_position_rates = (
        system_position_dynamics(model=model, data=data)
    )

    physics_model_derivative = PhysicsModelState.build(
        base_position=base_position_rate,
        base_quaternion=quaternion_rate,
        joint_positions=joint_position_rates,
        base_linear_velocity=base_acceleration[0:3],
        base_angular_velocity=base_acceleration[3:6],
        joint_velocities=joint_accelerations,
    )

    soft_contacts_derivative = SoftContactsState.build(
        tangential_deformation=deformation_rate,
    )

    # The derivative has the same pytree structure as the state, so generic
    # pytree integrators can consume it directly.
    derivative = ODEState.build(
        physics_model_state=physics_model_derivative,
        soft_contacts_state=soft_contacts_derivative,
    )

    return derivative, aux_data
@@ -0,0 +1,2 @@
1
+ from . import fixed_step
2
+ from .common import Integrator, Time, TimeStep