openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79)
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
openscvx/problem.py ADDED
@@ -0,0 +1,734 @@
1
+ """Core optimization problem interface for trajectory optimization.
2
+
3
+ This module provides the Problem class, the main entry point for defining
4
+ and solving trajectory optimization problems using Sequential Convex Programming (SCP).
5
+
6
+ Example:
7
+ The prototypical flow is to define a problem, then initialize, solve, and post-process the
8
+ results
9
+
10
+ problem = Problem(dynamics, constraints, states, controls, N, time)
11
+ problem.initialize()
12
+ result = problem.solve()
13
+ result = problem.post_process()
14
+ """
15
+
16
+ import copy
17
+ import os
18
+ import pickle
19
+ import queue
20
+ import threading
21
+ import time
22
+ from typing import TYPE_CHECKING, List, Optional, Union
23
+
24
+ import jax
25
+
26
# Process-wide side effect of importing this module. It unconditionally overwrites
# any value the user already exported. The name suggests it configures Equinox's
# runtime-error behaviour (NaN output instead of raising). It is placed after
# ``import jax`` but before the ``openscvx`` imports below — presumably so Equinox,
# pulled in transitively, sees it when first imported.
# TODO confirm; NOTE(review): consider os.environ.setdefault to respect user overrides.
os.environ["EQX_ON_ERROR"] = "nan"
27
+
28
+ from openscvx.algorithms import (
29
+ AlgorithmState,
30
+ OptimizationResults,
31
+ PenalizedTrustRegion,
32
+ )
33
+ from openscvx.config import (
34
+ Config,
35
+ ConvexSolverConfig,
36
+ DevConfig,
37
+ DiscretizationConfig,
38
+ PropagationConfig,
39
+ ScpConfig,
40
+ SimConfig,
41
+ )
42
+ from openscvx.discretization import get_discretization_solver
43
+ from openscvx.expert import ByofSpec
44
+ from openscvx.lowered import LoweredProblem, ParameterDict
45
+ from openscvx.lowered.dynamics import Dynamics
46
+ from openscvx.lowered.jax_constraints import (
47
+ LoweredCrossNodeConstraint,
48
+ LoweredJaxConstraints,
49
+ LoweredNodalConstraint,
50
+ )
51
+ from openscvx.propagation import get_propagation_solver, propagate_trajectory_results
52
+ from openscvx.solvers import optimal_control_problem
53
+ from openscvx.symbolic.builder import preprocess_symbolic_problem
54
+ from openscvx.symbolic.constraint_set import ConstraintSet
55
+ from openscvx.symbolic.expr import CTCS, Constraint
56
+ from openscvx.symbolic.expr.control import Control
57
+ from openscvx.symbolic.expr.state import State
58
+ from openscvx.symbolic.lower import lower_symbolic_problem
59
+ from openscvx.symbolic.problem import SymbolicProblem
60
+ from openscvx.symbolic.time import Time
61
+ from openscvx.utils import printing, profiling
62
+ from openscvx.utils.caching import (
63
+ get_solver_cache_paths,
64
+ load_or_compile_discretization_solver,
65
+ load_or_compile_propagation_solver,
66
+ prime_propagation_solver,
67
+ )
68
+
69
+ if TYPE_CHECKING:
70
+ import cvxpy as cp
71
+
72
+
73
class Problem:
    """Trajectory optimization problem solved with Sequential Convex Programming (SCP).

    Lifecycle: construct the problem, call ``initialize()`` to compile, then either
    ``solve()`` or repeated ``step()`` calls, and finally ``post_process()``.
    ``reset()`` restores the initial solver state without recompiling.
    """

    def __init__(
        self,
        dynamics: dict,
        constraints: List[Union[Constraint, CTCS]],
        states: List[State],
        controls: List[Control],
        N: int,
        time: Time,
        *,
        dynamics_prop: Optional[dict] = None,
        states_prop: Optional[List[State]] = None,
        licq_min=0.0,
        licq_max=1e-4,
        time_dilation_factor_min=0.3,
        time_dilation_factor_max=3.0,
        byof: Optional[ByofSpec] = None,
    ):
        """Preprocess and lower the symbolic problem and create default settings.

        Building the convex subproblem and compiling the solvers is deferred to
        ``initialize()``, so ``self.settings`` may be edited in between.

        Args:
            dynamics (dict): Dictionary mapping state names to their dynamics expressions.
                Each key should be a state name, and each value should be an Expr
                representing the derivative of that state.
            constraints (List[Union[Constraint, CTCS]]): Symbolic constraints, given as
                ``Constraint`` or ``CTCS`` objects.
            states (List[State]): List of State objects representing the state variables.
                May optionally include a State named "time" (see ``time`` below).
            controls (List[Control]): List of Control objects representing the control
                variables.
            N (int): Number of trajectory nodes.
                NOTE(review): originally documented as "number of segments". However,
                ``initialize()`` computes the normalized segment length as
                ``1 / (N - 1)``, which implies N nodes and N - 1 segments — confirm.
            time (Time): Time configuration object with initial, final, min, max.
                Required positional argument. If a "time" State is included in
                ``states``, the Time object is ignored. Time properties should then be
                set on that State instead.
            dynamics_prop (dict, optional): Dictionary mapping EXTRA state names to their
                dynamics expressions for propagation. Only specify additional states beyond
                the optimization states (e.g., {"distance": speed}). Do NOT duplicate
                optimization state dynamics here.
            states_prop (List[State], optional): List of EXTRA State objects for
                propagation only. Only specify additional states beyond the optimization
                states. Used with ``dynamics_prop``.
            licq_min: Minimum LICQ constraint value.
            licq_max: Maximum LICQ constraint value.
            time_dilation_factor_min: Minimum time dilation factor.
            time_dilation_factor_max: Maximum time dilation factor.
            byof: Expert mode only. Raw JAX functions to bypass the symbolic layer.
                See :class:`openscvx.expert.ByofSpec` for detailed documentation.

        Note:
            There are two approaches for handling time:

            1. Auto-create (simple): don't include "time" in ``states``; the Time
               object supplies the time configuration.
            2. User-provided (for time-dependent constraints): include a "time" State
               in ``states`` and its derivative in ``dynamics``. The ``time`` argument
               is then ignored.
        """
        # NOTE: the ``time`` parameter shadows the ``time`` module inside this method;
        # the module is not used here.

        # Symbolic preprocessing & augmentation
        self.symbolic: SymbolicProblem = preprocess_symbolic_problem(
            dynamics=dynamics,
            constraints=ConstraintSet(unsorted=list(constraints)),
            states=states,
            controls=controls,
            N=N,
            time=time,
            licq_min=licq_min,
            licq_max=licq_max,
            time_dilation_factor_min=time_dilation_factor_min,
            time_dilation_factor_max=time_dilation_factor_max,
            dynamics_prop_extra=dynamics_prop,
            states_prop_extra=states_prop,
            byof=byof,
        )

        # Validate byof early (after preprocessing, before lowering) to fail fast
        if byof is not None:
            from openscvx.expert.validation import validate_byof

            # Unified state/control dimensions from the preprocessed states/controls.
            # They include symbolic augmentation (time, CTCS) but not byof CTCS
            # augmentation, which is exactly what user byof functions will see.
            # Scalars (empty shape) count as one entry.
            n_x = sum(
                state.shape[0] if len(state.shape) > 0 else 1 for state in self.symbolic.states
            )
            n_u = sum(
                control.shape[0] if len(control.shape) > 0 else 1
                for control in self.symbolic.controls
            )

            validate_byof(byof, self.symbolic.states, n_x, n_u, N)

        # Lower to JAX and CVXPy (byof handling happens inside lower_symbolic_problem)
        self._lowered: LoweredProblem = lower_symbolic_problem(self.symbolic, byof=byof)

        # Parameters in two forms:
        # - plain dict (same object as self.symbolic.parameters) handed to JAX functions
        # - ParameterDict wrapper for user access that syncs to CVXPy
        self._parameters = self.symbolic.parameters
        self._parameter_wrapper = ParameterDict(self, self._parameters, self.symbolic.parameters)

        # SCP configuration with defaults; users may edit it before initialize()
        self.settings = Config(
            sim=SimConfig(
                x=self._lowered.x_unified,
                x_prop=self._lowered.x_prop_unified,
                u=self._lowered.u_unified,
                total_time=self._lowered.x_unified.initial[self._lowered.x_unified.time_slice][0],
                n_states=self._lowered.x_unified.initial.shape[0],
                n_states_prop=self._lowered.x_prop_unified.initial.shape[0],
                ctcs_node_intervals=self.symbolic.node_intervals,
            ),
            scp=ScpConfig(
                n=N,
                w_tr_max_scaling_factor=1e2,  # Maximum Trust Region Weight
            ),
            dis=DiscretizationConfig(),
            dev=DevConfig(),
            cvx=ConvexSolverConfig(),
            prp=PropagationConfig(),
        )

        # OCP construction happens in initialize() so users can modify
        # settings (like uniform_time_grid) between __init__ and initialize()
        self._optimal_control_problem: cp.Problem = None
        self._discretization_solver: callable = None
        self._propagation_solver: callable = None  # Built during initialize()
        self._solve_ocp: callable = None  # Solver callable (built during initialize)

        # Set up emitter & printer thread only if printing is enabled. The queue and
        # thread attributes are always defined (None when printing is off) so callers
        # can test them instead of hitting AttributeError.
        if self.settings.dev.printing:
            self.print_queue = queue.Queue()
            self.emitter_function = lambda data: self.print_queue.put(data)
            self.print_thread = threading.Thread(
                target=printing.intermediate,
                args=(self.print_queue, self.settings),
                daemon=True,
            )
            self.print_thread.start()
        else:
            self.print_queue = None
            self.print_thread = None
            # no-op emitter; nothing ever gets queued or printed
            self.emitter_function = lambda data: None

        self.timing_init = None
        self.timing_solve = None
        self.timing_post = None

        # Compiled dynamics (vmapped versions, set in initialize())
        self._compiled_dynamics: Optional[Dynamics] = None
        self._compiled_dynamics_prop: Optional[Dynamics] = None

        # Compiled constraints (JIT-compiled versions, set in initialize())
        self._compiled_constraints: Optional[LoweredJaxConstraints] = None

        # Solver state (created fresh for each solve)
        self._state: Optional[AlgorithmState] = None

        # Final solution state (saved after successful solve)
        self._solution: Optional[AlgorithmState] = None

        # SCP algorithm (currently hardcoded to PTR)
        self._algorithm = PenalizedTrustRegion()
