openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. See the package's registry page for more details.

openscvx/propagation.py CHANGED
@@ -2,6 +2,7 @@ import numpy as np
2
2
 
3
3
  from openscvx.config import Config
4
4
  from openscvx.integrators import solve_ivp_diffrax_prop
5
+ from openscvx.backend.parameter import Parameter
5
6
 
6
7
 
7
8
  def prop_aug_dy(
@@ -15,7 +16,29 @@ def prop_aug_dy(
15
16
  state_dot: callable,
16
17
  dis_type: str,
17
18
  N: int,
19
+ *params
18
20
  ) -> np.ndarray:
21
+ """Compute the augmented dynamics for propagation.
22
+
23
+ This function computes the time-scaled dynamics for propagating the system state,
24
+ taking into account the discretization type (ZOH or FOH) and time dilation.
25
+
26
+ Args:
27
+ tau (float): Current normalized time in [0,1].
28
+ x (np.ndarray): Current state vector.
29
+ u_current (np.ndarray): Control input at current node.
30
+ u_next (np.ndarray): Control input at next node.
31
+ tau_init (float): Initial normalized time.
32
+ node (int): Current node index.
33
+ idx_s (int): Index of time dilation variable in control vector.
34
+ state_dot (callable): Function computing state derivatives.
35
+ dis_type (str): Discretization type ("ZOH" or "FOH").
36
+ N (int): Number of nodes in trajectory.
37
+ *params: Additional parameters passed to state_dot.
38
+
39
+ Returns:
40
+ np.ndarray: Time-scaled state derivatives.
41
+ """
19
42
  x = x[None, :]
20
43
 
21
44
  if dis_type == "ZOH":
@@ -24,37 +47,69 @@ def prop_aug_dy(
24
47
  beta = (tau - tau_init) * N
25
48
  u = u_current + beta * (u_next - u_current)
26
49
 
27
- return u[:, idx_s] * state_dot(x, u[:, :-1], node).squeeze()
28
-
29
-
30
- def get_propagation_solver(state_dot, params):
31
- def propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, idx_s):
50
+ return u[:, idx_s] * state_dot(x, u[:, :-1], node, *params).squeeze()
51
+
52
+ def get_propagation_solver(state_dot, settings, param_map):
53
+ """Create a propagation solver function.
54
+
55
+ This function creates a solver that propagates the system state using the
56
+ specified dynamics and settings.
57
+
58
+ Args:
59
+ state_dot (callable): Function computing state derivatives.
60
+ settings: Configuration settings for propagation.
61
+ param_map (dict): Mapping of parameter names to values.
62
+
63
+ Returns:
64
+ callable: A function that solves the propagation problem.
65
+ """
66
+ def propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, idx_s, save_time, mask, *params):
67
+ param_map_update = dict(zip(param_map.keys(), params))
32
68
  return solve_ivp_diffrax_prop(
33
69
  f=prop_aug_dy,
34
- tau_final=tau_grid[1],
35
- y_0=V0,
70
+ tau_final=tau_grid[1], # scalar
71
+ y_0=V0, # shape (n_states,)
36
72
  args=(
37
- u_cur,
38
- u_next,
39
- tau_init,
40
- node,
41
- idx_s,
42
- state_dot,
43
- params.dis.dis_type,
44
- params.scp.n,
73
+ u_cur, # shape (1, n_controls)
74
+ u_next, # shape (1, n_controls)
75
+ tau_init, # shape (1, 1)
76
+ node, # shape (1, 1)
77
+ idx_s, # int
78
+ state_dot, # function or array
79
+ settings.dis.dis_type,
80
+ settings.scp.n,
81
+ *param_map_update.items(),
82
+ # additional named parameters as **kwargs
45
83
  ),
46
- tau_0=tau_grid[0],
84
+ tau_0=tau_grid[0], # scalar
85
+ save_time=save_time, # shape (MAX_TAU_LEN,)
86
+ mask=mask # shape (MAX_TAU_LEN,), dtype=bool
47
87
  )
