jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/simulation/ode.py
DELETED
@@ -1,290 +0,0 @@
|
|
1
|
-
from typing import Any, Dict, Tuple
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.physics import algos
|
9
|
-
from jaxsim.physics.algos.soft_contacts import (
|
10
|
-
SoftContacts,
|
11
|
-
SoftContactsParams,
|
12
|
-
collidable_points_pos_vel,
|
13
|
-
)
|
14
|
-
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
|
15
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
16
|
-
|
17
|
-
from . import ode_data
|
18
|
-
|
19
|
-
|
20
|
-
def compute_contact_forces(
    physics_model: PhysicsModel,
    ode_state: ode_data.ODEState,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    terrain: Terrain = FlatTerrain(),
) -> Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]:
    """
    Evaluate the soft-contact forces exchanged between the model and the terrain.

    The contact model is evaluated once per collidable point (vectorized with
    `jax.vmap`), and the resulting point forces are then summed over the points
    attached to the same body.

    Args:
        physics_model: The physics model to consider.
        ode_state: The state of the ODE corresponding to the physics model.
        soft_contacts_params: The parameters of the soft contacts model.
        terrain: The terrain model.

    Returns:
        A tuple containing:
        - The contact forces expressed in the world frame acting on the model's links.
        - The derivative of the tangential deformation of the terrain dynamics.
        - The contact forces expressed in the world frame acting on the model's
          collidable points.
    """

    model_state = ode_state.physics_model

    # Position and linear (mixed) velocity of every collidable point of the model
    points_position, points_velocity = collidable_points_pos_vel(
        model=physics_model,
        q=model_state.joint_positions,
        qd=model_state.joint_velocities,
        xfb=model_state.xfb(),
    )

    # Contact model of the compliant ground, mapped over all the points at once.
    # The inputs are transposed so that vmap iterates over the points.
    single_point_contact = SoftContacts(
        parameters=soft_contacts_params, terrain=terrain
    ).contact_model

    forces_by_point, deformation_dot_by_point = jax.vmap(single_point_contact)(
        points_position.T,
        points_velocity.T,
        ode_state.soft_contacts.tangential_deformation.T,
    )

    # Switch back to the layout in which the last axis indexes the points
    forces_by_column = forces_by_point.T
    deformation_dot = deformation_dot_by_point.T

    # Accumulator of the per-body contact forces, shaped like the external forces
    links_forces = jnp.zeros_like(
        ode_data.ODEInput.zero(physics_model).physics_model.f_ext
    )

    # Sum the forces of all the collidable points that belong to the same body
    parent_body_of_point = physics_model.gc.body

    for link_index in map(int, set(parent_body_of_point)):
        attached_to_link = parent_body_of_point == link_index
        links_forces = links_forces.at[link_index, :].set(
            jnp.sum(forces_by_column[:, attached_to_link], axis=1)
        )

    return links_forces, deformation_dot, forces_by_column.T
