openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/_version.py +2 -2
- openscvx/augmentation/dynamics_augmentation.py +22 -7
- openscvx/config.py +310 -192
- openscvx/constraints/__init__.py +0 -3
- openscvx/constraints/ctcs.py +188 -33
- openscvx/constraints/nodal.py +150 -11
- openscvx/constraints/violation.py +12 -2
- openscvx/discretization.py +115 -37
- openscvx/dynamics.py +150 -11
- openscvx/integrators.py +135 -16
- openscvx/io.py +129 -17
- openscvx/ocp.py +86 -67
- openscvx/plotting.py +72 -215
- openscvx/post_processing.py +57 -16
- openscvx/propagation.py +155 -55
- openscvx/ptr.py +96 -57
- openscvx/results.py +153 -0
- openscvx/trajoptproblem.py +359 -114
- openscvx/utils.py +50 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/METADATA +129 -41
- openscvx-0.2.1.dev0.dist-info/RECORD +27 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/WHEEL +1 -1
- openscvx/constraints/boundary.py +0 -49
- openscvx-0.1.2.dist-info/RECORD +0 -27
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/top_level.txt +0 -0
openscvx/discretization.py
CHANGED
|
@@ -17,7 +17,30 @@ def dVdt(
|
|
|
17
17
|
n_u: int,
|
|
18
18
|
N: int,
|
|
19
19
|
dis_type: str,
|
|
20
|
+
**params
|
|
20
21
|
) -> jnp.ndarray:
|
|
22
|
+
"""Compute the time derivative of the augmented state vector.
|
|
23
|
+
|
|
24
|
+
This function computes the time derivative of the augmented state vector V,
|
|
25
|
+
which includes the state, state transition matrix, and control influence matrix.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
tau (float): Current normalized time in [0,1].
|
|
29
|
+
V (jnp.ndarray): Augmented state vector.
|
|
30
|
+
u_cur (np.ndarray): Control input at current node.
|
|
31
|
+
u_next (np.ndarray): Control input at next node.
|
|
32
|
+
state_dot (callable): Function computing state derivatives.
|
|
33
|
+
A (callable): Function computing state Jacobian.
|
|
34
|
+
B (callable): Function computing control Jacobian.
|
|
35
|
+
n_x (int): Number of states.
|
|
36
|
+
n_u (int): Number of controls.
|
|
37
|
+
N (int): Number of nodes in trajectory.
|
|
38
|
+
dis_type (str): Discretization type ("ZOH" or "FOH").
|
|
39
|
+
**params: Additional parameters passed to state_dot, A, and B.
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
jnp.ndarray: Time derivative of augmented state vector.
|
|
43
|
+
"""
|
|
21
44
|
# Define the nodes
|
|
22
45
|
nodes = jnp.arange(0, N-1)
|
|
23
46
|
|
|
@@ -52,15 +75,15 @@ def dVdt(
|
|
|
52
75
|
u = u[: x.shape[0]]
|
|
53
76
|
|
|
54
77
|
# Compute the nonlinear propagation term
|
|
55
|
-
f = state_dot(x, u[:, :-1], nodes)
|
|
78
|
+
f = state_dot(x, u[:, :-1], nodes, *params.items())
|
|
56
79
|
F = s[:, None] * f
|
|
57
80
|
|
|
58
81
|
# Evaluate the State Jacobian
|
|
59
|
-
dfdx = A(x, u[:, :-1], nodes)
|
|
82
|
+
dfdx = A(x, u[:, :-1], nodes, *params.items())
|
|
60
83
|
sdfdx = s[:, None, None] * dfdx
|
|
61
84
|
|
|
62
85
|
# Evaluate the Control Jacobian
|
|
63
|
-
dfdu_veh = B(x, u[:, :-1], nodes)
|
|
86
|
+
dfdu_veh = B(x, u[:, :-1], nodes, *params.items())
|
|
64
87
|
dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * dfdu_veh)
|
|
65
88
|
dfdu = dfdu.at[:, :, -1].set(f)
|
|
66
89
|
|
|
@@ -76,7 +99,8 @@ def dVdt(
|
|
|
76
99
|
dVdt = dVdt.at[:, i3:i4].set((jnp.matmul(sdfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u))
|
|
77
100
|
dVdt = dVdt.at[:, i4:i5].set((jnp.matmul(sdfdx, V[:, i4:i5].reshape(-1, n_x)[..., None]).squeeze(-1) + z).reshape(-1, n_x))
|
|
78
101
|
# fmt: on
|
|
79
|
-
|
|
102
|
+
|
|
103
|
+
return dVdt.reshape(-1)
|
|
80
104
|
|
|
81
105
|
|
|
82
106
|
def calculate_discretization(
|
|
@@ -94,8 +118,38 @@ def calculate_discretization(
|
|
|
94
118
|
rtol,
|
|
95
119
|
atol,
|
|
96
120
|
dis_type: str,
|
|
121
|
+
**kwargs
|
|
97
122
|
):
|
|
98
|
-
|
|
123
|
+
"""Calculate the discretized system matrices.
|
|
124
|
+
|
|
125
|
+
This function computes the discretized system matrices (A_bar, B_bar, C_bar)
|
|
126
|
+
and defect vector (z_bar) using numerical integration.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
x: State trajectory.
|
|
130
|
+
u: Control trajectory.
|
|
131
|
+
state_dot (callable): Function computing state derivatives.
|
|
132
|
+
A (callable): Function computing state Jacobian.
|
|
133
|
+
B (callable): Function computing control Jacobian.
|
|
134
|
+
n_x (int): Number of states.
|
|
135
|
+
n_u (int): Number of controls.
|
|
136
|
+
N (int): Number of nodes in trajectory.
|
|
137
|
+
custom_integrator (bool): Whether to use custom RK45 integrator.
|
|
138
|
+
debug (bool): Whether to use debug mode.
|
|
139
|
+
solver (str): Name of the solver to use.
|
|
140
|
+
rtol (float): Relative tolerance for integration.
|
|
141
|
+
atol (float): Absolute tolerance for integration.
|
|
142
|
+
dis_type (str): Discretization type ("ZOH" or "FOH").
|
|
143
|
+
**kwargs: Additional parameters passed to state_dot, A, and B.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
tuple: (A_bar, B_bar, C_bar, z_bar, Vmulti) where:
|
|
147
|
+
- A_bar: Discretized state transition matrix
|
|
148
|
+
- B_bar: Discretized control influence matrix
|
|
149
|
+
- C_bar: Discretized control influence matrix for next node
|
|
150
|
+
- z_bar: Defect vector
|
|
151
|
+
- Vmulti: Full augmented state trajectory
|
|
152
|
+
"""
|
|
99
153
|
# Define indices for slicing the augmented state vector
|
|
100
154
|
i0 = 0
|
|
101
155
|
i1 = n_x
|
|
@@ -104,67 +158,91 @@ def calculate_discretization(
|
|
|
104
158
|
i4 = i3 + n_x * n_u
|
|
105
159
|
i5 = i4 + n_x
|
|
106
160
|
|
|
107
|
-
#
|
|
161
|
+
# Initial augmented state
|
|
108
162
|
V0 = jnp.zeros((N - 1, i5))
|
|
109
163
|
V0 = V0.at[:, :n_x].set(x[:-1].astype(float))
|
|
110
|
-
V0 = V0.at[:, n_x
|
|
164
|
+
V0 = V0.at[:, n_x:n_x + n_x * n_x].set(
|
|
111
165
|
jnp.eye(n_x).reshape(1, -1).repeat(N - 1, axis=0)
|
|
112
166
|
)
|
|
113
167
|
|
|
114
|
-
#
|
|
168
|
+
# Choose integrator
|
|
169
|
+
integrator_args = dict(
|
|
170
|
+
u_cur=u[:-1].astype(float),
|
|
171
|
+
u_next=u[1:].astype(float),
|
|
172
|
+
state_dot=state_dot,
|
|
173
|
+
A=A,
|
|
174
|
+
B=B,
|
|
175
|
+
n_x=n_x,
|
|
176
|
+
n_u=n_u,
|
|
177
|
+
N=N,
|
|
178
|
+
dis_type=dis_type,
|
|
179
|
+
**kwargs # <-- adds parameter values with names
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# Define dVdt wrapper using named arguments
|
|
183
|
+
def dVdt_wrapped(t, y):
|
|
184
|
+
return dVdt(t, y, **integrator_args)
|
|
185
|
+
|
|
186
|
+
# Choose integrator
|
|
115
187
|
if custom_integrator:
|
|
116
|
-
# fmt: off
|
|
117
188
|
sol = solve_ivp_rk45(
|
|
118
|
-
|
|
119
|
-
1.0/(N-1),
|
|
189
|
+
dVdt_wrapped,
|
|
190
|
+
1.0 / (N - 1),
|
|
120
191
|
V0.reshape(-1),
|
|
121
|
-
args=(
|
|
122
|
-
state_dot, A, B, n_x, n_u, N, dis_type),
|
|
192
|
+
args=(),
|
|
123
193
|
is_not_compiled=debug,
|
|
124
194
|
)
|
|
125
|
-
# fmt: on
|
|
126
195
|
else:
|
|
127
|
-
# fmt: off
|
|
128
196
|
sol = solve_ivp_diffrax(
|
|
129
|
-
|
|
130
|
-
1.0/(N-1),
|
|
197
|
+
dVdt_wrapped,
|
|
198
|
+
1.0 / (N - 1),
|
|
131
199
|
V0.reshape(-1),
|
|
132
|
-
args=(u[:-1].astype(float), u[1:].astype(float),
|
|
133
|
-
state_dot, A, B, n_x, n_u, N, dis_type),
|
|
134
200
|
solver_name=solver,
|
|
135
201
|
rtol=rtol,
|
|
136
202
|
atol=atol,
|
|
203
|
+
args=(),
|
|
137
204
|
extra_kwargs=None,
|
|
138
205
|
)
|
|
139
|
-
# fmt: on
|
|
140
206
|
|
|
141
207
|
Vend = sol[-1].T.reshape(-1, i5)
|
|
142
208
|
Vmulti = sol.T
|
|
143
209
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
C_bar = Vend[:, i3:i4].reshape(N-1, n_x, n_u).transpose(1,2,0).reshape(n_x*n_u, -1, order='F').T
|
|
210
|
+
A_bar = Vend[:, i1:i2].reshape(N - 1, n_x, n_x).transpose(1, 2, 0).reshape(n_x * n_x, -1, order='F').T
|
|
211
|
+
B_bar = Vend[:, i2:i3].reshape(N - 1, n_x, n_u).transpose(1, 2, 0).reshape(n_x * n_u, -1, order='F').T
|
|
212
|
+
C_bar = Vend[:, i3:i4].reshape(N - 1, n_x, n_u).transpose(1, 2, 0).reshape(n_x * n_u, -1, order='F').T
|
|
148
213
|
z_bar = Vend[:, i4:i5]
|
|
149
|
-
# fmt: on
|
|
150
214
|
|
|
151
215
|
return A_bar, B_bar, C_bar, z_bar, Vmulti
|
|
152
216
|
|
|
153
217
|
|
|
154
|
-
def get_discretization_solver(dyn: Dynamics,
|
|
155
|
-
|
|
218
|
+
def get_discretization_solver(dyn: Dynamics, settings, param_map):
|
|
219
|
+
"""Create a discretization solver function.
|
|
220
|
+
|
|
221
|
+
This function creates a solver that computes the discretized system matrices
|
|
222
|
+
using the specified dynamics and settings.
|
|
223
|
+
|
|
224
|
+
Args:
|
|
225
|
+
dyn (Dynamics): System dynamics object.
|
|
226
|
+
settings: Configuration settings for discretization.
|
|
227
|
+
param_map (dict): Mapping of parameter names to values.
|
|
228
|
+
|
|
229
|
+
Returns:
|
|
230
|
+
callable: A function that computes the discretized system matrices.
|
|
231
|
+
"""
|
|
232
|
+
return lambda x, u, *params: calculate_discretization(
|
|
156
233
|
x=x,
|
|
157
234
|
u=u,
|
|
158
235
|
state_dot=dyn.f,
|
|
159
236
|
A=dyn.A,
|
|
160
237
|
B=dyn.B,
|
|
161
|
-
n_x=
|
|
162
|
-
n_u=
|
|
163
|
-
N=
|
|
164
|
-
custom_integrator=
|
|
165
|
-
debug=
|
|
166
|
-
solver=
|
|
167
|
-
rtol=
|
|
168
|
-
atol=
|
|
169
|
-
dis_type=
|
|
170
|
-
|
|
238
|
+
n_x=settings.sim.n_states,
|
|
239
|
+
n_u=settings.sim.n_controls,
|
|
240
|
+
N=settings.scp.n,
|
|
241
|
+
custom_integrator=settings.dis.custom_integrator,
|
|
242
|
+
debug=settings.dev.debug,
|
|
243
|
+
solver=settings.dis.solver,
|
|
244
|
+
rtol=settings.dis.rtol,
|
|
245
|
+
atol=settings.dis.atol,
|
|
246
|
+
dis_type=settings.dis.dis_type,
|
|
247
|
+
**dict(zip(param_map.keys(), params)) # <--- Named keyword args
|
|
248
|
+
)
|
openscvx/dynamics.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import Callable, Optional
|
|
2
|
+
from typing import Callable, Optional, Union
|
|
3
3
|
import functools
|
|
4
4
|
|
|
5
5
|
import jax.numpy as jnp
|
|
@@ -7,25 +7,165 @@ import jax.numpy as jnp
|
|
|
7
7
|
|
|
8
8
|
@dataclass
|
|
9
9
|
class Dynamics:
|
|
10
|
+
"""
|
|
11
|
+
Dataclass to hold a system dynamics function and (optionally) its gradients.
|
|
12
|
+
This class is intended to be instantiated using the `dynamics` decorator wrapped around a function defining the system dynamics.
|
|
13
|
+
Both the dynamics and optional gradients should be composed of `jax` primitives to enable efficient computation.
|
|
14
|
+
|
|
15
|
+
Usage examples:
|
|
16
|
+
|
|
17
|
+
```python
|
|
18
|
+
@dynamics
|
|
19
|
+
def f(x_, u_):
|
|
20
|
+
return x_ + u_
|
|
21
|
+
# f is now a Dynamics object
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
```python
|
|
25
|
+
@dynamics(A=grad_f_x, B=grad_f_u)
|
|
26
|
+
def f(x_, u_):
|
|
27
|
+
return x_ + u_
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
Or, if a more lambda-function-style is desired, the function can be directly wrapped:
|
|
31
|
+
|
|
32
|
+
```python
|
|
33
|
+
dyn = dynamics(lambda x_, u_: x_ + u_)
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
---
|
|
37
|
+
**Using Parameters in Dynamics**
|
|
38
|
+
|
|
39
|
+
You can use symbolic `Parameter` objects in your dynamics function to represent tunable or environment-dependent values. **The argument names for parameters must match the parameter name with an underscore suffix** (e.g., `I_sp_` for a parameter named `I_sp`). This is required for the parameter mapping to work correctly.
|
|
40
|
+
|
|
41
|
+
Example (3DoF rocket landing):
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
from openscvx.backend.parameter import Parameter
|
|
45
|
+
import jax.numpy as jnp
|
|
46
|
+
|
|
47
|
+
I_sp = Parameter("I_sp")
|
|
48
|
+
g = Parameter("g")
|
|
49
|
+
theta = Parameter("theta")
|
|
50
|
+
|
|
51
|
+
@dynamics
|
|
52
|
+
def rocket_dynamics(x_, u_, I_sp_, g_, theta_):
|
|
53
|
+
m = x_[6]
|
|
54
|
+
T = u_
|
|
55
|
+
r_dot = x_[3:6]
|
|
56
|
+
g_vec = jnp.array([0, 0, g_])
|
|
57
|
+
v_dot = T/m - g_vec
|
|
58
|
+
m_dot = -jnp.linalg.norm(T) / (I_sp_ * 9.807 * jnp.cos(theta_))
|
|
59
|
+
t_dot = 1
|
|
60
|
+
return jnp.hstack([r_dot, v_dot, m_dot, t_dot])
|
|
61
|
+
|
|
62
|
+
# Set parameter values before solving
|
|
63
|
+
I_sp.value = 225
|
|
64
|
+
g.value = 3.7114
|
|
65
|
+
theta.value = 27 * jnp.pi / 180
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
---
|
|
69
|
+
**Using Parameters in Nodal Constraints**
|
|
70
|
+
|
|
71
|
+
You can also use symbolic `Parameter` objects in nodal constraints. As with dynamics, the argument names for parameters in the constraint function must match the parameter name with an underscore suffix (e.g., `g_` for a parameter named `g`).
|
|
72
|
+
|
|
73
|
+
Example:
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
from openscvx.backend.parameter import Parameter
|
|
77
|
+
from openscvx.constraints import nodal
|
|
78
|
+
import jax.numpy as jnp
|
|
79
|
+
|
|
80
|
+
g = Parameter("g")
|
|
81
|
+
g.value = 3.7114
|
|
82
|
+
|
|
83
|
+
@nodal
|
|
84
|
+
def terminal_velocity_constraint(x_, u_, g_):
|
|
85
|
+
# Enforce a terminal velocity constraint using the gravity parameter
|
|
86
|
+
return x_[5] + g_ * x_[7] # e.g., vz + g * t <= 0 at final node
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
When building your problem, collect all parameters with `Parameter.get_all()` and pass them to your problem setup.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
f (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
|
|
93
|
+
Function defining the continuous time nonlinear system dynamics as x_dot = f(x, u, ...params).
|
|
94
|
+
- x: 1D array (state at a single node), shape (n_x,)
|
|
95
|
+
- u: 1D array (control at a single node), shape (n_u,)
|
|
96
|
+
- Additional parameters: passed as keyword arguments with names matching the parameter name plus an underscore (e.g., g_ for Parameter('g')).
|
|
97
|
+
If you want to use parameters, include them as extra arguments with the underscore naming convention.
|
|
98
|
+
If you use vectorized integration or batch evaluation, x and u may be 2D arrays (N, n_x) and (N, n_u).
|
|
99
|
+
A (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
|
|
100
|
+
Jacobian of `f` w.r.t. `x`. If not specified, will be calculated using `jax.jacfwd`.
|
|
101
|
+
B (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
|
|
102
|
+
Jacobian of `f` w.r.t. `u`. If not specified, will be calculated using `jax.jacfwd`.
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
Dynamics: A dataclass bundling the system dynamics function and Jacobians.
|
|
106
|
+
"""
|
|
107
|
+
|
|
10
108
|
f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
|
|
11
109
|
A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
|
|
12
110
|
B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
|
|
13
111
|
|
|
112
|
+
|
|
14
113
|
def dynamics(
|
|
15
|
-
_func=None,
|
|
114
|
+
_func: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
|
|
16
115
|
*,
|
|
17
|
-
A: Optional[Callable] = None,
|
|
18
|
-
B: Optional[Callable] = None,
|
|
19
|
-
|
|
116
|
+
A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
|
|
117
|
+
B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
|
|
118
|
+
) -> Union[Callable, Dynamics]:
|
|
119
|
+
"""
|
|
120
|
+
Decorator that wraps a function defining the system dynamics as a `Dynamics` object.
|
|
121
|
+
You may optionally specify the system gradients w.r.t. `x`, `u` if desired, if not specified they will be calculated using `jax.jacfwd`.
|
|
122
|
+
Note: the dynamics as well as the optional gradients should be composed of `jax` primitives to enable efficient computation.
|
|
123
|
+
|
|
124
|
+
This decorator may be used with or without arguments:
|
|
20
125
|
|
|
21
|
-
|
|
22
|
-
@dynamics
|
|
23
|
-
def
|
|
126
|
+
```
|
|
127
|
+
@dynamics
|
|
128
|
+
def f(x, u): ...
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
or
|
|
132
|
+
|
|
133
|
+
```
|
|
134
|
+
@dynamics(A=grad_f_x, B=grad_f_u)
|
|
135
|
+
def f(x, u): ...
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
or, if a more lambda-function-style is desired, the function can be direclty wrapped
|
|
139
|
+
|
|
140
|
+
```
|
|
141
|
+
dyn = dynamics(f(x,u))
|
|
142
|
+
dyn_lambda = dynamics(lambda x, u: ...)
|
|
143
|
+
```
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
_func (callable, optional): The function to wrap. Populated
|
|
147
|
+
when using @dynamics with no extra args.
|
|
148
|
+
A (callable, optional): Jacobian of f wrt state x. Computed
|
|
149
|
+
via jax.jacfwd if not provided.
|
|
150
|
+
B (callable, optional): Jacobian of f wrt input u. Computed
|
|
151
|
+
via jax.jacfwd if not provided.
|
|
152
|
+
|
|
153
|
+
Returns:
|
|
154
|
+
Union[Callable, Dynamics]
|
|
155
|
+
A decorator if called without a function, or a `Dynamics` dataclass bundling system dynamics function
|
|
156
|
+
and Jacobians when applied to a function.
|
|
157
|
+
|
|
158
|
+
Examples:
|
|
159
|
+
>>> @dynamics
|
|
160
|
+
... def f(x, u):
|
|
161
|
+
... return x + u
|
|
162
|
+
>>> isinstance(f, Dynamics)
|
|
163
|
+
True
|
|
24
164
|
"""
|
|
25
165
|
|
|
26
166
|
def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
|
|
27
|
-
#
|
|
28
|
-
wrapped =
|
|
167
|
+
# Had to unwrap to ensure arguments are visible downstream. Originally wrapped so name, doc, signature stay on f
|
|
168
|
+
wrapped = f
|
|
29
169
|
return Dynamics(
|
|
30
170
|
f=wrapped,
|
|
31
171
|
A=A,
|
|
@@ -38,4 +178,3 @@ def dynamics(
|
|
|
38
178
|
# if called as dynamics(func), we immediately decorate
|
|
39
179
|
else:
|
|
40
180
|
return decorator(_func)
|
|
41
|
-
|
openscvx/integrators.py
CHANGED
|
@@ -1,7 +1,52 @@
|
|
|
1
|
+
import os
|
|
2
|
+
os.environ["EQX_ON_ERROR"] = "nan"
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Any, Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
import diffrax as dfx
|
|
11
|
+
from typing import Callable, Any
|
|
12
|
+
|
|
1
13
|
import jax
|
|
2
14
|
import jax.numpy as jnp
|
|
3
15
|
import diffrax as dfx
|
|
4
16
|
|
|
17
|
+
from diffrax._global_interpolation import DenseInterpolation
|
|
18
|
+
from jax import tree_util
|
|
19
|
+
|
|
20
|
+
# Safely check if DenseInterpolation is already registered
|
|
21
|
+
try:
|
|
22
|
+
# Attempt to flatten a dummy DenseInterpolation instance
|
|
23
|
+
# Provide dummy arguments to create a valid instance
|
|
24
|
+
dummy_instance = DenseInterpolation(
|
|
25
|
+
ts=jnp.array([]),
|
|
26
|
+
ts_size=0,
|
|
27
|
+
infos=None,
|
|
28
|
+
interpolation_cls=None,
|
|
29
|
+
direction=None,
|
|
30
|
+
t0_if_trivial=0.0,
|
|
31
|
+
y0_if_trivial=jnp.array([]),
|
|
32
|
+
)
|
|
33
|
+
tree_util.tree_flatten(dummy_instance)
|
|
34
|
+
except ValueError:
|
|
35
|
+
# Register DenseInterpolation as a PyTree node if not already registered
|
|
36
|
+
def dense_interpolation_flatten(obj):
|
|
37
|
+
# Flatten the internal data of DenseInterpolation
|
|
38
|
+
return (obj._data,), None
|
|
39
|
+
|
|
40
|
+
def dense_interpolation_unflatten(aux_data, children):
|
|
41
|
+
# Reconstruct DenseInterpolation from its flattened data
|
|
42
|
+
return DenseInterpolation(*children)
|
|
43
|
+
|
|
44
|
+
tree_util.register_pytree_node(
|
|
45
|
+
DenseInterpolation,
|
|
46
|
+
dense_interpolation_flatten,
|
|
47
|
+
dense_interpolation_unflatten,
|
|
48
|
+
)
|
|
49
|
+
|
|
5
50
|
SOLVER_MAP = {
|
|
6
51
|
"Tsit5": dfx.Tsit5,
|
|
7
52
|
"Euler": dfx.Euler,
|
|
@@ -19,7 +64,30 @@ SOLVER_MAP = {
|
|
|
19
64
|
}
|
|
20
65
|
|
|
21
66
|
# fmt: off
|
|
22
|
-
def rk45_step(
|
|
67
|
+
def rk45_step(
|
|
68
|
+
f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
|
|
69
|
+
t: jnp.ndarray,
|
|
70
|
+
y: jnp.ndarray,
|
|
71
|
+
h: float,
|
|
72
|
+
*args
|
|
73
|
+
) -> jnp.ndarray:
|
|
74
|
+
"""
|
|
75
|
+
Perform a single RK45 (Runge-Kutta-Fehlberg) integration step.
|
|
76
|
+
|
|
77
|
+
This implements the classic Dorman-Prince coefficients for an
|
|
78
|
+
explicit 4(5) method, returning the fourth-order estimate.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
|
|
82
|
+
ODE right-hand side; signature f(t, y, *args) -> dy/dt.
|
|
83
|
+
t (jnp.ndarray): Current time.
|
|
84
|
+
y (jnp.ndarray): Current state vector.
|
|
85
|
+
h (float): Step size.
|
|
86
|
+
*args: Additional arguments passed to `f`.
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
jnp.ndarray: Next state estimate at t + h.
|
|
90
|
+
"""
|
|
23
91
|
k1 = f(t, y, *args)
|
|
24
92
|
k2 = f(t + h/4, y + h*k1/4, *args)
|
|
25
93
|
k3 = f(t + 3*h/8, y + 3*h*k1/32 + 9*h*k2/32, *args)
|
|
@@ -31,14 +99,31 @@ def rk45_step(f, t, y, h, *args):
|
|
|
31
99
|
|
|
32
100
|
|
|
33
101
|
def solve_ivp_rk45(
|
|
34
|
-
f,
|
|
102
|
+
f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
|
|
35
103
|
tau_final: float,
|
|
36
|
-
y_0,
|
|
104
|
+
y_0: jnp.ndarray,
|
|
37
105
|
args,
|
|
38
106
|
tau_0: float = 0.0,
|
|
39
107
|
num_substeps: int = 50,
|
|
40
108
|
is_not_compiled: bool = False,
|
|
41
109
|
):
|
|
110
|
+
"""
|
|
111
|
+
Solve an initial-value ODE problem using fixed-step RK45 integration.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
|
|
115
|
+
ODE right-hand side; signature f(t, y, *args) -> dy/dt.
|
|
116
|
+
tau_final (float): Final integration time.
|
|
117
|
+
y_0 (jnp.ndarray): Initial state at tau_0.
|
|
118
|
+
args (tuple): Extra arguments to pass to `f`.
|
|
119
|
+
tau_0 (float, optional): Initial time. Defaults to 0.0.
|
|
120
|
+
num_substeps (int, optional): Number of output time points. Defaults to 50.
|
|
121
|
+
is_not_compiled (bool, optional): If True, use Python loop instead of
|
|
122
|
+
JAX `lax.fori_loop`. Defaults to False.
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
jnp.ndarray: Array of shape (num_substeps, state_dim) with solution at each time.
|
|
126
|
+
"""
|
|
42
127
|
substeps = jnp.linspace(tau_0, tau_final, num_substeps)
|
|
43
128
|
|
|
44
129
|
h = (tau_final - tau_0) / (len(substeps) - 1)
|
|
@@ -65,17 +150,40 @@ def solve_ivp_rk45(
|
|
|
65
150
|
|
|
66
151
|
|
|
67
152
|
def solve_ivp_diffrax(
|
|
68
|
-
f,
|
|
69
|
-
tau_final,
|
|
70
|
-
y_0,
|
|
153
|
+
f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
|
|
154
|
+
tau_final: float,
|
|
155
|
+
y_0: jnp.ndarray,
|
|
71
156
|
args,
|
|
72
157
|
tau_0: float = 0.0,
|
|
73
158
|
num_substeps: int = 50,
|
|
74
|
-
solver_name="Dopri8",
|
|
159
|
+
solver_name: str = "Dopri8",
|
|
75
160
|
rtol: float = 1e-3,
|
|
76
161
|
atol: float = 1e-6,
|
|
77
162
|
extra_kwargs=None,
|
|
78
163
|
):
|
|
164
|
+
"""
|
|
165
|
+
Solve an initial-value ODE problem using a Diffrax adaptive solver.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]): ODE right-hand side; signature f(t, y, *args) -> dy/dt.
|
|
169
|
+
tau_final (float): Final integration time.
|
|
170
|
+
y_0 (jnp.ndarray): Initial state at tau_0.
|
|
171
|
+
args (tuple): Extra arguments to pass to `f` in the solver term.
|
|
172
|
+
tau_0 (float, optional): Initial time. Defaults to 0.0.
|
|
173
|
+
num_substeps (int, optional): Number of save points between tau_0 and tau_final.
|
|
174
|
+
Defaults to 50.
|
|
175
|
+
solver_name (str, optional): Key into SOLVER_MAP for the Diffrax solver class.
|
|
176
|
+
Defaults to "Dopri8".
|
|
177
|
+
rtol (float, optional): Relative tolerance for adaptive stepping. Defaults to 1e-3.
|
|
178
|
+
atol (float, optional): Absolute tolerance for adaptive stepping. Defaults to 1e-6.
|
|
179
|
+
extra_kwargs (dict, optional): Additional keyword arguments forwarded to `diffeqsolve`.
|
|
180
|
+
|
|
181
|
+
Returns:
|
|
182
|
+
jnp.ndarray: Solution states at the requested save points, shape (num_substeps, state_dim).
|
|
183
|
+
|
|
184
|
+
Raises:
|
|
185
|
+
ValueError: If `solver_name` is not in SOLVER_MAP.
|
|
186
|
+
"""
|
|
79
187
|
substeps = jnp.linspace(tau_0, tau_final, num_substeps)
|
|
80
188
|
|
|
81
189
|
solver_class = SOLVER_MAP.get(solver_name)
|
|
@@ -95,26 +203,31 @@ def solve_ivp_diffrax(
|
|
|
95
203
|
args=args,
|
|
96
204
|
stepsize_controller=stepsize_controller,
|
|
97
205
|
saveat=dfx.SaveAt(ts=substeps),
|
|
206
|
+
progress_meter=dfx.NoProgressMeter(),
|
|
98
207
|
**(extra_kwargs or {}),
|
|
99
208
|
)
|
|
100
209
|
|
|
101
210
|
return solution.ys
|
|
102
211
|
|
|
103
212
|
|
|
104
|
-
# TODO: (norrisg) this function is basically identical to `solve_ivp_diffrax`, could combine, but requires returning solution and getting `.ys` wherever the `solve_ivp_diffrax` is called
|
|
105
213
|
def solve_ivp_diffrax_prop(
|
|
106
|
-
f,
|
|
107
|
-
tau_final,
|
|
108
|
-
y_0,
|
|
214
|
+
f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
|
|
215
|
+
tau_final: float,
|
|
216
|
+
y_0: jnp.ndarray,
|
|
109
217
|
args,
|
|
110
218
|
tau_0: float = 0.0,
|
|
111
219
|
num_substeps: int = 50,
|
|
112
|
-
solver_name="Dopri8",
|
|
220
|
+
solver_name: str = "Dopri8",
|
|
113
221
|
rtol: float = 1e-3,
|
|
114
222
|
atol: float = 1e-6,
|
|
115
223
|
extra_kwargs=None,
|
|
224
|
+
save_time: jnp.ndarray = None,
|
|
225
|
+
mask: jnp.ndarray = None,
|
|
116
226
|
):
|
|
117
|
-
|
|
227
|
+
if save_time is None:
|
|
228
|
+
raise ValueError("save_time must be provided for export compatibility.")
|
|
229
|
+
if mask is None:
|
|
230
|
+
mask = jnp.ones_like(save_time, dtype=bool)
|
|
118
231
|
|
|
119
232
|
solver_class = SOLVER_MAP.get(solver_name)
|
|
120
233
|
if solver_class is None:
|
|
@@ -123,17 +236,23 @@ def solve_ivp_diffrax_prop(
|
|
|
123
236
|
|
|
124
237
|
term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
|
|
125
238
|
stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)
|
|
239
|
+
|
|
126
240
|
solution = dfx.diffeqsolve(
|
|
127
241
|
term,
|
|
128
242
|
solver=solver,
|
|
129
243
|
t0=tau_0,
|
|
130
244
|
t1=tau_final,
|
|
131
|
-
dt0=(tau_final - tau_0) /
|
|
245
|
+
dt0=(tau_final - tau_0) / 1,
|
|
132
246
|
y0=y_0,
|
|
133
247
|
args=args,
|
|
134
248
|
stepsize_controller=stepsize_controller,
|
|
135
|
-
saveat=dfx.SaveAt(dense=True
|
|
249
|
+
saveat=dfx.SaveAt(dense=True),
|
|
136
250
|
**(extra_kwargs or {}),
|
|
137
251
|
)
|
|
138
252
|
|
|
139
|
-
|
|
253
|
+
# Evaluate all save_time points (static size), then mask them
|
|
254
|
+
all_evals = jax.vmap(solution.evaluate)(save_time) # shape: (MAX_TAU_LEN, n_states)
|
|
255
|
+
masked_array = jnp.where(mask[:, None], all_evals, jnp.zeros_like(all_evals))
|
|
256
|
+
# shape: (variable_len, n_states)
|
|
257
|
+
|
|
258
|
+
return masked_array
|