openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Any, Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class OptimizationResults:
    """
    Structured container for optimization results from the Successive Convexification (SCP) solver.

    This class provides a type-safe and organized way to store and access optimization results,
    replacing the previous dictionary-based approach. It includes core optimization data,
    iteration history for convergence analysis, post-processing results, and flexible
    storage for plotting and application-specific data.

    Besides plain attribute access, the object supports a dict-like interface
    (``results[key]``, ``results.get(key)``, ``key in results``, ``results.update(...)``).
    A key that names a dataclass field, a property (``x``/``u``), or an attribute set
    directly on the instance is resolved against the object itself. Any other key is
    read from / written to ``plotting_data``. Methods are never treated as keys, so e.g.
    ``results["update"] = value`` stores into ``plotting_data`` instead of overwriting
    the ``update`` method.

    Attributes:
        converged (bool): Whether the optimization successfully converged
        t_final (float): Final time of the optimized trajectory

        # Dictionary-based Access
        nodes (dict[str, np.ndarray]): Dictionary mapping state/control names to arrays
            at optimization nodes. Includes both user-defined and augmented variables.
        trajectory (dict[str, np.ndarray]): Dictionary mapping state/control names to arrays
            along the propagated trajectory. Added by post_process().

        # SCP Iteration History (for convergence analysis)
        X (list[np.ndarray]): State trajectories from each SCP iteration
        U (list[np.ndarray]): Control trajectories from each SCP iteration
        discretization_history (list[np.ndarray]): Time discretization from each iteration
        J_tr_history (list[np.ndarray]): Trust region cost history
        J_vb_history (list[np.ndarray]): Virtual buffer cost history
        J_vc_history (list[np.ndarray]): Virtual control cost history
        TR_history (list[np.ndarray]): Per-iteration trust-region history
        VC_history (list[np.ndarray]): Per-iteration virtual-control history

        # Convenience properties
        x (np.ndarray): Last entry of ``X`` (the final solution), shape (N, n_states)
        u (np.ndarray): Last entry of ``U`` (the final solution), shape (N, n_controls)

        # Post-processing Results (added by propagate_trajectory_results)
        t_full (Optional[np.ndarray]): Full time grid for interpolated trajectory
        x_full (Optional[np.ndarray]): Interpolated state trajectory on full time grid
        u_full (Optional[np.ndarray]): Interpolated control trajectory on full time grid
        cost (Optional[float]): Total cost of the optimized trajectory
        ctcs_violation (Optional[np.ndarray]): Continuous-time constraint violations

        # User-defined Data
        plotting_data (dict[str, Any]): Flexible storage for plotting and application data
    """

    # Core optimization results
    converged: bool
    t_final: float

    # Dictionary-based access to states and controls
    nodes: dict[str, np.ndarray] = field(default_factory=dict)
    trajectory: dict[str, np.ndarray] = field(default_factory=dict)

    # Internal metadata for dictionary construction
    _states: list = field(default_factory=list, repr=False)
    _controls: list = field(default_factory=list, repr=False)

    # History of SCP iterations (single source of truth)
    X: list[np.ndarray] = field(default_factory=list)
    U: list[np.ndarray] = field(default_factory=list)
    discretization_history: list[np.ndarray] = field(default_factory=list)
    J_tr_history: list[np.ndarray] = field(default_factory=list)
    J_vb_history: list[np.ndarray] = field(default_factory=list)
    J_vc_history: list[np.ndarray] = field(default_factory=list)
    TR_history: list[np.ndarray] = field(default_factory=list)
    VC_history: list[np.ndarray] = field(default_factory=list)

    # Post-processing results (added by propagate_trajectory_results)
    t_full: Optional[np.ndarray] = None
    x_full: Optional[np.ndarray] = None
    u_full: Optional[np.ndarray] = None
    cost: Optional[float] = None
    ctcs_violation: Optional[np.ndarray] = None

    # Additional plotting/application data (added by user)
    plotting_data: dict[str, Any] = field(default_factory=dict)

    @property
    def x(self) -> np.ndarray:
        """Optimal state trajectory at discretization nodes.

        Returns the last entry of the SCP iteration history ``X``.

        Returns:
            State trajectory array, shape (N, n_states)

        Raises:
            IndexError: If ``X`` is empty (no iterations recorded yet).
        """
        return self.X[-1]

    @property
    def u(self) -> np.ndarray:
        """Optimal control trajectory at discretization nodes.

        Returns the last entry of the SCP iteration history ``U``.

        Returns:
            Control trajectory array, shape (N, n_controls)

        Raises:
            IndexError: If ``U`` is empty (no iterations recorded yet).
        """
        return self.U[-1]

    def __post_init__(self):
        """Initialize the results object."""
        pass

    def _is_attribute_key(self, key: Any) -> bool:
        """Return True if ``key`` should be resolved against the object itself.

        That is the case for dataclass fields, attributes set directly on the instance,
        and properties defined on the class. Methods are deliberately excluded, so
        dict-style access can neither return nor overwrite them.
        """
        if not isinstance(key, str):
            return False
        if key in self.__dataclass_fields__ or key in vars(self):
            return True
        return isinstance(getattr(type(self), key, None), property)

    def update_plotting_data(self, **kwargs):
        """
        Update the plotting data with additional information.

        Args:
            **kwargs: Key-value pairs to add to plotting_data
        """
        self.plotting_data.update(kwargs)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value from the results, similar to dict.get().

        Args:
            key: The key to look up
            default: Default value if key is not found. This is also returned for
                ``"x"``/``"u"`` while the iteration history is still empty.

        Returns:
            The value associated with the key, or default if not found
        """
        try:
            return self[key]
        except (KeyError, IndexError):
            # IndexError comes from the x/u properties on an empty history.
            return default

    def __getitem__(self, key: str) -> Any:
        """
        Allow dictionary-style access to results.

        Args:
            key: The key to look up

        Returns:
            The value associated with the key

        Raises:
            KeyError: If key is not found
            IndexError: If key is ``"x"``/``"u"`` and the iteration history is empty
        """
        if self._is_attribute_key(key):
            return getattr(self, key)

        try:
            return self.plotting_data[key]
        except KeyError:
            raise KeyError(f"Key '{key}' not found in results") from None

    def __setitem__(self, key: str, value: Any):
        """
        Allow dictionary-style assignment to results.

        Fields and instance attributes are assigned directly. Every other key
        (including names of methods) is stored in ``plotting_data``.

        Args:
            key: The key to set
            value: The value to assign
        """
        if self._is_attribute_key(key):
            setattr(self, key, value)
        else:
            self.plotting_data[key] = value

    def __contains__(self, key: str) -> bool:
        """
        Check if a key exists in the results.

        A key is considered present when looking it up would succeed. So ``"x"``/``"u"``
        are reported as absent while the iteration history is empty.

        Args:
            key: The key to check

        Returns:
            True if key exists, False otherwise
        """
        try:
            self[key]
        except (KeyError, IndexError):
            return False
        return True

    def update(self, other: dict[str, Any]):
        """
        Update the results with additional data from a dictionary.

        Each item is assigned with ``self[key] = value`` (see ``__setitem__``).

        Args:
            other: Dictionary containing additional data
        """
        for key, value in other.items():
            self[key] = value

    def to_dict(self) -> dict[str, Any]:
        """
        Convert the results to a dictionary for backward compatibility.

        The result contains every dataclass field except ``plotting_data``, including
        the internal ``_states``/``_controls`` metadata. The entries of
        ``plotting_data`` are merged in at the top level. Properties such as
        ``x``/``u`` are not included.

        Returns:
            Dictionary representation of the results
        """
        result_dict = {}

        # Add all direct attributes
        for attr_name in self.__dataclass_fields__:
            if attr_name != "plotting_data":
                result_dict[attr_name] = getattr(self, attr_name)

        # Add plotting data
        result_dict.update(self.plotting_data)

        return result_dict
|
|
@@ -0,0 +1,384 @@
|
|
|
1
|
+
"""Penalized Trust Region (PTR) successive convexification algorithm.
|
|
2
|
+
|
|
3
|
+
This module implements the PTR algorithm for solving non-convex trajectory
|
|
4
|
+
optimization problems through iterative convex approximation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import TYPE_CHECKING, List
|
|
10
|
+
|
|
11
|
+
import cvxpy as cp
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.linalg as la
|
|
14
|
+
|
|
15
|
+
from openscvx.config import Config
|
|
16
|
+
|
|
17
|
+
from .autotuning import update_scp_weights
|
|
18
|
+
from .base import Algorithm, AlgorithmState
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from openscvx.lowered import LoweredJaxConstraints
|
|
22
|
+
|
|
23
|
+
# NOTE(review): an unqualified filterwarnings("ignore") installs a global filter, so
# merely importing this module silences *all* warnings in the process (user code and
# other libraries included). Consider scoping it with warnings.catch_warnings()
# around the noisy calls, or restricting it to specific categories/modules.
warnings.filterwarnings("ignore")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PenalizedTrustRegion(Algorithm):
    """Penalized Trust Region (PTR) successive convexification algorithm.

    PTR solves non-convex trajectory optimization problems through iterative
    convex approximation. Each subproblem balances competing cost terms:

    - **Trust region penalty**: Discourages large deviations from the previous
      iterate, keeping the solution within the region where linearization is valid.
    - **Virtual control**: Relaxes dynamics constraints, penalized to drive
      defects toward zero as the algorithm converges.
    - **Virtual buffer**: Relaxes non-convex constraints, similarly penalized
      to enforce feasibility at convergence.
    - **Problem objective and other terms**: The user-defined cost (e.g., minimum
      fuel, minimum time) and any additional penalty terms.

    The interplay between these terms guides the optimization: the trust region
    anchors the solution near the linearization point while virtual terms allow
    temporary constraint violations that shrink over iterations.

    Example:
        Using PTR with a Problem::

            from openscvx.algorithms import PenalizedTrustRegion

            problem = Problem(dynamics, constraints, states, controls, N, time)
            problem.initialize()
            result = problem.solve()
    """

    def __init__(self):
        """Initialize PTR with unset infrastructure.

        Call initialize() before step() to set up compiled components.
        """
        self._ocp: cp.Problem = None
        self._discretization_solver: callable = None
        self._jax_constraints: "LoweredJaxConstraints" = None
        self._solve_ocp: callable = None
        self._emitter: callable = None

    def initialize(
        self,
        ocp: cp.Problem,
        discretization_solver: callable,
        jax_constraints: "LoweredJaxConstraints",
        solve_ocp: callable,
        emitter: callable,
        params: dict,
        settings: Config,
    ) -> None:
        """Initialize PTR algorithm.

        Stores compiled infrastructure and performs a warm-start solve to
        initialize DPP and JAX jacobians.

        Args:
            ocp: CVXPy optimal control problem
            discretization_solver: Compiled discretization solver
            jax_constraints: JIT-compiled constraint functions
            solve_ocp: Callable that solves the OCP
            emitter: Callback for emitting iteration progress
            params: Problem parameters dictionary (for warm-start)
            settings: Configuration object (for warm-start)
        """
        # Store immutable infrastructure
        self._ocp = ocp
        self._discretization_solver = discretization_solver
        self._jax_constraints = jax_constraints
        self._solve_ocp = solve_ocp
        self._emitter = emitter

        # Boundary-condition parameters are optional in the OCP.
        if "x_init" in ocp.param_dict:
            ocp.param_dict["x_init"].value = settings.sim.x.initial

        if "x_term" in ocp.param_dict:
            ocp.param_dict["x_term"].value = settings.sim.x.final

        # Create temporary state for initialization solve
        init_state = AlgorithmState.from_settings(settings)

        # Solve one throwaway subproblem to initialize DPP and JAX jacobians
        # before the first real step().
        _ = self._subproblem(params, init_state, settings)

    def step(
        self,
        state: AlgorithmState,
        params: dict,
        settings: Config,
    ) -> bool:
        """Execute one PTR iteration.

        Solves the convex subproblem, updates state in place, and checks
        convergence based on trust region, virtual buffer, and virtual
        control costs.

        Args:
            state: Mutable solver state (modified in place)
            params: Problem parameters dictionary (may change between steps)
            settings: Configuration object (may change between steps)

        Returns:
            True if J_tr, J_vb, and J_vc are all below their thresholds.

        Raises:
            RuntimeError: If initialize() has not been called.
        """
        if self._ocp is None:
            raise RuntimeError(
                "PenalizedTrustRegion.step() called before initialize(). "
                "Call initialize() first to set up compiled infrastructure."
            )

        # Run the subproblem
        (
            x_sol,
            u_sol,
            cost,
            J_total,
            J_vb_vec,
            J_vc_vec,
            J_tr_vec,
            prob_stat,
            V_multi_shoot,
            subprop_time,
            dis_time,
            vc_mat,
            tr_mat,
        ) = self._subproblem(params, state, settings)

        # Update state in place by appending to history
        # The x_guess/u_guess properties will automatically return the latest entry
        state.V_history.append(V_multi_shoot)
        state.X.append(x_sol)
        state.U.append(u_sol)
        state.VC_history.append(vc_mat)
        state.TR_history.append(tr_mat)

        state.J_tr = np.sum(np.array(J_tr_vec))
        state.J_vb = np.sum(np.array(J_vb_vec))
        state.J_vc = np.sum(np.array(J_vc_vec))

        # Update weights in state
        update_scp_weights(state, settings, state.k)

        # Emit data (timings reported in milliseconds)
        self._emitter(
            {
                "iter": state.k,
                "dis_time": dis_time * 1000.0,
                "subprop_time": subprop_time * 1000.0,
                "J_total": J_total,
                "J_tr": state.J_tr,
                "J_vb": state.J_vb,
                "J_vc": state.J_vc,
                "cost": cost[-1],
                "prob_stat": prob_stat,
            }
        )

        # Increment iteration counter
        state.k += 1

        # Return convergence status
        return (
            (state.J_tr < settings.scp.ep_tr)
            and (state.J_vb < settings.scp.ep_vb)
            and (state.J_vc < settings.scp.ep_vc)
        )

    def _subproblem(
        self,
        params: dict,
        state: AlgorithmState,
        settings: Config,
    ):
        """Solve a single convex subproblem.

        Uses stored infrastructure (ocp, discretization_solver, jax_constraints)
        with per-step params and settings.

        Args:
            params: Problem parameters dictionary
            state: Current solver state. ``state.lam_vc`` is expanded in place from a
                scalar to a matrix if needed.
            settings: Configuration object

        Returns:
            Tuple ``(x_new_guess, u_new_guess, costs, J_total, J_vb_vec, J_vc_vec,
            J_tr_vec, status, V_multi_shoot, subprop_time, dis_time, vc_mat, |tr_mat|)``.
            Here ``costs`` is the per-node boundary-condition cost (see
            ``_boundary_costs``), ``J_total`` is the OCP objective value, ``status`` is
            the CVXPY problem status, and the two times are in seconds.
        """
        self._ocp.param_dict["x_bar"].value = state.x
        self._ocp.param_dict["u_bar"].value = state.u

        # Discretize the dynamics about the current reference trajectory.
        t0 = time.time()
        A_bar, B_bar, C_bar, x_prop, V_multi_shoot = self._discretization_solver.call(
            state.x, state.u.astype(float), params
        )

        self._ocp.param_dict["A_d"].value = A_bar.__array__()
        self._ocp.param_dict["B_d"].value = B_bar.__array__()
        self._ocp.param_dict["C_d"].value = C_bar.__array__()
        self._ocp.param_dict["x_prop"].value = x_prop.__array__()
        dis_time = time.time() - t0

        # Nodal and cross-node constraint linearizations. Convex constraints are already
        # lowered into the OCP, so no action is needed for them here.
        self._update_constraint_linearizations(state, params)

        # Initialize lam_vc as matrix if it's still a scalar in state
        if isinstance(state.lam_vc, (int, float)):
            # Convert scalar to matrix: (N-1, n_states)
            state.lam_vc = np.ones((settings.scp.n - 1, settings.sim.n_states)) * state.lam_vc

        # Update CVXPy parameters from state
        self._ocp.param_dict["w_tr"].value = state.w_tr
        self._ocp.param_dict["lam_cost"].value = state.lam_cost
        self._ocp.param_dict["lam_vc"].value = state.lam_vc
        self._ocp.param_dict["lam_vb"].value = state.lam_vb

        t0 = time.time()
        self._solve_ocp()
        subprop_time = time.time() - t0

        # Apply the affine maps (S_x, c_x) and (S_u, c_u) to the solver variables
        # (presumably undoing the variable scaling used inside the OCP).
        x_new_guess = (
            settings.sim.S_x @ self._ocp.var_dict["x"].value.T
            + np.expand_dims(settings.sim.c_x, axis=1)
        ).T
        u_new_guess = (
            settings.sim.S_u @ self._ocp.var_dict["u"].value.T
            + np.expand_dims(settings.sim.c_u, axis=1)
        ).T

        # Only terminal boundary-condition types contribute to this cost.
        costs = self._boundary_costs(x_new_guess, settings.sim.x.final_type)

        # Block-diagonal inverse scaling, built with np.block.
        inv_block_diag = np.block(
            [
                [
                    settings.sim.inv_S_x,
                    np.zeros((settings.sim.inv_S_x.shape[0], settings.sim.inv_S_u.shape[1])),
                ],
                [
                    np.zeros((settings.sim.inv_S_u.shape[0], settings.sim.inv_S_x.shape[1])),
                    settings.sim.inv_S_u,
                ],
            ]
        )

        # Trust-region cost: squared norm of the scaled step, one entry per node (column).
        tr_mat = inv_block_diag @ np.hstack((x_new_guess - state.x, u_new_guess - state.u)).T
        J_tr_vec = la.norm(tr_mat, axis=0) ** 2

        # Virtual-control cost: L1 norm of each row of the virtual control.
        vc_mat = np.abs(self._ocp.var_dict["nu"].value)
        J_vc_vec = np.sum(vc_mat, axis=1)

        J_vb_vec = self._virtual_buffer_costs()

        return (
            x_new_guess,
            u_new_guess,
            costs,
            self._ocp.value,
            J_vb_vec,
            J_vc_vec,
            J_tr_vec,
            self._ocp.status,
            V_multi_shoot,
            subprop_time,
            dis_time,
            vc_mat,
            abs(tr_mat),
        )

    def _update_constraint_linearizations(self, state: AlgorithmState, params: dict) -> None:
        """Write non-convex constraint values and Jacobians at the reference into the OCP.

        Nodal constraint ``i`` fills parameters ``g_<i>``, ``grad_g_x_<i>`` and
        ``grad_g_u_<i>``. Cross-node constraint ``i`` fills ``g_cross_<i>``,
        ``grad_g_X_cross_<i>`` and ``grad_g_U_cross_<i>``.
        """
        ocp_params = self._ocp.param_dict
        x_ref, u_ref = state.x, state.u

        # TODO: (norrisg) investigate why we are passing `0` for the node here
        for g_id, constraint in enumerate(self._jax_constraints.nodal or []):
            ocp_params[f"g_{g_id}"].value = np.asarray(
                constraint.func(x_ref, u_ref, 0, params)
            )
            ocp_params[f"grad_g_x_{g_id}"].value = np.asarray(
                constraint.grad_g_x(x_ref, u_ref, 0, params)
            )
            ocp_params[f"grad_g_u_{g_id}"].value = np.asarray(
                constraint.grad_g_u(x_ref, u_ref, 0, params)
            )

        # Cross-node constraints take (X, U, params) not (x, u, node, params)
        for g_id, constraint in enumerate(self._jax_constraints.cross_node or []):
            ocp_params[f"g_cross_{g_id}"].value = np.asarray(
                constraint.func(x_ref, u_ref, params)
            )
            ocp_params[f"grad_g_X_cross_{g_id}"].value = np.asarray(
                constraint.grad_g_X(x_ref, u_ref, params)
            )
            ocp_params[f"grad_g_U_cross_{g_id}"].value = np.asarray(
                constraint.grad_g_U(x_ref, u_ref, params)
            )

    @staticmethod
    def _boundary_costs(x: np.ndarray, final_types) -> np.ndarray:
        """Per-node sum of states marked "Minimize" minus states marked "Maximize".

        The sum is accumulated in a float array. (Adding to a Python list would
        *extend* it rather than add elementwise.)

        Args:
            x: State trajectory, one row per node.
            final_types: Terminal boundary-condition type for each state column.

        Returns:
            Array with one entry per node. step() reports its last entry as ``cost``.
        """
        costs = np.zeros(x.shape[0])
        for i, bc_type in enumerate(final_types):
            if bc_type == "Minimize":
                costs += x[:, i]
            elif bc_type == "Maximize":
                costs -= x[:, i]
        return costs

    def _virtual_buffer_costs(self) -> np.ndarray:
        """Return the total positive virtual-buffer slack of each non-convex constraint.

        Entries come first for the nodal constraints (``nu_vb_<i>``), then for the
        cross-node constraints (``nu_vb_cross_<i>``). Each slack variable is reduced to
        a scalar before stacking. This avoids broadcasting differently-shaped slack
        variables against each other, which could raise or count a slack more than once.
        """
        var_dict = self._ocp.var_dict
        names = [f"nu_vb_{i}" for i, _ in enumerate(self._jax_constraints.nodal or [])]
        names += [
            f"nu_vb_cross_{i}" for i, _ in enumerate(self._jax_constraints.cross_node or [])
        ]
        return np.array(
            [np.sum(np.maximum(0, var_dict[name].value)) for name in names], dtype=float
        )

    def citation(self) -> List[str]:
        """Return BibTeX citations for the PTR algorithm.

        Returns:
            List containing the BibTeX entry for the PTR paper.
        """
        return [
            r"""@article{drusvyatskiy2018error,
title={Error bounds, quadratic growth, and linear convergence of proximal methods},
author={Drusvyatskiy, Dmitriy and Lewis, Adrian S},
journal={Mathematics of operations research},
volume={43},
number={3},
pages={919--948},
year={2018},
publisher={INFORMS}
}""",
            r"""@article{szmuk2020successive,
title={Successive convexification for real-time six-degree-of-freedom powered descent guidance
with state-triggered constraints},
author={Szmuk, Michael and Reynolds, Taylor P and A{\c{c}}{\i}kme{\c{s}}e, Beh{\c{c}}et},
journal={Journal of Guidance, Control, and Dynamics},
volume={43},
number={8},
pages={1399--1413},
year={2020},
publisher={American Institute of Aeronautics and Astronautics}
}""",
            r"""@article{reynolds2020dual,
title={Dual quaternion-based powered descent guidance with state-triggered constraints},
author={Reynolds, Taylor P and Szmuk, Michael and Malyuta, Danylo and Mesbahi, Mehran and
A{\c{c}}{\i}kme{\c{s}}e, Beh{\c{c}}et and Carson III, John M},
journal={Journal of Guidance, Control, and Dynamics},
volume={43},
number={9},
pages={1584--1599},
year={2020},
publisher={American Institute of Aeronautics and Astronautics}
}""",
        ]
|