openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
openscvx/__init__.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+
3
+ # Set Equinox error handling to return NaN instead of crashing
4
+ os.environ["EQX_ON_ERROR"] = "nan"
5
+
6
+ # Cache management
7
+ # Core symbolic expressions - flat namespace for most common functions
8
+ import openscvx.symbolic.expr.lie as lie
9
+ import openscvx.symbolic.expr.linalg as linalg
10
+ import openscvx.symbolic.expr.spatial as spatial
11
+ import openscvx.symbolic.expr.stl as stl
12
+ from openscvx.expert import ByofSpec
13
+ from openscvx.problem import Problem
14
+ from openscvx.symbolic.expr import (
15
+ CTCS,
16
+ Abs,
17
+ Add,
18
+ Block,
19
+ Concat,
20
+ Constant,
21
+ Constraint,
22
+ Control,
23
+ Cos,
24
+ Diag,
25
+ Div,
26
+ Equality,
27
+ Exp,
28
+ Expr,
29
+ Fixed,
30
+ Free,
31
+ Hstack,
32
+ Index,
33
+ Inequality,
34
+ Leaf,
35
+ Log,
36
+ LogSumExp,
37
+ MatMul,
38
+ Max,
39
+ Maximize,
40
+ Minimize,
41
+ Mul,
42
+ Neg,
43
+ NodalConstraint,
44
+ Parameter,
45
+ Power,
46
+ Sin,
47
+ Sqrt,
48
+ Stack,
49
+ State,
50
+ Sub,
51
+ Sum,
52
+ Tan,
53
+ Variable,
54
+ Vstack,
55
+ ctcs,
56
+ )
57
+ from openscvx.symbolic.time import Time
58
+ from openscvx.utils.cache import clear_cache, get_cache_dir, get_cache_size
59
+
60
+ __all__ = [
61
+ # Main Trajectory Optimization Entrypoint
62
+ "Problem",
63
+ # Cache management
64
+ "get_cache_dir",
65
+ "clear_cache",
66
+ "get_cache_size",
67
+ # Time configuration
68
+ "Time",
69
+ # Core base classes
70
+ "Expr",
71
+ "Leaf",
72
+ "Parameter",
73
+ "Variable",
74
+ "State",
75
+ "Control",
76
+ # Boundary condition helpers
77
+ "Free",
78
+ "Fixed",
79
+ "Minimize",
80
+ "Maximize",
81
+ # Basic arithmetic operations
82
+ "Add",
83
+ "Sub",
84
+ "Mul",
85
+ "Div",
86
+ "MatMul",
87
+ "Neg",
88
+ "Power",
89
+ "Sum",
90
+ # Array operations
91
+ "Index",
92
+ "Concat",
93
+ "Stack",
94
+ "Hstack",
95
+ "Vstack",
96
+ "Block",
97
+ "Diag",
98
+ "Constant",
99
+ # Mathematical functions
100
+ "Sin",
101
+ "Cos",
102
+ "Tan",
103
+ "Sqrt",
104
+ "Abs",
105
+ "Exp",
106
+ "Log",
107
+ "LogSumExp",
108
+ "Max",
109
+ # Constraints
110
+ "Constraint",
111
+ "Equality",
112
+ "Inequality",
113
+ "NodalConstraint",
114
+ "CTCS",
115
+ "ctcs",
116
+ # Submodules
117
+ "stl",
118
+ "spatial",
119
+ "linalg",
120
+ "lie",
121
+ # Expert mode types
122
+ "ByofSpec",
123
+ ]
openscvx/_version.py ADDED
@@ -0,0 +1,34 @@
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
12
+
13
+ TYPE_CHECKING = False
14
+ if TYPE_CHECKING:
15
+ from typing import Tuple
16
+ from typing import Union
17
+
18
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
20
+ else:
21
+ VERSION_TUPLE = object
22
+ COMMIT_ID = object
23
+
24
+ version: str
25
+ __version__: str
26
+ __version_tuple__: VERSION_TUPLE
27
+ version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
30
+
31
+ __version__ = version = '0.3.2.dev170'
32
+ __version_tuple__ = version_tuple = (0, 3, 2, 'dev170')
33
+
34
+ __commit_id__ = commit_id = None
@@ -0,0 +1,92 @@
1
+ """Successive convexification algorithms for trajectory optimization.
2
+
3
+ This module provides implementations of SCvx (Successive Convexification) algorithms
4
+ for solving non-convex trajectory optimization problems through iterative convex
5
+ approximation.
6
+
7
+ All algorithms inherit from :class:`Algorithm`, enabling pluggable algorithm
8
+ implementations and custom SCvx variants:
9
+
10
+ ```python
11
+ class Algorithm(ABC):
12
+ @abstractmethod
13
+ def initialize(self, ocp, discretization_solver, jax_constraints,
14
+ solve_ocp, emitter, params, settings) -> None:
15
+ '''Store compiled infrastructure and warm-start solvers.'''
16
+ ...
17
+
18
+ @abstractmethod
19
+ def step(self, state, params, settings) -> bool:
20
+ '''Execute one iteration using stored infrastructure.'''
21
+ ...
22
+ ```
23
+
24
+ Immutable components (ocp, discretization_solver, jax_constraints, etc.) are stored
25
+ during ``initialize()``. Mutable configuration (params, settings) is passed per-step
26
+ to support runtime parameter updates and tolerance tuning.
27
+
28
+ :class:`AlgorithmState` holds mutable state during SCP iterations. Algorithms
29
+ that require additional state can subclass it:
30
+
31
+ ```python
32
+ @dataclass
33
+ class MyAlgorithmState(AlgorithmState):
34
+ my_custom_field: float = 0.0
35
+ ```
36
+
37
+ Note:
38
+ ``AlgorithmState`` currently combines iteration metrics (costs, weights),
39
+ trajectory history, and discretization data. A future refactor may separate
40
+ these concerns into distinct classes for clearer data flow:
41
+
42
+ ```python
43
+ @dataclass
44
+ class AlgorithmState:
45
+ # Mutable iteration state
46
+ k: int
47
+ J_tr: float
48
+ J_vb: float
49
+ J_vc: float
50
+ w_tr: float
51
+ lam_cost: float
52
+ lam_vc: ...
53
+ lam_vb: float
54
+
55
+ @dataclass
56
+ class TrajectoryHistory:
57
+ # Accumulated trajectory solutions
58
+ X: List[np.ndarray]
59
+ U: List[np.ndarray]
60
+
61
+ @property
62
+ def x(self): return self.X[-1]
63
+
64
+ @property
65
+ def u(self): return self.U[-1]
66
+
67
+ @dataclass
68
+ class DebugHistory:
69
+ # Optional diagnostic data (discretization matrices, etc.)
70
+ V_history: List[np.ndarray]
71
+ VC_history: List[np.ndarray]
72
+ TR_history: List[np.ndarray]
73
+ ```
74
+
75
+ Current Implementations:
76
+
77
+ - :class:`PenalizedTrustRegion`: Penalized Trust Region (PTR) algorithm
78
+ """
79
+
80
+ from .base import Algorithm, AlgorithmState
81
+ from .optimization_results import OptimizationResults
82
+ from .penalized_trust_region import PenalizedTrustRegion
83
+
84
+ __all__ = [
85
+ # Base class
86
+ "Algorithm",
87
+ "AlgorithmState",
88
+ # Core results
89
+ "OptimizationResults",
90
+ # PTR algorithm
91
+ "PenalizedTrustRegion",
92
+ ]
@@ -0,0 +1,24 @@
1
+ """Autotuning functions for SCP (Successive Convex Programming) parameters."""
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from openscvx.config import Config
6
+
7
+ if TYPE_CHECKING:
8
+ from .base import AlgorithmState
9
+
10
+
11
def update_scp_weights(state: "AlgorithmState", settings: Config, scp_k: int):
    """Adapt the trust-region weight and, late in the solve, the cost weight.

    On every call the trust-region weight is multiplied by
    ``settings.scp.w_tr_adapt`` and capped at ``settings.scp.w_tr_max``.
    Once ``scp_k`` is strictly greater than ``settings.scp.cost_drop`` the
    cost weight is additionally multiplied by ``settings.scp.cost_relax``.

    Args:
        state: Solver state whose ``w_tr`` and ``lam_cost`` are updated in place
        settings: Configuration object providing the ``scp`` adaptation parameters
        scp_k: Current SCP iteration number
    """
    scp = settings.scp

    # Scale the trust-region weight, never letting it exceed the configured cap
    scaled_w_tr = state.w_tr * scp.w_tr_adapt
    state.w_tr = min(scaled_w_tr, scp.w_tr_max)

    # The cost weight is only relaxed after the cost_drop iteration has passed
    if scp_k > scp.cost_drop:
        state.lam_cost = state.lam_cost * scp.cost_relax