|
73
|
-
|
74
|
-
|
75
|
-
def dx_dt(
    x: ode_data.ODEState,
    t: jtp.Float | None,
    physics_model: PhysicsModel,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    ode_input: ode_data.ODEInput | None = None,
    terrain: Terrain = FlatTerrain(),
) -> Tuple[ode_data.ODEState, Dict[str, Any]]:
    """
    Compute the state derivative of the ODE corresponding to the physics model.

    Args:
        x: The state of the ODE.
        t: The current time.
        physics_model: The physics model to consider.
        soft_contacts_params: The parameters of the soft contacts model.
        ode_input: The input of the ODE. When `None`, a zero input is used.
        terrain: The terrain model.

    Returns:
        A tuple containing:
        - The state derivative of the ODE.
        - A dictionary containing auxiliary information.

    Raises:
        ValueError: If `t` is a NumPy array that does not hold exactly one element.
    """

    if t is not None and isinstance(t, np.ndarray) and t.size != 1:
        raise ValueError(t.size)

    # Initialize arguments
    ode_state = x
    ode_input = (
        ode_input
        if ode_input is not None
        else ode_data.ODEInput.zero(physics_model=physics_model)
    )

    # ======================
    # Compute contact forces
    # ======================

    # Initialize the collidable points contact forces.
    # It stays None when the model has no collidable points.
    contact_forces_points = None

    # Initialize the contact forces, one per body
    contact_forces_links = jnp.zeros_like(ode_input.physics_model.f_ext)

    # Initialize the derivative of the tangential deformation
    tangential_deformation_dot = jnp.zeros_like(
        ode_state.soft_contacts.tangential_deformation
    )

    if len(physics_model.gc.body) > 0:
        (
            contact_forces_links,
            tangential_deformation_dot,
            contact_forces_points,
        ) = compute_contact_forces(
            physics_model=physics_model,
            soft_contacts_params=soft_contacts_params,
            ode_state=ode_state,
            terrain=terrain,
        )

    # =====================
    # Joint position limits
    # =====================

    if physics_model.dofs() > 0:
        # Get the joint position limits
        s_min, s_max = jnp.array(
            [j.position_limit for j in physics_model.description.joints_dict.values()]
        ).T

        # Get the spring/damper parameters of joint limits enforcement.
        # NOTE(review): the gain is read from `_joint_limit_damper` but it scales
        # a position error below (spring-like behavior) -- confirm this is intended.
        k_damper = jnp.array(list(physics_model._joint_limit_damper.values()))

        # Compute the joint torques that enforce joint limits.
        # The torque is non-zero only for the joints outside of their limits.
        s = ode_state.physics_model.joint_positions
        tau_min = jnp.where(s <= s_min, k_damper * (s_min - s), 0)
        tau_max = jnp.where(s >= s_max, k_damper * (s_max - s), 0)
        tau_limit = tau_max + tau_min

    else:
        tau_limit = jnp.zeros_like(ode_input.physics_model.tau)

    # ==============
    # Joint friction
    # ==============

    if physics_model.dofs() > 0:
        # Static and viscous joint friction parameters
        kc = jnp.array(list(physics_model._joint_friction_static.values()))
        kv = jnp.array(list(physics_model._joint_friction_viscous.values()))

        # Compute the joint friction torque.
        # Both terms oppose the joint motion: the static (Coulomb) term depends
        # on the sign of the joint velocity, the viscous term on its magnitude.
        tau_friction = -(
            jnp.diag(kc) @ jnp.sign(ode_state.physics_model.joint_velocities)
            + jnp.diag(kv) @ ode_state.physics_model.joint_velocities
        )

    else:
        tau_friction = jnp.zeros_like(ode_input.physics_model.tau)

    # ========================
    # Compute forward dynamics
    # ========================

    # Compute the total forces applied to the bodies
    total_forces = ode_input.physics_model.f_ext + contact_forces_links

    # Compute the joint torques to actuate
    tau = ode_input.physics_model.tau + tau_friction + tau_limit

    # Compute forward dynamics with the ABA algorithm
    W_a_WB, qdd = algos.aba.aba(
        model=physics_model,
        xfb=ode_state.physics_model.xfb(),
        q=ode_state.physics_model.joint_positions,
        qd=ode_state.physics_model.joint_velocities,
        tau=tau,
        f_ext=total_forces,
    )

    # =========================================
    # Compute the state derivative of base link
    # =========================================

    if not physics_model.is_floating_base:
        # The base of a fixed-base model does not move
        W_Qd_B = jnp.zeros(4)
        BW_v_WB = jnp.zeros(3)

    else:
        from jaxsim.math.conv import Convert
        from jaxsim.math.quaternion import Quaternion

        W_Qd_B = Quaternion.derivative(
            quaternion=ode_state.physics_model.base_quaternion,
            omega=ode_state.physics_model.base_angular_velocity,
            omega_in_body_fixed=False,
        ).squeeze()

        # Compute linear component of mixed velocity
        BW_v_WB = Convert.velocities_threed(
            v_6d=jnp.hstack(
                [
                    ode_state.physics_model.base_linear_velocity,
                    ode_state.physics_model.base_angular_velocity,
                ]
            ),
            p=ode_state.physics_model.base_position,
        ).squeeze()

    # Derivative of xfb (floating-base state), laid out as:
    # [0:4] quaternion derivative, [4:7] base linear velocity, [7:13] base acceleration
    xd_fb = jnp.hstack([W_Qd_B, BW_v_WB, W_a_WB.squeeze()]).squeeze()

    # =====================================
    # Build the full derivative of ODEState
    # =====================================

    def fix_one_dof(vector: jtp.Vector) -> jtp.Vector | None:
        """Fix the shape of computed quantities for models with just 1 DoF."""

        if vector is None:
            return None

        # A squeezed 1-DoF quantity is a 0-d array: restore its 1-d shape
        return jnp.array([vector]) if vector.shape == () else vector

    # Fill the PhysicsModelState object included in the input ODEState to store the
    # returned PhysicsModelState derivative
    physics_model_state_derivative = ode_state.physics_model.replace(
        joint_positions=fix_one_dof(ode_state.physics_model.joint_velocities.squeeze()),
        joint_velocities=fix_one_dof(qdd.squeeze()),
        base_quaternion=xd_fb.squeeze()[0:4],
        base_position=xd_fb.squeeze()[4:7],
        base_angular_velocity=xd_fb.squeeze()[10:13],
        base_linear_velocity=xd_fb.squeeze()[7:10],
    )

    # Fill the SoftContactsState object included in the input ODEState to store the
    # returned SoftContactsState derivative
    soft_contacts_state_derivative = ode_state.soft_contacts.replace(
        tangential_deformation=tangential_deformation_dot.squeeze(),
    )

    # We store the state derivative using the ODEState class so that the pytree
    # structure remains consistent, allowing to use our generic pytree integrators
    state_derivative = ode_data.ODEState(
        physics_model=physics_model_state_derivative,
        soft_contacts=soft_contacts_state_derivative,
    )

    # ===============================
    # Build auxiliary data and return
    # ===============================

    # Real ODEInput containing the real joint forces that have been actuated and
    # the total external forces (= original external forces + terrain + limits)
    ode_input_real = ode_data.ODEInput(
        physics_model=ode_data.PhysicsModelInput(tau=tau, f_ext=total_forces)
    )

    # Pack the inertial-fixed floating-base acceleration
    W_nud_WB = jnp.hstack([W_a_WB.squeeze(), qdd.squeeze()])

    # Build the auxiliary data
    aux_dict = {
        "model_acceleration": W_nud_WB,
        "ode_input": ode_input,
        "ode_input_real": ode_input_real,
        "contact_forces_links": contact_forces_links,
        "contact_forces_points": contact_forces_points,
        "tangential_deformation_dot": tangential_deformation_dot,
    }

    # Return the state derivative as a generic PyTree, and the dict with auxiliary info
    return state_derivative, aux_dict
