openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79):
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,215 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Optional
3
+
4
+ import numpy as np
5
+
6
+
7
@dataclass
class OptimizationResults:
    """Structured container for results of the successive convexification (SCP) solver.

    Holds the core outcome of a solve, the per-iteration history, optional
    post-processing outputs, and free-form user data. In addition to normal
    attribute access the object offers a small dict-like interface
    (``get``, ``[]``, ``in``, ``update``, ``to_dict``) for backward
    compatibility with the previous dictionary-based results.

    Dict-style keys are resolved in this order:

    1. data attributes of the instance (dataclass fields and any attribute
       stored on the instance) and properties such as ``x`` / ``u``;
    2. entries of ``plotting_data``.

    Methods are deliberately *not* treated as keys, so ``results["update"] = v``
    stores ``v`` in ``plotting_data`` instead of overwriting the method.

    Attributes:
        converged (bool): Whether the optimization successfully converged
        t_final (float): Final time of the optimized trajectory
        nodes (dict[str, np.ndarray]): Dictionary mapping state/control names to arrays
            at optimization nodes. Includes both user-defined and augmented variables.
        trajectory (dict[str, np.ndarray]): Dictionary mapping state/control names to
            arrays along the propagated trajectory. Added by post_process().
        X (list[np.ndarray]): State trajectories from each SCP iteration
        U (list[np.ndarray]): Control trajectories from each SCP iteration
        x (np.ndarray): Read-only view of the last entry of ``X``, shape (N, n_states)
        u (np.ndarray): Read-only view of the last entry of ``U``, shape (N, n_controls)
        discretization_history (list[np.ndarray]): Time discretization from each iteration
        J_tr_history (list[np.ndarray]): Trust region cost history
        J_vb_history (list[np.ndarray]): Virtual buffer cost history
        J_vc_history (list[np.ndarray]): Virtual control cost history
        TR_history (list[np.ndarray]): Per-iteration trust-region data
        VC_history (list[np.ndarray]): Per-iteration virtual-control data
        t_full (Optional[np.ndarray]): Full time grid for interpolated trajectory
        x_full (Optional[np.ndarray]): Interpolated state trajectory on full time grid
        u_full (Optional[np.ndarray]): Interpolated control trajectory on full time grid
        cost (Optional[float]): Total cost of the optimized trajectory
        ctcs_violation (Optional[np.ndarray]): Continuous-time constraint violations
        plotting_data (dict[str, Any]): Flexible storage for plotting and application data
    """

    # Core optimization results
    converged: bool
    t_final: float

    # Dictionary-based access to states and controls
    nodes: dict[str, np.ndarray] = field(default_factory=dict)
    trajectory: dict[str, np.ndarray] = field(default_factory=dict)

    # Internal metadata for dictionary construction
    _states: list = field(default_factory=list, repr=False)
    _controls: list = field(default_factory=list, repr=False)

    # History of SCP iterations (single source of truth for x / u)
    X: list[np.ndarray] = field(default_factory=list)
    U: list[np.ndarray] = field(default_factory=list)
    discretization_history: list[np.ndarray] = field(default_factory=list)
    J_tr_history: list[np.ndarray] = field(default_factory=list)
    J_vb_history: list[np.ndarray] = field(default_factory=list)
    J_vc_history: list[np.ndarray] = field(default_factory=list)
    TR_history: list[np.ndarray] = field(default_factory=list)
    VC_history: list[np.ndarray] = field(default_factory=list)

    # Post-processing results (added by propagate_trajectory_results)
    t_full: Optional[np.ndarray] = None
    x_full: Optional[np.ndarray] = None
    u_full: Optional[np.ndarray] = None
    cost: Optional[float] = None
    ctcs_violation: Optional[np.ndarray] = None

    # Additional plotting/application data (added by user)
    plotting_data: dict[str, Any] = field(default_factory=dict)

    @property
    def x(self) -> np.ndarray:
        """State trajectory of the last recorded SCP iteration.

        Returns:
            The last entry of ``X``, shape (N, n_states)

        Raises:
            IndexError: If no iteration has been recorded yet (``X`` is empty).
        """
        return self.X[-1]

    @property
    def u(self) -> np.ndarray:
        """Control trajectory of the last recorded SCP iteration.

        Returns:
            The last entry of ``U``, shape (N, n_controls)

        Raises:
            IndexError: If no iteration has been recorded yet (``U`` is empty).
        """
        return self.U[-1]

    def __post_init__(self):
        """Initialize the results object (no extra work is required)."""

    def _lookup(self, key: Any) -> tuple:
        """Resolve a dict-style key.

        Args:
            key: The key to look up

        Returns:
            ``(True, value)`` if the key names a data attribute, an available
            property, or an entry of ``plotting_data``; ``(False, None)`` otherwise.
        """
        if isinstance(key, str):
            instance_attrs = vars(self)
            if key in instance_attrs:
                return True, instance_attrs[key]
            if isinstance(getattr(type(self), key, None), property):
                try:
                    return True, getattr(self, key)
                except (IndexError, AttributeError):
                    # x / u have no value until an iterate has been recorded;
                    # treat them as absent rather than leaking the IndexError.
                    pass
        if key in self.plotting_data:
            return True, self.plotting_data[key]
        return False, None

    def update_plotting_data(self, **kwargs):
        """Update the plotting data with additional information.

        Args:
            **kwargs: Key-value pairs to add to plotting_data
        """
        self.plotting_data.update(kwargs)

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the results, similar to dict.get().

        Args:
            key: The key to look up
            default: Default value if key is not found

        Returns:
            The value associated with the key, or default if not found
        """
        found, value = self._lookup(key)
        return value if found else default

    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-style access to results.

        Args:
            key: The key to look up

        Returns:
            The value associated with the key

        Raises:
            KeyError: If key is not found
        """
        found, value = self._lookup(key)
        if not found:
            raise KeyError(f"Key '{key}' not found in results")
        return value

    def __setitem__(self, key: str, value: Any):
        """Allow dictionary-style assignment to results.

        Existing data attributes are updated in place; any other key is
        stored in ``plotting_data``.

        Args:
            key: The key to set
            value: The value to assign

        Raises:
            AttributeError: If key names a read-only property (e.g. ``x`` or ``u``).
        """
        if isinstance(key, str):
            if key in vars(self):
                setattr(self, key, value)
                return
            prop = getattr(type(self), key, None)
            if isinstance(prop, property):
                if prop.fset is None:
                    raise AttributeError(
                        f"'{key}' is a read-only property of {type(self).__name__} "
                        "and cannot be assigned"
                    )
                setattr(self, key, value)
                return
        # Unknown keys (including names of methods) go to the user data store.
        self.plotting_data[key] = value

    def __contains__(self, key: str) -> bool:
        """Check if a key exists in the results.

        Args:
            key: The key to check

        Returns:
            True if key exists, False otherwise
        """
        return self._lookup(key)[0]

    def update(self, other: dict[str, Any]):
        """Update the results with additional data from a dictionary.

        Args:
            other: Dictionary containing additional data
        """
        for key, value in other.items():
            self[key] = value

    def to_dict(self) -> dict[str, Any]:
        """Convert the results to a dictionary for backward compatibility.

        All dataclass fields except ``plotting_data`` are copied first; the
        entries of ``plotting_data`` are then merged on top at the top level.

        Returns:
            Dictionary representation of the results
        """
        result_dict = {
            attr_name: getattr(self, attr_name)
            for attr_name in self.__dataclass_fields__
            if attr_name != "plotting_data"
        }
        result_dict.update(self.plotting_data)
        return result_dict
