openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. See the registry's advisory page for this release for more details.

@@ -1,11 +1,15 @@
1
1
  import jax.numpy as jnp
2
- from typing import List, Union
2
+ from typing import List, Union, Optional
3
3
  import queue
4
4
  import threading
5
5
  import time
6
+ from pathlib import Path
7
+ from copy import deepcopy
6
8
 
7
9
  import cvxpy as cp
8
10
  import jax
11
+ from jax import export, ShapeDtypeStruct
12
+ from functools import partial
9
13
  import numpy as np
10
14
 
11
15
  from openscvx.config import (
@@ -23,13 +27,18 @@ from openscvx.augmentation.ctcs import sort_ctcs_constraints
23
27
  from openscvx.constraints.violation import get_g_funcs, CTCSViolation
24
28
  from openscvx.discretization import get_discretization_solver
25
29
  from openscvx.propagation import get_propagation_solver
26
- from openscvx.constraints.boundary import BoundaryConstraint
27
30
  from openscvx.constraints.ctcs import CTCSConstraint
28
31
  from openscvx.constraints.nodal import NodalConstraint
29
- from openscvx.ptr import PTR_init, PTR_main
32
+ from openscvx.ptr import PTR_init, PTR_subproblem, format_result
30
33
  from openscvx.post_processing import propagate_trajectory_results
31
34
  from openscvx.ocp import OptimalControlProblem
32
35
  from openscvx import io
36
+ from openscvx.utils import stable_function_hash
37
+ from openscvx.backend.state import State, Free
38
+ from openscvx.backend.control import Control
39
+ from openscvx.backend.parameter import Parameter
40
+ from openscvx.results import OptimizationResults
41
+
33
42
 
34
43
 
35
44
  # TODO: (norrisg) Decide whether to have constraints`, `cost`, alongside `dynamics`, ` etc.
@@ -38,28 +47,62 @@ class TrajOptProblem:
38
47
  self,
39
48
  dynamics: Dynamics,
40
49
  constraints: List[Union[CTCSConstraint, NodalConstraint]],
41
- idx_time: int,
50
+ x: State,
51
+ u: Control,
42
52
  N: int,
43
- time_init: float,
44
- x_guess: jnp.ndarray,
45
- u_guess: jnp.ndarray,
46
- initial_state: BoundaryConstraint,
47
- final_state: BoundaryConstraint,
48
- x_max: jnp.ndarray,
49
- x_min: jnp.ndarray,
50
- u_max: jnp.ndarray,
51
- u_min: jnp.ndarray,
52
- scp: ScpConfig = None,
53
- dis: DiscretizationConfig = None,
54
- prp: PropagationConfig = None,
55
- sim: SimConfig = None,
56
- dev: DevConfig = None,
57
- cvx: ConvexSolverConfig = None,
53
+ idx_time: int,
54
+ params: dict = {},
55
+ dynamics_prop: callable = None,
56
+ x_prop: State = None,
57
+ scp: Optional[ScpConfig] = None,
58
+ dis: Optional[DiscretizationConfig] = None,
59
+ prp: Optional[PropagationConfig] = None,
60
+ sim: Optional[SimConfig] = None,
61
+ dev: Optional[DevConfig] = None,
62
+ cvx: Optional[ConvexSolverConfig] = None,
58
63
  licq_min=0.0,
59
64
  licq_max=1e-4,
60
65
  time_dilation_factor_min=0.3,
61
66
  time_dilation_factor_max=3.0,
62
67
  ):
