openscvx 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +0 -0
- openscvx/_version.py +21 -0
- openscvx/augmentation.py +44 -0
- openscvx/config.py +247 -0
- openscvx/discretization.py +169 -0
- openscvx/dynamics.py +24 -0
- openscvx/integrators.py +139 -0
- openscvx/io.py +81 -0
- openscvx/ocp.py +160 -0
- openscvx/plotting.py +632 -0
- openscvx/post_processing.py +36 -0
- openscvx/propagation.py +135 -0
- openscvx/ptr.py +149 -0
- openscvx/trajoptproblem.py +336 -0
- openscvx/utils.py +80 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.1.dist-info}/METADATA +2 -2
- openscvx-0.1.1.dist-info/RECORD +25 -0
- openscvx-0.1.1.dist-info/top_level.txt +1 -0
- openscvx-0.1.0.dist-info/RECORD +0 -10
- openscvx-0.1.0.dist-info/top_level.txt +0 -1
- {constraints → openscvx/constraints}/__init__.py +0 -0
- {constraints → openscvx/constraints}/boundary.py +0 -0
- {constraints → openscvx/constraints}/ctcs.py +0 -0
- {constraints → openscvx/constraints}/nodal.py +0 -0
- {constraints → openscvx/constraints}/violation.py +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.1.dist-info}/WHEEL +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.1.dist-info}/licenses/LICENSE +0 -0
openscvx/__init__.py
ADDED
|
File without changes
|
openscvx/_version.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

# Public names exported by this module.
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Local stand-in for typing.TYPE_CHECKING: it is always False at runtime, so the
# typing imports below are only seen by static type checkers.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    # Version tuple entries are ints (e.g. 0, 1, 1) or strings.
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    # Runtime placeholder; in this module the alias only appears in annotations.
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# String and tuple forms of this release's version.
__version__ = version = '0.1.1'
__version_tuple__ = version_tuple = (0, 1, 1)
|
openscvx/augmentation.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from openscvx.constraints.ctcs import CTCSConstraint
|
|
4
|
+
|
|
5
|
+
def sort_ctcs_constraints(constraints_ctcs: List["CTCSConstraint"], N: int):
    """Assign every CTCS constraint an augmented-state index.

    Constraints active over the same node interval share one augmented state.
    A constraint may carry a user-chosen ``idx``; every constraint using that
    ``idx`` must then use the same interval. A constraint without an ``idx``
    reuses the index of an identical interval that is already registered.
    Otherwise it receives the lowest index that is neither in use nor
    requested explicitly anywhere in ``constraints_ctcs``, so automatic
    indices never collide with explicit ones.

    Args:
        constraints_ctcs: CTCS constraints. Each constraint's ``nodes`` and
            ``idx`` attributes are updated in place: ``nodes=None`` (or any
            falsy value) becomes the full horizon ``(0, N)``, and ``idx`` is
            always set.
        N: Number of discretization nodes, used for the default interval.

    Returns:
        tuple: ``(constraints_ctcs, node_intervals, num_augmented_states)``.
        ``node_intervals`` lists the interval of each augmented state in
        ascending index order.

    Raises:
        ValueError: If the same explicit ``idx`` is used with two different
            intervals.
    """
    # Indices requested explicitly anywhere in the list. Automatic ids must skip
    # them; otherwise a later explicit idx could clash with an earlier auto id.
    reserved = {c.idx for c in constraints_ctcs if c.idx is not None}

    idx_to_nodes: dict[int, tuple] = {}
    next_idx = 0
    for c in constraints_ctcs:
        # Normalize "no interval" to the full horizon.
        c.nodes = c.nodes or (0, N)
        key = c.nodes

        if c.idx is not None:
            # User-supplied identifier: it must always refer to the same interval.
            if c.idx in idx_to_nodes:
                if idx_to_nodes[c.idx] != key:
                    raise ValueError(
                        f"idx={c.idx} was first used with interval={idx_to_nodes[c.idx]}, "
                        f"but now you gave it interval={key}"
                    )
            else:
                idx_to_nodes[c.idx] = key
            continue

        # No identifier: reuse the id of an already-registered identical interval.
        existing = next(
            (ident for ident, nodes in idx_to_nodes.items() if nodes == key), None
        )
        if existing is not None:
            c.idx = existing
            continue

        # Brand-new interval: take the next id that is neither used nor reserved.
        while next_idx in idx_to_nodes or next_idx in reserved:
            next_idx += 1
        c.idx = next_idx
        idx_to_nodes[next_idx] = key
        next_idx += 1

    # Intervals in ascending-idx order; one augmented state per distinct idx.
    ordered_ids = sorted(idx_to_nodes)
    node_intervals = [idx_to_nodes[i] for i in ordered_ids]
    num_augmented_states = len(ordered_ids)

    return constraints_ctcs, node_intervals, num_augmented_states