@@ -0,0 +1,351 @@
1
+ """Base class for successive convexification algorithms.
2
+
3
+ This module defines the abstract interface that all SCP algorithm implementations
4
+ must follow, along with the AlgorithmState dataclass that holds mutable state
5
+ during SCP iterations.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass, field
10
+ from typing import TYPE_CHECKING, List, Union
11
+
12
+ import numpy as np
13
+
14
+ if TYPE_CHECKING:
15
+ import cvxpy as cp
16
+
17
+ from openscvx.config import Config
18
+ from openscvx.lowered.jax_constraints import LoweredJaxConstraints
19
+
20
+
21
@dataclass
class AlgorithmState:
    """Mutable state for SCP iterations.

    This dataclass holds all state that changes during the solve process.
    It stores only the evolving trajectory arrays, not the full State/Control
    objects which contain immutable configuration metadata.

    Trajectory arrays are stored in history lists, with the current guess
    accessed via properties that return the latest entry.

    A fresh instance is created for each solve, enabling easy reset functionality.

    Attributes:
        k: Current iteration number (starts at 1)
        J_tr: Current trust region cost
        J_vb: Current virtual buffer cost
        J_vc: Current virtual control cost
        w_tr: Current trust region weight (may adapt during solve)
        lam_cost: Current cost weight (may relax during solve)
        lam_vc: Current virtual control penalty weight
        lam_vb: Current virtual buffer penalty weight
        n_x: Number of states (for unpacking V vectors)
        n_u: Number of controls (for unpacking V vectors)
        N: Number of trajectory nodes (for unpacking V vectors)
        X: List of state trajectory iterates
        U: List of control trajectory iterates
        V_history: List of discretization history
        VC_history: Per-iteration history list; not read by this class
            (presumably virtual control diagnostics -- confirm against the
            algorithm that fills it)
        TR_history: Per-iteration history list; not read by this class
            (presumably trust region diagnostics -- confirm against the
            algorithm that fills it)
    """

    k: int
    J_tr: float
    J_vb: float
    J_vc: float
    w_tr: float
    lam_cost: float
    lam_vc: Union[float, np.ndarray]
    lam_vb: float
    n_x: int
    n_u: int
    N: int
    X: List[np.ndarray] = field(default_factory=list)
    U: List[np.ndarray] = field(default_factory=list)
    V_history: List[np.ndarray] = field(default_factory=list)
    VC_history: List[np.ndarray] = field(default_factory=list)
    TR_history: List[np.ndarray] = field(default_factory=list)

    @property
    def x(self) -> np.ndarray:
        """Get current state trajectory array.

        Returns:
            Current state trajectory guess (latest entry in ``X``)

        Raises:
            IndexError: If ``X`` is empty.
        """
        return self.X[-1]

    @property
    def u(self) -> np.ndarray:
        """Get current control trajectory array.

        Returns:
            Current control trajectory guess (latest entry in ``U``)

        Raises:
            IndexError: If ``U`` is empty.
        """
        return self.U[-1]

    def _latest_V_rows(self) -> Union[np.ndarray, None]:
        """Return the final-timestep column of the latest V as rows of width i4.

        Each row is laid out as ``[x_prop | A_d | B_d | C_d]`` with widths
        ``n_x``, ``n_x * n_x``, ``n_x * n_u`` and ``n_x * n_u``; the public
        properties slice their own segment out of these rows.

        Returns:
            2D array with ``i4`` columns, or None if ``V_history`` is empty.
        """
        if not self.V_history:
            return None

        # Total packed width of one row: x_prop + A_d + B_d + C_d
        i4 = self.n_x + self.n_x * self.n_x + 2 * self.n_x * self.n_u

        # V_history entries are indexed as (flattened_size, n_timesteps);
        # only the last timestep column is unpacked.
        V = self.V_history[-1]
        return V[:, -1].reshape(-1, i4)

    @property
    def x_prop(self) -> Union[np.ndarray, None]:
        """Extract propagated state trajectory from latest V.

        Returns:
            Propagated state trajectory x_prop with shape (N-1, n_x), or None if no V_history

        Example:
            After running an iteration, access the propagated states::

                problem.step()
                x_prop = problem.state.x_prop  # Shape (N-1, n_x)
        """
        V_final = self._latest_V_rows()
        if V_final is None:
            return None

        # Propagated state is the first n_x elements of each row
        return V_final[:, : self.n_x]

    @property
    def A_d(self) -> Union[np.ndarray, None]:
        """Extract discretized state transition matrix from latest V.

        Returns:
            Discretized state Jacobian A_d with shape (N-1, n_x, n_x), or None if no V_history

        Example:
            After running an iteration, access linearization matrices::

                problem.step()
                A_d = problem.state.A_d  # Shape (N-1, n_x, n_x)
        """
        V_final = self._latest_V_rows()
        if V_final is None:
            return None

        # A_d occupies the n_x * n_x entries right after the propagated state
        i1 = self.n_x
        i2 = i1 + self.n_x * self.n_x
        return V_final[:, i1:i2].reshape(self.N - 1, self.n_x, self.n_x)

    @property
    def B_d(self) -> Union[np.ndarray, None]:
        """Extract discretized control influence matrix (current node) from latest V.

        Returns:
            Discretized control Jacobian B_d with shape (N-1, n_x, n_u), or None if no V_history

        Example:
            After running an iteration, access linearization matrices::

                problem.step()
                B_d = problem.state.B_d  # Shape (N-1, n_x, n_u)
        """
        V_final = self._latest_V_rows()
        if V_final is None:
            return None

        # B_d occupies the n_x * n_u entries right after A_d
        i2 = self.n_x + self.n_x * self.n_x
        i3 = i2 + self.n_x * self.n_u
        return V_final[:, i2:i3].reshape(self.N - 1, self.n_x, self.n_u)

    @property
    def C_d(self) -> Union[np.ndarray, None]:
        """Extract discretized control influence matrix (next node) from latest V.

        Returns:
            Discretized control Jacobian C_d with shape (N-1, n_x, n_u), or None if no V_history

        Example:
            After running an iteration, access linearization matrices::

                problem.step()
                C_d = problem.state.C_d  # Shape (N-1, n_x, n_u)
        """
        V_final = self._latest_V_rows()
        if V_final is None:
            return None

        # C_d occupies the last n_x * n_u entries of each row
        i3 = self.n_x + self.n_x * self.n_x + self.n_x * self.n_u
        i4 = i3 + self.n_x * self.n_u
        return V_final[:, i3:i4].reshape(self.N - 1, self.n_x, self.n_u)

    @classmethod
    def from_settings(cls, settings: "Config") -> "AlgorithmState":
        """Create initial algorithm state from configuration.

        Copies only the trajectory arrays from settings, leaving all metadata
        (bounds, boundary conditions, etc.) in the original settings object.

        Args:
            settings: Configuration object containing initial guesses and SCP parameters

        Returns:
            Fresh AlgorithmState initialized from settings with copied arrays
        """
        return cls(
            k=1,
            # NOTE(review): the 1e2 starting costs look like placeholders that
            # are overwritten by the first iteration -- confirm with the solver.
            J_tr=1e2,
            J_vb=1e2,
            J_vc=1e2,
            w_tr=settings.scp.w_tr,
            lam_cost=settings.scp.lam_cost,
            lam_vc=settings.scp.lam_vc,
            lam_vb=settings.scp.lam_vb,
            n_x=settings.sim.n_states,
            n_u=settings.sim.n_controls,
            N=settings.scp.n,
            # Copy the guesses so later edits to the settings arrays are not
            # shared with the history lists.
            X=[settings.sim.x.guess.copy()],
            U=[settings.sim.u.guess.copy()],
            V_history=[],
            VC_history=[],
            TR_history=[],
        )
236
+
237
+
238
class Algorithm(ABC):
    """Interface that every successive convexification algorithm implements.

    Concrete algorithms must provide three methods:

    - ``initialize``: receive the compiled infrastructure and warm-start solvers
    - ``step``: run one convex subproblem iteration
    - ``citation``: report the BibTeX entries for the algorithm

    Immutable components (ocp, discretization_solver, jax_constraints, etc.)
    are handed over once in ``initialize()``. Mutable configuration (params,
    settings) is passed again on every ``step()`` so that parameters and
    tolerances can be changed at runtime.

    !!! tip "Statefulness"
        Avoid storing mutable iteration state (costs, weights, trajectories) on
        ``self``. All iteration state should live in :class:`AlgorithmState` or
        a subclass thereof, passed explicitly to ``step()``. This keeps algorithm
        classes stateless w.r.t. iteration and makes the data flow explicit.

    Example:
        Implementing a custom algorithm::

            class MyAlgorithm(Algorithm):
                def initialize(self, ocp, discretization_solver,
                               jax_constraints, solve_ocp, emitter,
                               params, settings):
                    # Store compiled infrastructure
                    self._ocp = ocp
                    self._discretization_solver = discretization_solver
                    self._jax_constraints = jax_constraints
                    self._solve_ocp = solve_ocp
                    self._emitter = emitter
                    # Warm-start with initial params/settings...

                def step(self, state, params, settings):
                    # Run one iteration using self._* and per-step params/settings
                    return converged

                def citation(self):
                    return ["@article{...}"]
    """

    @abstractmethod
    def initialize(
        self,
        ocp: "cp.Problem",
        discretization_solver: callable,
        jax_constraints: "LoweredJaxConstraints",
        solve_ocp: callable,
        emitter: callable,
        params: dict,
        settings: "Config",
    ) -> None:
        """Store the compiled infrastructure and prepare for the SCP loop.

        Called before iteration begins. Implementations keep references to the
        immutable components and perform any one-off setup such as
        warm-starting solvers. ``params`` and ``settings`` are supplied for
        that warm-start only; they may differ on later ``step()`` calls.

        Args:
            ocp: The CVXPy optimal control problem
            discretization_solver: Compiled discretization solver function
            jax_constraints: JIT-compiled JAX constraint functions
            solve_ocp: Callable that solves the OCP (captures solver config)
            emitter: Callback for emitting iteration progress data
            params: Problem parameters dictionary (for warm-start only)
            settings: Configuration object (for warm-start only)
        """
        ...

    @abstractmethod
    def step(
        self,
        state: AlgorithmState,
        params: dict,
        settings: "Config",
    ) -> bool:
        """Run a single SCP iteration and report convergence.

        Implementations solve one convex subproblem with the infrastructure
        stored by ``initialize()``, write the results into ``state`` in place,
        and evaluate the convergence criteria. ``params`` and ``settings`` are
        taken per call so runtime modifications are honoured.

        Args:
            state: Mutable algorithm state (modified in place)
            params: Problem parameters dictionary (may change between steps)
            settings: Configuration object (may change between steps)

        Returns:
            True if convergence criteria are satisfied, False otherwise.
        """
        ...

    @abstractmethod
    def citation(self) -> List[str]:
        """Return the BibTeX entries to cite when using this algorithm.

        Returns:
            List of BibTeX citation strings, one per paper.

        Example:
            Getting citations for an algorithm::

                algorithm = PenalizedTrustRegion()
                for bibtex in algorithm.citation():
                    print(bibtex)
        """
        ...