48
88
 
49
89
  return propagation_solver
50
90
 
51
91
 
52
- def s_to_t(u, params: Config):
53
- t = [0]
92
+
93
+
94
+ def s_to_t(x, u, params: Config):
95
+ """Convert normalized time s to real time t.
96
+
97
+ This function converts the normalized time variable s to real time t
98
+ based on the discretization type and time dilation factors.
99
+
100
+ Args:
101
+ x: State trajectory.
102
+ u: Control trajectory.
103
+ params (Config): Configuration settings.
104
+
105
+ Returns:
106
+ list: List of real time points.
107
+ """
108
+ t = [x.guess[:,params.sim.idx_t][0]]
54
109
  tau = np.linspace(0, 1, params.scp.n)
55
110
  for k in range(1, params.scp.n):
56
- s_kp = u[k - 1, -1]
57
- s_k = u[k, -1]
111
+ s_kp = u.guess[k - 1, -1]
112
+ s_k = u.guess[k, -1]
58
113
  if params.dis.dis_type == "ZOH":
59
114
  t.append(t[k - 1] + (tau[k] - tau[k - 1]) * (s_kp))
60
115
  else:
@@ -62,17 +117,30 @@ def s_to_t(u, params: Config):
62
117
  return t
63
118
 
64
119
 
65
- def t_to_tau(u, t, u_nodal, t_nodal, params: Config):
66
- u_lam = lambda new_t: np.array(
67
- [np.interp(new_t, t_nodal, u[:, i]) for i in range(u.shape[1])]
68
- ).T
120
+ def t_to_tau(u, t, t_nodal, params: Config):
121
+ """Convert real time t to normalized time tau.
122
+
123
+ This function converts real time t to normalized time tau and interpolates
124
+ the control inputs accordingly.
125
+
126
+ Args:
127
+ u: Control trajectory.
128
+ t (np.ndarray): Real time points.
129
+ t_nodal (np.ndarray): Nodal time points.
130
+ params (Config): Configuration settings.
131
+
132
+ Returns:
133
+ tuple: (tau, u_interp) where tau is normalized time and u_interp is interpolated controls.
134
+ """
135
+ u_guess = u.guess
136
+ u_lam = lambda new_t: np.array([np.interp(new_t, t_nodal, u_guess[:,i]) for i in range(u_guess.shape[1])]).T
69
137
  u = np.array([u_lam(t_i) for t_i in t])
70
138
 
71
139
  tau = np.zeros(len(t))
72
140
  tau_nodal = np.linspace(0, 1, params.scp.n)
73
141
  for k in range(1, len(t)):
74
142
  k_nodal = np.where(t_nodal < t[k])[0][-1]
75
- s_kp = u_nodal[k_nodal, -1]
143
+ s_kp = u_guess[k_nodal, -1]
76
144
  tp = t_nodal[k_nodal]
77
145
  tau_p = tau_nodal[k_nodal]
78
146
 
@@ -84,52 +152,84 @@ def t_to_tau(u, t, u_nodal, t_nodal, params: Config):
84
152
  return tau, u
85
153
 
86
154
 
