openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/_version.py +2 -2
- openscvx/augmentation/dynamics_augmentation.py +22 -7
- openscvx/config.py +310 -192
- openscvx/constraints/__init__.py +0 -3
- openscvx/constraints/ctcs.py +188 -33
- openscvx/constraints/nodal.py +150 -11
- openscvx/constraints/violation.py +12 -2
- openscvx/discretization.py +115 -37
- openscvx/dynamics.py +150 -11
- openscvx/integrators.py +135 -16
- openscvx/io.py +129 -17
- openscvx/ocp.py +86 -67
- openscvx/plotting.py +72 -215
- openscvx/post_processing.py +57 -16
- openscvx/propagation.py +155 -55
- openscvx/ptr.py +96 -57
- openscvx/results.py +153 -0
- openscvx/trajoptproblem.py +359 -114
- openscvx/utils.py +50 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/METADATA +129 -41
- openscvx-0.2.1.dev0.dist-info/RECORD +27 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/WHEEL +1 -1
- openscvx/constraints/boundary.py +0 -49
- openscvx-0.1.2.dist-info/RECORD +0 -27
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/top_level.txt +0 -0
openscvx/propagation.py
CHANGED
|
@@ -2,6 +2,7 @@ import numpy as np
|
|
|
2
2
|
|
|
3
3
|
from openscvx.config import Config
|
|
4
4
|
from openscvx.integrators import solve_ivp_diffrax_prop
|
|
5
|
+
from openscvx.backend.parameter import Parameter
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
def prop_aug_dy(
|
|
@@ -15,7 +16,29 @@ def prop_aug_dy(
|
|
|
15
16
|
state_dot: callable,
|
|
16
17
|
dis_type: str,
|
|
17
18
|
N: int,
|
|
19
|
+
*params
|
|
18
20
|
) -> np.ndarray:
|
|
21
|
+
"""Compute the augmented dynamics for propagation.
|
|
22
|
+
|
|
23
|
+
This function computes the time-scaled dynamics for propagating the system state,
|
|
24
|
+
taking into account the discretization type (ZOH or FOH) and time dilation.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
tau (float): Current normalized time in [0,1].
|
|
28
|
+
x (np.ndarray): Current state vector.
|
|
29
|
+
u_current (np.ndarray): Control input at current node.
|
|
30
|
+
u_next (np.ndarray): Control input at next node.
|
|
31
|
+
tau_init (float): Initial normalized time.
|
|
32
|
+
node (int): Current node index.
|
|
33
|
+
idx_s (int): Index of time dilation variable in control vector.
|
|
34
|
+
state_dot (callable): Function computing state derivatives.
|
|
35
|
+
dis_type (str): Discretization type ("ZOH" or "FOH").
|
|
36
|
+
N (int): Number of nodes in trajectory.
|
|
37
|
+
*params: Additional parameters passed to state_dot.
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
np.ndarray: Time-scaled state derivatives.
|
|
41
|
+
"""
|
|
19
42
|
x = x[None, :]
|
|
20
43
|
|
|
21
44
|
if dis_type == "ZOH":
|
|
@@ -24,37 +47,69 @@ def prop_aug_dy(
|
|
|
24
47
|
beta = (tau - tau_init) * N
|
|
25
48
|
u = u_current + beta * (u_next - u_current)
|
|
26
49
|
|
|
27
|
-
return u[:, idx_s] * state_dot(x, u[:, :-1], node).squeeze()
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
50
|
+
return u[:, idx_s] * state_dot(x, u[:, :-1], node, *params).squeeze()
|
|
51
|
+
|
|
52
|
+
def get_propagation_solver(state_dot, settings, param_map):
|
|
53
|
+
"""Create a propagation solver function.
|
|
54
|
+
|
|
55
|
+
This function creates a solver that propagates the system state using the
|
|
56
|
+
specified dynamics and settings.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
state_dot (callable): Function computing state derivatives.
|
|
60
|
+
settings: Configuration settings for propagation.
|
|
61
|
+
param_map (dict): Mapping of parameter names to values.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
callable: A function that solves the propagation problem.
|
|
65
|
+
"""
|
|
66
|
+
def propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, idx_s, save_time, mask, *params):
|
|
67
|
+
param_map_update = dict(zip(param_map.keys(), params))
|
|
32
68
|
return solve_ivp_diffrax_prop(
|
|
33
69
|
f=prop_aug_dy,
|
|
34
|
-
tau_final=tau_grid[1],
|
|
35
|
-
y_0=V0,
|
|
70
|
+
tau_final=tau_grid[1], # scalar
|
|
71
|
+
y_0=V0, # shape (n_states,)
|
|
36
72
|
args=(
|
|
37
|
-
u_cur,
|
|
38
|
-
u_next,
|
|
39
|
-
tau_init,
|
|
40
|
-
node,
|
|
41
|
-
idx_s,
|
|
42
|
-
state_dot,
|
|
43
|
-
|
|
44
|
-
|
|
73
|
+
u_cur, # shape (1, n_controls)
|
|
74
|
+
u_next, # shape (1, n_controls)
|
|
75
|
+
tau_init, # shape (1, 1)
|
|
76
|
+
node, # shape (1, 1)
|
|
77
|
+
idx_s, # int
|
|
78
|
+
state_dot, # function or array
|
|
79
|
+
settings.dis.dis_type,
|
|
80
|
+
settings.scp.n,
|
|
81
|
+
*param_map_update.items(),
|
|
82
|
+
# additional named parameters as **kwargs
|
|
45
83
|
),
|
|
46
|
-
tau_0=tau_grid[0],
|
|
84
|
+
tau_0=tau_grid[0], # scalar
|
|
85
|
+
save_time=save_time, # shape (MAX_TAU_LEN,)
|
|
86
|
+
mask=mask # shape (MAX_TAU_LEN,), dtype=bool
|
|
47
87
|
)
|
|
48
88
|
|
|
49
89
|
return propagation_solver
|
|
50
90
|
|
|
51
91
|
|
|
52
|
-
|
|
53
|
-
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def s_to_t(x, u, params: Config):
|
|
95
|
+
"""Convert normalized time s to real time t.
|
|
96
|
+
|
|
97
|
+
This function converts the normalized time variable s to real time t
|
|
98
|
+
based on the discretization type and time dilation factors.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
x: State trajectory.
|
|
102
|
+
u: Control trajectory.
|
|
103
|
+
params (Config): Configuration settings.
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
list: List of real time points.
|
|
107
|
+
"""
|
|
108
|
+
t = [x.guess[:,params.sim.idx_t][0]]
|
|
54
109
|
tau = np.linspace(0, 1, params.scp.n)
|
|
55
110
|
for k in range(1, params.scp.n):
|
|
56
|
-
s_kp = u[k - 1, -1]
|
|
57
|
-
s_k = u[k, -1]
|
|
111
|
+
s_kp = u.guess[k - 1, -1]
|
|
112
|
+
s_k = u.guess[k, -1]
|
|
58
113
|
if params.dis.dis_type == "ZOH":
|
|
59
114
|
t.append(t[k - 1] + (tau[k] - tau[k - 1]) * (s_kp))
|
|
60
115
|
else:
|
|
@@ -62,17 +117,30 @@ def s_to_t(u, params: Config):
|
|
|
62
117
|
return t
|
|
63
118
|
|
|
64
119
|
|
|
65
|
-
def t_to_tau(u, t,
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
120
|
+
def t_to_tau(u, t, t_nodal, params: Config):
|
|
121
|
+
"""Convert real time t to normalized time tau.
|
|
122
|
+
|
|
123
|
+
This function converts real time t to normalized time tau and interpolates
|
|
124
|
+
the control inputs accordingly.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
u: Control trajectory.
|
|
128
|
+
t (np.ndarray): Real time points.
|
|
129
|
+
t_nodal (np.ndarray): Nodal time points.
|
|
130
|
+
params (Config): Configuration settings.
|
|
131
|
+
|
|
132
|
+
Returns:
|
|
133
|
+
tuple: (tau, u_interp) where tau is normalized time and u_interp is interpolated controls.
|
|
134
|
+
"""
|
|
135
|
+
u_guess = u.guess
|
|
136
|
+
u_lam = lambda new_t: np.array([np.interp(new_t, t_nodal, u_guess[:,i]) for i in range(u_guess.shape[1])]).T
|
|
69
137
|
u = np.array([u_lam(t_i) for t_i in t])
|
|
70
138
|
|
|
71
139
|
tau = np.zeros(len(t))
|
|
72
140
|
tau_nodal = np.linspace(0, 1, params.scp.n)
|
|
73
141
|
for k in range(1, len(t)):
|
|
74
142
|
k_nodal = np.where(t_nodal < t[k])[0][-1]
|
|
75
|
-
s_kp =
|
|
143
|
+
s_kp = u_guess[k_nodal, -1]
|
|
76
144
|
tp = t_nodal[k_nodal]
|
|
77
145
|
tau_p = tau_nodal[k_nodal]
|
|
78
146
|
|
|
@@ -84,52 +152,84 @@ def t_to_tau(u, t, u_nodal, t_nodal, params: Config):
|
|
|
84
152
|
return tau, u
|
|
85
153
|
|
|
86
154
|
|
|
87
|
-
def simulate_nonlinear_time(
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
155
|
+
def simulate_nonlinear_time(params, x, u, tau_vals, t, settings, propagation_solver):
|
|
156
|
+
"""Simulate the nonlinear system dynamics over time.
|
|
157
|
+
|
|
158
|
+
This function simulates the system dynamics using the optimal control sequence
|
|
159
|
+
and returns the resulting state trajectory.
|
|
160
|
+
|
|
161
|
+
Args:
|
|
162
|
+
params: System parameters.
|
|
163
|
+
x: State trajectory.
|
|
164
|
+
u: Control trajectory.
|
|
165
|
+
tau_vals (np.ndarray): Normalized time points for simulation.
|
|
166
|
+
t (np.ndarray): Real time points.
|
|
167
|
+
settings: Configuration settings.
|
|
168
|
+
propagation_solver (callable): Function for propagating the system state.
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
np.ndarray: Simulated state trajectory.
|
|
172
|
+
"""
|
|
173
|
+
x_0 = settings.sim.x_prop.initial
|
|
174
|
+
|
|
175
|
+
n_segments = settings.scp.n - 1
|
|
176
|
+
n_states = x_0.shape[0]
|
|
177
|
+
n_tau = len(tau_vals)
|
|
178
|
+
|
|
179
|
+
params = params.items()
|
|
180
|
+
param_values = tuple([param.value for _, param in params])
|
|
181
|
+
|
|
182
|
+
states = np.empty((n_states, n_tau))
|
|
183
|
+
tau = np.linspace(0, 1, settings.scp.n)
|
|
184
|
+
|
|
185
|
+
# Precompute control interpolation
|
|
186
|
+
u_interp = np.stack([
|
|
187
|
+
np.interp(t, t, u.guess[:, i]) for i in range(u.guess.shape[1])
|
|
188
|
+
], axis=-1)
|
|
189
|
+
|
|
190
|
+
# Bin tau_vals into segments of tau
|
|
99
191
|
tau_inds = np.digitize(tau_vals, tau) - 1
|
|
100
|
-
|
|
101
|
-
tau_inds = np.where(tau_inds == params.scp.n - 1, params.scp.n - 2, tau_inds)
|
|
192
|
+
tau_inds = np.where(tau_inds == settings.scp.n - 1, settings.scp.n - 2, tau_inds)
|
|
102
193
|
|
|
103
194
|
prev_count = 0
|
|
195
|
+
out_idx = 0
|
|
104
196
|
|
|
105
|
-
for k in range(
|
|
106
|
-
controls_current =
|
|
107
|
-
controls_next =
|
|
197
|
+
for k in range(n_segments):
|
|
198
|
+
controls_current = u_interp[k][None, :]
|
|
199
|
+
controls_next = u_interp[k + 1][None, :]
|
|
108
200
|
|
|
109
|
-
#
|
|
201
|
+
# Mask for tau_vals in current segment
|
|
110
202
|
mask = (tau_inds >= k) & (tau_inds < k + 1)
|
|
111
|
-
|
|
112
203
|
count = np.sum(mask)
|
|
113
204
|
|
|
114
|
-
|
|
115
|
-
tau_cur =
|
|
205
|
+
tau_cur = tau_vals[prev_count:prev_count + count]
|
|
206
|
+
tau_cur = np.concatenate([tau_cur, np.array([tau[k + 1]])]) # Always include final point
|
|
207
|
+
count += 1
|
|
208
|
+
|
|
209
|
+
# Pad to fixed length
|
|
210
|
+
pad_len = settings.prp.max_tau_len - count
|
|
211
|
+
tau_cur_padded = np.pad(tau_cur, (0, pad_len), constant_values=tau[k + 1])
|
|
212
|
+
mask_padded = np.concatenate([np.ones(count), np.zeros(pad_len)]).astype(bool)
|
|
116
213
|
|
|
117
|
-
|
|
214
|
+
# Call the solver with padded tau_cur and mask
|
|
215
|
+
sol = propagation_solver.call(
|
|
118
216
|
x_0,
|
|
119
217
|
(tau[k], tau[k + 1]),
|
|
120
218
|
controls_current,
|
|
121
219
|
controls_next,
|
|
122
220
|
np.array([[tau[k]]]),
|
|
123
221
|
np.array([[k]]),
|
|
124
|
-
|
|
222
|
+
settings.sim.idx_s.stop,
|
|
223
|
+
tau_cur_padded,
|
|
224
|
+
mask_padded,
|
|
225
|
+
*param_values
|
|
125
226
|
)
|
|
126
227
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
228
|
+
# Only store the valid portion (excluding the final point which becomes next x_0)
|
|
229
|
+
states[:, out_idx:out_idx + count - 1] = sol[:count - 1].T
|
|
230
|
+
out_idx += count - 1
|
|
231
|
+
x_0 = sol[count - 1] # Last value used as next x_0
|
|
131
232
|
|
|
132
|
-
|
|
133
|
-
prev_count += count
|
|
233
|
+
prev_count += (count - 1)
|
|
134
234
|
|
|
135
|
-
return states.T
|
|
235
|
+
return states.T
|
openscvx/ptr.py
CHANGED
|
@@ -4,55 +4,91 @@ import cvxpy as cp
|
|
|
4
4
|
import pickle
|
|
5
5
|
import time
|
|
6
6
|
|
|
7
|
+
from openscvx.backend.parameter import Parameter
|
|
7
8
|
from openscvx.config import Config
|
|
9
|
+
from openscvx.results import OptimizationResults
|
|
8
10
|
|
|
9
11
|
import warnings
|
|
10
12
|
warnings.filterwarnings("ignore")
|
|
11
13
|
|
|
12
|
-
def PTR_init(ocp: cp.Problem, discretization_solver: callable,
|
|
13
|
-
if
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
14
|
+
def PTR_init(params, ocp: cp.Problem, discretization_solver: callable, settings: Config):
|
|
15
|
+
if settings.cvx.cvxpygen:
|
|
16
|
+
try:
|
|
17
|
+
from solver.cpg_solver import cpg_solve
|
|
18
|
+
with open('solver/problem.pickle', 'rb') as f:
|
|
19
|
+
prob = pickle.load(f)
|
|
20
|
+
except ImportError:
|
|
21
|
+
raise ImportError(
|
|
22
|
+
"cvxpygen solver not found. Make sure cvxpygen is installed and code generation has been run. "
|
|
23
|
+
"Install with: pip install openscvx[cvxpygen]"
|
|
24
|
+
)
|
|
17
25
|
else:
|
|
18
26
|
cpg_solve = None
|
|
19
27
|
|
|
28
|
+
if 'x_init' in ocp.param_dict:
|
|
29
|
+
ocp.param_dict['x_init'].value = settings.sim.x.initial
|
|
30
|
+
|
|
31
|
+
if 'x_term' in ocp.param_dict:
|
|
32
|
+
ocp.param_dict['x_term'].value = settings.sim.x.final
|
|
33
|
+
|
|
20
34
|
# Solve a dumb problem to intilize DPP and JAX jacobians
|
|
21
|
-
_ = PTR_subproblem(cpg_solve,
|
|
35
|
+
_ = PTR_subproblem(params.items(), cpg_solve, settings.sim.x, settings.sim.u, discretization_solver, ocp, settings)
|
|
22
36
|
|
|
23
37
|
return cpg_solve
|
|
24
38
|
|
|
25
|
-
def
|
|
39
|
+
def format_result(problem, converged: bool) -> OptimizationResults:
|
|
40
|
+
"""Formats the final result as an OptimizationResults object from the problem's state."""
|
|
41
|
+
return OptimizationResults(
|
|
42
|
+
converged=converged,
|
|
43
|
+
t_final=problem.settings.sim.x.guess[:, problem.settings.sim.idx_t][-1],
|
|
44
|
+
u=problem.settings.sim.u,
|
|
45
|
+
x=problem.settings.sim.x,
|
|
46
|
+
x_history=problem.scp_trajs,
|
|
47
|
+
u_history=problem.scp_controls,
|
|
48
|
+
discretization_history=problem.scp_V_multi_shoot_traj,
|
|
49
|
+
J_tr_history=problem.scp_J_tr,
|
|
50
|
+
J_vb_history=problem.scp_J_vb,
|
|
51
|
+
J_vc_history=problem.scp_J_vc,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
def PTR_main(params, settings: Config, prob: cp.Problem, aug_dy: callable, cpg_solve, emitter_function) -> OptimizationResults:
|
|
26
55
|
J_vb = 1E2
|
|
27
56
|
J_vc = 1E2
|
|
28
57
|
J_tr = 1E2
|
|
29
58
|
|
|
30
|
-
|
|
31
|
-
|
|
59
|
+
x = settings.sim.x
|
|
60
|
+
u = settings.sim.u
|
|
61
|
+
|
|
62
|
+
if 'x_init' in prob.param_dict:
|
|
63
|
+
prob.param_dict['x_init'].value = settings.sim.x.initial
|
|
64
|
+
|
|
65
|
+
if 'x_term' in prob.param_dict:
|
|
66
|
+
prob.param_dict['x_term'].value = settings.sim.x.final
|
|
67
|
+
|
|
32
68
|
|
|
33
|
-
scp_trajs = [
|
|
34
|
-
scp_controls = [
|
|
69
|
+
scp_trajs = [x.guess]
|
|
70
|
+
scp_controls = [u.guess]
|
|
35
71
|
V_multi_shoot_traj = []
|
|
36
72
|
|
|
37
73
|
k = 1
|
|
38
74
|
|
|
39
|
-
while k <=
|
|
40
|
-
|
|
75
|
+
while k <= settings.scp.k_max and ((J_tr >= settings.scp.ep_tr) or (J_vb >= settings.scp.ep_vb) or (J_vc >= settings.scp.ep_vc)):
|
|
76
|
+
x_sol, u_sol, cost, J_total, J_vb_vec, J_vc_vec, J_tr_vec, prob_stat, V_multi_shoot, subprop_time, dis_time = PTR_subproblem(params.items(), cpg_solve, x, u, aug_dy, prob, settings)
|
|
41
77
|
|
|
42
78
|
V_multi_shoot_traj.append(V_multi_shoot)
|
|
43
79
|
|
|
44
|
-
|
|
45
|
-
|
|
80
|
+
x.guess = x_sol
|
|
81
|
+
u.guess = u_sol
|
|
46
82
|
|
|
47
83
|
J_tr = np.sum(np.array(J_tr_vec))
|
|
48
84
|
J_vb = np.sum(np.array(J_vb_vec))
|
|
49
85
|
J_vc = np.sum(np.array(J_vc_vec))
|
|
50
|
-
scp_trajs.append(x)
|
|
51
|
-
scp_controls.append(u)
|
|
86
|
+
scp_trajs.append(x.guess)
|
|
87
|
+
scp_controls.append(u.guess)
|
|
52
88
|
|
|
53
|
-
|
|
54
|
-
if k >
|
|
55
|
-
|
|
89
|
+
settings.scp.w_tr = min(settings.scp.w_tr * settings.scp.w_tr_adapt, settings.scp.w_tr_max)
|
|
90
|
+
if k > settings.scp.cost_drop:
|
|
91
|
+
settings.scp.lam_cost = settings.scp.lam_cost * settings.scp.cost_relax
|
|
56
92
|
|
|
57
93
|
emitter_function(
|
|
58
94
|
{
|
|
@@ -63,34 +99,37 @@ def PTR_main(params: Config, prob: cp.Problem, aug_dy: callable, cpg_solve, emit
|
|
|
63
99
|
"J_tr": J_tr,
|
|
64
100
|
"J_vb": J_vb,
|
|
65
101
|
"J_vc": J_vc,
|
|
66
|
-
"cost":
|
|
102
|
+
"cost": cost[-1],
|
|
67
103
|
"prob_stat": prob_stat,
|
|
68
104
|
}
|
|
69
105
|
)
|
|
70
106
|
|
|
71
107
|
k += 1
|
|
72
108
|
|
|
73
|
-
result =
|
|
74
|
-
converged
|
|
75
|
-
t_final
|
|
76
|
-
u
|
|
77
|
-
x
|
|
78
|
-
x_history
|
|
79
|
-
u_history
|
|
80
|
-
discretization_history
|
|
81
|
-
J_tr_history
|
|
82
|
-
J_vb_history
|
|
83
|
-
J_vc_history
|
|
109
|
+
result = OptimizationResults(
|
|
110
|
+
converged=k <= settings.scp.k_max,
|
|
111
|
+
t_final=x.guess[:,settings.sim.idx_t][-1],
|
|
112
|
+
u=u,
|
|
113
|
+
x=x,
|
|
114
|
+
x_history=scp_trajs,
|
|
115
|
+
u_history=scp_controls,
|
|
116
|
+
discretization_history=V_multi_shoot_traj,
|
|
117
|
+
J_tr_history=J_tr_vec,
|
|
118
|
+
J_vb_history=J_vb_vec,
|
|
119
|
+
J_vc_history=J_vc_vec,
|
|
84
120
|
)
|
|
121
|
+
|
|
85
122
|
return result
|
|
86
123
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
prob.param_dict['
|
|
90
|
-
prob.param_dict['u_bar'].value = u_bar
|
|
124
|
+
def PTR_subproblem(params, cpg_solve, x, u, aug_dy, prob, settings: Config):
|
|
125
|
+
prob.param_dict['x_bar'].value = x.guess
|
|
126
|
+
prob.param_dict['u_bar'].value = u.guess
|
|
91
127
|
|
|
128
|
+
# Make a tuple from list of parameter values
|
|
129
|
+
param_values = tuple([param.value for _, param in params])
|
|
130
|
+
|
|
92
131
|
t0 = time.time()
|
|
93
|
-
A_bar, B_bar, C_bar, z_bar, V_multi_shoot = aug_dy(
|
|
132
|
+
A_bar, B_bar, C_bar, z_bar, V_multi_shoot = aug_dy.call(x.guess, u.guess.astype(float), *param_values)
|
|
94
133
|
|
|
95
134
|
prob.param_dict['A_d'].value = A_bar.__array__()
|
|
96
135
|
prob.param_dict['B_d'].value = B_bar.__array__()
|
|
@@ -98,52 +137,52 @@ def PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params: Config):
|
|
|
98
137
|
prob.param_dict['z_d'].value = z_bar.__array__()
|
|
99
138
|
dis_time = time.time() - t0
|
|
100
139
|
|
|
101
|
-
if
|
|
102
|
-
for g_id, constraint in enumerate(
|
|
140
|
+
if settings.sim.constraints_nodal:
|
|
141
|
+
for g_id, constraint in enumerate(settings.sim.constraints_nodal):
|
|
103
142
|
if not constraint.convex:
|
|
104
|
-
prob.param_dict['g_' + str(g_id)].value = np.asarray(constraint.g(
|
|
105
|
-
prob.param_dict['grad_g_x_' + str(g_id)].value = np.asarray(constraint.grad_g_x(
|
|
106
|
-
prob.param_dict['grad_g_u_' + str(g_id)].value = np.asarray(constraint.grad_g_u(
|
|
143
|
+
prob.param_dict['g_' + str(g_id)].value = np.asarray(constraint.g(x.guess, u.guess))
|
|
144
|
+
prob.param_dict['grad_g_x_' + str(g_id)].value = np.asarray(constraint.grad_g_x(x.guess, u.guess))
|
|
145
|
+
prob.param_dict['grad_g_u_' + str(g_id)].value = np.asarray(constraint.grad_g_u(x.guess, u.guess))
|
|
107
146
|
|
|
108
|
-
prob.param_dict['w_tr'].value =
|
|
109
|
-
prob.param_dict['lam_cost'].value =
|
|
147
|
+
prob.param_dict['w_tr'].value = settings.scp.w_tr
|
|
148
|
+
prob.param_dict['lam_cost'].value = settings.scp.lam_cost
|
|
110
149
|
|
|
111
|
-
if
|
|
150
|
+
if settings.cvx.cvxpygen:
|
|
112
151
|
t0 = time.time()
|
|
113
152
|
prob.register_solve('CPG', cpg_solve)
|
|
114
|
-
prob.solve(method = 'CPG', **
|
|
153
|
+
prob.solve(method = 'CPG', **settings.cvx.solver_args)
|
|
115
154
|
subprop_time = time.time() - t0
|
|
116
155
|
else:
|
|
117
156
|
t0 = time.time()
|
|
118
|
-
prob.solve(solver =
|
|
157
|
+
prob.solve(solver = settings.cvx.solver, **settings.cvx.solver_args)
|
|
119
158
|
subprop_time = time.time() - t0
|
|
120
159
|
|
|
121
|
-
|
|
122
|
-
|
|
160
|
+
x_new_guess = (settings.sim.S_x @ prob.var_dict['x'].value.T + np.expand_dims(settings.sim.c_x, axis = 1)).T
|
|
161
|
+
u_new_guess = (settings.sim.S_u @ prob.var_dict['u'].value.T + np.expand_dims(settings.sim.c_u, axis = 1)).T
|
|
123
162
|
|
|
124
163
|
i = 0
|
|
125
164
|
costs = [0]
|
|
126
|
-
for type in
|
|
165
|
+
for type in x.final_type:
|
|
127
166
|
if type == 'Minimize':
|
|
128
|
-
costs +=
|
|
167
|
+
costs += x_new_guess[:,i]
|
|
129
168
|
if type == 'Maximize':
|
|
130
|
-
costs -=
|
|
169
|
+
costs -= x_new_guess[:,i]
|
|
131
170
|
i += 1
|
|
132
171
|
|
|
133
172
|
# Create the block diagonal matrix using jax.numpy.block
|
|
134
173
|
inv_block_diag = np.block([
|
|
135
|
-
[
|
|
136
|
-
[np.zeros((
|
|
174
|
+
[settings.sim.inv_S_x, np.zeros((settings.sim.inv_S_x.shape[0], settings.sim.inv_S_u.shape[1]))],
|
|
175
|
+
[np.zeros((settings.sim.inv_S_u.shape[0], settings.sim.inv_S_x.shape[1])), settings.sim.inv_S_u]
|
|
137
176
|
])
|
|
138
177
|
|
|
139
178
|
# Calculate J_tr_vec using the JAX-compatible block diagonal matrix
|
|
140
|
-
J_tr_vec = la.norm(inv_block_diag @ np.hstack((
|
|
179
|
+
J_tr_vec = la.norm(inv_block_diag @ np.hstack((x_new_guess - x.guess, u_new_guess - u.guess)).T, axis=0)**2
|
|
141
180
|
J_vc_vec = np.sum(np.abs(prob.var_dict['nu'].value), axis = 1)
|
|
142
181
|
|
|
143
182
|
id_ncvx = 0
|
|
144
183
|
J_vb_vec = 0
|
|
145
|
-
for constraint in
|
|
184
|
+
for constraint in settings.sim.constraints_nodal:
|
|
146
185
|
if constraint.convex == False:
|
|
147
186
|
J_vb_vec += np.maximum(0, prob.var_dict['nu_vb_' + str(id_ncvx)].value)
|
|
148
187
|
id_ncvx += 1
|
|
149
|
-
return
|
|
188
|
+
return x_new_guess, u_new_guess, costs, prob.value, J_vb_vec, J_vc_vec, J_tr_vec, prob.status, V_multi_shoot, subprop_time, dis_time
|
openscvx/results.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import List, Optional, Any, Dict
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from openscvx.backend.state import State
|
|
6
|
+
from openscvx.backend.control import Control
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class OptimizationResults:
|
|
11
|
+
"""
|
|
12
|
+
Class to hold optimization results from the SCP solver.
|
|
13
|
+
|
|
14
|
+
This class replaces the dictionary-based results structure with a more
|
|
15
|
+
structured and type-safe approach.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
# Core optimization results
|
|
19
|
+
converged: bool
|
|
20
|
+
t_final: float
|
|
21
|
+
u: Control
|
|
22
|
+
x: State
|
|
23
|
+
|
|
24
|
+
# History of SCP iterations
|
|
25
|
+
x_history: List[np.ndarray] = field(default_factory=list)
|
|
26
|
+
u_history: List[np.ndarray] = field(default_factory=list)
|
|
27
|
+
discretization_history: List[np.ndarray] = field(default_factory=list)
|
|
28
|
+
J_tr_history: List[np.ndarray] = field(default_factory=list)
|
|
29
|
+
J_vb_history: List[np.ndarray] = field(default_factory=list)
|
|
30
|
+
J_vc_history: List[np.ndarray] = field(default_factory=list)
|
|
31
|
+
|
|
32
|
+
# Post-processing results (added by propagate_trajectory_results)
|
|
33
|
+
t_full: Optional[np.ndarray] = None
|
|
34
|
+
x_full: Optional[np.ndarray] = None
|
|
35
|
+
u_full: Optional[np.ndarray] = None
|
|
36
|
+
cost: Optional[float] = None
|
|
37
|
+
ctcs_violation: Optional[np.ndarray] = None
|
|
38
|
+
|
|
39
|
+
# Additional plotting/application data (added by user)
|
|
40
|
+
plotting_data: Dict[str, Any] = field(default_factory=dict)
|
|
41
|
+
|
|
42
|
+
def __post_init__(self):
|
|
43
|
+
"""Initialize the results object."""
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
def update_plotting_data(self, **kwargs):
|
|
47
|
+
"""
|
|
48
|
+
Update the plotting data with additional information.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
**kwargs: Key-value pairs to add to plotting_data
|
|
52
|
+
"""
|
|
53
|
+
self.plotting_data.update(kwargs)
|
|
54
|
+
|
|
55
|
+
def get(self, key: str, default: Any = None) -> Any:
|
|
56
|
+
"""
|
|
57
|
+
Get a value from the results, similar to dict.get().
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
key: The key to look up
|
|
61
|
+
default: Default value if key is not found
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
The value associated with the key, or default if not found
|
|
65
|
+
"""
|
|
66
|
+
# Check if it's a direct attribute
|
|
67
|
+
if hasattr(self, key):
|
|
68
|
+
return getattr(self, key)
|
|
69
|
+
|
|
70
|
+
# Check if it's in plotting_data
|
|
71
|
+
if key in self.plotting_data:
|
|
72
|
+
return self.plotting_data[key]
|
|
73
|
+
|
|
74
|
+
return default
|
|
75
|
+
|
|
76
|
+
def __getitem__(self, key: str) -> Any:
|
|
77
|
+
"""
|
|
78
|
+
Allow dictionary-style access to results.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
key: The key to look up
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
The value associated with the key
|
|
85
|
+
|
|
86
|
+
Raises:
|
|
87
|
+
KeyError: If key is not found
|
|
88
|
+
"""
|
|
89
|
+
# Check if it's a direct attribute
|
|
90
|
+
if hasattr(self, key):
|
|
91
|
+
return getattr(self, key)
|
|
92
|
+
|
|
93
|
+
# Check if it's in plotting_data
|
|
94
|
+
if key in self.plotting_data:
|
|
95
|
+
return self.plotting_data[key]
|
|
96
|
+
|
|
97
|
+
raise KeyError(f"Key '{key}' not found in results")
|
|
98
|
+
|
|
99
|
+
def __setitem__(self, key: str, value: Any):
|
|
100
|
+
"""
|
|
101
|
+
Allow dictionary-style assignment to results.
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
key: The key to set
|
|
105
|
+
value: The value to assign
|
|
106
|
+
"""
|
|
107
|
+
# Check if it's a direct attribute
|
|
108
|
+
if hasattr(self, key):
|
|
109
|
+
setattr(self, key, value)
|
|
110
|
+
else:
|
|
111
|
+
# Store in plotting_data
|
|
112
|
+
self.plotting_data[key] = value
|
|
113
|
+
|
|
114
|
+
def __contains__(self, key: str) -> bool:
|
|
115
|
+
"""
|
|
116
|
+
Check if a key exists in the results.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
key: The key to check
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
True if key exists, False otherwise
|
|
123
|
+
"""
|
|
124
|
+
return hasattr(self, key) or key in self.plotting_data
|
|
125
|
+
|
|
126
|
+
def update(self, other: Dict[str, Any]):
|
|
127
|
+
"""
|
|
128
|
+
Update the results with additional data from a dictionary.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
other: Dictionary containing additional data
|
|
132
|
+
"""
|
|
133
|
+
for key, value in other.items():
|
|
134
|
+
self[key] = value
|
|
135
|
+
|
|
136
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
137
|
+
"""
|
|
138
|
+
Convert the results to a dictionary for backward compatibility.
|
|
139
|
+
|
|
140
|
+
Returns:
|
|
141
|
+
Dictionary representation of the results
|
|
142
|
+
"""
|
|
143
|
+
result_dict = {}
|
|
144
|
+
|
|
145
|
+
# Add all direct attributes
|
|
146
|
+
for attr_name in self.__dataclass_fields__:
|
|
147
|
+
if attr_name != 'plotting_data':
|
|
148
|
+
result_dict[attr_name] = getattr(self, attr_name)
|
|
149
|
+
|
|
150
|
+
# Add plotting data
|
|
151
|
+
result_dict.update(self.plotting_data)
|
|
152
|
+
|
|
153
|
+
return result_dict
|