68
+ """
69
+ The primary class in charge of compiling and exporting the solvers
70
+
71
+
72
+ Args:
73
+ dynamics (Dynamics): Dynamics function decorated with @dynamics
74
+ constraints (List[Union[CTCSConstraint, NodalConstraint]]): List of constraints decorated with @ctcs or @nodal
75
+ idx_time (int): Index of the time variable in the state vector
76
+ N (int): Number of segments in the trajectory
77
+ time_init (float): Initial time for the trajectory
78
+ x_guess (jnp.ndarray): Initial guess for the state trajectory
79
+ u_guess (jnp.ndarray): Initial guess for the control trajectory
80
+ initial_state (BoundaryConstraint): Initial state constraint
81
+ final_state (BoundaryConstraint): Final state constraint
82
+ x_max (jnp.ndarray): Upper bound on the state variables
83
+ x_min (jnp.ndarray): Lower bound on the state variables
84
+ u_max (jnp.ndarray): Upper bound on the control variables
85
+ u_min (jnp.ndarray): Lower bound on the control variables
86
+ dynamics_prop: Propagation dynamics function decorated with @dynamics
87
+ initial_state_prop: Propagation initial state constraint
88
+ scp: SCP configuration object
89
+ dis: Discretization configuration object
90
+ prp: Propagation configuration object
91
+ sim: Simulation configuration object
92
+ dev: Development configuration object
93
+ cvx: Convex solver configuration object
94
+
95
+ Returns:
96
+ None
97
+ """
98
+
99
+ self.params = params
100
+
101
+ if dynamics_prop is None:
102
+ dynamics_prop = dynamics
103
+
104
+ if x_prop is None:
105
+ x_prop = deepcopy(x)
63
106
 
64
107
  # TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
65
108
  constraints_ctcs = []
@@ -81,50 +124,61 @@ class TrajOptProblem:
81
124
  constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)
82
125
 
83
126
  # Index tracking
84
- idx_x_true = slice(0, len(x_max))
85
- idx_u_true = slice(0, len(u_max))
127
+ idx_x_true = slice(0, x.shape[0])
128
+ idx_x_true_prop = slice(0, x_prop.shape[0])
129
+ idx_u_true = slice(0, u.shape[0])
86
130
  idx_constraint_violation = slice(
87
131
  idx_x_true.stop, idx_x_true.stop + num_augmented_states
88
132
  )
133
+ idx_constraint_violation_prop = slice(
134
+ idx_x_true_prop.stop, idx_x_true_prop.stop + num_augmented_states
135
+ )
89
136
 
90
137
  idx_time_dilation = slice(idx_u_true.stop, idx_u_true.stop + 1)
91
138
 
92
139
  # check that idx_time is in the correct range
93
140
  assert idx_time >= 0 and idx_time < len(
94
- x_max
141
+ x.max
95
142
  ), "idx_time must be in the range of the state vector and non-negative"
96
143
  idx_time = slice(idx_time, idx_time + 1)
97
144
 
98
- x_min_augmented = np.hstack([x_min, np.repeat(licq_min, num_augmented_states)])
99
- x_max_augmented = np.hstack([x_max, np.repeat(licq_max, num_augmented_states)])
100
-
101
- u_min_augmented = np.hstack([u_min, time_dilation_factor_min * time_init])
102
- u_max_augmented = np.hstack([u_max, time_dilation_factor_max * time_init])
103
-
104
- x_bar_augmented = np.hstack([x_guess, np.full((x_guess.shape[0], num_augmented_states), 0)])
105
- u_bar_augmented = np.hstack(
106
- [u_guess, np.full((u_guess.shape[0], 1), time_init)]
107
- )
145
+ # Create a new state object for the augmented states
146
+ if num_augmented_states != 0:
147
+ y = State(name="y", shape=(num_augmented_states,))
148
+ y.initial = np.zeros((num_augmented_states,))
149
+ y.final = np.array([Free(0)] * num_augmented_states)
150
+ y.guess = np.zeros((N, num_augmented_states,))
151
+ y.min = np.zeros((num_augmented_states,))
152
+ y.max = licq_max * np.ones((num_augmented_states,))
153
+
154
+ x.append(y, augmented=True)
155
+ x_prop.append(y, augmented=True)
156
+
157
+ s = Control(name="s", shape=(1,))
158
+ s.min = np.array([time_dilation_factor_min * x.final[idx_time][0]])
159
+ s.max = np.array([time_dilation_factor_max * x.final[idx_time][0]])
160
+ s.guess = np.ones((N, 1)) * x.final[idx_time][0]
161
+
162
+
163
+ u.append(s, augmented=True)
108
164
 
109
165
  if dis is None:
110
166
  dis = DiscretizationConfig()
111
167
 
112
168
  if sim is None:
113
169
  sim = SimConfig(
114
- x_bar=x_bar_augmented,
115
- u_bar=u_bar_augmented,
116
- initial_state=initial_state,
117
- final_state=final_state,
118
- max_state=x_max_augmented,
119
- min_state=x_min_augmented,
120
- max_control=u_max_augmented,
121
- min_control=u_min_augmented,
122
- total_time=time_init,
123
- n_states=len(x_max),
170
+ x=x,
171
+ x_prop=x_prop,
172
+ u=u,
173
+ total_time=x.initial[idx_time][0],
174
+ n_states=x.initial.shape[0],
175
+ n_states_prop=x_prop.initial.shape[0],
124
176
  idx_x_true=idx_x_true,
177
+ idx_x_true_prop=idx_x_true_prop,
125
178
  idx_u_true=idx_u_true,
126
179
  idx_t=idx_time,
127
180
  idx_y=idx_constraint_violation,
181
+ idx_y_prop=idx_constraint_violation_prop,
128
182
  idx_s=idx_time_dilation,
129
183
  ctcs_node_intervals=node_intervals,
130
184
  )
@@ -132,22 +186,11 @@ class TrajOptProblem:
132
186
  if scp is None:
133
187
  scp = ScpConfig(
134
188
  n=N,
135
- k_max=200,
136
- w_tr=1e1, # Weight on the Trust Reigon
137
- lam_cost=1e1, # Weight on the Nonlinear Cost
138
- lam_vc=1e2, # Weight on the Virtual Control Objective
139
- lam_vb=0e0, # Weight on the Virtual Buffer Objective (only for penalized nodal constraints)
140
- ep_tr=1e-4, # Trust Region Tolerance
141
- ep_vb=1e-4, # Virtual Control Tolerance
142
- ep_vc=1e-8, # Virtual Control Tolerance for CTCS
143
- cost_drop=4, # SCP iteration to relax minimal final time objective
144
- cost_relax=0.5, # Minimal Time Relaxation Factor
145
- w_tr_adapt=1.2, # Trust Region Adaptation Factor
146
189
  w_tr_max_scaling_factor=1e2, # Maximum Trust Region Weight
147
190
  )
148
191
  else:
149
192
  assert (
150
- self.scp.n == N
193
+ self.settings.scp.n == N
151
194
  ), "Number of segments must be the same as in the config"
152
195
 
153
196
  if dev is None:
@@ -162,8 +205,9 @@ class TrajOptProblem:
162
205
 
163
206
  ctcs_violation_funcs = get_g_funcs(constraints_ctcs)
164
207
  self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)
208
+ self.dynamics_augmented_prop = build_augmented_dynamics(dynamics_prop, ctcs_violation_funcs, idx_x_true_prop, idx_u_true)
165
209
 
166
- self.params = Config(
210
+ self.settings = Config(
167
211
  sim=sim,
168
212
  scp=scp,
169
213
  dis=dis,
@@ -171,18 +215,18 @@ class TrajOptProblem:
171
215
  cvx=cvx,
172
216
  prp=prp,
173
217
  )
174
-
218
+
175
219
  self.optimal_control_problem: cp.Problem = None
176
220
  self.discretization_solver: callable = None
177
221
  self.cpg_solve = None
178
222
 
179
223
  # set up emitter & thread only if printing is enabled
180
- if self.params.dev.printing:
224
+ if self.settings.dev.printing:
181
225
  self.print_queue = queue.Queue()
182
226
  self.emitter_function = lambda data: self.print_queue.put(data)
183
227
  self.print_thread = threading.Thread(
184
228
  target=io.intermediate,
185
- args=(self.print_queue, self.params),
229
+ args=(self.print_queue, self.settings),
186
230
  daemon=True,
187
231
  )
188
232
  self.print_thread.start()
@@ -195,11 +239,23 @@ class TrajOptProblem:
195
239
  self.timing_solve = None
196
240
  self.timing_post = None
197
241
 
242
+ # SCP state variables
243
+ self.scp_k = 0
244
+ self.scp_J_tr = 1e2
245
+ self.scp_J_vb = 1e2
246
+ self.scp_J_vc = 1e2
247
+ self.scp_trajs = []
248
+ self.scp_controls = []
249
+ self.scp_V_multi_shoot_traj = []
250
+
198
251
  def initialize(self):
199
252
  io.intro()
200
253
 
254
+ # Print problem summary
255
+ io.print_problem_summary(self.settings)
256
+
201
257
  # Enable the profiler
202
- if self.params.dev.profiling:
258
+ if self.settings.dev.profiling:
203
259
  import cProfile
204
260
 
205
261
  pr = cProfile.Profile()
@@ -207,15 +263,17 @@ class TrajOptProblem:
207
263
 
208
264
  t_0_while = time.time()
209
265
  # Ensure parameter sizes and normalization are correct
210
- self.params.scp.__post_init__()
211
- self.params.sim.__post_init__()
266
+ self.settings.scp.__post_init__()
267
+ self.settings.sim.__post_init__()
212
268
 
213
269
  # Compile dynamics and jacobians
214
- self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f)
215
- self.dynamics_augmented.A = jax.jit(jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0)))
216
- self.dynamics_augmented.B = jax.jit(jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0)))
270
+ self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
271
+ self.dynamics_augmented.A = jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0, *(None,) * len(self.params)))
272
+ self.dynamics_augmented.B = jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0, *(None,) * len(self.params)))
273
+
274
+ self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
217
275
 