87
- def simulate_nonlinear_time(x_0, u, tau_vals, t, params, propagation_solver):
88
- states = np.empty(
89
- (x_0.shape[0], 0)
90
- ) # Initialize states as a 2D array with shape (n, 0)
91
-
92
- tau = np.linspace(0, 1, params.scp.n)
93
-
94
- u_lam = lambda new_t: np.array(
95
- [np.interp(new_t, t, u[:, i]) for i in range(u.shape[1])]
96
- ).T
97
-
98
- # Bin the tau_vals into with respect to the uniform tau grid, tau
155
+ def simulate_nonlinear_time(params, x, u, tau_vals, t, settings, propagation_solver):
156
+ """Simulate the nonlinear system dynamics over time.
157
+
158
+ This function simulates the system dynamics using the optimal control sequence
159
+ and returns the resulting state trajectory.
160
+
161
+ Args:
162
+ params: System parameters.
163
+ x: State trajectory.
164
+ u: Control trajectory.
165
+ tau_vals (np.ndarray): Normalized time points for simulation.
166
+ t (np.ndarray): Real time points.
167
+ settings: Configuration settings.
168
+ propagation_solver (callable): Function for propagating the system state.
169
+
170
+ Returns:
171
+ np.ndarray: Simulated state trajectory.
172
+ """
173
+ x_0 = settings.sim.x_prop.initial
174
+
175
+ n_segments = settings.scp.n - 1
176
+ n_states = x_0.shape[0]
177
+ n_tau = len(tau_vals)
178
+
179
+ params = params.items()
180
+ param_values = tuple([param.value for _, param in params])
181
+
182
+ states = np.empty((n_states, n_tau))
183
+ tau = np.linspace(0, 1, settings.scp.n)
184
+
185
+ # Precompute control interpolation
186
+ u_interp = np.stack([
187
+ np.interp(t, t, u.guess[:, i]) for i in range(u.guess.shape[1])
188
+ ], axis=-1)
189
+
190
+ # Bin tau_vals into segments of tau
99
191
  tau_inds = np.digitize(tau_vals, tau) - 1
100
- # Force the last indice to be in the same bin as the previous ones
101
- tau_inds = np.where(tau_inds == params.scp.n - 1, params.scp.n - 2, tau_inds)
192
+ tau_inds = np.where(tau_inds == settings.scp.n - 1, settings.scp.n - 2, tau_inds)
102
193
 
103
194
  prev_count = 0
195
+ out_idx = 0
104
196
 
105
- for k in range(params.scp.n - 1):
106
- controls_current = np.squeeze(u_lam(t[k]))[None, :]
107
- controls_next = np.squeeze(u_lam(t[k + 1]))[None, :]
197
+ for k in range(n_segments):
198
+ controls_current = u_interp[k][None, :]
199
+ controls_next = u_interp[k + 1][None, :]
108
200
 
109
- # Create a mask
201
+ # Mask for tau_vals in current segment
110
202
  mask = (tau_inds >= k) & (tau_inds < k + 1)
111
-
112
203
  count = np.sum(mask)
113
204
 
114
- # Use count to grab the first count number of elements
115
- tau_cur = tau_vals[prev_count : prev_count + count]
205
+ tau_cur = tau_vals[prev_count:prev_count + count]
206
+ tau_cur = np.concatenate([tau_cur, np.array([tau[k + 1]])]) # Always include final point
207
+ count += 1
208
+
209
+ # Pad to fixed length
210
+ pad_len = settings.prp.max_tau_len - count
211
+ tau_cur_padded = np.pad(tau_cur, (0, pad_len), constant_values=tau[k + 1])
212
+ mask_padded = np.concatenate([np.ones(count), np.zeros(pad_len)]).astype(bool)
116
213
 
