openscvx 0.1.3__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

@@ -3,9 +3,13 @@ from typing import List, Union, Optional
3
3
  import queue
4
4
  import threading
5
5
  import time
6
+ from pathlib import Path
7
+ from copy import deepcopy
6
8
 
7
9
  import cvxpy as cp
8
10
  import jax
11
+ from jax import export, ShapeDtypeStruct
12
+ from functools import partial
9
13
  import numpy as np
10
14
 
11
15
  from openscvx.config import (
@@ -23,13 +27,18 @@ from openscvx.augmentation.ctcs import sort_ctcs_constraints
23
27
  from openscvx.constraints.violation import get_g_funcs, CTCSViolation
24
28
  from openscvx.discretization import get_discretization_solver
25
29
  from openscvx.propagation import get_propagation_solver
26
- from openscvx.constraints.boundary import BoundaryConstraint, boundary
27
30
  from openscvx.constraints.ctcs import CTCSConstraint
28
31
  from openscvx.constraints.nodal import NodalConstraint
29
- from openscvx.ptr import PTR_init, PTR_main
32
+ from openscvx.ptr import PTR_init, PTR_subproblem, format_result
30
33
  from openscvx.post_processing import propagate_trajectory_results
31
34
  from openscvx.ocp import OptimalControlProblem
32
35
  from openscvx import io
36
+ from openscvx.utils import stable_function_hash
37
+ from openscvx.backend.state import State, Free
38
+ from openscvx.backend.control import Control
39
+ from openscvx.backend.parameter import Parameter
40
+ from openscvx.results import OptimizationResults
41
+
33
42
 
34
43
 
35
44
  # TODO: (norrisg) Decide whether to have constraints`, `cost`, alongside `dynamics`, ` etc.
@@ -38,19 +47,13 @@ class TrajOptProblem:
38
47
  self,
39
48
  dynamics: Dynamics,
40
49
  constraints: List[Union[CTCSConstraint, NodalConstraint]],
41
- idx_time: int,
50
+ x: State,
51
+ u: Control,
42
52
  N: int,
43
- time_init: float,
44
- x_guess: jnp.ndarray,
45
- u_guess: jnp.ndarray,
46
- initial_state: BoundaryConstraint,
47
- final_state: BoundaryConstraint,
48
- x_max: jnp.ndarray,
49
- x_min: jnp.ndarray,
50
- u_max: jnp.ndarray,
51
- u_min: jnp.ndarray,
53
+ idx_time: int,
54
+ params: dict = {},
52
55
  dynamics_prop: callable = None,
53
- initial_state_prop: BoundaryConstraint = None,
56
+ x_prop: State = None,
54
57
  scp: Optional[ScpConfig] = None,
55
58
  dis: Optional[DiscretizationConfig] = None,
56
59
  prp: Optional[PropagationConfig] = None,
@@ -62,11 +65,44 @@ class TrajOptProblem:
62
65
  time_dilation_factor_min=0.3,
63
66
  time_dilation_factor_max=3.0,
64
67
  ):
68
+ """
69
+ The primary class in charge of compiling and exporting the solvers
70
+
71
+
72
+ Args:
73
+ dynamics (Dynamics): Dynamics function decorated with @dynamics
74
+ constraints (List[Union[CTCSConstraint, NodalConstraint]]): List of constraints decorated with @ctcs or @nodal
75
+ idx_time (int): Index of the time variable in the state vector
76
+ N (int): Number of segments in the trajectory
77
+ time_init (float): Initial time for the trajectory
78
+ x_guess (jnp.ndarray): Initial guess for the state trajectory
79
+ u_guess (jnp.ndarray): Initial guess for the control trajectory
80
+ initial_state (BoundaryConstraint): Initial state constraint
81
+ final_state (BoundaryConstraint): Final state constraint
82
+ x_max (jnp.ndarray): Upper bound on the state variables
83
+ x_min (jnp.ndarray): Lower bound on the state variables
84
+ u_max (jnp.ndarray): Upper bound on the control variables
85
+ u_min (jnp.ndarray): Lower bound on the control variables
86
+ dynamics_prop: Propagation dynamics function decorated with @dynamics
87
+ initial_state_prop: Propagation initial state constraint
88
+ scp: SCP configuration object
89
+ dis: Discretization configuration object
90
+ prp: Propagation configuration object
91
+ sim: Simulation configuration object
92
+ dev: Development configuration object
93
+ cvx: Convex solver configuration object
94
+
95
+ Returns:
96
+ None
97
+ """
98
+
99
+ self.params = params
100
+
65
101
  if dynamics_prop is None:
66
102
  dynamics_prop = dynamics
67
103
 
68
- if initial_state_prop is None:
69
- initial_state_prop = initial_state
104
+ if x_prop is None:
105
+ x_prop = deepcopy(x)
70
106
 
71
107
  # TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
72
108
  constraints_ctcs = []
@@ -88,9 +124,9 @@ class TrajOptProblem:
88
124
  constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)
89
125
 
90
126
  # Index tracking
91
- idx_x_true = slice(0, len(initial_state.value))
92
- idx_x_true_prop = slice(0, len(initial_state_prop.value))
93
- idx_u_true = slice(0, len(u_max))
127
+ idx_x_true = slice(0, x.shape[0])
128
+ idx_x_true_prop = slice(0, x_prop.shape[0])
129
+ idx_u_true = slice(0, u.shape[0])
94
130
  idx_constraint_violation = slice(
95
131
  idx_x_true.stop, idx_x_true.stop + num_augmented_states
96
132
  )
@@ -102,43 +138,41 @@ class TrajOptProblem:
102
138
 
103
139
  # check that idx_time is in the correct range
104
140
  assert idx_time >= 0 and idx_time < len(
105
- x_max
141
+ x.max
106
142
  ), "idx_time must be in the range of the state vector and non-negative"
107
143
  idx_time = slice(idx_time, idx_time + 1)
108
144
 
109
- x_min_augmented = np.hstack([x_min, np.repeat(licq_min, num_augmented_states)])
110
- x_max_augmented = np.hstack([x_max, np.repeat(licq_max, num_augmented_states)])
145
+ # Create a new state object for the augmented states
146
+ if num_augmented_states != 0:
147
+ y = State(name="y", shape=(num_augmented_states,))
148
+ y.initial = np.zeros((num_augmented_states,))
149
+ y.final = np.array([Free(0)] * num_augmented_states)
150
+ y.guess = np.zeros((N, num_augmented_states,))
151
+ y.min = np.zeros((num_augmented_states,))
152
+ y.max = licq_max * np.ones((num_augmented_states,))
153
+
154
+ x.append(y, augmented=True)
155
+ x_prop.append(y, augmented=True)
156
+
157
+ s = Control(name="s", shape=(1,))
158
+ s.min = np.array([time_dilation_factor_min * x.final[idx_time][0]])
159
+ s.max = np.array([time_dilation_factor_max * x.final[idx_time][0]])
160
+ s.guess = np.ones((N, 1)) * x.final[idx_time][0]
111
161
 
112
- u_min_augmented = np.hstack([u_min, time_dilation_factor_min * time_init])
113
- u_max_augmented = np.hstack([u_max, time_dilation_factor_max * time_init])
114
-
115
- x_bar_augmented = np.hstack([x_guess, np.full((x_guess.shape[0], num_augmented_states), 0)])
116
- u_bar_augmented = np.hstack(
117
- [u_guess, np.full((u_guess.shape[0], 1), time_init)]
118
- )
119
-
120
- initial_state_prop_values = np.hstack([initial_state_prop.value, np.repeat(licq_min, num_augmented_states)])
121
- initial_state_prop_types = np.hstack([initial_state_prop.type, ["Fix"] * num_augmented_states])
122
- initial_state_prop = boundary(initial_state_prop_values)
123
- initial_state_prop.types = initial_state_prop_types
162
+
163
+ u.append(s, augmented=True)
124
164
 
125
165
  if dis is None:
126
166
  dis = DiscretizationConfig()
127
167
 
128
168
  if sim is None:
129
169
  sim = SimConfig(
130
- x_bar=x_bar_augmented,
131
- u_bar=u_bar_augmented,
132
- initial_state=initial_state,
133
- initial_state_prop=initial_state_prop,
134
- final_state=final_state,
135
- max_state=x_max_augmented,
136
- min_state=x_min_augmented,
137
- max_control=u_max_augmented,
138
- min_control=u_min_augmented,
139
- total_time=time_init,
140
- n_states=len(initial_state.value),
141
- n_states_prop=len(initial_state_prop.value),
170
+ x=x,
171
+ x_prop=x_prop,
172
+ u=u,
173
+ total_time=x.initial[idx_time][0],
174
+ n_states=x.initial.shape[0],
175
+ n_states_prop=x_prop.initial.shape[0],
142
176
  idx_x_true=idx_x_true,
143
177
  idx_x_true_prop=idx_x_true_prop,
144
178
  idx_u_true=idx_u_true,
@@ -152,22 +186,11 @@ class TrajOptProblem:
152
186
  if scp is None:
153
187
  scp = ScpConfig(
154
188
  n=N,
155
- k_max=200,
156
- w_tr=1e1, # Weight on the Trust Reigon
157
- lam_cost=1e1, # Weight on the Nonlinear Cost
158
- lam_vc=1e2, # Weight on the Virtual Control Objective
159
- lam_vb=0e0, # Weight on the Virtual Buffer Objective (only for penalized nodal constraints)
160
- ep_tr=1e-4, # Trust Region Tolerance
161
- ep_vb=1e-4, # Virtual Control Tolerance
162
- ep_vc=1e-8, # Virtual Control Tolerance for CTCS
163
- cost_drop=4, # SCP iteration to relax minimal final time objective
164
- cost_relax=0.5, # Minimal Time Relaxation Factor
165
- w_tr_adapt=1.2, # Trust Region Adaptation Factor
166
189
  w_tr_max_scaling_factor=1e2, # Maximum Trust Region Weight
167
190
  )
168
191
  else:
169
192
  assert (
170
- self.scp.n == N
193
+ self.settings.scp.n == N
171
194
  ), "Number of segments must be the same as in the config"
172
195
 
173
196
  if dev is None:
@@ -184,7 +207,7 @@ class TrajOptProblem:
184
207
  self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)
185
208
  self.dynamics_augmented_prop = build_augmented_dynamics(dynamics_prop, ctcs_violation_funcs, idx_x_true_prop, idx_u_true)
186
209
 
187
- self.params = Config(
210
+ self.settings = Config(
188
211
  sim=sim,
189
212
  scp=scp,
190
213
  dis=dis,
@@ -192,18 +215,18 @@ class TrajOptProblem:
192
215
  cvx=cvx,
193
216
  prp=prp,
194
217
  )
195
-
218
+
196
219
  self.optimal_control_problem: cp.Problem = None
197
220
  self.discretization_solver: callable = None
198
221
  self.cpg_solve = None
199
222
 
200
223
  # set up emitter & thread only if printing is enabled
201
- if self.params.dev.printing:
224
+ if self.settings.dev.printing:
202
225
  self.print_queue = queue.Queue()
203
226
  self.emitter_function = lambda data: self.print_queue.put(data)
204
227
  self.print_thread = threading.Thread(
205
228
  target=io.intermediate,
206
- args=(self.print_queue, self.params),
229
+ args=(self.print_queue, self.settings),
207
230
  daemon=True,
208
231
  )
209
232
  self.print_thread.start()
@@ -216,11 +239,23 @@ class TrajOptProblem:
216
239
  self.timing_solve = None
217
240
  self.timing_post = None
218
241
 
242
+ # SCP state variables
243
+ self.scp_k = 0
244
+ self.scp_J_tr = 1e2
245
+ self.scp_J_vb = 1e2
246
+ self.scp_J_vc = 1e2
247
+ self.scp_trajs = []
248
+ self.scp_controls = []
249
+ self.scp_V_multi_shoot_traj = []
250
+
219
251
  def initialize(self):
220
252
  io.intro()
221
253
 
254
+ # Print problem summary
255
+ io.print_problem_summary(self.settings)
256
+
222
257
  # Enable the profiler
223
- if self.params.dev.profiling:
258
+ if self.settings.dev.profiling:
224
259
  import cProfile
225
260
 
226
261
  pr = cProfile.Profile()
@@ -228,18 +263,17 @@ class TrajOptProblem:
228
263
 
229
264
  t_0_while = time.time()
230
265
  # Ensure parameter sizes and normalization are correct
231
- self.params.scp.__post_init__()
232
- self.params.sim.__post_init__()
266
+ self.settings.scp.__post_init__()
267
+ self.settings.sim.__post_init__()
233
268
 
234
269
  # Compile dynamics and jacobians
235
- self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f)
236
- self.dynamics_augmented.A = jax.jit(jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0)))
237
- self.dynamics_augmented.B = jax.jit(jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0)))
270
+ self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
271
+ self.dynamics_augmented.A = jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0, *(None,) * len(self.params)))
272
+ self.dynamics_augmented.B = jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0, *(None,) * len(self.params)))
273
+
274
+ self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
238
275
 