|
jaxsim/simulation/ode_data.py
DELETED
@@ -1,96 +0,0 @@
|
|
1
|
-
import jax.flatten_util
|
2
|
-
import jax_dataclasses
|
3
|
-
|
4
|
-
import jaxsim.typing as jtp
|
5
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsState
|
6
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
7
|
-
from jaxsim.physics.model.physics_model_state import (
|
8
|
-
PhysicsModelInput,
|
9
|
-
PhysicsModelState,
|
10
|
-
)
|
11
|
-
from jaxsim.utils import JaxsimDataclass
|
12
|
-
|
13
|
-
|
14
|
-
@jax_dataclasses.pytree_dataclass
class ODEInput(JaxsimDataclass):
    """Input of the ODE system, wrapping the input of the physics model."""

    # Input of the physics model
    physics_model: PhysicsModelInput

    @staticmethod
    def build(
        physics_model_input: PhysicsModelInput | None = None,
        physics_model: PhysicsModel | None = None,
    ) -> "ODEInput":
        """
        Build an `ODEInput` object.

        Args:
            physics_model_input: The input of the physics model. When `None`,
                a zero input is created from `physics_model`.
            physics_model: The physics model used to create the zero input.

        Returns:
            The `ODEInput` wrapping the selected physics model input.
        """

        if physics_model_input is None:
            physics_model_input = PhysicsModelInput.zero(physics_model=physics_model)

        return ODEInput(physics_model=physics_model_input)

    @staticmethod
    def zero(physics_model: PhysicsModel) -> "ODEInput":
        """Build the `ODEInput` wrapping the zero input of the given physics model."""

        return ODEInput.build(physics_model=physics_model)

    def valid(self, physics_model: PhysicsModel) -> bool:
        """Return the outcome of the validity check of the wrapped physics model input."""

        wrapped_input = self.physics_model
        return wrapped_input.valid(physics_model=physics_model)
|
43
|
-
|
44
|
-
|
45
|
-
@jax_dataclasses.pytree_dataclass
class ODEState(JaxsimDataclass):
    """State of the ODE system: physics model state plus soft contacts state."""

    # State of the physics model
    physics_model: PhysicsModelState
    # State of the soft contacts model
    soft_contacts: SoftContactsState

    @staticmethod
    def build(
        physics_model_state: PhysicsModelState | None = None,
        soft_contacts_state: SoftContactsState | None = None,
        physics_model: PhysicsModel | None = None,
    ) -> "ODEState":
        """
        Build an `ODEState` object.

        Args:
            physics_model_state: The state of the physics model. When `None`,
                a zero state is created from `physics_model`.
            soft_contacts_state: The state of the soft contacts. When `None`,
                a zero state is created from `physics_model`.
            physics_model: The physics model used to create the missing states.

        Returns:
            The `ODEState` holding the selected states.
        """

        if physics_model_state is None:
            physics_model_state = PhysicsModelState.zero(physics_model=physics_model)

        if soft_contacts_state is None:
            soft_contacts_state = SoftContactsState.zero(physics_model=physics_model)

        return ODEState(
            physics_model=physics_model_state,
            soft_contacts=soft_contacts_state,
        )

    @staticmethod
    def deserialize(data: jtp.VectorJax, physics_model: PhysicsModel) -> "ODEState":
        """
        Rebuild an `ODEState` from its flattened representation.

        Args:
            data: The flat vector to unflatten.
            physics_model: The physics model defining the structure of the state.

        Returns:
            The object produced by unflattening `data` with the structure of
            a zero `ODEState` of the given physics model.
        """

        # A zero state is only used as a template providing the pytree structure
        template = ODEState.zero(physics_model=physics_model)
        unravel = jax.flatten_util.ravel_pytree(template)[1]

        return unravel(data)

    @staticmethod
    def zero(physics_model: PhysicsModel) -> "ODEState":
        """Build the zero `ODEState` of the given physics model."""

        zero_state = ODEState(
            physics_model=PhysicsModelState.zero(physics_model=physics_model),
            soft_contacts=SoftContactsState.zero(physics_model=physics_model),
        )

        assert zero_state.valid(physics_model)
        return zero_state

    def valid(self, physics_model: PhysicsModel) -> bool:
        """Check the validity of both wrapped states against the physics model."""

        # The soft contacts state is checked only if the physics model state is valid
        physics_state_ok = self.physics_model.valid(physics_model=physics_model)
        return physics_state_ok and self.soft_contacts.valid(
            physics_model=physics_model
        )