117
- sol = propagation_solver(
214
+ # Call the solver with padded tau_cur and mask
215
+ sol = propagation_solver.call(
118
216
  x_0,
119
217
  (tau[k], tau[k + 1]),
120
218
  controls_current,
121
219
  controls_next,
122
220
  np.array([[tau[k]]]),
123
221
  np.array([[k]]),
124
- params.sim.idx_s.stop,
222
+ settings.sim.idx_s.stop,
223
+ tau_cur_padded,
224
+ mask_padded,
225
+ *param_values
125
226
  )
126
227
 
127
- x = sol.ys
128
- for tau_i in tau_cur:
129
- new_state = sol.evaluate(tau_i).reshape(-1, 1) # Ensure new_state is 2D
130
- states = np.concatenate([states, new_state], axis=1)
228
+ # Only store the valid portion (excluding the final point which becomes next x_0)
229
+ states[:, out_idx:out_idx + count - 1] = sol[:count - 1].T
230
+ out_idx += count - 1
231
+ x_0 = sol[count - 1] # Last value used as next x_0
131
232
 
132
- x_0 = x[-1]
133
- prev_count += count
233
+ prev_count += (count - 1)
134
234
 
135
- return states.T
235
+ return states.T
openscvx/ptr.py CHANGED
@@ -4,55 +4,91 @@ import cvxpy as cp
4
4
  import pickle
5
5
  import time
6
6
 
7
+ from openscvx.backend.parameter import Parameter
7
8
  from openscvx.config import Config
9
+ from openscvx.results import OptimizationResults
8
10
 
9
11
  import warnings
10
12
  warnings.filterwarnings("ignore")
11
13
 
12
- def PTR_init(ocp: cp.Problem, discretization_solver: callable, params: Config):
13
- if params.cvx.cvxpygen:
14
- from solver.cpg_solver import cpg_solve
15
- with open('solver/problem.pickle', 'rb') as f:
16
- prob = pickle.load(f)
14
+ def PTR_init(params, ocp: cp.Problem, discretization_solver: callable, settings: Config):
15
+ if settings.cvx.cvxpygen:
16
+ try:
17
+ from solver.cpg_solver import cpg_solve
18
+ with open('solver/problem.pickle', 'rb') as f:
19
+ prob = pickle.load(f)
20
+ except ImportError:
21
+ raise ImportError(
22
+ "cvxpygen solver not found. Make sure cvxpygen is installed and code generation has been run. "
23
+ "Install with: pip install openscvx[cvxpygen]"
24
+ )
17
25
  else:
18
26
  cpg_solve = None
19
27
 
28
+ if 'x_init' in ocp.param_dict:
29
+ ocp.param_dict['x_init'].value = settings.sim.x.initial
30
+
31
+ if 'x_term' in ocp.param_dict:
32
+ ocp.param_dict['x_term'].value = settings.sim.x.final
33
+
20
34
  # Solve a dumb problem to intilize DPP and JAX jacobians
21
- _ = PTR_subproblem(cpg_solve, params.sim.x_bar, params.sim.u_bar, discretization_solver, ocp, params)
35
+ _ = PTR_subproblem(params.items(), cpg_solve, settings.sim.x, settings.sim.u, discretization_solver, ocp, settings)
22
36
 
23
37
  return cpg_solve
24
38
 
25
- def PTR_main(params: Config, prob: cp.Problem, aug_dy: callable, cpg_solve, emitter_function) -> dict:
39
+ def format_result(problem, converged: bool) -> OptimizationResults:
40
+ """Formats the final result as an OptimizationResults object from the problem's state."""
41
+ return OptimizationResults(
42
+ converged=converged,
43
+ t_final=problem.settings.sim.x.guess[:, problem.settings.sim.idx_t][-1],
44
+ u=problem.settings.sim.u,
45
+ x=problem.settings.sim.x,
46
+ x_history=problem.scp_trajs,
47
+ u_history=problem.scp_controls,
48
+ discretization_history=problem.scp_V_multi_shoot_traj,
49
+ J_tr_history=problem.scp_J_tr,
50
+ J_vb_history=problem.scp_J_vb,
51
+ J_vc_history=problem.scp_J_vc,
52
+ )
53
+
54
+ def PTR_main(params, settings: Config, prob: cp.Problem, aug_dy: callable, cpg_solve, emitter_function) -> OptimizationResults:
26
55
  J_vb = 1E2
27
56
  J_vc = 1E2
28
57
  J_tr = 1E2
29
58
 
30
- x_bar = params.sim.x_bar
31
- u_bar = params.sim.u_bar
59
+ x = settings.sim.x
60
+ u = settings.sim.u
61
+
62
+ if 'x_init' in prob.param_dict:
63
+ prob.param_dict['x_init'].value = settings.sim.x.initial
64
+
65
+ if 'x_term' in prob.param_dict:
66
+ prob.param_dict['x_term'].value = settings.sim.x.final
67
+
32
68
 
33
- scp_trajs = [x_bar]
34
- scp_controls = [u_bar]
69
+ scp_trajs = [x.guess]
70
+ scp_controls = [u.guess]
35
71
  V_multi_shoot_traj = []
36
72
 
37
73
  k = 1
38
74
 
39
- while k <= params.scp.k_max and ((J_tr >= params.scp.ep_tr) or (J_vb >= params.scp.ep_vb) or (J_vc >= params.scp.ep_vc)):
40
- x, u, t, J_total, J_vb_vec, J_vc_vec, J_tr_vec, prob_stat, V_multi_shoot, subprop_time, dis_time = PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params)
75
+ while k <= settings.scp.k_max and ((J_tr >= settings.scp.ep_tr) or (J_vb >= settings.scp.ep_vb) or (J_vc >= settings.scp.ep_vc)):
76
+ x_sol, u_sol, cost, J_total, J_vb_vec, J_vc_vec, J_tr_vec, prob_stat, V_multi_shoot, subprop_time, dis_time = PTR_subproblem(params.items(), cpg_solve, x, u, aug_dy, prob, settings)
41
77
 