218
- for constraint in self.params.sim.constraints_nodal:
276
+ for constraint in self.settings.sim.constraints_nodal:
219
277
  if not constraint.convex:
220
278
  # TODO: (haynec) switch to AOT instead of JIT
221
279
  constraint.g = jax.jit(constraint.g)
@@ -223,55 +281,241 @@ class TrajOptProblem:
223
281
  constraint.grad_g_u = jax.jit(constraint.grad_g_u)
224
282
 
225
283
  # Generate solvers and optimal control problem
226
- self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
227
- self.propagation_solver = get_propagation_solver(self.dynamics_augmented.f, self.params)
228
- self.optimal_control_problem = OptimalControlProblem(self.params)
229
-
230
- # Initialize the PTR loop
231
- self.cpg_solve = PTR_init(
232
- self.optimal_control_problem,
233
- self.discretization_solver,
234
- self.params,
284
+ self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.settings, self.params)
285
+ self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.settings, self.params)
286
+ self.optimal_control_problem = OptimalControlProblem(self.settings)
287
+
288
+ # Collect all relevant functions
289
+ functions_to_hash = [self.dynamics_augmented.f, self.dynamics_augmented_prop.f]
290
+ for constraint in self.settings.sim.constraints_nodal:
291
+ functions_to_hash.append(constraint.func)
292
+ for constraint in self.settings.sim.constraints_ctcs:
293
+ functions_to_hash.append(constraint.func)
294
+
295
+ # Get unique source-based hash
296
+ function_hash = stable_function_hash(
297
+ functions_to_hash,
298
+ n_discretization_nodes=self.settings.scp.n,
299
+ dt=self.settings.prp.dt,
300
+ total_time=self.settings.sim.total_time,
301
+ state_max=self.settings.sim.x.max,
302
+ state_min=self.settings.sim.x.min,
303
+ control_max=self.settings.sim.u.max,
304
+ control_min=self.settings.sim.u.min
235
305
  )
236
306
 
307
+ solver_dir = Path(".tmp")
308
+ solver_dir.mkdir(parents=True, exist_ok=True)
309
+ dis_solver_file = solver_dir / f"compiled_discretization_solver_{function_hash}.jax"
310
+ prop_solver_file = solver_dir / f"compiled_propagation_solver_{function_hash}.jax"
311
+
312
+
237
313
  # Compile the solvers
