openscvx 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


openscvx/__init__.py ADDED
File without changes
openscvx/_version.py ADDED
@@ -0,0 +1,21 @@
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '0.1.1'
+ __version_tuple__ = version_tuple = (0, 1, 1)
@@ -0,0 +1,44 @@
+ from typing import List
+
+ from openscvx.constraints.ctcs import CTCSConstraint
+
+ def sort_ctcs_constraints(constraints_ctcs: List[CTCSConstraint], N: int):
+     idx_to_nodes: dict[int, tuple] = {}
+     next_idx = 0
+     for c in constraints_ctcs:
+         # normalize None to full horizon
+         c.nodes = c.nodes or (0, N)
+         key = c.nodes
+
+         if c.idx is not None:
+             # user supplied an identifier: ensure it always points to the same interval
+             if c.idx in idx_to_nodes:
+                 if idx_to_nodes[c.idx] != key:
+                     raise ValueError(
+                         f"idx={c.idx} was first used with interval={idx_to_nodes[c.idx]}, "
+                         f"but now you gave it interval={key}"
+                     )
+             else:
+                 idx_to_nodes[c.idx] = key
+
+         else:
+             # no identifier: see if this interval already has one
+             for existing_id, nodes in idx_to_nodes.items():
+                 if nodes == key:
+                     c.idx = existing_id
+                     break
+             else:
+                 # brand-new interval: pick the next free auto-id
+                 while next_idx in idx_to_nodes:
+                     next_idx += 1
+                 c.idx = next_idx
+                 idx_to_nodes[next_idx] = key
+                 next_idx += 1
+
+     # Extract the intervals in ascending-idx order
+     ordered_ids = sorted(idx_to_nodes.keys())
+     node_intervals = [idx_to_nodes[i] for i in ordered_ids]
+     id_to_position = {ident: pos for pos, ident in enumerate(ordered_ids)}
+     num_augmented_states = len(ordered_ids)
+
+     return constraints_ctcs, node_intervals, num_augmented_states
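
The sketch below (not part of the released files) illustrates how sort_ctcs_constraints assigns idx values; it uses SimpleNamespace stand-ins for CTCSConstraint, since only the .nodes and .idx attributes are read here.

from types import SimpleNamespace

cs = [
    SimpleNamespace(nodes=None, idx=None),    # full horizon -> auto idx 0
    SimpleNamespace(nodes=(0, 5), idx=None),  # new interval -> auto idx 1
    SimpleNamespace(nodes=(0, 5), idx=None),  # same interval -> reuses idx 1
]
cs, intervals, n_aug = sort_ctcs_constraints(cs, N=10)
print([c.idx for c in cs], intervals, n_aug)  # [0, 1, 1] [(0, 10), (0, 5)] 2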
openscvx/config.py ADDED
@@ -0,0 +1,247 @@
+ import numpy as np
+ from dataclasses import dataclass, field
+ from typing import Dict, List
+
+
+ def get_affine_scaling_matrices(n, minimum, maximum):
+     S = np.diag(np.maximum(np.ones(n), abs(minimum - maximum) / 2))
+     c = (maximum + minimum) / 2
+     return S, c
+
+
+ @dataclass
+ class DiscretizationConfig:
+     dis_type: str = "FOH"
+     custom_integrator: bool = True
+     solver: str = "Tsit5"
+     args: Dict = field(default_factory=dict)
+     atol: float = 1e-3
+     rtol: float = 1e-6
+
+     """
+     Configuration class for discretization settings.
+
+     This class defines the parameters required for discretizing system dynamics.
+
+     Main arguments:
+         These are the arguments most commonly used day-to-day.
+
+     Args:
+         dis_type (str): The type of discretization to use (e.g., "FOH" for First-Order Hold). Defaults to "FOH".
+         custom_integrator (bool): This enables our custom fixed-step RK45 algorithm. It tends to be faster than Diffrax, but unless you are going for speed, it is recommended to stick with Diffrax for robustness and the wider choice of solvers. Defaults to True.
+         solver (str): Not used if custom_integrator is enabled. Any choice of solver in Diffrax is valid; see [How to Choose a Solver](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/). Defaults to "Tsit5".
+
+     Other arguments:
+         These arguments are less frequently used, and for most purposes you shouldn't need to understand these.
+
+     Args:
+         args (Dict): Additional arguments to pass to the solver, which can be found [here](https://docs.kidger.site/diffrax/api/diffeqsolve/). Defaults to an empty dictionary.
+         atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
+         rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
+     """
+
+
+ @dataclass
+ class DevConfig:
+     profiling: bool = False
+     debug: bool = False
+     printing: bool = True
+
+     """
+     Configuration class for development settings.
+
+     This class defines the parameters used for development and debugging purposes.
+
+     Main arguments:
+         These are the arguments most commonly used day-to-day.
+
+     Args:
+         profiling (bool): Whether to enable profiling for performance analysis. Defaults to False.
+         debug (bool): Disables all precompilation so you can place breakpoints and inspect values. Defaults to False.
+     """
+
+
+ @dataclass
+ class ConvexSolverConfig:
+     solver: str = "QOCO"
+     solver_args: dict = field(default_factory=lambda: {"abstol": 1e-6, "reltol": 1e-9})
+     cvxpygen: bool = False
+
+     """
+     Configuration class for convex solver settings.
+
+     This class defines the parameters required for configuring a convex solver.
+
+     These are the arguments most commonly used day-to-day. Generally I have found [QOCO](https://qoco-org.github.io/qoco/index.html) to be the most performant of the CVXPY solvers for these types of problems (I do have a bias, as the author is from my group), and it can handle problems up to SOCPs.
+     [CLARABEL](https://clarabel.org/stable/) is also a great option, with feasibility checking, and can handle a few more problem types.
+     [CVXPYGen](https://github.com/cvxgrp/cvxpygen) is also a great option if your problem isn't too large. I have found qocogen to be the most performant of the CVXPYGen solvers.
+
+     Args:
+         solver (str): The name of the CVXPY solver to use. A list of options can be found [here](https://www.cvxpy.org/tutorial/solvers/index.html). Defaults to "QOCO".
+         solver_args (dict): Additional arguments to configure the solver, such as tolerances. Ensure you are using the correct arguments for your solver, as they are not common to all solvers.
+             Defaults to {"abstol": 1e-6, "reltol": 1e-9}.
+         cvxpygen (bool): Whether to enable CVXPY code generation for the solver. Defaults to False.
+     """
+
+
+ @dataclass
+ class PropagationConfig:
+     inter_sample: int = 30
+     dt: float = 0.1
+     solver: str = "Dopri8"
+     args: Dict = field(default_factory=dict)
+     atol: float = 1e-3
+     rtol: float = 1e-6
+
+     """
+     Configuration class for propagation settings.
+
+     This class defines the parameters required for propagating the nonlinear system dynamics using the optimal control sequence.
+
+     Main arguments:
+         These are the arguments most commonly used day-to-day.
+
+     Args:
+         dt (float): The time step for propagation. Defaults to 0.1.
+         inter_sample (int): How dense the propagation within the multiple-shooting discretization should be. Defaults to 30.
+
+     Other arguments:
+         The solver should likely not be changed, as it is a high-accuracy 8th-order Runge-Kutta method.
+
+     Args:
+         solver (str): The numerical solver to use for propagation (e.g., "Dopri8"). Defaults to "Dopri8".
+         args (Dict): Additional arguments to pass to the solver. Defaults to an empty dictionary.
+         atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
+         rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
+     """
+
+
+ @dataclass
+ class SimConfig:
+     x_bar: np.ndarray
+     u_bar: np.ndarray
+     initial_state: np.ndarray
+     final_state: np.ndarray
+     max_state: np.ndarray
+     min_state: np.ndarray
+     max_control: np.ndarray
+     min_control: np.ndarray
+     total_time: float
+     idx_x_true: slice
+     idx_u_true: slice
+     idx_t: slice
+     idx_y: slice
+     idx_s: slice
+     ctcs_node_intervals: list = None
+     constraints_ctcs: List[callable] = field(
+         default_factory=list
+     )  # TODO (norrisg): clean this up, consider moving to dedicated `constraints` dataclass
+     constraints_nodal: List[callable] = field(default_factory=list)
+     n_states: int = None
+     n_controls: int = None
+     S_x: np.ndarray = None
+     inv_S_x: np.ndarray = None
+     c_x: np.ndarray = None
+     S_u: np.ndarray = None
+     inv_S_u: np.ndarray = None
+     c_u: np.ndarray = None
+
+     def __post_init__(self):
+         self.n_states = len(self.max_state)
+         self.n_controls = len(self.max_control)
+
+         assert (
+             len(self.initial_state.value) == self.n_states - (self.idx_y.stop - self.idx_y.start)
+         ), f"Initial state must have {self.n_states - (self.idx_y.stop - self.idx_y.start)} elements"
+         assert (
+             len(self.final_state.value) == self.n_states - (self.idx_y.stop - self.idx_y.start)
+         ), f"Final state must have {self.n_states - (self.idx_y.stop - self.idx_y.start)} elements"
+         assert (
+             self.max_state.shape[0] == self.n_states
+         ), f"Max state must have {self.n_states} elements"
+         assert (
+             self.min_state.shape[0] == self.n_states
+         ), f"Min state must have {self.n_states} elements"
+         assert (
+             self.max_control.shape[0] == self.n_controls
+         ), f"Max control must have {self.n_controls} elements"
+         assert (
+             self.min_control.shape[0] == self.n_controls
+         ), f"Min control must have {self.n_controls} elements"
+
+         if self.S_x is None or self.c_x is None:
+             self.S_x, self.c_x = get_affine_scaling_matrices(
+                 self.n_states, self.min_state, self.max_state
+             )
+         # Use the fact that S_x is diagonal to compute the inverse
+         self.inv_S_x = np.diag(1 / np.diag(self.S_x))
+         if self.S_u is None or self.c_u is None:
+             self.S_u, self.c_u = get_affine_scaling_matrices(
+                 self.n_controls, self.min_control, self.max_control
+             )
+         self.inv_S_u = np.diag(1 / np.diag(self.S_u))
+
+
+ @dataclass
+ class ScpConfig:
+     n: int = None
+     k_max: int = 200
+     w_tr: float = 1e0
+     lam_vc: float = 1e0
+     ep_tr: float = 1e-4
+     ep_vb: float = 1e-4
+     ep_vc: float = 1e-8
+     lam_cost: float = 0.0
+     lam_vb: float = 0.0
+     uniform_time_grid: bool = False
+     cost_drop: int = -1
+     cost_relax: float = 1.0
+     w_tr_adapt: float = 1.0
+     w_tr_max: float = None
+     w_tr_max_scaling_factor: float = None
+
+     """
+     Configuration class for Sequential Convex Programming (SCP).
+
+     This class defines the parameters used to configure the SCP solver. You will very likely need to modify
+     the weights for your problem. Please refer to my guide [here](https://haynec.github.io/openscvx/hyperparameter_tuning) for more information.
+
+     Attributes:
+         n (int): The number of discretization nodes. Defaults to `None`.
+         k_max (int): The maximum number of SCP iterations. Defaults to 200.
+         w_tr (float): The trust region weight. Defaults to 1.0.
+         lam_vc (float): The penalty weight for virtual control. Defaults to 1.0.
+         ep_tr (float): The trust region convergence tolerance. Defaults to 1e-4.
+         ep_vb (float): The boundary constraint convergence tolerance. Defaults to 1e-4.
+         ep_vc (float): The virtual constraint convergence tolerance. Defaults to 1e-8.
+         lam_cost (float): The weight for the original cost. Defaults to 0.0.
+         lam_vb (float): The weight for the virtual buffer. This is only used if there are nonconvex nodal constraints present. Defaults to 0.0.
+         uniform_time_grid (bool): Whether to use a uniform time grid. TODO haynec add a link to the time dilation page. Defaults to `False`.
+         cost_drop (int): The number of iterations to allow for cost stagnation before termination. Defaults to -1 (disabled).
+         cost_relax (float): The relaxation factor for cost reduction. Defaults to 1.0.
+         w_tr_adapt (float): The adaptation factor for the trust region weight. Defaults to 1.0.
+         w_tr_max (float): The maximum allowable trust region weight. Defaults to `None`.
+         w_tr_max_scaling_factor (float): The scaling factor for the maximum trust region weight. Defaults to `None`.
+     """
+
+     def __post_init__(self):
+         keys_to_scale = ["w_tr", "lam_vc", "lam_cost", "lam_vb"]
+         scale = max(getattr(self, key) for key in keys_to_scale)
+         for key in keys_to_scale:
+             setattr(self, key, getattr(self, key) / scale)
+
+         if self.w_tr_max_scaling_factor is not None and self.w_tr_max is None:
+             self.w_tr_max = self.w_tr_max_scaling_factor * self.w_tr
+
+
+ @dataclass
+ class Config:
+     sim: SimConfig
+     scp: ScpConfig
+     cvx: ConvexSolverConfig
+     dis: DiscretizationConfig
+     prp: PropagationConfig
+     dev: DevConfig
+
+     def __post_init__(self):
+         pass
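
A small usage sketch (not part of the released files): ScpConfig.__post_init__ rescales the penalty weights so that the largest of w_tr, lam_vc, lam_cost, and lam_vb becomes 1.0, and the other config dataclasses are plain keyword constructions.

scp = ScpConfig(n=20, w_tr=100.0, lam_vc=10.0)
print(scp.w_tr, scp.lam_vc, scp.lam_cost, scp.lam_vb)  # 1.0 0.1 0.0 0.0
dis = DiscretizationConfig(custom_integrator=False, solver="Dopri5")  # use a Diffrax solver
dev = DevConfig(debug=True)  # disable precompilation for breakpoint-friendly runs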
@@ -0,0 +1,169 @@
+ import jax.numpy as jnp
+ import numpy as np
+
+ from openscvx.integrators import solve_ivp_rk45, solve_ivp_diffrax
+
+
+ def dVdt(
+     tau: float,
+     V: jnp.ndarray,
+     u_cur: np.ndarray,
+     u_next: np.ndarray,
+     state_dot: callable,
+     A: callable,
+     B: callable,
+     n_x: int,
+     n_u: int,
+     N: int,
+     dis_type: str,
+ ) -> jnp.ndarray:
+     # Define the nodes
+     nodes = jnp.arange(0, N-1)
+
+     # Define indices for slicing the augmented state vector
+     i0 = 0
+     i1 = n_x
+     i2 = i1 + n_x * n_x
+     i3 = i2 + n_x * n_u
+     i4 = i3 + n_x * n_u
+     i5 = i4 + n_x
+
+     # Unflatten V
+     V = V.reshape(-1, i5)
+
+     # Compute the interpolation factor based on the discretization type
+     if dis_type == "ZOH":
+         beta = 0.0
+     elif dis_type == "FOH":
+         beta = tau * N
+     alpha = 1 - beta
+
+     # Interpolate the control input
+     u = u_cur + beta * (u_next - u_cur)
+     s = u[:, -1]
+
+     # Initialize the augmented Jacobians
+     dfdx = jnp.zeros((V.shape[0], n_x, n_x))
+     dfdu = jnp.zeros((V.shape[0], n_x, n_u))
+
+     # Ensure x and u have the same batch size
+     x = V[:, :n_x]
+     u = u[: x.shape[0]]
+
+     # Compute the nonlinear propagation term
+     f = state_dot(x, u[:, :-1], nodes)
+     F = s[:, None] * f
+
+     # Evaluate the State Jacobian
+     dfdx = A(x, u[:, :-1], nodes)
+     sdfdx = s[:, None, None] * dfdx
+
+     # Evaluate the Control Jacobian
+     dfdu_veh = B(x, u[:, :-1], nodes)
+     dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * dfdu_veh)
+     dfdu = dfdu.at[:, :, -1].set(f)
+
+     # Compute the defect
+     z = F - jnp.einsum("ijk,ik->ij", sdfdx, x) - jnp.einsum("ijk,ik->ij", dfdu, u)
+
+     # Stack up the results into the augmented state vector
+     # fmt: off
+     dVdt = jnp.zeros_like(V)
+     dVdt = dVdt.at[:, i0:i1].set(F)
+     dVdt = dVdt.at[:, i1:i2].set(jnp.matmul(sdfdx, V[:, i1:i2].reshape(-1, n_x, n_x)).reshape(-1, n_x * n_x))
+     dVdt = dVdt.at[:, i2:i3].set((jnp.matmul(sdfdx, V[:, i2:i3].reshape(-1, n_x, n_u)) + dfdu * alpha).reshape(-1, n_x * n_u))
+     dVdt = dVdt.at[:, i3:i4].set((jnp.matmul(sdfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u))
+     dVdt = dVdt.at[:, i4:i5].set((jnp.matmul(sdfdx, V[:, i4:i5].reshape(-1, n_x)[..., None]).squeeze(-1) + z).reshape(-1, n_x))
+     # fmt: on
+     return dVdt.flatten()
+
+
+ def calculate_discretization(
+     x,
+     u,
+     state_dot: callable,
+     A: callable,
+     B: callable,
+     n_x: int,
+     n_u: int,
+     N: int,
+     custom_integrator: bool,
+     debug: bool,
+     solver: str,
+     rtol,
+     atol,
+     dis_type: str,
+ ):
+
+     # Define indices for slicing the augmented state vector
+     i0 = 0
+     i1 = n_x
+     i2 = i1 + n_x * n_x
+     i3 = i2 + n_x * n_u
+     i4 = i3 + n_x * n_u
+     i5 = i4 + n_x
+
+     # initial augmented state
+     V0 = jnp.zeros((N - 1, i5))
+     V0 = V0.at[:, :n_x].set(x[:-1].astype(float))
+     V0 = V0.at[:, n_x : n_x + n_x * n_x].set(
+         jnp.eye(n_x).reshape(1, -1).repeat(N - 1, axis=0)
+     )
+
+     # choose integrator
+     if custom_integrator:
+         # fmt: off
+         sol = solve_ivp_rk45(
+             lambda t,y,*a: dVdt(t, y, *a),
+             1.0/(N-1),
+             V0.reshape(-1),
+             args=(u[:-1].astype(float), u[1:].astype(float),
+                   state_dot, A, B, n_x, n_u, N, dis_type),
+             is_not_compiled=debug,
+         )
+         # fmt: on
+     else:
+         # fmt: off
+         sol = solve_ivp_diffrax(
+             lambda t,y,*a: dVdt(t, y, *a),
+             1.0/(N-1),
+             V0.reshape(-1),
+             args=(u[:-1].astype(float), u[1:].astype(float),
+                   state_dot, A, B, n_x, n_u, N, dis_type),
+             solver_name=solver,
+             rtol=rtol,
+             atol=atol,
+             extra_kwargs=None,
+         )
+         # fmt: on
+
+     Vend = sol[-1].T.reshape(-1, i5)
+     Vmulti = sol.T
+
+     # fmt: off
+     A_bar = Vend[:, i1:i2].reshape(N-1, n_x, n_x).transpose(1,2,0).reshape(n_x*n_x, -1, order='F').T
+     B_bar = Vend[:, i2:i3].reshape(N-1, n_x, n_u).transpose(1,2,0).reshape(n_x*n_u, -1, order='F').T
+     C_bar = Vend[:, i3:i4].reshape(N-1, n_x, n_u).transpose(1,2,0).reshape(n_x*n_u, -1, order='F').T
+     z_bar = Vend[:, i4:i5]
+     # fmt: on
+
+     return A_bar, B_bar, C_bar, z_bar, Vmulti
+
+
+ def get_discretization_solver(state_dot, A, B, params):
+     return lambda x, u: calculate_discretization(
+         x=x,
+         u=u,
+         state_dot=state_dot,
+         A=A,
+         B=B,
+         n_x=params.sim.n_states,
+         n_u=params.sim.n_controls,
+         N=params.scp.n,
+         custom_integrator=params.dis.custom_integrator,
+         debug=params.dev.debug,
+         solver=params.dis.solver,
+         rtol=params.dis.rtol,
+         atol=params.dis.atol,
+         dis_type=params.dis.dis_type,
+     )
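
Reading aid (not part of the released files): each row of the augmented state V integrated by dVdt stacks the propagated state, the state-transition matrix, the sensitivities to the current and next control, and the defect term; the returned A_bar, B_bar, C_bar, z_bar then enter the usual first-order-hold update, roughly x[k+1] ≈ A_bar[k] @ x[k] + B_bar[k] @ u[k] + C_bar[k] @ u[k+1] + z_bar[k]. For a hypothetical problem with n_x = 7 states and n_u = 4 controls the row width works out as:

n_x, n_u = 7, 4
i5 = n_x + n_x * n_x + 2 * (n_x * n_u) + n_x
print(i5)  # 119 = 7 (state) + 49 (A block) + 28 (B block) + 28 (C block) + 7 (defect)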
openscvx/dynamics.py ADDED
@@ -0,0 +1,24 @@
+ import jax
+ import jax.numpy as jnp
+
+
+ def get_augmented_dynamics(
+     dynamics: callable, g_funcs: list[callable], idx_x_true: slice, idx_u_true: slice
+ ) -> callable:
+     def dynamics_augmented(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
+         x_dot = dynamics(x[idx_x_true], u[idx_u_true])
+
+         # Iterate through the list of g functions and stack the output of each
+         # onto x_dot
+         for g in g_funcs:
+             x_dot = jnp.hstack([x_dot, g(x[idx_x_true], u[idx_u_true], node)])
+
+         return x_dot
+
+     return dynamics_augmented
+
+
+ def get_jacobians(dyn: callable):
+     A = jax.jacfwd(dyn, argnums=0)
+     B = jax.jacfwd(dyn, argnums=1)
+     return A, B
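
A minimal sketch (not part of the released files) of how these helpers compose, using a hypothetical single-integrator dynamics and a hypothetical CTCS-style penalty g appended as an extra state rate:

import jax.numpy as jnp

def dyn(x, u):
    return u  # hypothetical single integrator: x_dot = u

def g(x, u, node):
    return jnp.maximum(0.0, x[0] - 1.0) ** 2  # hypothetical penalty term

aug = get_augmented_dynamics(dyn, [g], idx_x_true=slice(0, 1), idx_u_true=slice(0, 1))
A, B = get_jacobians(aug)
x, u = jnp.array([0.5, 0.0]), jnp.array([0.2])
print(aug(x, u, 0))      # [0.2 0. ]
print(A(x, u, 0).shape)  # (2, 2): Jacobian of the augmented dynamics w.r.t. the augmented state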
@@ -0,0 +1,139 @@
+ import jax
+ import jax.numpy as jnp
+ import diffrax as dfx
+
+ SOLVER_MAP = {
+     "Tsit5": dfx.Tsit5,
+     "Euler": dfx.Euler,
+     "Heun": dfx.Heun,
+     "Midpoint": dfx.Midpoint,
+     "Ralston": dfx.Ralston,
+     "Dopri5": dfx.Dopri5,
+     "Dopri8": dfx.Dopri8,
+     "Bosh3": dfx.Bosh3,
+     "ReversibleHeun": dfx.ReversibleHeun,
+     "ImplicitEuler": dfx.ImplicitEuler,
+     "KenCarp3": dfx.KenCarp3,
+     "KenCarp4": dfx.KenCarp4,
+     "KenCarp5": dfx.KenCarp5,
+ }
+
+ # fmt: off
+ def rk45_step(f, t, y, h, *args):
+     k1 = f(t, y, *args)
+     k2 = f(t + h/4, y + h*k1/4, *args)
+     k3 = f(t + 3*h/8, y + 3*h*k1/32 + 9*h*k2/32, *args)
+     k4 = f(t + 12*h/13, y + 1932*h*k1/2197 - 7200*h*k2/2197 + 7296*h*k3/2197, *args)
+     k5 = f(t + h, y + 439*h*k1/216 - 8*h*k2 + 3680*h*k3/513 - 845*h*k4/4104, *args)
+     y_next = y + h * (25*k1/216 + 1408*k3/2565 + 2197*k4/4104 - k5/5)
+     return y_next
+ # fmt: on
+
+
+ def solve_ivp_rk45(
+     f,
+     tau_final: float,
+     y_0,
+     args,
+     tau_0: float = 0.0,
+     num_substeps: int = 50,
+     is_not_compiled: bool = False,
+ ):
+     substeps = jnp.linspace(tau_0, tau_final, num_substeps)
+
+     h = (tau_final - tau_0) / (len(substeps) - 1)
+     solution = jnp.zeros((len(substeps), len(y_0)))
+     solution = solution.at[0].set(y_0)
+
+     if is_not_compiled:
+         for i in range(1, len(substeps)):
+             t = tau_0 + i * h
+             solution = solution.at[i].set(rk45_step(f, t, solution[i - 1], h, *args))
+     else:
+
+         def body_fun(i, val):
+             t, y, V_result = val
+             y_next = rk45_step(f, t, y, h, *args)
+             V_result = V_result.at[i].set(y_next)
+             return (t + h, y_next, V_result)
+
+         _, _, solution = jax.lax.fori_loop(
+             1, len(substeps), body_fun, (tau_0, y_0, solution)
+         )
+
+     return solution
+
+
+ def solve_ivp_diffrax(
+     f,
+     tau_final,
+     y_0,
+     args,
+     tau_0: float = 0.0,
+     num_substeps: int = 50,
+     solver_name="Dopri8",
+     rtol: float = 1e-3,
+     atol: float = 1e-6,
+     extra_kwargs=None,
+ ):
+     substeps = jnp.linspace(tau_0, tau_final, num_substeps)
+
+     solver_class = SOLVER_MAP.get(solver_name)
+     if solver_class is None:
+         raise ValueError(f"Unknown solver: {solver_name}")
+     solver = solver_class()
+
+     term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
+     stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)
+     solution = dfx.diffeqsolve(
+         term,
+         solver=solver,
+         t0=tau_0,
+         t1=tau_final,
+         dt0=(tau_final - tau_0) / (len(substeps) - 1),
+         y0=y_0,
+         args=args,
+         stepsize_controller=stepsize_controller,
+         saveat=dfx.SaveAt(ts=substeps),
+         **(extra_kwargs or {}),
+     )
+
+     return solution.ys
+
+
+ # TODO: (norrisg) this function is basically identical to `solve_ivp_diffrax`, could combine, but requires returning solution and getting `.ys` wherever the `solve_ivp_diffrax` is called
+ def solve_ivp_diffrax_prop(
+     f,
+     tau_final,
+     y_0,
+     args,
+     tau_0: float = 0.0,
+     num_substeps: int = 50,
+     solver_name="Dopri8",
+     rtol: float = 1e-3,
+     atol: float = 1e-6,
+     extra_kwargs=None,
+ ):
+     substeps = jnp.linspace(tau_0, tau_final, num_substeps)
+
+     solver_class = SOLVER_MAP.get(solver_name)
+     if solver_class is None:
+         raise ValueError(f"Unknown solver: {solver_name}")
+     solver = solver_class()
+
+     term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
+     stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)
+     solution = dfx.diffeqsolve(
+         term,
+         solver=solver,
+         t0=tau_0,
+         t1=tau_final,
+         dt0=(tau_final - tau_0) / (len(substeps) - 1),
+         y0=y_0,
+         args=args,
+         stepsize_controller=stepsize_controller,
+         saveat=dfx.SaveAt(dense=True, ts=substeps),
+         **(extra_kwargs or {}),
+     )
+
+     return solution
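
A short usage sketch (not part of the released files) for the fixed-step helper, integrating the hypothetical linear decay y' = -k*y:

import jax.numpy as jnp

def f(t, y, k):
    return -k * y  # simple linear decay

ys = solve_ivp_rk45(f, tau_final=1.0, y_0=jnp.array([1.0]), args=(2.0,))
print(ys.shape, ys[-1])  # (50, 1); the final value is close to exp(-2) ≈ 0.135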