231
+
232
+ @property
233
+ def parameters(self):
234
+ """Get the parameters dictionary.
235
+
236
+ The returned dictionary automatically syncs to CVXPy when modified:
237
+ problem.parameters["obs_radius"] = 2.0 # Auto-syncs to CVXPy
238
+ problem.parameters.update({"gate_0_center": center}) # Also syncs
239
+
240
+ Returns:
241
+ ParameterDict: Special dict that syncs to CVXPy on assignment
242
+ """
243
+ return self._parameter_wrapper
244
+
245
+ @parameters.setter
246
+ def parameters(self, new_params: dict):
247
+ """Replace the entire parameters dictionary and sync to CVXPy.
248
+
249
+ Args:
250
+ new_params: New parameters dictionary
251
+ """
252
+ self._parameters = dict(new_params) # Create new plain dict
253
+ self._parameter_wrapper = ParameterDict(self, self._parameters, new_params)
254
+ self._sync_parameters()
255
+
256
+ def _sync_parameters(self):
257
+ """Sync all parameter values to CVXPy parameters."""
258
+ if self._lowered.cvxpy_params is not None:
259
+ for name, value in self._parameter_wrapper.items():
260
+ if name in self._lowered.cvxpy_params:
261
+ self._lowered.cvxpy_params[name].value = value
262
+
263
+ @property
264
+ def state(self) -> Optional[AlgorithmState]:
265
+ """Access the current solver state.
266
+
267
+ The solver state contains all mutable state from the SCP iterations,
268
+ including current guesses, costs, weights, and history.
269
+
270
+ Returns:
271
+ AlgorithmState if initialized, None otherwise
272
+
273
+ Example:
274
+ When using `Problem.step()` can use the state to check convergence _etc._
275
+
276
+ problem.initialize()
277
+ problem.step()
278
+ print(f"Iteration {problem.state.k}, J_tr={problem.state.J_tr}")
279
+ """
280
+ return self._state
281
+
282
+ @property
283
+ def lowered(self) -> LoweredProblem:
284
+ """Access the lowered problem containing JAX/CVXPy objects.
285
+
286
+ Returns:
287
+ LoweredProblem with dynamics, constraints, unified interfaces, and CVXPy vars
288
+ """
289
+ return self._lowered
290
+
291
+ @property
292
+ def x_unified(self):
293
+ """Unified state interface (delegates to lowered.x_unified)."""
294
+ return self._lowered.x_unified
295
+
296
+ @property
297
+ def u_unified(self):
298
+ """Unified control interface (delegates to lowered.u_unified)."""
299
+ return self._lowered.u_unified
300
+
301
+ @property
302
+ def slices(self) -> dict[str, slice]:
303
+ """Get mapping of state and control names to their slices in unified vectors.
304
+
305
+ This property returns a dictionary mapping each state and control variable name
306
+ to its slice in the respective unified vector. This is particularly useful for
307
+ expert users working with byof (bring-your-own functions) who need to manually
308
+ index into the unified x and u vectors.
309
+
310
+ Returns:
311
+ dict[str, slice]: Dictionary mapping variable names to slice objects.
312
+ State variables map to slices in the x vector.
313
+ Control variables map to slices in the u vector.
314
+
315
+ Example:
316
+ problem = ox.Problem(dynamics, states, controls, ...)
317
+ print(problem.slices)
318
+ # {'position': slice(0, 3), 'velocity': slice(3, 6), 'theta': slice(0, 1)}
319
+
320
+ # Use in byof functions
321
+ byof = {
322
+ "nodal_constraints": [
323
+ lambda x, u, node, params: x[problem.slices["velocity"][0]] - 10.0,
324
+ lambda x, u, node, params: u[problem.slices["theta"][0]] - 1.57,
325
+ ]
326
+ }
327
+ """
328
+ slices = {}
329
+ slices.update({state.name: state.slice for state in self.symbolic.states})
330
+ slices.update({control.name: control.slice for control in self.symbolic.controls})
331
+ return slices
332
+
333
+ def _format_result(self, state: AlgorithmState, converged: bool) -> OptimizationResults:
334
+ """Format solver state as an OptimizationResults object.
335
+
336
+ Converts the internal solver state into a user-facing results object,
337
+ mapping state/control arrays to named fields based on symbolic metadata.
338
+
339
+ Args:
340
+ state: The AlgorithmState to extract results from.
341
+ converged: Whether the optimization converged.
342
+
343
+ Returns:
344
+ OptimizationResults containing the solution data.
345
+ """
346
+ # Build nodes dictionary with all states and controls
347
+ nodes_dict = {}
348
+
349
+ # Add all states (user-defined and augmented)
350
+ for sym_state in self.symbolic.states:
351
+ nodes_dict[sym_state.name] = state.x[:, sym_state._slice]
352
+
353
+ # Add all controls (user-defined and augmented)
354
+ for control in self.symbolic.controls:
355
+ nodes_dict[control.name] = state.u[:, control._slice]
356
+
357
+ return OptimizationResults(
358
+ converged=converged,
359
+ t_final=state.x[:, self.settings.sim.time_slice][-1],
360
+ nodes=nodes_dict,
361
+ trajectory={}, # Populated by post_process
362
+ _states=self.symbolic.states_prop, # Use propagation states for trajectory dict
363
+ _controls=self.symbolic.controls,
364
+ X=state.X, # Single source of truth - x and u are properties
365
+ U=state.U,
366
+ discretization_history=state.V_history,
367
+ J_tr_history=state.J_tr,
368
+ J_vb_history=state.J_vb,
369
+ J_vc_history=state.J_vc,
370
+ TR_history=state.TR_history,
371
+ VC_history=state.VC_history,
372
+ )
373
+
374
+ def initialize(self):
375
+ """Compile dynamics, constraints, and solvers; prepare for optimization.
376
+
377
+ This method vmaps dynamics, JIT-compiles constraints, builds the convex
378
+ subproblem, and initializes the solver state. Must be called before solve().
379
+
380
+ Example:
381
+ Prior to calling the `.solve()` method it is necessary to initialize the problem
382
+
383
+ problem = Problem(dynamics, constraints, states, controls, N, time)
384
+ problem.initialize() # Compile and prepare
385
+ problem.solve() # Run optimization
386
+ """
387
+ printing.intro()
388
+
389
+ # Print problem summary
390
+ printing.print_problem_summary(self.settings, self._lowered)
391
+
392
+ # Enable the profiler
393
+ pr = profiling.profiling_start(self.settings.dev.profiling)
394
+
395
+ t_0_while = time.time()
396
+ # Ensure parameter sizes and normalization are correct
397
+ self.settings.scp.__post_init__()
398
+ self.settings.sim.__post_init__()
399
+
400
+ # Create compiled (vmapped) dynamics as new instances
401
+ # This preserves the original un-vmapped versions in _lowered
402
+ self._compiled_dynamics = Dynamics(
403
+ f=jax.vmap(self._lowered.dynamics.f, in_axes=(0, 0, 0, None)),
404
+ A=jax.vmap(self._lowered.dynamics.A, in_axes=(0, 0, 0, None)),
405
+ B=jax.vmap(self._lowered.dynamics.B, in_axes=(0, 0, 0, None)),
406
+ )
407
+
408
+ self._compiled_dynamics_prop = Dynamics(
409
+ f=jax.vmap(self._lowered.dynamics_prop.f, in_axes=(0, 0, 0, None)),
410
+ )
411
+
412
+ # Create compiled (JIT-compiled) constraints as new instances
413
+ # This preserves the original un-JIT'd versions in _lowered
414
+ # TODO: (haynec) switch to AOT instead of JIT
415
+ compiled_nodal = [
416
+ LoweredNodalConstraint(
417
+ func=jax.jit(c.func),
418
+ grad_g_x=jax.jit(c.grad_g_x),
419
+ grad_g_u=jax.jit(c.grad_g_u),
420
+ nodes=c.nodes,
421
+ )
422
+ for c in self._lowered.jax_constraints.nodal
423
+ ]
424
+
425
+ compiled_cross_node = [
426
+ LoweredCrossNodeConstraint(
427
+ func=jax.jit(c.func),
428
+ grad_g_X=jax.jit(c.grad_g_X),
429
+ grad_g_U=jax.jit(c.grad_g_U),
430
+ )
431
+ for c in self._lowered.jax_constraints.cross_node
432
+ ]
433
+
434
+ self._compiled_constraints = LoweredJaxConstraints(
435
+ nodal=compiled_nodal,
436
+ cross_node=compiled_cross_node,
437
+ ctcs=self._lowered.jax_constraints.ctcs, # CTCS aren't JIT-compiled here
438
+ )
439
+
440
+ # Generate solvers using compiled (vmapped) dynamics
441
+ self._discretization_solver = get_discretization_solver(
442
+ self._compiled_dynamics, self.settings
443
+ )
444
+ self._propagation_solver = get_propagation_solver(
445
+ self._compiled_dynamics_prop.f, self.settings
446
+ )
447
+
448
+ # Build optimal control problem using LoweredProblem
449
+ self._optimal_control_problem = optimal_control_problem(self.settings, self._lowered)
450
+
451
+ # Get cache file paths using symbolic AST hashing
452
+ # This is more stable than hashing lowered JAX code
453
+ dis_solver_file, prop_solver_file = get_solver_cache_paths(
454
+ self.symbolic,
455
+ dt=self.settings.prp.dt,
456
+ total_time=self.settings.sim.total_time,
457
+ )
458
+
459
+ # Compile the discretization solver
460
+ self._discretization_solver = load_or_compile_discretization_solver(
461
+ self._discretization_solver,
462
+ dis_solver_file,
463
+ self._parameters, # Plain dict for JAX
464
+ self.settings.scp.n,
465
+ self.settings.sim.n_states,
466
+ self.settings.sim.n_controls,
467
+ save_compiled=self.settings.sim.save_compiled,
468
+ debug=self.settings.dev.debug,
469
+ )
470
+
471
+ # Setup propagation solver parameters
472
+ dtau = 1.0 / (self.settings.scp.n - 1)
473
+ dt_max = self.settings.sim.u.max[self.settings.sim.time_dilation_slice][0] * dtau
474
+ self.settings.prp.max_tau_len = int(dt_max / self.settings.prp.dt) + 2
475
+
476
+ # Compile the propagation solver
477
+ self._propagation_solver = load_or_compile_propagation_solver(
478
+ self._propagation_solver,
479
+ prop_solver_file,
480
+ self._parameters, # Plain dict for JAX
481
+ self.settings.sim.n_states_prop,
482
+ self.settings.sim.n_controls,
483
+ self.settings.prp.max_tau_len,
484
+ save_compiled=self.settings.sim.save_compiled,
485
+ )
486
+
487
+ # Build solver callable (handle CVXPyGen if enabled)
488
+ if self.settings.cvx.cvxpygen:
489
+ try:
490
+ from solver.cpg_solver import cpg_solve
491
+
492
+ with open("solver/problem.pickle", "rb") as f:
493
+ pickle.load(f)
494
+ self._optimal_control_problem.register_solve("CPG", cpg_solve)
495
+ solver_args = self.settings.cvx.solver_args
496
+ self._solve_ocp = lambda: self._optimal_control_problem.solve(
497
+ method="CPG", **solver_args
498
+ )
499
+ except ImportError:
500
+ raise ImportError(
501
+ "cvxpygen solver not found. Make sure cvxpygen is installed and code "
502
+ "generation has been run. Install with: pip install openscvx[cvxpygen]"
503
+ )
504
+ else:
505
+ solver = self.settings.cvx.solver
506
+ solver_args = self.settings.cvx.solver_args
507
+ self._solve_ocp = lambda: self._optimal_control_problem.solve(
508
+ solver=solver, **solver_args
509
+ )
510
+
511
+ # Initialize the SCP algorithm
512
+ print("Initializing the SCvx Subproblem Solver...")
513
+ self._algorithm.initialize(
514
+ self._optimal_control_problem,
515
+ self._discretization_solver,
516
+ self._compiled_constraints,
517
+ self._solve_ocp,
518
+ self.emitter_function,
519
+ self._parameters, # For warm-start only
520
+ self.settings, # For warm-start only
521
+ )
522
+ print("✓ SCvx Subproblem Solver initialized")
523
+
524
+ # Create fresh solver state
525
+ self._state = AlgorithmState.from_settings(self.settings)
526
+
527
+ t_f_while = time.time()
528
+ self.timing_init = t_f_while - t_0_while
529
+ print("Total Initialization Time: ", self.timing_init)
530
+
531
+ # Prime the propagation solver
532
+ prime_propagation_solver(self._propagation_solver, self._parameters, self.settings)
533
+
534
+ profiling.profiling_end(pr, "initialize")
535
+
536
+ def reset(self):
537
+ """Reset solver state to re-run optimization from initial conditions.
538
+
539
+ Creates fresh AlgorithmState while preserving compiled dynamics and solvers.
540
+ Use this to run multiple optimizations without re-initializing.
541
+
542
+ Raises:
543
+ ValueError: If initialize() has not been called yet.
544
+
545
+ Example:
546
+ After calling `.step()` it may be necessary to reset the problem back to the initial
547
+ conditions
548
+
549
+ problem.initialize()
550
+ result1 = problem.step()
551
+ problem.reset()
552
+ result2 = problem.solve() # Fresh run with same setup
553
+ """
554
+ if self._compiled_dynamics is None:
555
+ raise ValueError("Problem has not been initialized. Call initialize() first")
556
+
557
+ # Create fresh solver state from settings
558
+ self._state = AlgorithmState.from_settings(self.settings)
559
+
560
+ # Reset solution
561
+ self._solution = None
562
+
563
+ # Reset timing
564
+ self.timing_solve = None
565
+ self.timing_post = None
566
+
567
+ def step(self) -> dict:
568
+ """Perform a single SCP iteration.
569
+
570
+ Designed for real-time plotting and interactive optimization. Performs one
571
+ iteration including subproblem solve, state update, and progress emission.
572
+
573
+ Note:
574
+ This method is NOT idempotent - it mutates internal state and advances
575
+ the iteration counter. Use reset() to return to initial conditions.
576
+
577
+ Returns:
578
+ dict: Contains "converged" (bool) and current iteration state
579
+
580
+ Example:
581
+ Call `.step()` manually in a loop to control the algorithm directly
582
+
583
+ problem.initialize()
584
+ while not problem.step()["converged"]:
585
+ plot_trajectory(problem.state.trajs[-1])
586
+ """
587
+ if self._state is None:
588
+ raise ValueError("Problem has not been initialized. Call initialize() first")
589
+
590
+ converged = self._algorithm.step(
591
+ self._state,
592
+ self._parameters, # May change between steps
593
+ self.settings, # May change between steps
594
+ )
595
+
596
+ # Return dict matching original API
597
+ return {
598
+ "converged": converged,
599
+ "scp_k": self._state.k,
600
+ "scp_J_tr": self._state.J_tr,
601
+ "scp_J_vb": self._state.J_vb,
602
+ "scp_J_vc": self._state.J_vc,
603
+ }
604
+
605
+ def solve(
606
+ self, max_iters: Optional[int] = None, continuous: bool = False
607
+ ) -> OptimizationResults:
608
+ """Run the SCP algorithm until convergence or iteration limit.
609
+
610
+ Args:
611
+ max_iters: Maximum iterations (default: settings.scp.k_max)
612
+ continuous: If True, run all iterations regardless of convergence
613
+
614
+ Returns:
615
+ OptimizationResults with trajectory and convergence info
616
+ (call post_process() for full propagation)
617
+ """
618
+ # Sync parameters before solving
619
+ self._sync_parameters()
620
+
621
+ required = [
622
+ self._compiled_dynamics,
623
+ self._compiled_constraints,
624
+ self._optimal_control_problem,
625
+ self._discretization_solver,
626
+ self._state,
627
+ ]
628
+ if any(r is None for r in required):
629
+ raise ValueError("Problem has not been initialized. Call initialize() before solve()")
630
+
631
+ # Enable the profiler
632
+ pr = profiling.profiling_start(self.settings.dev.profiling)
633
+
634
+ t_0_while = time.time()
635
+ # Print top header for solver results
636
+ printing.header()
637
+
638
+ k_max = max_iters if max_iters is not None else self.settings.scp.k_max
639
+
640
+ while self._state.k <= k_max:
641
+ result = self.step()
642
+ if result["converged"] and not continuous:
643
+ break
644
+
645
+ t_f_while = time.time()
646
+ self.timing_solve = t_f_while - t_0_while
647
+
648
+ while self.print_queue.qsize() > 0:
649
+ time.sleep(0.1)
650
+
651
+ # Print bottom footer for solver results as well as total computation time
652
+ printing.footer()
653
+
654
+ profiling.profiling_end(pr, "solve")
655
+
656
+ # Store solution state
657
+ self._solution = copy.deepcopy(self._state)
658
+
659
+ return self._format_result(self._state, self._state.k <= k_max)
660
+
661
+ def post_process(self) -> OptimizationResults:
662
+ """Propagate solution through full nonlinear dynamics for high-fidelity trajectory.
663
+
664
+ Integrates the converged SCP solution through the nonlinear dynamics to
665
+ produce x_full, u_full, and t_full. Call after solve() for final results.
666
+
667
+ Returns:
668
+ OptimizationResults with propagated trajectory fields
669
+
670
+ Raises:
671
+ ValueError: If solve() has not been called yet.
672
+ """
673
+ if self._solution is None:
674
+ raise ValueError("No solution available. Call solve() first.")
675
+
676
+ # Enable the profiler
677
+ pr = profiling.profiling_start(self.settings.dev.profiling)
678
+
679
+ # Create result from stored solution state
680
+ result = self._format_result(self._solution, self._solution.k <= self.settings.scp.k_max)
681
+
682
+ t_0_post = time.time()
683
+ result = propagate_trajectory_results(
684
+ self._parameters, self.settings, result, self._propagation_solver
685
+ )
686
+ t_f_post = time.time()
687
+
688
+ self.timing_post = t_f_post - t_0_post
689
+
690
+ # Store the propagated result back into _solution for plotting
691
+ # Store as a cached attribute on the _solution object
692
+ self._solution._propagated_result = result
693
+
694
+ # Print results summary
695
+ printing.print_results_summary(
696
+ result, self.timing_post, self.timing_init, self.timing_solve
697
+ )
698
+
699
+ profiling.profiling_end(pr, "postprocess")
700
+ return result
701
+
702
+ def citation(self) -> str:
703
+ """Return BibTeX citations for all components used in this problem.
704
+
705
+ Aggregates citations from the algorithm and other components (discretization,
706
+ convex solver, etc.) Each section is prefixed with a comment indicating which component the
707
+ citation is for.
708
+
709
+ Returns:
710
+ Formatted string containing all BibTeX citations with comments.
711
+
712
+ Example:
713
+ Print all citations for a problem::
714
+
715
+ problem = Problem(dynamics, constraints, states, controls, N, time)
716
+ print(problem.citation())
717
+ """
718
+ sections = []
719
+
720
+ sections.append(r"% --- AUTO-GENERATED CITATIONS FOR OPENSCVX CONFIGURATION ---")
721
+
722
+ # Algorithm citations
723
+ algo_citations = self._algorithm.citation()
724
+ if algo_citations:
725
+ algo_name = type(self._algorithm).__name__
726
+ header = f"% Algorithm: {algo_name}"
727
+ citations = "\n".join(algo_citations)
728
+ sections.append(f"{header}\n\n{citations}")
729
+
730
+ # Future: add citations from discretization, constraint formulations, etc.
731
+
732
+ sections.append(r"% --- END AUTO-GENERATED CITATIONS")
733
+
734
+ return "\n\n".join(sections)