|
openscvx/config.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Dict, List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_affine_scaling_matrices(n, minimum, maximum):
    """Build the diagonal scale matrix and offset of an affine variable map.

    Each diagonal entry is half the width of the ``[minimum, maximum]`` box
    in that coordinate, but never less than 1. The offset is the box midpoint.

    Args:
        n: Number of variables (length of the diagonal).
        minimum: Lower bounds (array supporting elementwise arithmetic).
        maximum: Upper bounds.

    Returns:
        tuple: ``(S, c)``, where ``S`` is the diagonal scale matrix and ``c``
        is the midpoint vector.
    """
    half_width = abs(minimum - maximum) / 2
    scale_diagonal = np.maximum(np.ones(n), half_width)
    midpoint = (maximum + minimum) / 2
    return np.diag(scale_diagonal), midpoint
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class DiscretizationConfig:
    """
    Configuration class for discretization settings.

    This class defines the parameters required for discretizing system dynamics.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        dis_type (str): The type of discretization to use: "FOH" (First-Order Hold) or "ZOH" (Zero-Order Hold). Defaults to "FOH".
        custom_integrator (bool): Enables our custom fixed-step RK45 integrator. This tends to be faster than Diffrax, but unless you are going for speed, it is recommended to stick with Diffrax for robustness and more solver options. Defaults to True.
        solver (str): Not used if custom_integrator is enabled. Must be one of the Diffrax solver names registered in `openscvx.integrators.SOLVER_MAP` (e.g. "Tsit5", "Dopri5", "Dopri8"); for guidance see [How to Choose a Solver](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/). Defaults to "Tsit5".

    Other arguments:
    These arguments are less frequently used, and for most purposes you shouldn't need to understand these.

    Args:
        args (Dict): Additional arguments intended for the solver, see [here](https://docs.kidger.site/diffrax/api/diffeqsolve/). Note: `get_discretization_solver` does not currently forward these (it passes `extra_kwargs=None`). Defaults to an empty dictionary.
        atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
        rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
    """

    dis_type: str = "FOH"
    custom_integrator: bool = True
    solver: str = "Tsit5"
    args: Dict = field(default_factory=dict)
    atol: float = 1e-3
    rtol: float = 1e-6
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class DevConfig:
    """
    Configuration class for development settings.

    This class defines the parameters used for development and debugging purposes.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        profiling (bool): Whether to enable profiling for performance analysis. Defaults to False.
        debug (bool): Disables all precompilation so you can place breakpoints and inspect values. Defaults to False.
        printing (bool): Flag presumably controlling console progress output (it is not read in this module). Defaults to True.
    """

    profiling: bool = False
    debug: bool = False
    printing: bool = True
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class ConvexSolverConfig:
    """
    Configuration class for convex solver settings.

    This class defines the parameters required for configuring a convex solver.

    These are the arguments most commonly used day-to-day. Generally I have found [QOCO](https://qoco-org.github.io/qoco/index.html) to be the most performant of the CVXPY solvers for these types of problems (I do have a bias as the author is from my group), and it can handle problems up to SOCPs.
    [CLARABEL](https://clarabel.org/stable/) is also a great option with feasibility checking and can handle a few more problem types.
    [CVXPYGen](https://github.com/cvxgrp/cvxpygen) is also great if your problem is not too large. I have found qocogen to be the most performant of the CVXPYGen solvers.

    Args:
        solver (str): The name of the CVXPY solver to use. A list of options can be found [here](https://www.cvxpy.org/tutorial/solvers/index.html). Defaults to "QOCO".
        solver_args (dict): Additional arguments to configure the solver, such as tolerances. Ensure you are using the correct arguments for your solver, as they are not all common.
            Defaults to {"abstol": 1e-6, "reltol": 1e-9}.
        cvxpygen (bool): Whether to enable CVXPY code generation for the solver. Defaults to False.
    """

    solver: str = "QOCO"
    solver_args: dict = field(default_factory=lambda: {"abstol": 1e-6, "reltol": 1e-9})
    cvxpygen: bool = False
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class PropagationConfig:
    """
    Configuration class for propagation settings.

    This class defines the parameters required for propagating the nonlinear system dynamics using the optimal control sequence.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        dt (float): The time step for propagation. Defaults to 0.1.
        inter_sample (int): How densely the propagation is sampled within each multiple-shooting interval. Defaults to 30.

    Other arguments:
    The solver should likely not be changed, as it is a high-accuracy 8th-order Runge-Kutta method.

    Args:
        solver (str): The numerical solver to use for propagation (e.g., "Dopri8"). Defaults to "Dopri8".
        args (Dict): Additional arguments to pass to the solver. Defaults to an empty dictionary.
        atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
        rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
    """

    inter_sample: int = 30
    dt: float = 0.1
    solver: str = "Dopri8"
    args: Dict = field(default_factory=dict)
    atol: float = 1e-3
    rtol: float = 1e-6
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclass
class SimConfig:
    """
    Problem data and derived scaling for a trajectory optimization problem.

    The required fields hold:
    - the reference trajectory (`x_bar`, `u_bar`);
    - the boundary conditions and the state/control bounds;
    - the total time;
    - index slices into the state and control vectors.

    `__post_init__` derives `n_states` and `n_controls` from the bounds and
    validates dimensions. It then fills in whichever affine scaling data
    (`S_x`, `c_x`, `inv_S_x`, `S_u`, `c_u`, `inv_S_u`) was not supplied.

    Raises:
        AssertionError: If the boundary conditions or bounds have the wrong length.
    """

    x_bar: np.ndarray
    u_bar: np.ndarray
    # Boundary-condition objects: the checks below read their `.value` vector,
    # so these are presumably not raw arrays despite the annotation — TODO confirm.
    initial_state: np.ndarray
    final_state: np.ndarray
    max_state: np.ndarray
    min_state: np.ndarray
    max_control: np.ndarray
    min_control: np.ndarray
    total_time: float
    idx_x_true: slice
    idx_u_true: slice
    idx_t: slice
    # Slice of the augmented CTCS states; excluded from the boundary conditions.
    idx_y: slice
    idx_s: slice
    ctcs_node_intervals: list = None
    constraints_ctcs: List[callable] = field(
        default_factory=list
    )  # TODO (norrisg): clean this up, consider moving to dedicated `constraints` dataclass
    constraints_nodal: List[callable] = field(default_factory=list)
    n_states: int = None
    n_controls: int = None
    S_x: np.ndarray = None
    inv_S_x: np.ndarray = None
    c_x: np.ndarray = None
    S_u: np.ndarray = None
    inv_S_u: np.ndarray = None
    c_u: np.ndarray = None

    def __post_init__(self):
        self.n_states = len(self.max_state)
        self.n_controls = len(self.max_control)

        # Boundary conditions cover every state except the augmented CTCS states.
        n_bc = self.n_states - (self.idx_y.stop - self.idx_y.start)
        assert len(self.initial_state.value) == n_bc, f"Initial state must have {n_bc} elements"
        assert len(self.final_state.value) == n_bc, f"Final state must have {n_bc} elements"
        assert self.max_state.shape[0] == self.n_states, f"Max state must have {self.n_states} elements"
        assert self.min_state.shape[0] == self.n_states, f"Min state must have {self.n_states} elements"
        assert self.max_control.shape[0] == self.n_controls, f"Max control must have {self.n_controls} elements"
        assert self.min_control.shape[0] == self.n_controls, f"Min control must have {self.n_controls} elements"

        if self.S_x is None or self.c_x is None:
            self.S_x, self.c_x = get_affine_scaling_matrices(
                self.n_states, self.min_state, self.max_state
            )
            # get_affine_scaling_matrices returns a diagonal S_x, so invert its diagonal.
            self.inv_S_x = np.diag(1 / np.diag(self.S_x))
        elif self.inv_S_x is None:
            # Caller supplied S_x/c_x without an inverse. Use a general inverse,
            # since a user-provided S_x need not be diagonal.
            self.inv_S_x = np.linalg.inv(self.S_x)

        if self.S_u is None or self.c_u is None:
            self.S_u, self.c_u = get_affine_scaling_matrices(
                self.n_controls, self.min_control, self.max_control
            )
            self.inv_S_u = np.diag(1 / np.diag(self.S_u))
        elif self.inv_S_u is None:
            self.inv_S_u = np.linalg.inv(self.S_u)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@dataclass