42
78
  V_multi_shoot_traj.append(V_multi_shoot)
43
79
 
44
- x_bar = x
45
- u_bar = u
80
+ x.guess = x_sol
81
+ u.guess = u_sol
46
82
 
47
83
  J_tr = np.sum(np.array(J_tr_vec))
48
84
  J_vb = np.sum(np.array(J_vb_vec))
49
85
  J_vc = np.sum(np.array(J_vc_vec))
50
- scp_trajs.append(x)
51
- scp_controls.append(u)
86
+ scp_trajs.append(x.guess)
87
+ scp_controls.append(u.guess)
52
88
 
53
- params.scp.w_tr = min(params.scp.w_tr * params.scp.w_tr_adapt, params.scp.w_tr_max)
54
- if k > params.scp.cost_drop:
55
- params.scp.lam_cost = params.scp.lam_cost * params.scp.cost_relax
89
+ settings.scp.w_tr = min(settings.scp.w_tr * settings.scp.w_tr_adapt, settings.scp.w_tr_max)
90
+ if k > settings.scp.cost_drop:
91
+ settings.scp.lam_cost = settings.scp.lam_cost * settings.scp.cost_relax
56
92
 
57
93
  emitter_function(
58
94
  {
@@ -63,34 +99,37 @@ def PTR_main(params: Config, prob: cp.Problem, aug_dy: callable, cpg_solve, emit
63
99
  "J_tr": J_tr,
64
100
  "J_vb": J_vb,
65
101
  "J_vc": J_vc,
66
- "cost": t[-1],
102
+ "cost": cost[-1],
67
103
  "prob_stat": prob_stat,
68
104
  }
69
105
  )
70
106
 
71
107
  k += 1
72
108
 
73
- result = dict(
74
- converged = k <= params.scp.k_max,
75
- t_final = x[:,params.sim.idx_t][-1],
76
- u = u,
77
- x = x,
78
- x_history = scp_trajs,
79
- u_history = scp_controls,
80
- discretization_history = V_multi_shoot_traj,
81
- J_tr_history = J_tr_vec,
82
- J_vb_history = J_vb_vec,
83
- J_vc_history = J_vc_vec,
109
+ result = OptimizationResults(
110
+ converged=k <= settings.scp.k_max,
111
+ t_final=x.guess[:,settings.sim.idx_t][-1],
112
+ u=u,
113
+ x=x,
114
+ x_history=scp_trajs,
115
+ u_history=scp_controls,
116
+ discretization_history=V_multi_shoot_traj,
117
+ J_tr_history=J_tr_vec,
118
+ J_vb_history=J_vb_vec,
119
+ J_vc_history=J_vc_vec,
84
120
  )
121
+
85
122
  return result
86
123
 
87
-
88
- def PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params: Config):
89
- prob.param_dict['x_bar'].value = x_bar
90
- prob.param_dict['u_bar'].value = u_bar
124
+ def PTR_subproblem(params, cpg_solve, x, u, aug_dy, prob, settings: Config):
125
+ prob.param_dict['x_bar'].value = x.guess
126
+ prob.param_dict['u_bar'].value = u.guess
91
127
 
128
+ # Make a tuple from list of parameter values
129
+ param_values = tuple([param.value for _, param in params])
130
+
92
131
  t0 = time.time()
93
- A_bar, B_bar, C_bar, z_bar, V_multi_shoot = aug_dy(x_bar, u_bar.astype(float))
132
+ A_bar, B_bar, C_bar, z_bar, V_multi_shoot = aug_dy.call(x.guess, u.guess.astype(float), *param_values)
94
133
 
95
134
  prob.param_dict['A_d'].value = A_bar.__array__()
96
135
  prob.param_dict['B_d'].value = B_bar.__array__()
@@ -98,52 +137,52 @@ def PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params: Config):
98
137
  prob.param_dict['z_d'].value = z_bar.__array__()
