openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
openscvx/problem.py
ADDED
|
@@ -0,0 +1,734 @@
|
|
|
1
|
+
"""Core optimization problem interface for trajectory optimization.
|
|
2
|
+
|
|
3
|
+
This module provides the Problem class, the main entry point for defining
|
|
4
|
+
and solving trajectory optimization problems using Sequential Convex Programming (SCP).
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
The prototypical flow is to define a problem, then initialize, solve, and post-process the
|
|
8
|
+
results:
|
|
9
|
+
|
|
10
|
+
problem = Problem(dynamics, constraints, states, controls, N, time)
|
|
11
|
+
problem.initialize()
|
|
12
|
+
result = problem.solve()
|
|
13
|
+
result = problem.post_process()
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import copy
|
|
17
|
+
import os
|
|
18
|
+
import pickle
|
|
19
|
+
import queue
|
|
20
|
+
import threading
|
|
21
|
+
import time
|
|
22
|
+
from typing import TYPE_CHECKING, List, Optional, Union
|
|
23
|
+
|
|
24
|
+
import jax
|
|
25
|
+
|
|
26
|
+
# Configure Equinox's runtime-error mode via its environment variable before the
# openscvx imports below (which presumably pull in equinox transitively) run.
# "nan" asks Equinox to produce NaNs instead of raising on runtime checks.
os.environ.update(EQX_ON_ERROR="nan")
|
|
27
|
+
|
|
28
|
+
from openscvx.algorithms import (
|
|
29
|
+
AlgorithmState,
|
|
30
|
+
OptimizationResults,
|
|
31
|
+
PenalizedTrustRegion,
|
|
32
|
+
)
|
|
33
|
+
from openscvx.config import (
|
|
34
|
+
Config,
|
|
35
|
+
ConvexSolverConfig,
|
|
36
|
+
DevConfig,
|
|
37
|
+
DiscretizationConfig,
|
|
38
|
+
PropagationConfig,
|
|
39
|
+
ScpConfig,
|
|
40
|
+
SimConfig,
|
|
41
|
+
)
|
|
42
|
+
from openscvx.discretization import get_discretization_solver
|
|
43
|
+
from openscvx.expert import ByofSpec
|
|
44
|
+
from openscvx.lowered import LoweredProblem, ParameterDict
|
|
45
|
+
from openscvx.lowered.dynamics import Dynamics
|
|
46
|
+
from openscvx.lowered.jax_constraints import (
|
|
47
|
+
LoweredCrossNodeConstraint,
|
|
48
|
+
LoweredJaxConstraints,
|
|
49
|
+
LoweredNodalConstraint,
|
|
50
|
+
)
|
|
51
|
+
from openscvx.propagation import get_propagation_solver, propagate_trajectory_results
|
|
52
|
+
from openscvx.solvers import optimal_control_problem
|
|
53
|
+
from openscvx.symbolic.builder import preprocess_symbolic_problem
|
|
54
|
+
from openscvx.symbolic.constraint_set import ConstraintSet
|
|
55
|
+
from openscvx.symbolic.expr import CTCS, Constraint
|
|
56
|
+
from openscvx.symbolic.expr.control import Control
|
|
57
|
+
from openscvx.symbolic.expr.state import State
|
|
58
|
+
from openscvx.symbolic.lower import lower_symbolic_problem
|
|
59
|
+
from openscvx.symbolic.problem import SymbolicProblem
|
|
60
|
+
from openscvx.symbolic.time import Time
|
|
61
|
+
from openscvx.utils import printing, profiling
|
|
62
|
+
from openscvx.utils.caching import (
|
|
63
|
+
get_solver_cache_paths,
|
|
64
|
+
load_or_compile_discretization_solver,
|
|
65
|
+
load_or_compile_propagation_solver,
|
|
66
|
+
prime_propagation_solver,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
if TYPE_CHECKING:
|
|
70
|
+
import cvxpy as cp
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class Problem:
    """Trajectory optimization problem solved with Sequential Convex Programming (SCP).

    Lifecycle:
        1. ``Problem(...)`` preprocesses and lowers the symbolic problem and builds
           default settings (which may still be modified).
        2. :meth:`initialize` compiles dynamics, constraints and solvers.
        3. :meth:`step` / :meth:`solve` run SCP iterations.
        4. :meth:`post_process` propagates the solution through the nonlinear dynamics.
    """

    def __init__(
        self,
        dynamics: dict,
        constraints: List[Union[Constraint, CTCS]],
        states: List[State],
        controls: List[Control],
        N: int,
        time: Time,
        *,
        dynamics_prop: Optional[dict] = None,
        states_prop: Optional[List[State]] = None,
        licq_min: float = 0.0,
        licq_max: float = 1e-4,
        time_dilation_factor_min: float = 0.3,
        time_dilation_factor_max: float = 3.0,
        byof: Optional[ByofSpec] = None,
    ):
        """The primary class in charge of compiling and exporting the solvers.

        Args:
            dynamics (dict): Dictionary mapping state names to their dynamics expressions.
                Each key should be a state name, and each value should be an Expr
                representing the derivative of that state.
            constraints (List[Union[Constraint, CTCS]]): List of constraints
                (e.g. decorated with @ctcs or @nodal).
            states (List[State]): List of State objects representing the state variables.
                May optionally include a State named "time" (see ``time`` below).
            controls (List[Control]): List of Control objects representing the control
                variables.
            N (int): Number of discretization nodes, stored as ``settings.scp.n``
                (the normalized step used for propagation is ``1 / (N - 1)``).
            time (Time): Time configuration object with initial, final, min, max.
                If a "time" state is included in ``states``, this object is ignored and
                time properties should be set on the time State object instead.
            dynamics_prop (dict, optional): Dictionary mapping EXTRA state names to their
                dynamics expressions for propagation. Only specify additional states
                beyond the optimization states (e.g., {"distance": speed}). Do NOT
                duplicate optimization state dynamics here.
            states_prop (List[State], optional): List of EXTRA State objects for
                propagation only. Only specify additional states beyond the optimization
                states. Used with ``dynamics_prop``.
            licq_min: Minimum LICQ constraint value.
            licq_max: Maximum LICQ constraint value.
            time_dilation_factor_min: Minimum time dilation factor.
            time_dilation_factor_max: Maximum time dilation factor.
            byof: Expert mode only. Raw JAX functions to bypass the symbolic layer.
                See :class:`openscvx.expert.ByofSpec` for detailed documentation.

        Note:
            There are two approaches for handling time:
            1. Auto-create (simple): don't include "time" in states; provide a Time object.
            2. User-provided (for time-dependent constraints): include a "time" State in
               ``states`` and its derivative in ``dynamics``; the Time object is then
               ignored.
        """
        # Symbolic preprocessing & augmentation
        self.symbolic: SymbolicProblem = preprocess_symbolic_problem(
            dynamics=dynamics,
            constraints=ConstraintSet(unsorted=list(constraints)),
            states=states,
            controls=controls,
            N=N,
            time=time,
            licq_min=licq_min,
            licq_max=licq_max,
            time_dilation_factor_min=time_dilation_factor_min,
            time_dilation_factor_max=time_dilation_factor_max,
            dynamics_prop_extra=dynamics_prop,
            states_prop_extra=states_prop,
            byof=byof,
        )

        # Validate byof early (after preprocessing, before lowering) to fail fast
        if byof is not None:
            from openscvx.expert.validation import validate_byof

            # Unified state/control dimensions from the preprocessed variables. These
            # include symbolic augmentation (time, CTCS) but not byof CTCS augmentation,
            # which is exactly what user byof functions will see. Scalars count as 1.
            n_x = sum(s.shape[0] if len(s.shape) > 0 else 1 for s in self.symbolic.states)
            n_u = sum(c.shape[0] if len(c.shape) > 0 else 1 for c in self.symbolic.controls)

            validate_byof(byof, self.symbolic.states, n_x, n_u, N)

        # Lower to JAX and CVXPy (byof handling happens inside lower_symbolic_problem)
        self._lowered: LoweredProblem = lower_symbolic_problem(self.symbolic, byof=byof)

        # Parameters are stored in two forms: a plain dict handed to JAX functions,
        # and a ParameterDict wrapper for user access that syncs writes to CVXPy.
        self._parameters = self.symbolic.parameters
        self._parameter_wrapper = ParameterDict(self, self._parameters, self.symbolic.parameters)

        # SCP configuration
        x_unified = self._lowered.x_unified
        x_prop_unified = self._lowered.x_prop_unified
        self.settings = Config(
            sim=SimConfig(
                x=x_unified,
                x_prop=x_prop_unified,
                u=self._lowered.u_unified,
                total_time=x_unified.initial[x_unified.time_slice][0],
                n_states=x_unified.initial.shape[0],
                n_states_prop=x_prop_unified.initial.shape[0],
                ctcs_node_intervals=self.symbolic.node_intervals,
            ),
            scp=ScpConfig(
                n=N,
                w_tr_max_scaling_factor=1e2,  # Maximum Trust Region Weight
            ),
            dis=DiscretizationConfig(),
            dev=DevConfig(),
            cvx=ConvexSolverConfig(),
            prp=PropagationConfig(),
        )

        # OCP construction happens in initialize() so users can modify
        # settings (like uniform_time_grid) between __init__ and initialize()
        self._optimal_control_problem: "cp.Problem" = None
        self._discretization_solver: callable = None
        self._propagation_solver: callable = None  # Built during initialize()
        self._solve_ocp: callable = None  # Solver callable (built during initialize)

        # Set up emitter & printer thread only if printing is enabled. The queue and
        # thread attributes always exist so later code can test them against None.
        if self.settings.dev.printing:
            self.print_queue = queue.Queue()
            self.emitter_function = lambda data: self.print_queue.put(data)
            self.print_thread = threading.Thread(
                target=printing.intermediate,
                args=(self.print_queue, self.settings),
                daemon=True,
            )
            self.print_thread.start()
        else:
            # no-op emitter; nothing ever gets queued or printed
            self.print_queue = None
            self.print_thread = None
            self.emitter_function = lambda data: None

        self.timing_init = None
        self.timing_solve = None
        self.timing_post = None

        # Compiled dynamics (vmapped versions, set in initialize())
        self._compiled_dynamics: Optional[Dynamics] = None
        self._compiled_dynamics_prop: Optional[Dynamics] = None

        # Compiled constraints (JIT-compiled versions, set in initialize())
        self._compiled_constraints: Optional[LoweredJaxConstraints] = None

        # Solver state (created fresh for each solve)
        self._state: Optional[AlgorithmState] = None

        # Final solution state (saved after solve) and the convergence flag
        # reported by the last SCP step of that solve.
        self._solution: Optional[AlgorithmState] = None
        self._converged: Optional[bool] = None

        # SCP algorithm (currently hardcoded to PTR)
        self._algorithm = PenalizedTrustRegion()

    @property
    def parameters(self):
        """Get the parameters dictionary.

        The returned dictionary automatically syncs to CVXPy when modified::

            problem.parameters["obs_radius"] = 2.0  # Auto-syncs to CVXPy
            problem.parameters.update({"gate_0_center": center})  # Also syncs

        Returns:
            ParameterDict: Special dict that syncs to CVXPy on assignment.
        """
        return self._parameter_wrapper

    @parameters.setter
    def parameters(self, new_params: dict):
        """Replace the entire parameters dictionary and sync to CVXPy.

        Args:
            new_params: New parameters dictionary. It is shallow-copied into a new
                plain dict used by the JAX functions.
        """
        self._parameters = dict(new_params)  # Create new plain dict
        self._parameter_wrapper = ParameterDict(self, self._parameters, new_params)
        self._sync_parameters()

    def _sync_parameters(self):
        """Copy every user parameter value onto its matching CVXPy parameter.

        Parameters without a CVXPy counterpart are skipped. Does nothing when the
        lowered problem has no CVXPy parameters.
        """
        cvxpy_params = self._lowered.cvxpy_params
        if cvxpy_params is None:
            return
        for name, value in self._parameter_wrapper.items():
            if name in cvxpy_params:
                cvxpy_params[name].value = value

    @property
    def state(self) -> Optional[AlgorithmState]:
        """Access the current solver state.

        The solver state contains all mutable state from the SCP iterations,
        including current guesses, costs, weights, and history.

        Returns:
            AlgorithmState if initialized, None otherwise.

        Example:
            When using `Problem.step()` the state can be used to check convergence etc.::

                problem.initialize()
                problem.step()
                print(f"Iteration {problem.state.k}, J_tr={problem.state.J_tr}")
        """
        return self._state

    @property
    def lowered(self) -> LoweredProblem:
        """Access the lowered problem containing JAX/CVXPy objects.

        Returns:
            LoweredProblem with dynamics, constraints, unified interfaces, and CVXPy vars.
        """
        return self._lowered

    @property
    def x_unified(self):
        """Unified state interface (delegates to lowered.x_unified)."""
        return self._lowered.x_unified

    @property
    def u_unified(self):
        """Unified control interface (delegates to lowered.u_unified)."""
        return self._lowered.u_unified

    @property
    def slices(self) -> dict[str, slice]:
        """Get mapping of state and control names to their slices in unified vectors.

        This is particularly useful for expert users working with byof
        (bring-your-own functions) who need to manually index into the unified
        x and u vectors. States and controls share one namespace; if a control has
        the same name as a state, the control's entry wins.

        Returns:
            dict[str, slice]: Dictionary mapping variable names to slice objects.
            State variables map to slices in the x vector; control variables map to
            slices in the u vector.

        Example::

            problem = ox.Problem(dynamics, constraints, states, controls, N, time)
            print(problem.slices)
            # {'position': slice(0, 3), 'velocity': slice(3, 6), 'theta': slice(0, 1)}

            # Use in byof functions
            byof = {
                "nodal_constraints": [
                    lambda x, u, node, params: x[problem.slices["velocity"]][0] - 10.0,
                    lambda x, u, node, params: u[problem.slices["theta"]][0] - 1.57,
                ]
            }
        """
        slices = {}
        slices.update({state.name: state.slice for state in self.symbolic.states})
        slices.update({control.name: control.slice for control in self.symbolic.controls})
        return slices

    def _format_result(self, state: AlgorithmState, converged: bool) -> OptimizationResults:
        """Format solver state as an OptimizationResults object.

        Converts the internal solver state into a user-facing results object,
        mapping state/control arrays to named fields based on symbolic metadata.

        Args:
            state: The AlgorithmState to extract results from.
            converged: Whether the optimization converged.

        Returns:
            OptimizationResults containing the solution data. The ``trajectory``
            field is left empty; it is populated by post_process().
        """
        # Per-node values of every state and control (user-defined and augmented),
        # keyed by variable name.
        nodes_dict = {}
        for sym_state in self.symbolic.states:
            nodes_dict[sym_state.name] = state.x[:, sym_state._slice]
        for control in self.symbolic.controls:
            nodes_dict[control.name] = state.u[:, control._slice]

        return OptimizationResults(
            converged=converged,
            # Time component at the last node
            t_final=state.x[:, self.settings.sim.time_slice][-1],
            nodes=nodes_dict,
            trajectory={},  # Populated by post_process
            _states=self.symbolic.states_prop,  # Use propagation states for trajectory dict
            _controls=self.symbolic.controls,
            X=state.X,  # Single source of truth - x and u are properties
            U=state.U,
            discretization_history=state.V_history,
            J_tr_history=state.J_tr,
            J_vb_history=state.J_vb,
            J_vc_history=state.J_vc,
            TR_history=state.TR_history,
            VC_history=state.VC_history,
        )

    def _compile_constraints(self) -> LoweredJaxConstraints:
        """Build JIT-wrapped copies of the lowered nodal and cross-node constraints.

        New constraint objects are created so the original un-JIT'd versions in
        ``self._lowered`` are preserved. CTCS constraints are passed through as-is.

        Returns:
            LoweredJaxConstraints holding the JIT-wrapped constraints.
        """
        lowered = self._lowered.jax_constraints
        # TODO: (haynec) switch to AOT instead of JIT
        compiled_nodal = [
            LoweredNodalConstraint(
                func=jax.jit(c.func),
                grad_g_x=jax.jit(c.grad_g_x),
                grad_g_u=jax.jit(c.grad_g_u),
                nodes=c.nodes,
            )
            for c in lowered.nodal
        ]
        compiled_cross_node = [
            LoweredCrossNodeConstraint(
                func=jax.jit(c.func),
                grad_g_X=jax.jit(c.grad_g_X),
                grad_g_U=jax.jit(c.grad_g_U),
            )
            for c in lowered.cross_node
        ]
        return LoweredJaxConstraints(
            nodal=compiled_nodal,
            cross_node=compiled_cross_node,
            ctcs=lowered.ctcs,  # CTCS aren't JIT-compiled here
        )

    def _build_solve_ocp(self):
        """Build the zero-argument callable that solves the convex subproblem.

        With ``settings.cvx.cvxpygen`` enabled, the code-generated "CPG" solve method
        is registered on the CVXPy problem; otherwise ``settings.cvx.solver`` is used.
        In both cases ``settings.cvx.solver_args`` are forwarded to ``solve``.

        Returns:
            Callable taking no arguments that runs ``self._optimal_control_problem.solve``.

        Raises:
            ImportError: If CVXPyGen is enabled but the generated solver cannot be imported.
        """
        if self.settings.cvx.cvxpygen:
            try:
                from solver.cpg_solver import cpg_solve

                # NOTE(review): the unpickled object is discarded, so this only checks
                # that the generated artifact loads. pickle.load can execute arbitrary
                # code; "solver/problem.pickle" (relative to the CWD) must be trusted.
                with open("solver/problem.pickle", "rb") as f:
                    pickle.load(f)
                self._optimal_control_problem.register_solve("CPG", cpg_solve)
                solver_args = self.settings.cvx.solver_args
                return lambda: self._optimal_control_problem.solve(
                    method="CPG", **solver_args
                )
            except ImportError as err:
                raise ImportError(
                    "cvxpygen solver not found. Make sure cvxpygen is installed and code "
                    "generation has been run. Install with: pip install openscvx[cvxpygen]"
                ) from err

        solver = self.settings.cvx.solver
        solver_args = self.settings.cvx.solver_args
        return lambda: self._optimal_control_problem.solve(solver=solver, **solver_args)

    def initialize(self):
        """Compile dynamics, constraints, and solvers; prepare for optimization.

        This method vmaps dynamics, JIT-compiles constraints, builds the convex
        subproblem, and initializes the solver state. Must be called before solve().

        Raises:
            ImportError: If CVXPyGen is enabled but the generated solver is missing.

        Example:
            Prior to calling the `.solve()` method it is necessary to initialize the
            problem::

                problem = Problem(dynamics, constraints, states, controls, N, time)
                problem.initialize()  # Compile and prepare
                problem.solve()  # Run optimization
        """
        printing.intro()

        # Print problem summary
        printing.print_problem_summary(self.settings, self._lowered)

        # Enable the profiler
        pr = profiling.profiling_start(self.settings.dev.profiling)

        t_0_while = time.time()
        # Ensure parameter sizes and normalization are correct (settings may have
        # been modified by the user since __init__)
        self.settings.scp.__post_init__()
        self.settings.sim.__post_init__()

        # Create compiled (vmapped) dynamics as new instances; this preserves the
        # original un-vmapped versions in _lowered. The first three arguments are
        # batched and the fourth is shared across the batch (presumably the
        # parameter dict - confirm against the lowered dynamics signature).
        vmap_axes = (0, 0, 0, None)
        self._compiled_dynamics = Dynamics(
            f=jax.vmap(self._lowered.dynamics.f, in_axes=vmap_axes),
            A=jax.vmap(self._lowered.dynamics.A, in_axes=vmap_axes),
            B=jax.vmap(self._lowered.dynamics.B, in_axes=vmap_axes),
        )
        self._compiled_dynamics_prop = Dynamics(
            f=jax.vmap(self._lowered.dynamics_prop.f, in_axes=vmap_axes),
        )

        # Create compiled (JIT-wrapped) constraints as new instances
        self._compiled_constraints = self._compile_constraints()

        # Generate solvers using compiled (vmapped) dynamics
        self._discretization_solver = get_discretization_solver(
            self._compiled_dynamics, self.settings
        )
        self._propagation_solver = get_propagation_solver(
            self._compiled_dynamics_prop.f, self.settings
        )

        # Build optimal control problem using LoweredProblem
        self._optimal_control_problem = optimal_control_problem(self.settings, self._lowered)

        # Get cache file paths using symbolic AST hashing
        # (more stable than hashing lowered JAX code)
        dis_solver_file, prop_solver_file = get_solver_cache_paths(
            self.symbolic,
            dt=self.settings.prp.dt,
            total_time=self.settings.sim.total_time,
        )

        # Compile the discretization solver
        self._discretization_solver = load_or_compile_discretization_solver(
            self._discretization_solver,
            dis_solver_file,
            self._parameters,  # Plain dict for JAX
            self.settings.scp.n,
            self.settings.sim.n_states,
            self.settings.sim.n_controls,
            save_compiled=self.settings.sim.save_compiled,
            debug=self.settings.dev.debug,
        )

        # Size the propagation buffer from the largest time-dilation control bound:
        # one normalized interval (1 / (n - 1)) scaled by that bound, divided by the
        # propagation step, with a margin of 2 extra steps.
        dtau = 1.0 / (self.settings.scp.n - 1)
        dt_max = self.settings.sim.u.max[self.settings.sim.time_dilation_slice][0] * dtau
        self.settings.prp.max_tau_len = int(dt_max / self.settings.prp.dt) + 2

        # Compile the propagation solver
        self._propagation_solver = load_or_compile_propagation_solver(
            self._propagation_solver,
            prop_solver_file,
            self._parameters,  # Plain dict for JAX
            self.settings.sim.n_states_prop,
            self.settings.sim.n_controls,
            self.settings.prp.max_tau_len,
            save_compiled=self.settings.sim.save_compiled,
        )

        # Build solver callable (handle CVXPyGen if enabled)
        self._solve_ocp = self._build_solve_ocp()

        # Initialize the SCP algorithm
        print("Initializing the SCvx Subproblem Solver...")
        self._algorithm.initialize(
            self._optimal_control_problem,
            self._discretization_solver,
            self._compiled_constraints,
            self._solve_ocp,
            self.emitter_function,
            self._parameters,  # For warm-start only
            self.settings,  # For warm-start only
        )
        print("✓ SCvx Subproblem Solver initialized")

        # Create fresh solver state
        self._state = AlgorithmState.from_settings(self.settings)

        t_f_while = time.time()
        self.timing_init = t_f_while - t_0_while
        print("Total Initialization Time: ", self.timing_init)

        # Prime the propagation solver (not included in timing_init)
        prime_propagation_solver(self._propagation_solver, self._parameters, self.settings)

        profiling.profiling_end(pr, "initialize")

    def reset(self):
        """Reset solver state to re-run optimization from initial conditions.

        Creates a fresh AlgorithmState while preserving compiled dynamics and solvers,
        and clears the stored solution, convergence flag and solve/post timings.
        Use this to run multiple optimizations without re-initializing.

        Raises:
            ValueError: If initialize() has not been called yet.

        Example:
            After calling `.step()` it may be necessary to reset the problem back to
            the initial conditions::

                problem.initialize()
                result1 = problem.step()
                problem.reset()
                result2 = problem.solve()  # Fresh run with same setup
        """
        if self._compiled_dynamics is None:
            raise ValueError("Problem has not been initialized. Call initialize() first")

        # Create fresh solver state from settings
        self._state = AlgorithmState.from_settings(self.settings)

        # Reset solution
        self._solution = None
        self._converged = None

        # Reset timing
        self.timing_solve = None
        self.timing_post = None

    def step(self) -> dict:
        """Perform a single SCP iteration.

        Designed for real-time plotting and interactive optimization. Performs one
        iteration including subproblem solve, state update, and progress emission.

        Note:
            This method is NOT idempotent - it mutates internal state and advances
            the iteration counter. Use reset() to return to initial conditions.

        Returns:
            dict: Contains "converged" (bool) plus the current iteration counter and
            costs under "scp_k", "scp_J_tr", "scp_J_vb" and "scp_J_vc".

        Raises:
            ValueError: If initialize() has not been called yet.

        Example:
            Call `.step()` manually in a loop to control the algorithm directly::

                problem.initialize()
                while not problem.step()["converged"]:
                    plot_trajectory(problem.state.trajs[-1])
        """
        if self._state is None:
            raise ValueError("Problem has not been initialized. Call initialize() first")

        converged = self._algorithm.step(
            self._state,
            self._parameters,  # May change between steps
            self.settings,  # May change between steps
        )

        # Return dict matching original API
        return {
            "converged": converged,
            "scp_k": self._state.k,
            "scp_J_tr": self._state.J_tr,
            "scp_J_vb": self._state.J_vb,
            "scp_J_vc": self._state.J_vc,
        }

    def _drain_print_queue(self, poll_interval: float = 0.1) -> None:
        """Block until the printer thread has consumed all queued progress messages.

        Returns immediately when printing is disabled (no queue), and stops waiting
        if the printer thread is no longer alive so a crashed printer cannot hang
        the caller.

        Args:
            poll_interval: Seconds to sleep between queue-size checks.
        """
        print_queue = getattr(self, "print_queue", None)
        if print_queue is None:
            return
        print_thread = getattr(self, "print_thread", None)
        while print_queue.qsize() > 0:
            if print_thread is not None and not print_thread.is_alive():
                break
            time.sleep(poll_interval)

    def solve(
        self, max_iters: Optional[int] = None, continuous: bool = False
    ) -> OptimizationResults:
        """Run the SCP algorithm until convergence or iteration limit.

        Args:
            max_iters: Maximum iterations (default: settings.scp.k_max).
            continuous: If True, run all iterations regardless of convergence.

        Returns:
            OptimizationResults with trajectory and convergence info
            (call post_process() for full propagation). ``converged`` is the flag
            reported by the last SCP step executed, or False if no step ran.

        Raises:
            ValueError: If initialize() has not been called yet.
        """
        # Sync parameters before solving
        self._sync_parameters()

        required = [
            self._compiled_dynamics,
            self._compiled_constraints,
            self._optimal_control_problem,
            self._discretization_solver,
            self._state,
        ]
        if any(r is None for r in required):
            raise ValueError("Problem has not been initialized. Call initialize() before solve()")

        # Enable the profiler
        pr = profiling.profiling_start(self.settings.dev.profiling)

        t_0_while = time.time()
        # Print top header for solver results
        printing.header()

        k_max = max_iters if max_iters is not None else self.settings.scp.k_max

        # Track the convergence flag reported by the algorithm itself rather than
        # inferring it from the iteration counter, which misreports convergence on
        # the final iteration and always reports False in continuous mode.
        converged = False
        while self._state.k <= k_max:
            converged = bool(self.step()["converged"])
            if converged and not continuous:
                break

        t_f_while = time.time()
        self.timing_solve = t_f_while - t_0_while

        # Let the printer thread flush pending iteration output (no-op when
        # printing is disabled)
        self._drain_print_queue()

        # Print bottom footer for solver results as well as total computation time
        printing.footer()

        profiling.profiling_end(pr, "solve")

        # Store solution state and its convergence flag for post_process()
        self._solution = copy.deepcopy(self._state)
        self._converged = converged

        return self._format_result(self._state, converged)

    def post_process(self) -> OptimizationResults:
        """Propagate solution through full nonlinear dynamics for high-fidelity trajectory.

        Integrates the SCP solution stored by solve() through the nonlinear dynamics
        to produce the propagated trajectory. Call after solve() for final results.
        The propagated result is also cached on the stored solution.

        Returns:
            OptimizationResults with propagated trajectory fields.

        Raises:
            ValueError: If solve() has not been called yet.
        """
        if self._solution is None:
            raise ValueError("No solution available. Call solve() first.")

        # Enable the profiler
        pr = profiling.profiling_start(self.settings.dev.profiling)

        # Use the convergence flag recorded by solve(); fall back to the
        # iteration-count heuristic if none was recorded.
        converged = getattr(self, "_converged", None)
        if converged is None:
            converged = self._solution.k <= self.settings.scp.k_max

        # Create result from stored solution state
        result = self._format_result(self._solution, converged)

        t_0_post = time.time()
        result = propagate_trajectory_results(
            self._parameters, self.settings, result, self._propagation_solver
        )
        t_f_post = time.time()

        self.timing_post = t_f_post - t_0_post

        # Cache the propagated result on the _solution object for plotting
        self._solution._propagated_result = result

        # Print results summary
        printing.print_results_summary(
            result, self.timing_post, self.timing_init, self.timing_solve
        )

        profiling.profiling_end(pr, "postprocess")
        return result

    def citation(self) -> str:
        """Return BibTeX citations for all components used in this problem.

        Aggregates citations from the algorithm and other components (discretization,
        convex solver, etc.). Each section is prefixed with a comment indicating which
        component the citation is for.

        Returns:
            Formatted string containing all BibTeX citations with comments.

        Example:
            Print all citations for a problem::

                problem = Problem(dynamics, constraints, states, controls, N, time)
                print(problem.citation())
        """
        sections = [r"% --- AUTO-GENERATED CITATIONS FOR OPENSCVX CONFIGURATION ---"]

        # Algorithm citations
        algo_citations = self._algorithm.citation()
        if algo_citations:
            header = f"% Algorithm: {type(self._algorithm).__name__}"
            citations = "\n".join(algo_citations)
            sections.append(f"{header}\n\n{citations}")

        # Future: add citations from discretization, constraint formulations, etc.

        sections.append(r"% --- END AUTO-GENERATED CITATIONS")

        return "\n\n".join(sections)
|