class ScpConfig:
    """
    Configuration class for Sequential Convex Programming (SCP).

    This class defines the parameters used to configure the SCP solver. You will very likely need to modify
    the weights for your problem. Please refer to my guide [here](https://haynec.github.io/openscvx/hyperparameter_tuning) for more information.

    Note: in `__post_init__`, `w_tr`, `lam_vc`, `lam_cost` and `lam_vb` are divided by the largest of the four (when it is nonzero).
    An explicitly supplied `w_tr_max` is NOT rescaled. A `w_tr_max` derived from `w_tr_max_scaling_factor` uses the rescaled `w_tr`.

    Attributes:
        n (int): The number of discretization nodes. Defaults to `None`.
        k_max (int): The maximum number of SCP iterations. Defaults to 200.
        w_tr (float): The trust region weight. Defaults to 1.0.
        lam_vc (float): The penalty weight for virtual control. Defaults to 1.0.
        ep_tr (float): The trust region convergence tolerance. Defaults to 1e-4.
        ep_vb (float): The boundary constraint convergence tolerance. Defaults to 1e-4.
        ep_vc (float): The virtual constraint convergence tolerance. Defaults to 1e-8.
        lam_cost (float): The weight for original cost. Defaults to 0.0.
        lam_vb (float): The weight for virtual buffer. This is only used if there are nonconvex nodal constraints present. Defaults to 0.0.
        uniform_time_grid (bool): Whether to use a uniform time grid. TODO haynec add a link to the time dilation page. Defaults to `False`.
        cost_drop (int): The number of iterations to allow for cost stagnation before termination. Defaults to -1 (disabled).
        cost_relax (float): The relaxation factor for cost reduction. Defaults to 1.0.
        w_tr_adapt (float): The adaptation factor for the trust region weight. Defaults to 1.0.
        w_tr_max (float): The maximum allowable trust region weight. Defaults to `None`.
        w_tr_max_scaling_factor (float): If set and `w_tr_max` is `None`, `w_tr_max` becomes this factor times the rescaled `w_tr`. Defaults to `None`.
    """

    n: int = None
    k_max: int = 200
    w_tr: float = 1e0
    lam_vc: float = 1e0
    ep_tr: float = 1e-4
    ep_vb: float = 1e-4
    ep_vc: float = 1e-8
    lam_cost: float = 0.0
    lam_vb: float = 0.0
    uniform_time_grid: bool = False
    cost_drop: int = -1
    cost_relax: float = 1.0
    w_tr_adapt: float = 1.0
    w_tr_max: float = None
    w_tr_max_scaling_factor: float = None

    def __post_init__(self):
        # Normalize the weights by the largest one. Skip when all are zero to
        # avoid a ZeroDivisionError.
        keys_to_scale = ["w_tr", "lam_vc", "lam_cost", "lam_vb"]
        scale = max(getattr(self, key) for key in keys_to_scale)
        if scale != 0:
            for key in keys_to_scale:
                setattr(self, key, getattr(self, key) / scale)

        if self.w_tr_max_scaling_factor is not None and self.w_tr_max is None:
            self.w_tr_max = self.w_tr_max_scaling_factor * self.w_tr
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@dataclass
class Config:
    """Top-level bundle of every configuration section.

    Attributes:
        sim: Problem data and scaling (`SimConfig`).
        scp: SCP algorithm settings (`ScpConfig`).
        cvx: Convex sub-problem solver settings (`ConvexSolverConfig`).
        dis: Discretization settings (`DiscretizationConfig`).
        prp: Propagation settings (`PropagationConfig`).
        dev: Development / debugging settings (`DevConfig`).
    """

    sim: SimConfig
    scp: ScpConfig
    cvx: ConvexSolverConfig
    dis: DiscretizationConfig
    prp: PropagationConfig
    dev: DevConfig

    def __post_init__(self):
        """Hook for cross-section validation; intentionally does nothing yet."""
        return None
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from openscvx.integrators import solve_ivp_rk45, solve_ivp_diffrax
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def dVdt(
    tau: float,
    V: jnp.ndarray,
    u_cur: np.ndarray,
    u_next: np.ndarray,
    state_dot: callable,
    A: callable,
    B: callable,
    n_x: int,
    n_u: int,
    N: int,
    dis_type: str,
) -> jnp.ndarray:
    """Right-hand side of the augmented ODE used for multiple-shooting discretization.

    All ``N - 1`` intervals are propagated together, one row per interval.
    Each row packs, in order:
    - the state ``x`` (``n_x``);
    - the state transition matrix (``n_x * n_x``);
    - the sensitivity to the control at the interval start (``n_x * n_u``);
    - the sensitivity to the control at the interval end (``n_x * n_u``);
    - the affine defect term (``n_x``).

    The last control channel is the time-dilation factor ``s``. The dynamics
    and their Jacobians are multiplied by it, and its own column of the
    control Jacobian is the undilated dynamics ``f``.

    Args:
        tau: Time within the current interval, in ``[0, 1 / (N - 1)]``.
        V: Flattened augmented state of all intervals.
        u_cur: Controls at the start of each interval, one row per interval.
        u_next: Controls at the end of each interval, one row per interval.
        state_dot: Batched dynamics ``f(x, u, nodes)``.
        A: Batched state Jacobian ``df/dx(x, u, nodes)``.
        B: Batched control Jacobian ``df/du(x, u, nodes)``.
        n_x: Number of states.
        n_u: Number of controls, including the time-dilation channel.
        N: Number of nodes.
        dis_type: ``"ZOH"`` (zero-order hold) or ``"FOH"`` (first-order hold).

    Returns:
        Flattened time derivative of ``V``.

    Raises:
        ValueError: If ``dis_type`` is neither ``"ZOH"`` nor ``"FOH"``.
    """
    # One node index per interval (interval k starts at node k).
    nodes = jnp.arange(0, N - 1)

    # Offsets of the packed blocks within one row of the augmented state.
    i0 = 0
    i1 = n_x
    i2 = i1 + n_x * n_x
    i3 = i2 + n_x * n_u
    i4 = i3 + n_x * n_u
    i5 = i4 + n_x

    # Unflatten V: one row per interval.
    V = V.reshape(-1, i5)

    # Interpolation weights: alpha multiplies u_cur, beta multiplies u_next.
    if dis_type == "ZOH":
        beta = 0.0
    elif dis_type == "FOH":
        # Each interval spans tau in [0, 1/(N-1)], so beta must run from 0 to 1.
        beta = tau * (N - 1)
    else:
        raise ValueError(
            f"Unknown discretization type {dis_type!r}; expected 'ZOH' or 'FOH'"
        )
    alpha = 1 - beta

    # Interpolated control, aligned with the batch of interval states.
    u = u_cur + beta * (u_next - u_cur)
    x = V[:, :n_x]
    u = u[: x.shape[0]]
    s = u[:, -1]  # time-dilation factor (last control channel)
    u_true = u[:, :-1]

    # Nonlinear dynamics, dilated by s.
    f = state_dot(x, u_true, nodes)
    F = s[:, None] * f

    # Dilated state Jacobian.
    sdfdx = s[:, None, None] * A(x, u_true, nodes)

    # Control Jacobian; the time-dilation column is the undilated dynamics.
    dfdu = jnp.zeros((V.shape[0], n_x, n_u))
    dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * B(x, u_true, nodes))
    dfdu = dfdu.at[:, :, -1].set(f)

    # Defect between the nonlinear dynamics and their linearization.
    z = F - jnp.einsum("ijk,ik->ij", sdfdx, x) - jnp.einsum("ijk,ik->ij", dfdu, u)

    # Stack the results into the augmented state derivative.
    # fmt: off
    dV = jnp.zeros_like(V)
    dV = dV.at[:, i0:i1].set(F)
    dV = dV.at[:, i1:i2].set(jnp.matmul(sdfdx, V[:, i1:i2].reshape(-1, n_x, n_x)).reshape(-1, n_x * n_x))
    dV = dV.at[:, i2:i3].set((jnp.matmul(sdfdx, V[:, i2:i3].reshape(-1, n_x, n_u)) + dfdu * alpha).reshape(-1, n_x * n_u))
    dV = dV.at[:, i3:i4].set((jnp.matmul(sdfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u))
    dV = dV.at[:, i4:i5].set((jnp.matmul(sdfdx, V[:, i4:i5].reshape(-1, n_x)[..., None]).squeeze(-1) + z).reshape(-1, n_x))
    # fmt: on
    return dV.flatten()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def calculate_discretization(
    x,
    u,
    state_dot: callable,
    A: callable,
    B: callable,
    n_x: int,
    n_u: int,
    N: int,
    custom_integrator: bool,
    debug: bool,
    solver: str,
    rtol,
    atol,
    dis_type: str,
):
    """Discretize the dynamics about a trajectory by multiple shooting.

    Every interval between consecutive nodes is integrated over
    ``[0, 1 / (N - 1)]`` together with its sensitivities (see ``dVdt``).

    Args:
        x: Node states; rows ``0 .. N-2`` seed the intervals.
        u: Node controls; ``u[:-1]`` and ``u[1:]`` are the start/end controls.
        state_dot, A, B: Batched dynamics and Jacobians forwarded to ``dVdt``.
        n_x: Number of states.
        n_u: Number of controls.
        N: Number of nodes.
        custom_integrator: Use ``solve_ivp_rk45`` instead of Diffrax.
        debug: Forwarded to ``solve_ivp_rk45`` as ``is_not_compiled``.
        solver, rtol, atol: Diffrax solver name and tolerances (used only
            when ``custom_integrator`` is False).
        dis_type: ``"ZOH"`` or ``"FOH"``.

    Returns:
        tuple: ``(A_bar, B_bar, C_bar, z_bar, Vmulti)``.
        - ``A_bar``, ``B_bar`` and ``C_bar`` hold one row per interval: the
          column-major flattening of the state transition matrix and of the
          start/end control sensitivities.
        - ``z_bar`` is the defect term.
        - ``Vmulti`` is the transposed augmented solution at every substep.
    """
    # Layout of one row of the augmented state (mirrors dVdt).
    i1 = n_x
    i2 = i1 + n_x * n_x
    i3 = i2 + n_x * n_u
    i4 = i3 + n_x * n_u
    i5 = i4 + n_x
    n_intervals = N - 1

    # Each interval starts at its node state with an identity transition matrix;
    # the sensitivity and defect blocks start at zero.
    identity_rows = jnp.eye(n_x).reshape(1, -1).repeat(n_intervals, axis=0)
    V0 = (
        jnp.zeros((n_intervals, i5))
        .at[:, :n_x].set(x[:-1].astype(float))
        .at[:, i1:i2].set(identity_rows)
    )

    def _rhs(t, y, *ode_args):
        return dVdt(t, y, *ode_args)

    ode_args = (
        u[:-1].astype(float), u[1:].astype(float),
        state_dot, A, B, n_x, n_u, N, dis_type,
    )
    interval_length = 1.0 / (N - 1)

    if custom_integrator:
        sol = solve_ivp_rk45(
            _rhs,
            interval_length,
            V0.reshape(-1),
            args=ode_args,
            is_not_compiled=debug,
        )
    else:
        sol = solve_ivp_diffrax(
            _rhs,
            interval_length,
            V0.reshape(-1),
            args=ode_args,
            solver_name=solver,
            rtol=rtol,
            atol=atol,
            extra_kwargs=None,
        )

    # State of every interval at the end of the integration window.
    Vend = sol[-1].reshape(-1, i5)
    Vmulti = sol.T

    def _column_major(block, rows, cols):
        # Flatten each interval's (rows x cols) matrix in column-major order.
        return block.reshape(n_intervals, rows, cols).transpose(0, 2, 1).reshape(n_intervals, rows * cols)

    A_bar = _column_major(Vend[:, i1:i2], n_x, n_x)
    B_bar = _column_major(Vend[:, i2:i3], n_x, n_u)
    C_bar = _column_major(Vend[:, i3:i4], n_x, n_u)
    z_bar = Vend[:, i4:i5]

    return A_bar, B_bar, C_bar, z_bar, Vmulti
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_discretization_solver(state_dot, A, B, params):
    """Bind dynamics and configuration into a two-argument discretization function.

    Args:
        state_dot, A, B: Batched dynamics and Jacobians passed through to
            ``calculate_discretization``.
        params: Configuration object. It is read (at call time) for
            ``sim.n_states``, ``sim.n_controls``, ``scp.n``, the ``dis``
            settings and ``dev.debug``.

    Returns:
        callable: ``discretize(x, u)``, returning the tuple produced by
        ``calculate_discretization``.
    """

    def discretize(x, u):
        return calculate_discretization(
            x=x,
            u=u,
            state_dot=state_dot,
            A=A,
            B=B,
            n_x=params.sim.n_states,
            n_u=params.sim.n_controls,
            N=params.scp.n,
            custom_integrator=params.dis.custom_integrator,
            debug=params.dev.debug,
            solver=params.dis.solver,
            rtol=params.dis.rtol,
            atol=params.dis.atol,
            dis_type=params.dis.dis_type,
        )

    return discretize
|
openscvx/dynamics.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get_augmented_dynamics(
    dynamics: callable, g_funcs: list[callable], idx_x_true: slice, idx_u_true: slice
) -> callable:
    """Wrap ``dynamics`` so extra state derivatives are appended to its output.

    Args:
        dynamics: ``dynamics(x_true, u_true)`` returning the true state derivative.
        g_funcs: Functions ``g(x_true, u_true, node)``; their outputs are
            appended, in order, after the dynamics output.
        idx_x_true: Slice selecting the true states from the full state.
        idx_u_true: Slice selecting the true controls from the full control.

    Returns:
        callable: ``dynamics_augmented(x, u, node)``.
    """

    def dynamics_augmented(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
        x_true = x[idx_x_true]
        u_true = u[idx_u_true]
        base = dynamics(x_true, u_true)

        appended = [g(x_true, u_true, node) for g in g_funcs]
        if not appended:
            # No augmentation: hand back the dynamics output untouched.
            return base
        return jnp.hstack([base, *appended])

    return dynamics_augmented
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_jacobians(dyn: callable):
    """Forward-mode Jacobians of ``dyn`` with respect to its first two arguments.

    Returns:
        tuple: ``(A, B)`` with ``A = d dyn / d arg0`` and ``B = d dyn / d arg1``;
        both accept the same arguments as ``dyn``.
    """
    state_jacobian, control_jacobian = (jax.jacfwd(dyn, argnums=k) for k in (0, 1))
    return state_jacobian, control_jacobian
|
openscvx/integrators.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import diffrax as dfx
|
|
4
|
+
|
|
5
|
+
# Diffrax solver classes, keyed by the names accepted as ``solver_name`` below.
SOLVER_MAP = dict(
    # Explicit Runge-Kutta methods
    Tsit5=dfx.Tsit5,
    Euler=dfx.Euler,
    Heun=dfx.Heun,
    Midpoint=dfx.Midpoint,
    Ralston=dfx.Ralston,
    Dopri5=dfx.Dopri5,
    Dopri8=dfx.Dopri8,
    Bosh3=dfx.Bosh3,
    ReversibleHeun=dfx.ReversibleHeun,
    # Implicit / IMEX methods
    ImplicitEuler=dfx.ImplicitEuler,
    KenCarp3=dfx.KenCarp3,
    KenCarp4=dfx.KenCarp4,
    KenCarp5=dfx.KenCarp5,
)
|
|
20
|
+
|
|
21
|
+
# fmt: off
def rk45_step(f, t, y, h, *args):
    """Advance ``y`` by one fixed step of size ``h`` starting at time ``t``.

    Uses the Runge-Kutta-Fehlberg tableau, but only its fourth-order
    solution. No error estimate or step-size control is computed.

    Args:
        f: Right-hand side ``f(t, y, *args)``.
        t: Time at the start of the step.
        y: State at time ``t``.
        h: Step size.
        *args: Extra arguments forwarded to ``f``.

    Returns:
        The state at time ``t + h``.
    """
    k1 = f(t, y, *args)
    k2 = f(t + h/4, y + h*k1/4, *args)
    k3 = f(t + 3*h/8, y + 3*h*k1/32 + 9*h*k2/32, *args)
    k4 = f(t + 12*h/13, y + 1932*h*k1/2197 - 7200*h*k2/2197 + 7296*h*k3/2197, *args)
    k5 = f(t + h, y + 439*h*k1/216 - 8*h*k2 + 3680*h*k3/513 - 845*h*k4/4104, *args)
    y_next = y + h * (25*k1/216 + 1408*k3/2565 + 2197*k4/4104 - k5/5)
    return y_next
# fmt: on


def solve_ivp_rk45(
    f,
    tau_final: float,
    y_0,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    is_not_compiled: bool = False,
):
    """Integrate ``dy/dtau = f(tau, y, *args)`` with fixed steps of ``rk45_step``.

    ``[tau_0, tau_final]`` is split into ``num_substeps`` evenly spaced points
    (``num_substeps - 1`` steps). The state is recorded at every point.

    Args:
        f: Right-hand side ``f(tau, y, *args)``.
        tau_final: End of the integration window.
        y_0: Initial state (1-D).
        args: Tuple of extra arguments forwarded to ``f``.
        tau_0: Start of the integration window.
        num_substeps: Number of recorded points, including both endpoints.
        is_not_compiled: If True, run a plain Python loop (easier to debug);
            otherwise use ``jax.lax.fori_loop``. Both paths evaluate ``f`` at
            the same times.

    Returns:
        Array of shape ``(num_substeps, len(y_0))`` holding the state at each point.

    Raises:
        ValueError: If ``num_substeps < 2``.
    """
    if num_substeps < 2:
        raise ValueError(f"num_substeps must be at least 2, got {num_substeps}")

    h = (tau_final - tau_0) / (num_substeps - 1)
    solution = jnp.zeros((num_substeps, len(y_0)))
    solution = solution.at[0].set(y_0)

    if is_not_compiled:
        for i in range(1, num_substeps):
            # Step i advances from point i-1, so it starts at tau_0 + (i-1)*h.
            t = tau_0 + (i - 1) * h
            solution = solution.at[i].set(rk45_step(f, t, solution[i - 1], h, *args))
        return solution

    def body_fun(i, val):
        t, y, V_result = val
        y_next = rk45_step(f, t, y, h, *args)
        V_result = V_result.at[i].set(y_next)
        return (t + h, y_next, V_result)

    _, _, solution = jax.lax.fori_loop(
        1, num_substeps, body_fun, (tau_0, y_0, solution)
    )
    return solution
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _diffeqsolve(
    f,
    tau_final,
    y_0,
    args,
    tau_0,
    num_substeps,
    solver_name,
    rtol,
    atol,
    extra_kwargs,
    dense,
):
    """Shared Diffrax driver for the two public wrappers; returns the full Diffrax solution.

    The state is saved at ``num_substeps`` evenly spaced times in
    ``[tau_0, tau_final]``; ``dense=True`` additionally keeps a dense interpolant.

    Raises:
        ValueError: If ``num_substeps < 2`` or ``solver_name`` is not in ``SOLVER_MAP``.
    """
    if num_substeps < 2:
        raise ValueError(f"num_substeps must be at least 2, got {num_substeps}")

    substeps = jnp.linspace(tau_0, tau_final, num_substeps)

    solver_class = SOLVER_MAP.get(solver_name)
    if solver_class is None:
        raise ValueError(f"Unknown solver: {solver_name}")

    # Diffrax passes ``args`` as one object; unpack it for ``f``.
    term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
    return dfx.diffeqsolve(
        term,
        solver=solver_class(),
        t0=tau_0,
        t1=tau_final,
        dt0=(tau_final - tau_0) / (num_substeps - 1),
        y0=y_0,
        args=args,
        stepsize_controller=dfx.PIDController(rtol=rtol, atol=atol),
        saveat=dfx.SaveAt(dense=dense, ts=substeps),
        **(extra_kwargs or {}),
    )


def solve_ivp_diffrax(
    f,
    tau_final,
    y_0,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    solver_name="Dopri8",
    rtol: float = 1e-3,
    atol: float = 1e-6,
    extra_kwargs=None,
):
    """Integrate ``dy/dtau = f(tau, y, *args)`` with an adaptive Diffrax solver.

    Args:
        f: Right-hand side ``f(tau, y, *args)``.
        tau_final: End of the integration window.
        y_0: Initial state.
        args: Tuple of extra arguments forwarded to ``f``.
        tau_0: Start of the integration window.
        num_substeps: Number of evenly spaced save points (also sets the initial step).
        solver_name: Key into ``SOLVER_MAP``.
        rtol, atol: PID step-size controller tolerances. NOTE(review): these
            defaults are the reverse of the config classes' (atol=1e-3,
            rtol=1e-6); callers in this package pass them explicitly.
        extra_kwargs: Extra keyword arguments for ``diffrax.diffeqsolve``.

    Returns:
        The saved states (``solution.ys``), one row per save point.

    Raises:
        ValueError: If ``num_substeps < 2`` or ``solver_name`` is unknown.
    """
    return _diffeqsolve(
        f, tau_final, y_0, args, tau_0, num_substeps,
        solver_name, rtol, atol, extra_kwargs, dense=False,
    ).ys


def solve_ivp_diffrax_prop(
    f,
    tau_final,
    y_0,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    solver_name="Dopri8",
    rtol: float = 1e-3,
    atol: float = 1e-6,
    extra_kwargs=None,
):
    """Like ``solve_ivp_diffrax`` but also keeps a dense interpolant.

    Arguments are the same as for ``solve_ivp_diffrax``.

    Returns:
        The full Diffrax solution object (saved at the substeps, with dense output).

    Raises:
        ValueError: If ``num_substeps < 2`` or ``solver_name`` is unknown.
    """
    return _diffeqsolve(
        f, tau_final, y_0, args, tau_0, num_substeps,
        solver_name, rtol, atol, extra_kwargs, dense=True,
    )
|