|
jaxsim/simulation/ode_integration.py
DELETED
@@ -1,62 +0,0 @@
|
|
1
|
-
import enum
|
2
|
-
import functools
|
3
|
-
from typing import Any, Dict, Tuple, Union
|
4
|
-
|
5
|
-
import jax.flatten_util
|
6
|
-
from jax.experimental.ode import odeint
|
7
|
-
|
8
|
-
import jaxsim.typing as jtp
|
9
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
|
10
|
-
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
|
11
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
12
|
-
from jaxsim.simulation import integrators, ode
|
13
|
-
from jaxsim.simulation.integrators import IntegratorType
|
14
|
-
|
15
|
-
|
16
|
-
@jax.jit
def ode_integration_rk4_adaptive(
    x0: jtp.Array,
    t: integrators.TimeHorizon,
    physics_model: PhysicsModel,
    *args,
    **kwargs,
) -> jtp.Array:
    """
    Integrate the ODE of the physics model with `jax.experimental.ode.odeint`.

    Args:
        x0: The initial state of the ODE.
        t: The time horizon over which the ODE is integrated.
        physics_model: The physics model to consider.
        *args: Additional positional arguments forwarded to `ode.dx_dt`.
        **kwargs: Additional keyword arguments forwarded to `odeint`.

    Returns:
        The output of `odeint`.
    """

    # Close function over its inputs and parameters.
    # `ode.dx_dt` returns a (state derivative, auxiliary dict) pair, while `odeint`
    # expects only the derivative, with the same pytree structure of `x0`.
    def dx_dt_closure(x, ts):
        return ode.dx_dt(x, ts, physics_model, *args)[0]

    return odeint(dx_dt_closure, x0, t, **kwargs)
|
28
|
-
|
29
|
-
|
30
|
-
@functools.partial(
    jax.jit, static_argnames=["num_sub_steps", "integrator_type", "return_aux"]
)
def ode_integration_fixed_step(
    x0: ode.ode_data.ODEState,
    t: integrators.TimeHorizon,
    physics_model: PhysicsModel,
    integrator_type: IntegratorType,
    soft_contacts_params: SoftContactsParams = SoftContactsParams(),
    terrain: Terrain = FlatTerrain(),
    ode_input: ode.ode_data.ODEInput | None = None,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict]]:
    """
    Integrate the ODE of the physics model with a fixed-step integrator.

    Args:
        x0: The initial state of the ODE.
        t: The time horizon over which the ODE is integrated.
        physics_model: The physics model to consider.
        integrator_type: The type of integrator to use.
        soft_contacts_params: The parameters of the soft contacts model.
        terrain: The terrain model.
        ode_input: The input of the ODE.
        *args: Additional positional arguments forwarded to `ode.dx_dt`.
        num_sub_steps: The number of sub-steps forwarded to the integrator.
        return_aux: Whether to also return the auxiliary data of the integration.

    Returns:
        The integrated state and, if `return_aux` is true, a tuple with the
        integrated state and the auxiliary data.
    """

    # System dynamics closed over the model, its parameters, and its input
    def system_dynamics(state, time):
        return ode.dx_dt(
            state,
            time,
            physics_model,
            soft_contacts_params,
            ode_input,
            terrain,
            *args,
        )

    # Integrate over the horizon
    integration_output = integrators.odeint(
        func=system_dynamics,
        y0=x0,
        t=t,
        num_sub_steps=num_sub_steps,
        return_aux=return_aux,
        integrator_type=integrator_type,
    )

    # Without auxiliary data, the integrator output is the state pytree itself
    if not return_aux:
        return integration_output

    # Otherwise, split the output into the state pytree and the aux dict
    return (integration_output[0], integration_output[1])
|