239
-
240
- self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f)
241
-
242
- for constraint in self.params.sim.constraints_nodal:
276
+ for constraint in self.settings.sim.constraints_nodal:
243
277
  if not constraint.convex:
244
278
  # TODO: (haynec) switch to AOT instead of JIT
245
279
  constraint.g = jax.jit(constraint.g)
@@ -247,55 +281,241 @@ class TrajOptProblem:
247
281
  constraint.grad_g_u = jax.jit(constraint.grad_g_u)
248
282
 
249
283
  # Generate solvers and optimal control problem
250
- self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
251
- self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.params)
252
- self.optimal_control_problem = OptimalControlProblem(self.params)
253
-
254
- # Initialize the PTR loop
255
- self.cpg_solve = PTR_init(
256
- self.optimal_control_problem,
257
- self.discretization_solver,
258
- self.params,
284
+ self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.settings, self.params)
285
+ self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.settings, self.params)
286
+ self.optimal_control_problem = OptimalControlProblem(self.settings)
287
+
288
+ # Collect all relevant functions
289
+ functions_to_hash = [self.dynamics_augmented.f, self.dynamics_augmented_prop.f]
290
+ for constraint in self.settings.sim.constraints_nodal:
291
+ functions_to_hash.append(constraint.func)
292
+ for constraint in self.settings.sim.constraints_ctcs:
293
+ functions_to_hash.append(constraint.func)
294
+
295
+ # Get unique source-based hash
296
+ function_hash = stable_function_hash(
297
+ functions_to_hash,
298
+ n_discretization_nodes=self.settings.scp.n,
299
+ dt=self.settings.prp.dt,
300
+ total_time=self.settings.sim.total_time,
301
+ state_max=self.settings.sim.x.max,
302
+ state_min=self.settings.sim.x.min,
303
+ control_max=self.settings.sim.u.max,
304
+ control_min=self.settings.sim.u.min
259
305
  )
260
306
 
307
+ solver_dir = Path(".tmp")
308
+ solver_dir.mkdir(parents=True, exist_ok=True)
309
+ dis_solver_file = solver_dir / f"compiled_discretization_solver_{function_hash}.jax"
310
+ prop_solver_file = solver_dir / f"compiled_propagation_solver_{function_hash}.jax"
311
+
312
+
261
313
  # Compile the solvers