99
138
  dis_time = time.time() - t0
100
139
 
101
- if params.sim.constraints_nodal:
102
- for g_id, constraint in enumerate(params.sim.constraints_nodal):
140
+ if settings.sim.constraints_nodal:
141
+ for g_id, constraint in enumerate(settings.sim.constraints_nodal):
103
142
  if not constraint.convex:
104
- prob.param_dict['g_' + str(g_id)].value = np.asarray(constraint.g(x_bar, u_bar))
105
- prob.param_dict['grad_g_x_' + str(g_id)].value = np.asarray(constraint.grad_g_x(x_bar, u_bar))
106
- prob.param_dict['grad_g_u_' + str(g_id)].value = np.asarray(constraint.grad_g_u(x_bar, u_bar))
143
+ prob.param_dict['g_' + str(g_id)].value = np.asarray(constraint.g(x.guess, u.guess))
144
+ prob.param_dict['grad_g_x_' + str(g_id)].value = np.asarray(constraint.grad_g_x(x.guess, u.guess))
145
+ prob.param_dict['grad_g_u_' + str(g_id)].value = np.asarray(constraint.grad_g_u(x.guess, u.guess))
107
146
 
108
- prob.param_dict['w_tr'].value = params.scp.w_tr
109
- prob.param_dict['lam_cost'].value = params.scp.lam_cost
147
+ prob.param_dict['w_tr'].value = settings.scp.w_tr
148
+ prob.param_dict['lam_cost'].value = settings.scp.lam_cost
110
149
 
111
- if params.cvx.cvxpygen:
150
+ if settings.cvx.cvxpygen:
112
151
  t0 = time.time()
113
152
  prob.register_solve('CPG', cpg_solve)
114
- prob.solve(method = 'CPG', **params.cvx.solver_args)
153
+ prob.solve(method = 'CPG', **settings.cvx.solver_args)
115
154
  subprop_time = time.time() - t0
116
155
  else:
117
156
  t0 = time.time()
118
- prob.solve(solver = params.cvx.solver, enforce_dpp = True, **params.cvx.solver_args)
157
+ prob.solve(solver = settings.cvx.solver, **settings.cvx.solver_args)
119
158
  subprop_time = time.time() - t0
120
159
 