@@ -0,0 +1,384 @@
1
+ """Penalized Trust Region (PTR) successive convexification algorithm.
2
+
3
+ This module implements the PTR algorithm for solving non-convex trajectory
4
+ optimization problems through iterative convex approximation.
5
+ """
6
+
7
+ import time
8
+ import warnings
9
+ from typing import TYPE_CHECKING, List
10
+
11
+ import cvxpy as cp
12
+ import numpy as np
13
+ import numpy.linalg as la
14
+
15
+ from openscvx.config import Config
16
+
17
+ from .autotuning import update_scp_weights
18
+ from .base import Algorithm, AlgorithmState
19
+
20
+ if TYPE_CHECKING:
21
+ from openscvx.lowered import LoweredJaxConstraints
22
+
23
+ warnings.filterwarnings("ignore")  # NOTE(review): runs at import time and silences ALL warnings process-wide, not only this module's; consider a scoped warnings.catch_warnings() or a category/module filter
24
+
25
+
26
class PenalizedTrustRegion(Algorithm):
    """Penalized Trust Region (PTR) successive convexification algorithm.

    PTR solves non-convex trajectory optimization problems through iterative
    convex approximation. Each subproblem balances competing cost terms:

    - **Trust region penalty**: Discourages large deviations from the previous
      iterate, keeping the solution within the region where linearization is valid.
    - **Virtual control**: Relaxes dynamics constraints, penalized to drive
      defects toward zero as the algorithm converges.
    - **Virtual buffer**: Relaxes non-convex constraints, similarly penalized
      to enforce feasibility at convergence.
    - **Problem objective and other terms**: The user-defined cost (e.g., minimum
      fuel, minimum time) and any additional penalty terms.

    The interplay between these terms guides the optimization: the trust region
    anchors the solution near the linearization point while virtual terms allow
    temporary constraint violations that shrink over iterations.

    Example:
        Using PTR with a Problem::

            from openscvx.algorithms import PenalizedTrustRegion

            problem = Problem(dynamics, constraints, states, controls, N, time)
            problem.initialize()
            result = problem.solve()
    """

    def __init__(self):
        """Initialize PTR with unset infrastructure.

        Call initialize() before step() to set up compiled components.
        """
        self._ocp: cp.Problem = None
        self._discretization_solver: callable = None
        self._jax_constraints: "LoweredJaxConstraints" = None
        self._solve_ocp: callable = None
        self._emitter: callable = None

    def initialize(
        self,
        ocp: cp.Problem,
        discretization_solver: callable,
        jax_constraints: "LoweredJaxConstraints",
        solve_ocp: callable,
        emitter: callable,
        params: dict,
        settings: Config,
    ) -> None:
        """Initialize PTR algorithm.

        Stores compiled infrastructure and performs a warm-start solve to
        initialize DPP and JAX jacobians.

        Args:
            ocp: CVXPy optimal control problem
            discretization_solver: Compiled discretization solver
            jax_constraints: JIT-compiled constraint functions
            solve_ocp: Callable that solves the OCP
            emitter: Callback for emitting iteration progress
            params: Problem parameters dictionary (for warm-start)
            settings: Configuration object (for warm-start)

        Raises:
            RuntimeError: If the warm-start subproblem yields no solution.
        """
        # Store immutable infrastructure
        self._ocp = ocp
        self._discretization_solver = discretization_solver
        self._jax_constraints = jax_constraints
        self._solve_ocp = solve_ocp
        self._emitter = emitter

        # Boundary-condition parameters only exist in some problem formulations.
        ocp_params = ocp.param_dict
        if "x_init" in ocp_params:
            ocp_params["x_init"].value = settings.sim.x.initial
        if "x_term" in ocp_params:
            ocp_params["x_term"].value = settings.sim.x.final

        # Create temporary state for initialization solve
        init_state = AlgorithmState.from_settings(settings)

        # Solve a throwaway problem to initialize DPP and JAX jacobians;
        # its result is intentionally discarded.
        self._subproblem(params, init_state, settings)

    def step(
        self,
        state: AlgorithmState,
        params: dict,
        settings: Config,
    ) -> bool:
        """Execute one PTR iteration.

        Solves the convex subproblem, updates state in place, and checks
        convergence based on trust region, virtual buffer, and virtual
        control costs.

        Args:
            state: Mutable solver state (modified in place)
            params: Problem parameters dictionary (may change between steps)
            settings: Configuration object (may change between steps)

        Returns:
            True if J_tr, J_vb, and J_vc are all below their thresholds.

        Raises:
            RuntimeError: If initialize() has not been called, or if the
                convex subproblem yields no solution.
        """
        if self._ocp is None:
            raise RuntimeError(
                "PenalizedTrustRegion.step() called before initialize(). "
                "Call initialize() first to set up compiled infrastructure."
            )

        # Run the subproblem
        (
            x_sol,
            u_sol,
            cost,
            J_total,
            J_vb_vec,
            J_vc_vec,
            J_tr_vec,
            prob_stat,
            V_multi_shoot,
            subprop_time,
            dis_time,
            vc_mat,
            tr_mat,
        ) = self._subproblem(params, state, settings)

        # Update state in place by appending to history.
        # NOTE(review): state.x / state.u are expected to reflect the latest
        # appended entry on the next step -- confirm against AlgorithmState.
        state.V_history.append(V_multi_shoot)
        state.X.append(x_sol)
        state.U.append(u_sol)
        state.VC_history.append(vc_mat)
        state.TR_history.append(tr_mat)

        state.J_tr = np.sum(np.array(J_tr_vec))
        state.J_vb = np.sum(np.array(J_vb_vec))
        state.J_vc = np.sum(np.array(J_vc_vec))

        # Update weights in state
        update_scp_weights(state, settings, state.k)

        # Emit data (timings are converted from seconds to milliseconds)
        self._emitter(
            {
                "iter": state.k,
                "dis_time": dis_time * 1000.0,
                "subprop_time": subprop_time * 1000.0,
                "J_total": J_total,
                "J_tr": state.J_tr,
                "J_vb": state.J_vb,
                "J_vc": state.J_vc,
                "cost": cost[-1],
                "prob_stat": prob_stat,
            }
        )

        # Increment iteration counter
        state.k += 1

        # Return convergence status as a plain bool
        return bool(
            (state.J_tr < settings.scp.ep_tr)
            and (state.J_vb < settings.scp.ep_vb)
            and (state.J_vc < settings.scp.ep_vc)
        )

    def _subproblem(
        self,
        params: dict,
        state: AlgorithmState,
        settings: Config,
    ):
        """Solve a single convex subproblem.

        Uses stored infrastructure (ocp, discretization_solver, jax_constraints)
        with per-step params and settings. As a side effect, a scalar
        ``state.lam_vc`` is expanded in place to a matrix.

        Args:
            params: Problem parameters dictionary
            state: Current solver state
            settings: Configuration object

        Returns:
            Tuple of, in order: new state guess, new control guess, per-node
            boundary cost array, OCP objective value, virtual buffer cost,
            virtual control cost vector, trust region cost vector, solver
            status, multi-shoot data returned by the discretization solver,
            subproblem solve time (s), discretization time (s), absolute
            virtual control matrix, absolute scaled trust region matrix.

        Raises:
            RuntimeError: If the solver leaves the state or control variables
                without a value (e.g. infeasible or failed solve).
        """
        # param_dict / var_dict are looked up once and reused below.
        ocp_params = self._ocp.param_dict
        ocp_vars = self._ocp.var_dict
        x_bar = state.x
        u_bar = state.u

        # Linearization point for this subproblem
        ocp_params["x_bar"].value = x_bar
        ocp_params["u_bar"].value = u_bar

        t0 = time.time()
        A_bar, B_bar, C_bar, x_prop, V_multi_shoot = self._discretization_solver.call(
            x_bar, u_bar.astype(float), params
        )
        ocp_params["A_d"].value = np.asarray(A_bar)
        ocp_params["B_d"].value = np.asarray(B_bar)
        ocp_params["C_d"].value = np.asarray(C_bar)
        ocp_params["x_prop"].value = np.asarray(x_prop)
        dis_time = time.time() - t0

        # Update nodal constraint linearization parameters
        # TODO: (norrisg) investigate why we are passing `0` for the node here
        if self._jax_constraints.nodal:
            for g_id, constraint in enumerate(self._jax_constraints.nodal):
                ocp_params[f"g_{g_id}"].value = np.asarray(
                    constraint.func(x_bar, u_bar, 0, params)
                )
                ocp_params[f"grad_g_x_{g_id}"].value = np.asarray(
                    constraint.grad_g_x(x_bar, u_bar, 0, params)
                )
                ocp_params[f"grad_g_u_{g_id}"].value = np.asarray(
                    constraint.grad_g_u(x_bar, u_bar, 0, params)
                )

        # Update cross-node constraint linearization parameters
        if self._jax_constraints.cross_node:
            for g_id, constraint in enumerate(self._jax_constraints.cross_node):
                # Cross-node constraints take (X, U, params) not (x, u, node, params)
                ocp_params[f"g_cross_{g_id}"].value = np.asarray(
                    constraint.func(x_bar, u_bar, params)
                )
                ocp_params[f"grad_g_X_cross_{g_id}"].value = np.asarray(
                    constraint.grad_g_X(x_bar, u_bar, params)
                )
                ocp_params[f"grad_g_U_cross_{g_id}"].value = np.asarray(
                    constraint.grad_g_U(x_bar, u_bar, params)
                )

        # Convex constraints are already lowered and handled in the OCP, no action needed here

        # Expand a scalar lam_vc to a (N-1, n_states) matrix. np.ndim also
        # recognises NumPy scalar types that are not instances of int/float.
        if np.ndim(state.lam_vc) == 0:
            state.lam_vc = np.ones((settings.scp.n - 1, settings.sim.n_states)) * state.lam_vc

        # Update CVXPy parameters from state
        ocp_params["w_tr"].value = state.w_tr
        ocp_params["lam_cost"].value = state.lam_cost
        ocp_params["lam_vc"].value = state.lam_vc
        ocp_params["lam_vb"].value = state.lam_vb

        t0 = time.time()
        self._solve_ocp()
        subprop_time = time.time() - t0

        x_scaled = ocp_vars["x"].value
        u_scaled = ocp_vars["u"].value
        if x_scaled is None or u_scaled is None:
            # A failed/infeasible solve leaves variable values unset; fail
            # with the solver status instead of an opaque NoneType error.
            raise RuntimeError(
                f"Convex subproblem returned no solution (solver status: {self._ocp.status})."
            )

        # Map the scaled solver variables back through the affine scaling
        x_new_guess = (settings.sim.S_x @ x_scaled.T + np.expand_dims(settings.sim.c_x, axis=1)).T
        u_new_guess = (settings.sim.S_u @ u_scaled.T + np.expand_dims(settings.sim.c_u, axis=1)).T

        # Per-node cost from the boundary-condition types; only final_type is
        # considered. An ndarray accumulator is required here: `+=` on a Python
        # list would extend the list instead of adding element-wise.
        costs = np.zeros(x_new_guess.shape[0])
        for i, bc_type in enumerate(settings.sim.x.final_type):
            if bc_type == "Minimize":
                costs = costs + x_new_guess[:, i]
            elif bc_type == "Maximize":
                costs = costs - x_new_guess[:, i]

        # Block-diagonal inverse scaling matrix for the stacked [x; u] deviation
        inv_S_x = settings.sim.inv_S_x
        inv_S_u = settings.sim.inv_S_u
        inv_block_diag = np.block(
            [
                [inv_S_x, np.zeros((inv_S_x.shape[0], inv_S_u.shape[1]))],
                [np.zeros((inv_S_u.shape[0], inv_S_x.shape[1])), inv_S_u],
            ]
        )

        # Scaled deviation from the linearization point and resulting costs
        tr_mat = inv_block_diag @ np.hstack((x_new_guess - x_bar, u_new_guess - u_bar)).T
        J_tr_vec = la.norm(tr_mat, axis=0) ** 2
        vc_mat = np.abs(ocp_vars["nu"].value)
        J_vc_vec = np.sum(vc_mat, axis=1)

        # Accumulate the positive part of every virtual buffer variable
        J_vb_vec = 0
        if self._jax_constraints.nodal:
            for g_id, _ in enumerate(self._jax_constraints.nodal):
                J_vb_vec += np.maximum(0, ocp_vars[f"nu_vb_{g_id}"].value)

        # Add cross-node constraint violations
        if self._jax_constraints.cross_node:
            for g_id, _ in enumerate(self._jax_constraints.cross_node):
                J_vb_vec += np.maximum(0, ocp_vars[f"nu_vb_cross_{g_id}"].value)

        # Convex constraints are already handled in the OCP, no processing needed here
        return (
            x_new_guess,
            u_new_guess,
            costs,
            self._ocp.value,
            J_vb_vec,
            J_vc_vec,
            J_tr_vec,
            self._ocp.status,
            V_multi_shoot,
            subprop_time,
            dis_time,
            vc_mat,
            np.abs(tr_mat),
        )

    def citation(self) -> List[str]:
        """Return BibTeX citations for the PTR algorithm.

        Returns:
            List of three BibTeX entries (raw strings) as listed in the body.
        """
        return [
            r"""@article{drusvyatskiy2018error,
title={Error bounds, quadratic growth, and linear convergence of proximal methods},
author={Drusvyatskiy, Dmitriy and Lewis, Adrian S},
journal={Mathematics of operations research},
volume={43},
number={3},
pages={919--948},
year={2018},
publisher={INFORMS}
}""",
            r"""@article{szmuk2020successive,
title={Successive convexification for real-time six-degree-of-freedom powered descent guidance
with state-triggered constraints},
author={Szmuk, Michael and Reynolds, Taylor P and A{\c{c}}{\i}kme{\c{s}}e, Beh{\c{c}}et},
journal={Journal of Guidance, Control, and Dynamics},
volume={43},
number={8},
pages={1399--1413},
year={2020},
publisher={American Institute of Aeronautics and Astronautics}
}""",
            r"""@article{reynolds2020dual,
title={Dual quaternion-based powered descent guidance with state-triggered constraints},
author={Reynolds, Taylor P and Szmuk, Michael and Malyuta, Danylo and Mesbahi, Mehran and
A{\c{c}}{\i}kme{\c{s}}e, Beh{\c{c}}et and Carson III, John M},
journal={Journal of Guidance, Control, and Dynamics},
volume={43},
number={9},
pages={1584--1599},
year={2020},
publisher={American Institute of Aeronautics and Astronautics}
}""",
        ]