262
- if not self.params.dev.debug:
263
- self.discretization_solver = (
264
- jax.jit(self.discretization_solver)
265
- .lower(
266
- np.ones((self.params.scp.n, self.params.sim.n_states)),
267
- np.ones((self.params.scp.n, self.params.sim.n_controls)),
314
+ if not self.settings.dev.debug:
315
+ if self.settings.sim.save_compiled:
316
+ # Check if the compiled file already exists
317
+ try:
318
+ with open(dis_solver_file, "rb") as f:
319
+ serial_dis = f.read()
320
+ # Load the compiled code
321
+ self.discretization_solver = export.deserialize(serial_dis)
322
+ print("✓ Loaded existing discretization solver")
323
+ except FileNotFoundError:
324
+ print("Compiling discretization solver...")
325
+ # Extract parameter values and names in order
326
+ param_values = [param.value for _, param in self.params.items()]
327
+
328
+ self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
329
+ np.ones((self.settings.scp.n, self.settings.sim.n_states)),
330
+ np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
331
+ *param_values
332
+ )
333
+ # Serialize and Save the compiled code in a temp directory
334
+ with open(dis_solver_file, "wb") as f:
335
+ f.write(self.discretization_solver.serialize())
336
+ print("✓ Discretization solver compiled and saved")
337
+ else:
338
+ print("Compiling discretization solver (not saving/loading from disk)...")
339
+ param_values = [param.value for _, param in self.params.items()]
340
+ self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
341
+ np.ones((self.settings.scp.n, self.settings.sim.n_states)),
342
+ np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
343
+ *param_values
344
+ )
345
+
346
+ # Compile the discretization solver and save it
347
+ dtau = 1.0 / (self.settings.scp.n - 1)
348
+ dt_max = self.settings.sim.u.max[self.settings.sim.idx_s][0] * dtau
349
+
350
+ self.settings.prp.max_tau_len = int(dt_max / self.settings.prp.dt) + 2
351
+
352
+ # Check if the compiled file already exists
353
+ if self.settings.sim.save_compiled:
354
+ try:
355
+ with open(prop_solver_file, "rb") as f:
356
+ serial_prop = f.read()
357
+ # Load the compiled code
358
+ self.propagation_solver = export.deserialize(serial_prop)
359
+ print("✓ Loaded existing propagation solver")
360
+ except FileNotFoundError:
361
+ print("Compiling propagation solver...")
362
+ # Extract parameter values and names in order
363
+ param_values = [param.value for _, param in self.params.items()]
364
+
365
+ propagation_solver = export.export(jax.jit(self.propagation_solver))(
366
+ np.ones((self.settings.sim.n_states_prop)), # x_0
367
+ (0.0, 0.0), # time span
368
+ np.ones((1, self.settings.sim.n_controls)), # controls_current
369
+ np.ones((1, self.settings.sim.n_controls)), # controls_next
370
+ np.ones((1, 1)), # tau_0
371
+ np.ones((1, 1)).astype("int"), # segment index
372
+ 0, # idx_s_stop
373
+ np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
374
+ np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
375
+ *param_values, # additional parameters
268
376
  )
269
- .compile()
270
- )
271
377
 