121
- x = (params.sim.S_x @ prob.var_dict['x'].value.T + np.expand_dims(params.sim.c_x, axis = 1)).T
122
- u = (params.sim.S_u @ prob.var_dict['u'].value.T + np.expand_dims(params.sim.c_u, axis = 1)).T
160
+ x_new_guess = (settings.sim.S_x @ prob.var_dict['x'].value.T + np.expand_dims(settings.sim.c_x, axis = 1)).T
161
+ u_new_guess = (settings.sim.S_u @ prob.var_dict['u'].value.T + np.expand_dims(settings.sim.c_u, axis = 1)).T
123
162
 
124
163
  i = 0
125
164
  costs = [0]
126
- for type in params.sim.final_state.type:
165
+ for type in x.final_type:
127
166
  if type == 'Minimize':
128
- costs += x[:,i]
167
+ costs += x_new_guess[:,i]
129
168
  if type == 'Maximize':
130
- costs -= x[:,i]
169
+ costs -= x_new_guess[:,i]
131
170
  i += 1
132
171
 
133
172
  # Create the block diagonal matrix using jax.numpy.block
134
173
  inv_block_diag = np.block([
135
- [params.sim.inv_S_x, np.zeros((params.sim.inv_S_x.shape[0], params.sim.inv_S_u.shape[1]))],
136
- [np.zeros((params.sim.inv_S_u.shape[0], params.sim.inv_S_x.shape[1])), params.sim.inv_S_u]
174
+ [settings.sim.inv_S_x, np.zeros((settings.sim.inv_S_x.shape[0], settings.sim.inv_S_u.shape[1]))],
175
+ [np.zeros((settings.sim.inv_S_u.shape[0], settings.sim.inv_S_x.shape[1])), settings.sim.inv_S_u]
137
176
  ])
138
177
 
139
178
  # Calculate J_tr_vec using the JAX-compatible block diagonal matrix
140
- J_tr_vec = la.norm(inv_block_diag @ np.hstack((x - x_bar, u - u_bar)).T, axis=0)**2
179
+ J_tr_vec = la.norm(inv_block_diag @ np.hstack((x_new_guess - x.guess, u_new_guess - u.guess)).T, axis=0)**2
141
180
  J_vc_vec = np.sum(np.abs(prob.var_dict['nu'].value), axis = 1)
142
181
 
143
182
  id_ncvx = 0
144
183
  J_vb_vec = 0
145
- for constraint in params.sim.constraints_nodal:
184
+ for constraint in settings.sim.constraints_nodal:
146
185
  if constraint.convex == False:
147
186
  J_vb_vec += np.maximum(0, prob.var_dict['nu_vb_' + str(id_ncvx)].value)
148
187
  id_ncvx += 1