238
- if not self.params.dev.debug:
239
- self.discretization_solver = (
240
- jax.jit(self.discretization_solver)
241
- .lower(
242
- np.ones((self.params.scp.n, self.params.sim.n_states)),
243
- np.ones((self.params.scp.n, self.params.sim.n_controls)),
314
+ if not self.settings.dev.debug:
315
+ if self.settings.sim.save_compiled:
316
+ # Check if the compiled file already exists
317
+ try:
318
+ with open(dis_solver_file, "rb") as f:
319
+ serial_dis = f.read()
320
+ # Load the compiled code
321
+ self.discretization_solver = export.deserialize(serial_dis)
322
+ print("✓ Loaded existing discretization solver")
323
+ except FileNotFoundError:
324
+ print("Compiling discretization solver...")
325
+ # Extract parameter values and names in order
326
+ param_values = [param.value for _, param in self.params.items()]
327
+
328
+ self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
329
+ np.ones((self.settings.scp.n, self.settings.sim.n_states)),
330
+ np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
331
+ *param_values
332
+ )
333
+ # Serialize and Save the compiled code in a temp directory
334
+ with open(dis_solver_file, "wb") as f:
335
+ f.write(self.discretization_solver.serialize())
336
+ print("✓ Discretization solver compiled and saved")
337
+ else:
338
+ print("Compiling discretization solver (not saving/loading from disk)...")
339
+ param_values = [param.value for _, param in self.params.items()]
340
+ self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
341
+ np.ones((self.settings.scp.n, self.settings.sim.n_states)),
342
+ np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
343
+ *param_values
344
+ )
345
+
346
+ # Compile the discretization solver and save it
347
+ dtau = 1.0 / (self.settings.scp.n - 1)
348
+ dt_max = self.settings.sim.u.max[self.settings.sim.idx_s][0] * dtau
349
+
350
+ self.settings.prp.max_tau_len = int(dt_max / self.settings.prp.dt) + 2
351
+
352
+ # Check if the compiled file already exists
353
+ if self.settings.sim.save_compiled:
354
+ try:
355
+ with open(prop_solver_file, "rb") as f:
356
+ serial_prop = f.read()
357
+ # Load the compiled code
358
+ self.propagation_solver = export.deserialize(serial_prop)
359
+ print("✓ Loaded existing propagation solver")
360
+ except FileNotFoundError:
361
+ print("Compiling propagation solver...")
362
+ # Extract parameter values and names in order
363
+ param_values = [param.value for _, param in self.params.items()]
364
+
365
+ propagation_solver = export.export(jax.jit(self.propagation_solver))(
366
+ np.ones((self.settings.sim.n_states_prop)), # x_0
367
+ (0.0, 0.0), # time span
368
+ np.ones((1, self.settings.sim.n_controls)), # controls_current
369
+ np.ones((1, self.settings.sim.n_controls)), # controls_next
370
+ np.ones((1, 1)), # tau_0
371
+ np.ones((1, 1)).astype("int"), # segment index
372
+ 0, # idx_s_stop
373
+ np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
374
+ np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
375
+ *param_values, # additional parameters
244
376
  )
245
- .compile()
246
- )
247
377
 