272
- self.propagation_solver = (
273
- jax.jit(self.propagation_solver)
274
- .lower(
275
- np.ones((self.params.sim.n_states_prop)),
276
- (0.0, 0.0),
277
- np.ones((1, self.params.sim.n_controls)),
278
- np.ones((1, self.params.sim.n_controls)),
279
- np.ones((1, 1)),
280
- np.ones((1, 1)).astype("int"),
281
- 0,
378
+ # Serialize and Save the compiled code in a temp directory
379
+ self.propagation_solver = propagation_solver
380
+
381
+ with open(prop_solver_file, "wb") as f:
382
+ f.write(self.propagation_solver.serialize())
383
+ print("✓ Propagation solver compiled and saved")
384
+ else:
385
+ print("Compiling propagation solver (not saving/loading from disk)...")
386
+ param_values = [param.value for _, param in self.params.items()]
387
+ propagation_solver = export.export(jax.jit(self.propagation_solver))(
388
+ np.ones((self.settings.sim.n_states_prop)), # x_0
389
+ (0.0, 0.0), # time span
390
+ np.ones((1, self.settings.sim.n_controls)), # controls_current
391
+ np.ones((1, self.settings.sim.n_controls)), # controls_next
392
+ np.ones((1, 1)), # tau_0
393
+ np.ones((1, 1)).astype("int"), # segment index
394
+ 0, # idx_s_stop
395
+ np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
396
+ np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
397
+ *param_values, # additional parameters
282
398
  )
283
- .compile()
399
+ self.propagation_solver = propagation_solver
400
+
401
+ # Initialize the PTR loop
402
+ print("Initializing the SCvx Subproblem Solver...")
403
+ self.cpg_solve = PTR_init(
404
+ self.params,
405
+ self.optimal_control_problem,
406
+ self.discretization_solver,
407
+ self.settings,
284
408
  )
409
+ print("✓ SCvx Subproblem Solver initialized")
410
+
411
+ # Reset SCP state
412
+ self.scp_k = 1
413
+ self.scp_J_tr = 1e2
414
+ self.scp_J_vb = 1e2
415
+ self.scp_J_vc = 1e2
416
+ self.scp_trajs = [self.settings.sim.x.guess]
417
+ self.scp_controls = [self.settings.sim.u.guess]
418
+ self.scp_V_multi_shoot_traj = []
285
419
 
286
420
  t_f_while = time.time()
287
421
  self.timing_init = t_f_while - t_0_while
288
422
  print("Total Initialization Time: ", self.timing_init)
289
423
 
290
- if self.params.dev.profiling:
424
+ # Robust priming call for propagation_solver.call (no debug prints)
425
+ try:
426
+ x_0 = np.ones(self.settings.sim.x_prop.initial.shape, dtype=self.settings.sim.x_prop.initial.dtype)
427
+ tau_grid = (0.0, 1.0)
428
+ controls_current = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
429
+ controls_next = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
430
+ tau_init = np.array([[0.0]], dtype=np.float64)
431
+ node = np.array([[0]], dtype=np.int64)
432
+ idx_s_stop = self.settings.sim.idx_s.stop
433
+ save_time = np.ones((self.settings.prp.max_tau_len,), dtype=np.float64)
434
+ mask_padded = np.ones((self.settings.prp.max_tau_len,), dtype=bool)
435
+ param_values = [np.ones_like(param.value) if hasattr(param.value, 'shape') else float(param.value) for _, param in self.params.items()]
436
+ self.propagation_solver.call(
437
+ x_0, tau_grid, controls_current, controls_next, tau_init, node, idx_s_stop, save_time, mask_padded, *param_values
438
+ )
439
+ except Exception as e:
440
+ print(f"[Initialization] Priming propagation_solver.call failed: {e}")
441
+
442
+ if self.settings.dev.profiling:
291
443
  pr.disable()
292
444
  # Save results so it can be viusualized with snakeviz
293
445
  pr.dump_stats("profiling_initialize.prof")
294
446
 
295
- def solve(self):
447
+ def step(self):
448
+ """Performs a single SCP iteration.
449
+
450
+ This method is designed for real-time plotting and interactive optimization.
451
+ It performs one complete SCP iteration including subproblem solving,
452
+ state updates, and progress emission for real-time visualization.
453
+
454
+ Returns:
455
+ dict: Dictionary containing convergence status and current state
456
+ """
457
+ x = self.settings.sim.x
458
+ u = self.settings.sim.u
459
+
460
+ # Run the subproblem
461
+ x_sol, u_sol, cost, J_total, J_vb_vec, J_vc_vec, J_tr_vec, prob_stat, V_multi_shoot, subprop_time, dis_time = PTR_subproblem(
462
+ self.params.items(),
463
+ self.cpg_solve,
464
+ x,
465
+ u,
466
+ self.discretization_solver,
467
+ self.optimal_control_problem,
468
+ self.settings,
469
+ )
470
+
471
+ # Update state
472
+ self.scp_V_multi_shoot_traj.append(V_multi_shoot)
473
+ x.guess = x_sol
474
+ u.guess = u_sol
475
+ self.scp_trajs.append(x.guess)
476
+ self.scp_controls.append(u.guess)
477
+
478
+ self.scp_J_tr = np.sum(np.array(J_tr_vec))
479
+ self.scp_J_vb = np.sum(np.array(J_vb_vec))
480
+ self.scp_J_vc = np.sum(np.array(J_vc_vec))
481
+
482
+ # Update weights
483
+ self.settings.scp.w_tr = min(self.settings.scp.w_tr * self.settings.scp.w_tr_adapt, self.settings.scp.w_tr_max)
484
+ if self.scp_k > self.settings.scp.cost_drop:
485
+ self.settings.scp.lam_cost = self.settings.scp.lam_cost * self.settings.scp.cost_relax
486
+
487
+ # Emit data
488
+ self.emitter_function(
489
+ {
490
+ "iter": self.scp_k,
491
+ "dis_time": dis_time * 1000.0,
492
+ "subprop_time": subprop_time * 1000.0,
493
+ "J_total": J_total,
494
+ "J_tr": self.scp_J_tr,
495
+ "J_vb": self.scp_J_vb,
496
+ "J_vc": self.scp_J_vc,
497
+ "cost": cost[-1],
498
+ "prob_stat": prob_stat,
499
+ }
500
+ )
501
+
502
+ # Increment counter
503
+ self.scp_k += 1
504
+
505
+ # Create a result dictionary for this step
506
+ return {
507
+ "converged": (self.scp_J_tr < self.settings.scp.ep_tr) and \
508
+ (self.scp_J_vb < self.settings.scp.ep_vb) and \
509
+ (self.scp_J_vc < self.settings.scp.ep_vc),
510
+ "u": u,
511
+ "x": x,
512
+ "V_multi_shoot": V_multi_shoot
513
+ }
514
+
515
+ def solve(self, max_iters: Optional[int] = None, continuous: bool = False) -> OptimizationResults:
296
516
  # Ensure parameter sizes and normalization are correct
297
- self.params.scp.__post_init__()
298
- self.params.sim.__post_init__()
517
+ self.settings.scp.__post_init__()
518
+ self.settings.sim.__post_init__()
299
519
 
300
520
  if self.optimal_control_problem is None or self.discretization_solver is None:
301
521
  raise ValueError(
@@ -303,7 +523,7 @@ class TrajOptProblem:
303
523
  )
304
524
 
305
525
  # Enable the profiler
306
- if self.params.dev.profiling:
526
+ if self.settings.dev.profiling:
307
527
  import cProfile
308
528
 
309
529
  pr = cProfile.Profile()
@@ -312,14 +532,13 @@ class TrajOptProblem:
312
532
  t_0_while = time.time()
313
533
  # Print top header for solver results
314
534
  io.header()
315
-
316
- result = PTR_main(
317
- self.params,
318
- self.optimal_control_problem,
319
- self.discretization_solver,
320
- self.cpg_solve,
321
- self.emitter_function,
322
- )
535
+
536
+ k_max = max_iters if max_iters is not None else self.settings.scp.k_max
537
+
538
+ while self.scp_k <= k_max:
539
+ result = self.step()
540
+ if result["converged"] and not continuous:
541
+ break
323
542
 
324
543
  t_f_while = time.time()
325
544
  self.timing_solve = t_f_while - t_0_while
@@ -328,33 +547,35 @@ class TrajOptProblem:
328
547
  time.sleep(0.1)
329
548
 
330
549
  # Print bottom footer for solver results as well as total computation time
331
- io.footer(self.timing_solve)
550
+ io.footer()
332
551
 
333
552
  # Disable the profiler
334
- if self.params.dev.profiling:
553
+ if self.settings.dev.profiling:
335
554
  pr.disable()
336
555
  # Save results so it can be viusualized with snakeviz
337
556
  pr.dump_stats("profiling_solve.prof")
338
557
 
339
- return result
558
+ return format_result(self, self.scp_k <= k_max)
340
559
 
341
- def post_process(self, result):
560
+ def post_process(self, result: OptimizationResults) -> OptimizationResults:
342
561
  # Enable the profiler
343
- if self.params.dev.profiling:
562
+ if self.settings.dev.profiling:
344
563
  import cProfile
345
564
 
346
565
  pr = cProfile.Profile()
347
566
  pr.enable()
348
567
 
349
568
  t_0_post = time.time()
350
- result = propagate_trajectory_results(self.params, result, self.propagation_solver)
569
+ result = propagate_trajectory_results(self.params, self.settings, result, self.propagation_solver)
351
570
  t_f_post = time.time()
352
571
 
353
572
  self.timing_post = t_f_post - t_0_post
354
- print("Total Post Processing Time: ", self.timing_post)
573
+
574
+ # Print results summary
575
+ io.print_results_summary(result, self.timing_post, self.timing_init, self.timing_solve)
355
576
 
356
577
  # Disable the profiler
357
- if self.params.dev.profiling:
578
+ if self.settings.dev.profiling:
358
579
  pr.disable()
359
580
  # Save results so it can be viusualized with snakeviz
360
581
  pr.dump_stats("profiling_postprocess.prof")