149
- return x, u, costs, prob.value, J_vb_vec, J_vc_vec, J_tr_vec, prob.status, V_multi_shoot, subprop_time, dis_time
188
+ return x_new_guess, u_new_guess, costs, prob.value, J_vb_vec, J_vc_vec, J_tr_vec, prob.status, V_multi_shoot, subprop_time, dis_time
openscvx/results.py ADDED
@@ -0,0 +1,153 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional, Any, Dict
3
+ import numpy as np
4
+
5
+ from openscvx.backend.state import State
6
+ from openscvx.backend.control import Control
7
+
8
+
9
+ @dataclass
10
+ class OptimizationResults:
11
+ """
12
+ Class to hold optimization results from the SCP solver.
13
+
14
+ This class replaces the dictionary-based results structure with a more
15
+ structured and type-safe approach.
16
+ """
17
+
18
+ # Core optimization results
19
+ converged: bool
20
+ t_final: float
21
+ u: Control
22
+ x: State
23
+
24
+ # History of SCP iterations
25
+ x_history: List[np.ndarray] = field(default_factory=list)
26
+ u_history: List[np.ndarray] = field(default_factory=list)
27
+ discretization_history: List[np.ndarray] = field(default_factory=list)
28
+ J_tr_history: List[np.ndarray] = field(default_factory=list)
29
+ J_vb_history: List[np.ndarray] = field(default_factory=list)
30
+ J_vc_history: List[np.ndarray] = field(default_factory=list)
31
+
32
+ # Post-processing results (added by propagate_trajectory_results)
33
+ t_full: Optional[np.ndarray] = None
34
+ x_full: Optional[np.ndarray] = None
35
+ u_full: Optional[np.ndarray] = None
36
+ cost: Optional[float] = None
37
+ ctcs_violation: Optional[np.ndarray] = None
38
+
39
+ # Additional plotting/application data (added by user)
40
+ plotting_data: Dict[str, Any] = field(default_factory=dict)
41
+
42
+ def __post_init__(self):
43
+ """Initialize the results object."""
44
+ pass
45
+
46
+ def update_plotting_data(self, **kwargs):
47
+ """
48
+ Update the plotting data with additional information.
49
+
50
+ Args:
51
+ **kwargs: Key-value pairs to add to plotting_data
52
+ """
53
+ self.plotting_data.update(kwargs)
54
+
55
+ def get(self, key: str, default: Any = None) -> Any:
56
+ """
57
+ Get a value from the results, similar to dict.get().
58
+
59
+ Args:
60
+ key: The key to look up
61
+ default: Default value if key is not found
62
+
63
+ Returns:
64
+ The value associated with the key, or default if not found
65
+ """
66
+ # Check if it's a direct attribute
67
+ if hasattr(self, key):
68
+ return getattr(self, key)
69
+
70
+ # Check if it's in plotting_data
71
+ if key in self.plotting_data:
72
+ return self.plotting_data[key]
73
+
74
+ return default
75
+
76
+ def __getitem__(self, key: str) -> Any:
77
+ """
78
+ Allow dictionary-style access to results.
79
+
80
+ Args:
81
+ key: The key to look up
82
+
83
+ Returns:
84
+ The value associated with the key
85
+
86
+ Raises:
87
+ KeyError: If key is not found
88
+ """
89
+ # Check if it's a direct attribute
90
+ if hasattr(self, key):
91
+ return getattr(self, key)
92
+
93
+ # Check if it's in plotting_data
94
+ if key in self.plotting_data:
95
+ return self.plotting_data[key]
96
+
97
+ raise KeyError(f"Key '{key}' not found in results")
98
+
99
+ def __setitem__(self, key: str, value: Any):
100
+ """
101
+ Allow dictionary-style assignment to results.
102
+
103
+ Args:
104
+ key: The key to set
105
+ value: The value to assign
106
+ """
107
+ # Check if it's a direct attribute
108
+ if hasattr(self, key):
109
+ setattr(self, key, value)
110
+ else:
111
+ # Store in plotting_data
112
+ self.plotting_data[key] = value
113
+
114
+ def __contains__(self, key: str) -> bool:
115
+ """
116
+ Check if a key exists in the results.
117
+
118
+ Args:
119
+ key: The key to check
120
+
121
+ Returns:
122
+ True if key exists, False otherwise
123
+ """
124
+ return hasattr(self, key) or key in self.plotting_data
125
+
126
+ def update(self, other: Dict[str, Any]):
127
+ """
128
+ Update the results with additional data from a dictionary.
129
+
130
+ Args:
131
+ other: Dictionary containing additional data
132
+ """
133
+ for key, value in other.items():
134
+ self[key] = value
135
+
136
+ def to_dict(self) -> Dict[str, Any]:
137
+ """
138
+ Convert the results to a dictionary for backward compatibility.
139
+
140
+ Returns:
141
+ Dictionary representation of the results
142
+ """
143
+ result_dict = {}
144
+
145
+ # Add all direct attributes
146
+ for attr_name in self.__dataclass_fields__:
147
+ if attr_name != 'plotting_data':
148
+ result_dict[attr_name] = getattr(self, attr_name)
149
+
150
+ # Add plotting data
151
+ result_dict.update(self.plotting_data)
152
+
153
+ return result_dict