jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1.dev401.dist-info/METADATA +0 -167
- jaxsim-0.1.dev401.dist-info/RECORD +0 -64
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/simulation/integrators.py
DELETED
@@ -1,452 +0,0 @@
|
|
1
|
-
from typing import Any, Callable, Dict, Tuple, Union
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import jax.numpy as jnp
|
5
|
-
from jax.tree_util import tree_map
|
6
|
-
|
7
|
-
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsState
|
9
|
-
from jaxsim.physics.model.physics_model_state import PhysicsModelState
|
10
|
-
from jaxsim.simulation.ode_data import ODEState
|
11
|
-
|
12
|
-
# A time instant is a plain scalar.
Time = float

# The integration horizon is a vector of time instants.
TimeHorizon = jtp.Vector

# Both the state and its time derivative are generic pytrees.
State = jtp.PyTree
StateDerivative = jtp.PyTree

# System dynamics: (state, time) -> (state derivative, auxiliary dictionary).
StateDerivativeCallable = Callable[
    [State, Time],
    Tuple[StateDerivative, Dict[str, Any]],
]
|
21
|
-
|
22
|
-
|
23
|
-
# =======================
|
24
|
-
# Single-step integration
|
25
|
-
# =======================
|
26
|
-
|
27
|
-
|
28
|
-
def odeint_euler_one_step(
    dx_dt: StateDerivativeCallable,
    x0: State,
    t0: Time,
    tf: Time,
    num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
    """
    Advance a state from t0 to tf with the explicit (forward) Euler scheme.

    The interval [t0, tf] is divided into `num_sub_steps` sub-steps of equal
    duration, and one Euler update is applied to every leaf of the state
    pytree in each of them.

    Args:
        dx_dt: Function returning the state derivative and an auxiliary dictionary.
        x0: State at the beginning of the interval.
        t0: Time at the beginning of the interval.
        tf: Time at the end of the interval.
        num_sub_steps: Number of equal sub-steps the interval is split into.

    Returns:
        The state at the end of the interval and the auxiliary dictionary
        evaluated at (x0, t0).
    """

    # Duration of each sub-step.
    h = (tf - t0) / num_sub_steps

    def advance(
        carry: Tuple[State, Time], unused: None
    ) -> Tuple[Tuple[State, Time], None]:
        x_k, t_k = carry

        # Only the derivative is used while stepping; the aux data is dropped.
        x_dot_k, _ = dx_dt(x_k, t_k)

        # Euler update applied leaf-by-leaf to the state pytree.
        x_next = jax.tree_util.tree_map(
            lambda leaf, leaf_dot: leaf + h * leaf_dot, x_k, x_dot_k
        )

        return (x_next, t_k + h), None

    # Chain the sub-steps.
    (x_final, _), _ = jax.lax.scan(
        f=advance, init=(x0, t0), xs=None, length=num_sub_steps
    )

    # The auxiliary data refers to the initial instant, hence this extra
    # evaluation of the dynamics at (x0, t0).
    _, aux_at_t0 = dx_dt(x0, t0)

    return x_final, aux_at_t0
|
87
|
-
|
88
|
-
|
89
|
-
def odeint_rk4_one_step(
    dx_dt: StateDerivativeCallable,
    x0: State,
    t0: Time,
    tf: Time,
    num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
    """
    Advance a state from t0 to tf with the classic 4th-order Runge-Kutta scheme.

    The interval [t0, tf] is divided into `num_sub_steps` sub-steps of equal
    duration, and one RK4 update is applied in each of them.

    Args:
        dx_dt: Function returning the state derivative and an auxiliary dictionary.
        x0: State at the beginning of the interval.
        t0: Time at the beginning of the interval.
        tf: Time at the end of the interval.
        num_sub_steps: Number of equal sub-steps the interval is split into.

    Returns:
        The state at the end of the interval and the auxiliary dictionary
        evaluated at (x0, t0).
    """

    # Duration of a full sub-step and of half of it.
    h = (tf - t0) / num_sub_steps
    half_h = 0.5 * h

    def shifted(x: State, slope: StateDerivative, step) -> State:
        # Move every leaf of the state along the given slope for `step` seconds.
        return tree_map(lambda leaf, leaf_slope: leaf + step * leaf_slope, x, slope)

    def advance(
        carry: Tuple[State, Time], unused: None
    ) -> Tuple[Tuple[State, Time], None]:
        x_k, t_k = carry
        t_mid = t_k + half_h

        # Slopes at the beginning, twice at the midpoint, and at the end.
        k1, _ = dx_dt(x_k, t_k)
        k2, _ = dx_dt(shifted(x_k, k1, half_h), t_mid)
        k3, _ = dx_dt(shifted(x_k, k2, half_h), t_mid)
        k4, _ = dx_dt(shifted(x_k, k3, h), t_k + h)

        # Weighted average of the four slopes.
        slope = tree_map(
            lambda a, b, c, d: (a + 2 * b + 2 * c + d) / 6, k1, k2, k3, k4
        )

        # Full step along the averaged slope.
        return (shifted(x_k, slope, h), t_k + h), None

    # Chain the sub-steps.
    (x_final, _), _ = jax.lax.scan(
        f=advance, init=(x0, t0), xs=None, length=num_sub_steps
    )

    # The auxiliary data refers to the initial instant, hence this extra
    # evaluation of the dynamics at (x0, t0).
    _, aux_at_t0 = dx_dt(x0, t0)

    return x_final, aux_at_t0
|
157
|
-
|
158
|
-
|
159
|
-
def odeint_euler_semi_implicit_one_step(
    dx_dt: StateDerivativeCallable,
    x0: ODEState,
    t0: Time,
    tf: Time,
    num_sub_steps: int = 1,
) -> Tuple[ODEState, Dict[str, Any]]:
    """
    Semi-implicit Euler integrator.

    In each sub-step the velocities are integrated first, and the positions
    are then integrated using the *updated* velocities. The tangential
    deformation of the soft contacts is integrated with a plain Euler update.

    Args:
        dx_dt: Callable that computes the state derivative.
        x0: Initial state.
        t0: Initial time.
        tf: Final time.
        num_sub_steps: Number of sub-steps to break the integration into.

    Returns:
        The final state and a dictionary including auxiliary data at t0.
    """

    # Kept as function-local imports, like the original implementation.
    from jaxsim.math.conv import Convert
    from jaxsim.math.quaternion import Quaternion

    # Compute the sub-step size.
    # We break dt in configurable sub-steps.
    dt = tf - t0
    sub_step_dt = dt / num_sub_steps

    # Initialize the carry
    Carry = Tuple[ODEState, Time]
    carry_init: Carry = (x0, t0)

    def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
        # Unpack the carry
        x_t0, t0 = carry

        # Extract the initial position and velocity
        pos_t0 = x_t0.physics_model.position()
        vel_t0 = x_t0.physics_model.velocity()

        # Compute the state derivative (the auxiliary data is not needed here)
        dxdt_t0, _ = dx_dt(x_t0, t0)

        # Extract the velocity derivative
        d_vel_dt = dxdt_t0.physics_model.velocity()

        # Perform semi-implicit Euler integration [1-4].

        # 1. Integrate the velocities
        vel_tf = vel_t0 + sub_step_dt * d_vel_dt

        # The generalized velocity is laid out as
        # [base linear (3), base angular (3), joints (...)].
        W_v_lin_WB_tf = vel_tf[0:3]
        W_v_ang_WB_tf = vel_tf[3:6]
        joint_velocities_tf = vel_tf[6:]

        # 2. Compute the quaternion derivative and the base position derivative.
        #    Both must use the velocities at tf: using the velocities at t0
        #    for the base position would make that component explicit Euler.
        W_Qd_B = Quaternion.derivative(
            quaternion=x_t0.physics_model.base_quaternion,
            omega=W_v_ang_WB_tf,
            omega_in_body_fixed=False,
        ).squeeze()

        # Compute linear component of mixed velocity BW_v_WB
        BW_v_WB = Convert.velocities_threed(
            v_6d=jnp.hstack([W_v_lin_WB_tf, W_v_ang_WB_tf]),
            p=x_t0.physics_model.base_position.squeeze(),
        ).squeeze()

        # 3. Compute the derivative of the position
        posd_tf = jnp.hstack([BW_v_WB, W_Qd_B, joint_velocities_tf])

        # 4. Integrate the positions
        pos_tf = pos_t0 + sub_step_dt * posd_tf

        # Integrate the remaining state
        u = x_t0.soft_contacts.tangential_deformation
        ud = dxdt_t0.soft_contacts.tangential_deformation
        tangential_deformation_tf = u + sub_step_dt * ud

        # The generalized position is laid out as
        # [base position (3), base quaternion (4), joints (...)].
        x_tf = ODEState(
            physics_model=PhysicsModelState(
                base_position=pos_tf[0:3],
                base_quaternion=pos_tf[3:7],
                joint_positions=pos_tf[7:],
                base_linear_velocity=W_v_lin_WB_tf,
                base_angular_velocity=W_v_ang_WB_tf,
                joint_velocities=joint_velocities_tf,
            ),
            soft_contacts=SoftContactsState(
                tangential_deformation=tangential_deformation_tf
            ),
        )

        # Update the time
        tf = t0 + sub_step_dt

        # Pack the carry
        carry = (x_tf, tf)

        return carry, None

    # Integrate over the given horizon
    (x_tf, _), _ = jax.lax.scan(
        f=body_fun, init=carry_init, xs=None, length=num_sub_steps
    )

    # Compute the aux dictionary at t0
    _, aux_t0 = dx_dt(x0, t0)

    return x_tf, aux_t0
|
278
|
-
|
279
|
-
|
280
|
-
# ===============================
|
281
|
-
# Adapter: single step -> horizon
|
282
|
-
# ===============================
|
283
|
-
|
284
|
-
|
285
|
-
def integrate_single_step_over_horizon(
    integrator_single_step: Callable[[Time, Time, State], Tuple[State, Dict[str, Any]]],
    t: TimeHorizon,
    x0: State,
) -> Tuple[State, Dict[str, Any]]:
    """
    Integrate a single-step integrator over a given horizon.

    One step is performed for each instant in `t`, from that instant to the
    next one. The last instant has no successor, so the end of its interval
    is clamped to the instant itself and the corresponding step spans a
    zero-length interval.

    Args:
        integrator_single_step:
            A single-step integrator, called as `integrator_single_step(t0, tf, x)`.
        t: The vector of time instants of the integration horizon.
        x0: The initial state of the integration horizon.

    Returns:
        The states at the beginning of each step (one per instant in `t`,
        stacked along a new leading axis by the scan) and the auxiliary data
        returned by the integrator at those same instants.
    """

    horizon = jnp.asarray(t)
    num_instants = horizon.shape[0]
    last_idx = num_instants - 1

    def body_fun(x_t0: State, idx: int) -> Tuple[State, jtp.PyTree]:
        # Get the integration interval.
        # The upper index is clamped explicitly instead of relying on the
        # clamping that JAX applies to out-of-bounds indices.
        t0 = horizon[idx]
        tf = horizon[jnp.minimum(idx + 1, last_idx)]

        # Perform a single-step integration of the ODE
        x_tf, aux_t0 = integrator_single_step(t0, tf, x_t0)

        # The state at t0 is emitted, the state at tf is carried to the next step
        return x_tf, (x_t0, aux_t0)

    # Integrate over the given horizon
    _, (x_horizon, aux_horizon) = jax.lax.scan(
        f=body_fun, init=x0, xs=jnp.arange(num_instants)
    )

    return x_horizon, aux_horizon
|
328
|
-
|
329
|
-
|
330
|
-
# ===================================================================
|
331
|
-
# Integration over horizon (same APIs of jax.experimental.ode.odeint)
|
332
|
-
# ===================================================================
|
333
|
-
|
334
|
-
|
335
|
-
def odeint_euler(
    func,
    y0: State,
    t: TimeHorizon,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
    """
    Integrate a system of ODEs over a horizon with the forward Euler method.

    Args:
        func: Function computing the time derivative of the state,
            called as `func(x, ts, *args)`.
        y0: The initial state.
        t: The vector of time instants of the integration horizon.
        *args: Extra positional arguments forwarded to `func`.
        num_sub_steps: Number of sub-steps performed within each integration step.
        return_aux: Whether to also return the auxiliary data.

    Returns:
        The states computed by `integrate_single_step_over_horizon` over `t`,
        followed by the corresponding auxiliary data if `return_aux` is true.
    """

    def dynamics(x, ts):
        # Bind the extra arguments to the user-provided function.
        return func(x, ts, *args)

    def single_step(t0, tf, x0):
        # Bind the dynamics and the number of sub-steps to the one-step integrator.
        return odeint_euler_one_step(
            dx_dt=dynamics, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
        )

    states, aux = integrate_single_step_over_horizon(
        integrator_single_step=single_step, t=t, x0=y0
    )

    if return_aux:
        return states, aux

    return states
|
373
|
-
|
374
|
-
|
375
|
-
def odeint_euler_semi_implicit(
    func,
    y0: State,
    t: TimeHorizon,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
    """
    Integrate a system of ODEs over a horizon with the semi-implicit Euler method.

    Args:
        func: Function computing the time derivative of the state,
            called as `func(x, ts, *args)`.
        y0: The initial state.
        t: The vector of time instants of the integration horizon.
        *args: Extra positional arguments forwarded to `func`.
        num_sub_steps: Number of sub-steps performed within each integration step.
        return_aux: Whether to also return the auxiliary data.

    Returns:
        The states computed by `integrate_single_step_over_horizon` over `t`,
        followed by the corresponding auxiliary data if `return_aux` is true.
    """

    def dynamics(x, ts):
        # Bind the extra arguments to the user-provided function.
        return func(x, ts, *args)

    def single_step(t0, tf, x0):
        # Bind the dynamics and the number of sub-steps to the one-step integrator.
        return odeint_euler_semi_implicit_one_step(
            dx_dt=dynamics, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
        )

    states, aux = integrate_single_step_over_horizon(
        integrator_single_step=single_step, t=t, x0=y0
    )

    if return_aux:
        return states, aux

    return states
|
413
|
-
|
414
|
-
|
415
|
-
def odeint_rk4(
    func,
    y0: State,
    t: TimeHorizon,
    *args,
    num_sub_steps: int = 1,
    return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
    """
    Integrate a system of ODEs over a horizon with the Runge-Kutta 4 method.

    Args:
        func: Function computing the time derivative of the state,
            called as `func(x, ts, *args)`.
        y0: The initial state.
        t: The vector of time instants of the integration horizon.
        *args: Extra positional arguments forwarded to `func`.
        num_sub_steps: Number of sub-steps performed within each integration step.
        return_aux: Whether to also return the auxiliary data.

    Returns:
        The states computed by `integrate_single_step_over_horizon` over `t`,
        followed by the corresponding auxiliary data if `return_aux` is true.
    """

    def dynamics(x, ts):
        # Bind the extra arguments to the user-provided function.
        return func(x, ts, *args)

    def single_step(t0, tf, x0):
        # Bind the dynamics and the number of sub-steps to the one-step integrator.
        return odeint_rk4_one_step(
            dx_dt=dynamics, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
        )

    states, aux = integrate_single_step_over_horizon(
        integrator_single_step=single_step, t=t, x0=y0
    )

    if return_aux:
        return states, aux

    return states
|