248
- self.propagation_solver = (
249
- jax.jit(self.propagation_solver)
250
- .lower(
251
- np.ones((self.params.sim.n_states)),
252
- (0.0, 0.0),
253
- np.ones((1, self.params.sim.n_controls)),
254
- np.ones((1, self.params.sim.n_controls)),
255
- np.ones((1, 1)),
256
- np.ones((1, 1)).astype("int"),
257
- 0,
378
+ # Serialize and Save the compiled code in a temp directory
379
+ self.propagation_solver = propagation_solver
380
+
381
+ with open(prop_solver_file, "wb") as f:
382
+ f.write(self.propagation_solver.serialize())
383
+ print("✓ Propagation solver compiled and saved")
384
+ else:
385
+ print("Compiling propagation solver (not saving/loading from disk)...")
386
+ param_values = [param.value for _, param in self.params.items()]
387
+ propagation_solver = export.export(jax.jit(self.propagation_solver))(
388
+ np.ones((self.settings.sim.n_states_prop)), # x_0
389
+ (0.0, 0.0), # time span
390
+ np.ones((1, self.settings.sim.n_controls)), # controls_current
391
+ np.ones((1, self.settings.sim.n_controls)), # controls_next
392
+ np.ones((1, 1)), # tau_0
393
+ np.ones((1, 1)).astype("int"), # segment index
394
+ 0, # idx_s_stop
395
+ np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
396
+ np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
397
+ *param_values, # additional parameters
258
398
  )
259
- .compile()
399
+ self.propagation_solver = propagation_solver
400
+
401
+ # Initialize the PTR loop
402
+ print("Initializing the SCvx Subproblem Solver...")
403
+ self.cpg_solve = PTR_init(
404
+ self.params,
405
+ self.optimal_control_problem,
406
+ self.discretization_solver,
407
+ self.settings,
260
408
  )
409
+ print("✓ SCvx Subproblem Solver initialized")
410
+
411
+ # Reset SCP state
412
+ self.scp_k = 1
413
+ self.scp_J_tr = 1e2
414
+ self.scp_J_vb = 1e2
415
+ self.scp_J_vc = 1e2
416
+ self.scp_trajs = [self.settings.sim.x.guess]
417
+ self.scp_controls = [self.settings.sim.u.guess]
418
+ self.scp_V_multi_shoot_traj = []
261
419
 
262
420
  t_f_while = time.time()
263
421
  self.timing_init = t_f_while - t_0_while
264
422
  print("Total Initialization Time: ", self.timing_init)
265
423
 
266
- if self.params.dev.profiling:
424
+ # Robust priming call for propagation_solver.call (no debug prints)
425
+ try:
426
+ x_0 = np.ones(self.settings.sim.x_prop.initial.shape, dtype=self.settings.sim.x_prop.initial.dtype)
427
+ tau_grid = (0.0, 1.0)
428
+ controls_current = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
429
+ controls_next = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
430
+ tau_init = np.array([[0.0]], dtype=np.float64)
431
+ node = np.array([[0]], dtype=np.int64)
432
+ idx_s_stop = self.settings.sim.idx_s.stop
433
+ save_time = np.ones((self.settings.prp.max_tau_len,), dtype=np.float64)
434
+ mask_padded = np.ones((self.settings.prp.max_tau_len,), dtype=bool)
435
+ param_values = [np.ones_like(param.value) if hasattr(param.value, 'shape') else float(param.value) for _, param in self.params.items()]
436
+ self.propagation_solver.call(
437
+ x_0, tau_grid, controls_current, controls_next, tau_init, node, idx_s_stop, save_time, mask_padded, *param_values
438
+ )
439
+ except Exception as e:
440
+ print(f"[Initialization] Priming propagation_solver.call failed: {e}")
441
+
442
+ if self.settings.dev.profiling:
267
443
  pr.disable()
268
444
  # Save results so it can be viusualized with snakeviz
269
445
  pr.dump_stats("profiling_initialize.prof")
270
446
 
271
- def solve(self):
447
+ def step(self):
448
+ """Performs a single SCP iteration.
449
+
450
+ This method is designed for real-time plotting and interactive optimization.
451
+ It performs one complete SCP iteration including subproblem solving,
452
+ state updates, and progress emission for real-time visualization.
453
+
454
+ Returns:
455
+ dict: Dictionary containing convergence status and current state
456
+ """
457
+ x = self.settings.sim.x
458
+ u = self.settings.sim.u
459
+
460
+ # Run the subproblem
461
+ x_sol, u_sol, cost, J_total, J_vb_vec, J_vc_vec, J_tr_vec, prob_stat, V_multi_shoot, subprop_time, dis_time = PTR_subproblem(
462
+ self.params.items(),
463
+ self.cpg_solve,
464
+ x,
465
+ u,
466
+ self.discretization_solver,
467
+ self.optimal_control_problem,
468
+ self.settings,
469
+ )
470
+
471
+ # Update state
472
+ self.scp_V_multi_shoot_traj.append(V_multi_shoot)
473
+ x.guess = x_sol
474
+ u.guess = u_sol
475
+ self.scp_trajs.append(x.guess)
476
+ self.scp_controls.append(u.guess)
477
+
478
+ self.scp_J_tr = np.sum(np.array(J_tr_vec))
479
+ self.scp_J_vb = np.sum(np.array(J_vb_vec))
480
+ self.scp_J_vc = np.sum(np.array(J_vc_vec))
481
+
482
+ # Update weights
483
+ self.settings.scp.w_tr = min(self.settings.scp.w_tr * self.settings.scp.w_tr_adapt, self.settings.scp.w_tr_max)
484
+ if self.scp_k > self.settings.scp.cost_drop:
485
+ self.settings.scp.lam_cost = self.settings.scp.lam_cost * self.settings.scp.cost_relax
486
+
487
+ # Emit data
488
+ self.emitter_function(
489
+ {
490
+ "iter": self.scp_k,
491
+ "dis_time": dis_time * 1000.0,
492
+ "subprop_time": subprop_time * 1000.0,
493
+ "J_total": J_total,
494
+ "J_tr": self.scp_J_tr,
495
+ "J_vb": self.scp_J_vb,
496
+ "J_vc": self.scp_J_vc,
497
+ "cost": cost[-1],
498
+ "prob_stat": prob_stat,
499
+ }
500
+ )
501
+
502
+ # Increment counter
503
+ self.scp_k += 1
504
+
505
+ # Create a result dictionary for this step
506
+ return {
507
+ "converged": (self.scp_J_tr < self.settings.scp.ep_tr) and \
508
+ (self.scp_J_vb < self.settings.scp.ep_vb) and \
509
+ (self.scp_J_vc < self.settings.scp.ep_vc),
510
+ "u": u,
511
+ "x": x,
512
+ "V_multi_shoot": V_multi_shoot
513
+ }
514
+
515
+ def solve(self, max_iters: Optional[int] = None, continuous: bool = False) -> OptimizationResults:
272
516
  # Ensure parameter sizes and normalization are correct
273
- self.params.scp.__post_init__()
274
- self.params.sim.__post_init__()
517
+ self.settings.scp.__post_init__()
518
+ self.settings.sim.__post_init__()
275
519
 
276
520
  if self.optimal_control_problem is None or self.discretization_solver is None:
277
521
  raise ValueError(
@@ -279,7 +523,7 @@ class TrajOptProblem:
279
523
  )
280
524
 
281
525
  # Enable the profiler
282
- if self.params.dev.profiling:
526
+ if self.settings.dev.profiling:
283
527
  import cProfile
284
528
 
285
529
  pr = cProfile.Profile()
@@ -288,14 +532,13 @@ class TrajOptProblem:
288
532
  t_0_while = time.time()
289
533
  # Print top header for solver results
290
534
  io.header()
291
-
292
- result = PTR_main(
293
- self.params,
294
- self.optimal_control_problem,
295
- self.discretization_solver,
296
- self.cpg_solve,
297
- self.emitter_function,
298
- )
535
+
536
+ k_max = max_iters if max_iters is not None else self.settings.scp.k_max
537
+
538
+ while self.scp_k <= k_max:
539
+ result = self.step()
540
+ if result["converged"] and not continuous:
541
+ break
299
542
 
300
543
  t_f_while = time.time()
301
544
  self.timing_solve = t_f_while - t_0_while
@@ -304,33 +547,35 @@ class TrajOptProblem:
304
547
  time.sleep(0.1)
305
548
 
306
549
  # Print bottom footer for solver results as well as total computation time
307
- io.footer(self.timing_solve)
550
+ io.footer()
308
551
 
309
552
  # Disable the profiler
310
- if self.params.dev.profiling:
553
+ if self.settings.dev.profiling:
311
554
  pr.disable()
312
555
  # Save results so it can be viusualized with snakeviz
313
556
  pr.dump_stats("profiling_solve.prof")
314
557
 
315
- return result
558
+ return format_result(self, self.scp_k <= k_max)
316
559
 
317
- def post_process(self, result):
560
+ def post_process(self, result: OptimizationResults) -> OptimizationResults:
318
561
  # Enable the profiler
319
- if self.params.dev.profiling:
562
+ if self.settings.dev.profiling:
320
563
  import cProfile
321
564
 
322
565
  pr = cProfile.Profile()
323
566
  pr.enable()
324
567
 
325
568
  t_0_post = time.time()
326
- result = propagate_trajectory_results(self.params, result, self.propagation_solver)
569
+ result = propagate_trajectory_results(self.params, self.settings, result, self.propagation_solver)
327
570
  t_f_post = time.time()
328
571
 
329
572
  self.timing_post = t_f_post - t_0_post
330
- print("Total Post Processing Time: ", self.timing_post)
573
+
574
+ # Print results summary
575
+ io.print_results_summary(result, self.timing_post, self.timing_init, self.timing_solve)
331
576
 
332
577
  # Disable the profiler
333
- if self.params.dev.profiling:
578
+ if self.settings.dev.profiling:
334
579
  pr.disable()
335
580
  # Save results so it can be viusualized with snakeviz
336
581
  pr.dump_stats("profiling_postprocess.prof")