openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,60 @@
1
+ """Trajectory propagation for trajectory optimization.
2
+
3
+ This module provides implementations of trajectory propagation methods that
4
+ simulate the nonlinear system dynamics forward in time. Propagation is used
5
+ to evaluate solution quality, verify constraint satisfaction, and generate
6
+ high-fidelity trajectories from optimized control sequences.
7
+
8
+ Current Implementations:
9
+ Forward Simulation: The default propagation method that integrates the
10
+ nonlinear dynamics forward in time using adaptive or fixed-step
11
+ numerical integration (via Diffrax). Supports both ZOH and FOH
12
+ control interpolation schemes.
13
+
14
+ Planned Architecture (ABC-based):
15
+
16
+ A base class will be introduced to enable pluggable propagation methods.
17
+ This will enable users to implement custom propagation methods.
18
+ Future propagators will implement the Propagator interface:
19
+
20
+ ```python
21
+ # propagation/base.py (planned):
22
+ class Propagator(ABC):
23
+ def __init__(self, integrator: Integrator):
24
+ '''Initialize with a numerical integrator.'''
25
+ self.integrator = integrator
26
+
27
+ @abstractmethod
28
+ def propagate(self, dynamics, x0, u_traj, time_grid) -> Array:
29
+ '''Propagate trajectory forward in time.
30
+
31
+ Args:
32
+ dynamics: Continuous-time dynamics object
33
+ x0: Initial state
34
+ u_traj: Control trajectory
35
+ time_grid: Time points for dense output
36
+
37
+ Returns:
38
+ State trajectory evaluated at time_grid points
39
+ '''
40
+ ...
41
+ ```
42
+ """
43
+
44
+ from .post_processing import propagate_trajectory_results
45
+ from .propagation import (
46
+ get_propagation_solver,
47
+ prop_aug_dy,
48
+ s_to_t,
49
+ simulate_nonlinear_time,
50
+ t_to_tau,
51
+ )
52
+
53
+ __all__ = [
54
+ "get_propagation_solver",
55
+ "simulate_nonlinear_time",
56
+ "prop_aug_dy",
57
+ "s_to_t",
58
+ "t_to_tau",
59
+ "propagate_trajectory_results",
60
+ ]
@@ -0,0 +1,104 @@
1
+ import copy
2
+
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+
6
+ from openscvx.algorithms import OptimizationResults
7
+ from openscvx.config import Config
8
+ from openscvx.utils import calculate_cost_from_boundaries
9
+
10
+ from .propagation import s_to_t, simulate_nonlinear_time, t_to_tau
11
+
12
+
13
def propagate_trajectory_results(
    params: dict, settings: Config, result: OptimizationResults, propagation_solver: callable
) -> OptimizationResults:
    """Re-simulate the optimized trajectory and attach the outcome to ``result``.

    The nodal solution stored in ``result`` is converted to real time, resampled
    on a grid with spacing ``settings.prp.dt``, and integrated through the
    nonlinear dynamics by ``simulate_nonlinear_time``.

    Args:
        params (dict): System parameters forwarded to the propagation solver.
        settings (Config): Configuration settings. ``settings.sim.x_prop`` is
            swapped for a modified copy while propagating and is restored
            afterwards, even when propagation raises.
        result (OptimizationResults): Optimization results; read for ``x``,
            ``u``, ``_states`` and ``_controls`` and updated in place.
        propagation_solver (callable): Solver handed to
            ``simulate_nonlinear_time``.

    Returns:
        OptimizationResults: The same ``result`` object with these fields set:
            - t_full: Time grid from the first nodal time up to (not including)
              the last nodal time, in steps of ``settings.prp.dt``
            - x_full: Propagated state trajectory
            - u_full: Controls interpolated onto ``t_full``
            - cost: Value returned by ``calculate_cost_from_boundaries``
            - ctcs_violation: Final propagated value of the CTCS states
            - trajectory: Mapping from state/control name to its columns of
              ``x_full`` / ``u_full``
    """
    x_nodes = result.x
    u_nodes = result.u

    # Nodal times, the dense time grid, and the matching tau values / controls.
    t_nodes = np.array(s_to_t(x_nodes, u_nodes, settings)).squeeze()
    t_full = np.arange(t_nodes[0], t_nodes[-1], settings.prp.dt)
    tau_vals, u_full = t_to_tau(u_nodes, t_full, t_nodes, settings)

    # Work on a shallow copy so the configured x_prop object is not modified.
    seeded_x_prop = copy.copy(settings.sim.x_prop)

    # Propagation may carry extra states appended after the optimization states.
    n_opt_states = x_nodes.shape[1]
    n_prop_states = settings.sim.x_prop.initial.shape[0]

    # Entries whose initial condition was "Free" take the optimized initial value.
    free_mask = jnp.array(
        [kind == "Free" for kind in settings.sim.x.initial_type], dtype=bool
    )

    if n_opt_states == n_prop_states:
        seeded_initial = jnp.where(free_mask, x_nodes[0, :], settings.sim.x_prop.initial)
    else:
        # Only the leading, overlapping portion is overwritten.
        seeded_initial = settings.sim.x_prop.initial.copy()
        seeded_initial[:n_opt_states] = jnp.where(
            free_mask, x_nodes[0, :], settings.sim.x_prop.initial[:n_opt_states]
        )
    seeded_x_prop.initial = seeded_initial

    # simulate_nonlinear_time reads settings.sim.x_prop, so install the seeded
    # copy for the duration of the call and always put the original back.
    configured_x_prop = settings.sim.x_prop
    settings.sim.x_prop = seeded_x_prop
    try:
        x_full = simulate_nonlinear_time(
            params, x_nodes, u_nodes, tau_vals, t_nodes, settings, propagation_solver
        )
    finally:
        settings.sim.x_prop = configured_x_prop

    cost = calculate_cost_from_boundaries(
        x_nodes, settings.sim.x.initial_type, settings.sim.x.final_type
    )

    # CTCS violation is read off the last propagated sample.
    ctcs_violation = x_full[-1, settings.sim.ctcs_slice_prop]

    # States first, then controls (a control with a clashing name would win).
    trajectory = {state.name: x_full[:, state._slice] for state in result._states}
    trajectory.update(
        {control.name: u_full[:, control._slice] for control in result._controls}
    )

    result.t_full = t_full
    result.x_full = x_full
    result.u_full = u_full
    result.cost = cost
    result.ctcs_violation = ctcs_violation
    result.trajectory = trajectory

    return result
@@ -0,0 +1,248 @@
1
+ import numpy as np
2
+
3
+ from openscvx.config import Config
4
+ from openscvx.integrators import solve_ivp_diffrax_prop
5
+ from openscvx.lowered import Dynamics
6
+
7
+
8
def prop_aug_dy(
    tau: float,
    x: np.ndarray,
    u_current: np.ndarray,
    u_next: np.ndarray,
    tau_init: float,
    node: int,
    idx_s: int,
    state_dot: callable,
    dis_type: str,
    N: int,
    params,
) -> np.ndarray:
    """Compute the time-scaled dynamics used during propagation.

    The control is held at ``u_current`` (ZOH) or blended linearly towards
    ``u_next`` (FOH). The state derivative returned by ``state_dot`` is then
    multiplied by the time-dilation entry of that control.

    Args:
        tau (float): Current normalized time.
        x (np.ndarray): Current state vector (1-D; a leading axis is added
            before calling ``state_dot``).
        u_current (np.ndarray): Control input at the current node, 2-D with the
            controls along the last axis.
        u_next (np.ndarray): Control input at the next node, same layout.
        tau_init (float): Normalized time at the start of the segment.
        node (int): Current node index, forwarded to ``state_dot``.
        idx_s (int): Column of the control array holding the time dilation.
        state_dot (callable): ``state_dot(x, u, node, params)`` returning the
            state derivative. It receives the control without its last column.
        dis_type (str): Discretization type, ``"ZOH"`` or ``"FOH"``.
        N (int): Factor applied to ``tau - tau_init`` to form the FOH blend
            weight.
        params: Additional parameters passed through to ``state_dot``.

    Returns:
        np.ndarray: Time-scaled state derivatives.

    Raises:
        ValueError: If ``dis_type`` is neither ``"ZOH"`` nor ``"FOH"``.
    """
    if dis_type == "ZOH":
        beta = 0.0
    elif dis_type == "FOH":
        # NOTE(review): simulate_nonlinear_time passes settings.scp.n as N while
        # its tau grid has spacing 1 / (n - 1), so beta reaches n / (n - 1) at
        # the segment end -- confirm this matches the discretization convention.
        beta = (tau - tau_init) * N
    else:
        # Previously fell through and failed later with an unbound ``beta``.
        raise ValueError("Currently unsupported discretization type")

    x = x[None, :]
    u = u_current + beta * (u_next - u_current)

    # The last control column is stripped before evaluating the dynamics.
    return u[:, idx_s] * state_dot(x, u[:, :-1], node, params).squeeze()
51
+
52
+
53
def get_propagation_solver(state_dot: Dynamics, settings: Config):
    """Build the per-segment propagation solver.

    Args:
        state_dot (Dynamics): Dynamics object forwarded to ``prop_aug_dy`` as
            its ``state_dot`` argument.
        settings (Config): Configuration settings. The discretization type,
            node count and propagation options (``prp.solver``, ``prp.rtol``,
            ``prp.atol``, ``prp.args``) are read each time the returned
            function is called.

    Returns:
        callable: ``propagation_solver(V0, tau_grid, u_cur, u_next, tau_init,
        node, idx_s, save_time, mask, params)``, which integrates
        ``prop_aug_dy`` from ``tau_grid[0]`` to ``tau_grid[1]`` with
        ``solve_ivp_diffrax_prop`` and returns its result.
    """

    def propagation_solver(
        V0, tau_grid, u_cur, u_next, tau_init, node, idx_s, save_time, mask, params
    ):
        # Positional arguments that follow (tau, x) in prop_aug_dy's signature.
        dynamics_args = (
            u_cur,  # shape (1, n_controls)
            u_next,  # shape (1, n_controls)
            tau_init,  # shape (1, 1)
            node,  # shape (1, 1)
            idx_s,  # int
            state_dot,
            settings.dis.dis_type,
            settings.scp.n,
            params,
        )
        return solve_ivp_diffrax_prop(
            f=prop_aug_dy,
            tau_final=tau_grid[1],  # scalar
            y_0=V0,  # shape (n_states,)
            args=dynamics_args,
            tau_0=tau_grid[0],  # scalar
            solver_name=settings.prp.solver,
            rtol=settings.prp.rtol,
            atol=settings.prp.atol,
            extra_kwargs=settings.prp.args,
            save_time=save_time,  # shape (MAX_TAU_LEN,)
            mask=mask,  # shape (MAX_TAU_LEN,), dtype=bool
        )

    return propagation_solver
98
+
99
+
100
def s_to_t(x: np.ndarray, u: np.ndarray, settings: Config):
    """Accumulate real time at every node from the time-dilation control.

    Starting from the time entry of the first state row, each node adds the
    width of the preceding ``tau`` interval multiplied by the time dilation
    (last control column): the previous node's value for ZOH, the average of
    the previous and current node's values otherwise.

    Args:
        x: State trajectory array, shape (N, n_states).
        u: Control trajectory array, shape (N, n_controls).
        settings (Config): Configuration settings (``sim.time_slice``,
            ``scp.n`` and ``dis.dis_type`` are read).

    Returns:
        list: One entry per node; each entry keeps the shape produced by
        indexing ``x`` with ``settings.sim.time_slice``.
    """
    n_nodes = settings.scp.n
    zero_order_hold = settings.dis.dis_type == "ZOH"

    times = [x[:, settings.sim.time_slice][0]]
    tau_grid = np.linspace(0, 1, n_nodes)

    for k in range(1, n_nodes):
        d_tau = tau_grid[k] - tau_grid[k - 1]
        s_prev = u[k - 1, -1]
        s_here = u[k, -1]
        if zero_order_hold:
            elapsed = d_tau * (s_prev)
        else:
            elapsed = 0.5 * (s_here + s_prev) * d_tau
        times.append(times[k - 1] + elapsed)

    return times
124
+
125
+
126
def t_to_tau(u: np.ndarray, t, t_nodal, settings: Config):
    """Convert real time t to normalized time tau and resample the controls.

    Args:
        u (np.ndarray): Control trajectory array, shape (N, n_controls). The
            last column is used as the time dilation.
        t (np.ndarray): Real time points to convert.
        t_nodal (np.ndarray): Real time at each node.
        settings (Config): Configuration settings (``dis.dis_type`` and
            ``scp.n`` are read).

    Returns:
        tuple: ``(tau, u_interp)`` where ``tau`` has one entry per point of
        ``t`` (the first entry is always 0) and ``u_interp`` holds the controls
        evaluated at ``t`` -- the most recent nodal value for ZOH, linear
        interpolation in ``t`` for FOH.

    Raises:
        ValueError: If the discretization type is neither "ZOH" nor "FOH".
    """
    dis_type = settings.dis.dis_type
    n_nodal = len(t_nodal)

    if dis_type == "ZOH":
        # Zero-Order Hold: index of the last nodal time <= t, kept in range.
        held = np.searchsorted(t_nodal, t, side="right") - 1
        held = np.clip(held, 0, n_nodal - 1)
        u_interp = np.asarray(u[held, :])
    elif dis_type == "FOH":
        # First-Order Hold: linear interpolation, one control column at a time.
        u_interp = np.stack(
            [np.interp(t, t_nodal, u[:, i]) for i in range(u.shape[1])], axis=-1
        )
    else:
        raise ValueError("Currently unsupported discretization type")

    tau = np.zeros(len(t))
    tau_nodal = np.linspace(0, 1, settings.scp.n)

    # A query time that lands marginally past the final node (np.arange can
    # overshoot its stop value for float steps) would otherwise select the last
    # node as segment start and index one past the end of ``u``. Such points
    # are attributed to the final segment instead.
    last_segment = n_nodal - 2

    for k in range(1, len(t)):
        k_nodal = min(np.where(t_nodal < t[k])[0][-1], last_segment)
        s_kp = u[k_nodal, -1]
        elapsed = t[k] - t_nodal[k_nodal]
        if dis_type == "ZOH":
            tau[k] = tau_nodal[k_nodal] + elapsed / s_kp
        else:
            # NOTE(review): dividing by the mean dilation of the segment is
            # exact at the segment ends only if the dilation varies linearly in
            # tau -- confirm this approximation is intended.
            s_k = u[k_nodal + 1, -1]
            tau[k] = tau_nodal[k_nodal] + 2 * elapsed / (s_k + s_kp)
    return tau, u_interp
171
+
172
+
173
def simulate_nonlinear_time(params, x, u, tau_vals, t, settings, propagation_solver):
    """Propagate the nonlinear dynamics segment by segment.

    The normalized-time grid is split into ``settings.scp.n - 1`` segments. For
    each segment the requested ``tau_vals`` that fall inside it, plus the
    segment end point, are padded to ``settings.prp.max_tau_len`` and handed to
    ``propagation_solver.call``. The state at the segment end becomes the
    initial state of the next segment.

    Args:
        params: System parameters, forwarded to the solver.
        x: State trajectory array, shape (N, n_states). Not read by this
            function; kept for interface compatibility.
        u: Control trajectory array, shape (N, n_controls).
        tau_vals (np.ndarray): Normalized time points at which the state is
            saved. Consumed in order, so they are expected to be sorted.
        t (np.ndarray): Real time at each node.
        settings: Configuration settings (``sim.x_prop.initial``, ``scp.n``,
            ``prp.max_tau_len`` and ``sim.time_dilation_slice`` are read).
        propagation_solver: Object exposing ``call(...)`` that returns the
            saved states for one segment, one row per save time.

    Returns:
        np.ndarray: Simulated state trajectory, shape (len(tau_vals), n_states).

    Raises:
        ValueError: If a segment needs more save points than
            ``settings.prp.max_tau_len`` allows.
    """
    x_0 = settings.sim.x_prop.initial

    n_nodes = settings.scp.n
    n_segments = n_nodes - 1
    n_states = x_0.shape[0]
    n_tau = len(tau_vals)
    max_tau_len = settings.prp.max_tau_len

    states = np.empty((n_states, n_tau))
    tau = np.linspace(0, 1, n_nodes)

    # Evaluated at the nodal times themselves, i.e. one row per node.
    u_interp = np.stack([np.interp(t, t, u[:, i]) for i in range(u.shape[1])], axis=-1)

    # Bin tau_vals into segments of tau; points at or beyond the last grid
    # value are assigned to the final segment.
    tau_inds = np.digitize(tau_vals, tau) - 1
    tau_inds = np.where(tau_inds == n_nodes - 1, n_nodes - 2, tau_inds)

    # NOTE(review): this is the exclusive stop of the slice, yet prop_aug_dy
    # uses idx_s as a column index (u[:, idx_s]) -- confirm this is intended.
    idx_s = settings.sim.time_dilation_slice.stop

    prev_count = 0
    out_idx = 0

    for k in range(n_segments):
        controls_current = u_interp[k][None, :]
        controls_next = u_interp[k + 1][None, :]

        # Requested points in this segment, plus the segment end point.
        n_requested = int(np.sum(tau_inds == k))
        count = n_requested + 1

        if count > max_tau_len:
            # Without this check np.pad fails on a negative pad width with a
            # message that does not mention the offending setting.
            raise ValueError(
                f"Segment {k} needs {count} save points but settings.prp.max_tau_len "
                f"is {max_tau_len}; increase max_tau_len or use a coarser time grid"
            )

        tau_cur = np.concatenate(
            [tau_vals[prev_count : prev_count + n_requested], np.array([tau[k + 1]])]
        )

        # Pad to fixed length; the mask marks the entries that are real.
        pad_len = max_tau_len - count
        tau_cur_padded = np.pad(tau_cur, (0, pad_len), constant_values=tau[k + 1])
        mask_padded = np.concatenate([np.ones(count), np.zeros(pad_len)]).astype(bool)

        sol = propagation_solver.call(
            x_0,
            (tau[k], tau[k + 1]),
            controls_current,
            controls_next,
            np.array([[tau[k]]]),
            np.array([[k]]),
            idx_s,
            tau_cur_padded,
            mask_padded,
            params,
        )

        # Store the requested points; the end point seeds the next segment.
        states[:, out_idx : out_idx + n_requested] = sol[:n_requested].T
        out_idx += n_requested
        x_0 = sol[n_requested]

        prev_count += n_requested

    return states.T
@@ -0,0 +1,51 @@
1
+ """Convex subproblem solvers for trajectory optimization.
2
+
3
+ This module provides implementations of convex subproblem solvers used within
4
+ SCvx algorithms. At each iteration of a successive convexification algorithm,
5
+ the non-convex problem is approximated by a convex subproblem, which is then
6
+ solved using one of these solver backends.
7
+
8
+ Current Implementations:
9
+ CVXPy Solver: The default solver backend using CVXPy's modeling language
10
+ with support for multiple backend solvers (CLARABEL, etc.).
11
+ Includes optional code generation via cvxpygen for improved performance.
12
+
13
+ Note:
14
+ CVXPyGen setup logic is currently in :class:`Problem`. When the
15
+ ``ConvexSolver`` base class is implemented, this setup will move here.
16
+
17
+ Planned Architecture (ABC-based):
18
+
19
+ A base class will be introduced to enable pluggable solver implementations.
20
+ This will enable users to implement custom solver backends such as:
21
+
22
+ - Direct Clarabel solver (Rust-based, GPU-capable)
23
+ - QPAX (JAX-based QP solver for end-to-end differentiability)
24
+ - OSQP direct interface (specialized for QP structure)
25
+ - Custom embedded solvers for real-time applications
26
+ - Research solvers with specialized structure exploitation
27
+
28
+ This should also make the solver choice independent of the algorithm choice
29
+
30
+ Future solvers will implement the ConvexSolver interface:
31
+
32
+ ```python
33
+ # solvers/base.py (planned):
34
+ class ConvexSolver(ABC):
35
+ @abstractmethod
36
+ def build_subproblem(self, state: AlgorithmState, lowered: LoweredProblem):
37
+ '''Build the convex subproblem from current state.'''
38
+ ...
39
+
40
+ @abstractmethod
41
+ def solve(self) -> OptimizationResults:
42
+ '''Solve the convex subproblem and return results.'''
43
+ ...
44
+ ```
45
+ """
46
+
47
+ from .cvxpy import optimal_control_problem
48
+
49
+ __all__ = [
50
+ "optimal_control_problem",
51
+ ]
@@ -0,0 +1,226 @@
1
+ import os
2
+ from typing import TYPE_CHECKING
3
+
4
+ import cvxpy as cp
5
+ import numpy as np
6
+
7
+ from openscvx.config import Config
8
+
9
+ if TYPE_CHECKING:
10
+ from openscvx.lowered import LoweredProblem
11
+
12
+ # Optional cvxpygen import
13
+ try:
14
+ from cvxpygen import cpg
15
+
16
+ CVXPYGEN_AVAILABLE = True
17
+ except ImportError:
18
+ CVXPYGEN_AVAILABLE = False
19
+ cpg = None
20
+
21
+
22
def optimal_control_problem(settings: Config, lowered: "LoweredProblem"):
    """Assemble the convex subproblem as a CVXPy problem.

    Collects the linearized nodal and cross-node constraints, the already
    convex CVXPy constraints, boundary conditions, scaling relations,
    discretized dynamics and box bounds, then adds the trust-region and slack
    penalties to the cost. When ``settings.cvx.cvxpygen`` is set, solver code
    is generated into the ``solver`` directory (asking on stdin before
    overwriting unless ``settings.cvx.cvxpygen_override`` is set).

    Args:
        settings: Configuration settings for the optimization problem
        lowered: LoweredProblem containing ocp_vars and lowered constraints

    Returns:
        The ``cp.Problem`` minimizing the assembled cost subject to the
        assembled constraints.

    Raises:
        ImportError: If code generation is requested but cvxpygen is missing.
    """
    ocp_vars = lowered.ocp_vars

    # Penalty weights
    w_tr = ocp_vars.w_tr
    lam_cost = ocp_vars.lam_cost
    lam_vc = ocp_vars.lam_vc
    lam_vb = ocp_vars.lam_vb
    # States
    x = ocp_vars.x
    dx = ocp_vars.dx
    x_bar = ocp_vars.x_bar
    x_init = ocp_vars.x_init
    x_term = ocp_vars.x_term
    # Controls
    u = ocp_vars.u
    du = ocp_vars.du
    u_bar = ocp_vars.u_bar
    # Discretized dynamics
    A_d = ocp_vars.A_d
    B_d = ocp_vars.B_d
    C_d = ocp_vars.C_d
    x_prop = ocp_vars.x_prop
    nu = ocp_vars.nu
    # Linearized nodal constraints
    g = ocp_vars.g
    grad_g_x = ocp_vars.grad_g_x
    grad_g_u = ocp_vars.grad_g_u
    nu_vb = ocp_vars.nu_vb
    # Linearized cross-node constraints
    g_cross = ocp_vars.g_cross
    grad_g_X_cross = ocp_vars.grad_g_X_cross
    grad_g_U_cross = ocp_vars.grad_g_U_cross
    nu_vb_cross = ocp_vars.nu_vb_cross
    # Scaling
    S_x = ocp_vars.S_x
    c_x = ocp_vars.c_x
    S_u = ocp_vars.S_u
    c_u = ocp_vars.c_u
    x_nonscaled = ocp_vars.x_nonscaled
    u_nonscaled = ocp_vars.u_nonscaled
    dx_nonscaled = ocp_vars.dx_nonscaled
    du_nonscaled = ocp_vars.du_nonscaled

    jax_constraints = lowered.jax_constraints
    cvxpy_constraints = lowered.cvxpy_constraints

    sim = settings.sim
    n_nodes = settings.scp.n

    constr = []
    cost = lam_cost * 0
    cost += lam_vb * 0

    #############
    # CONSTRAINTS
    #############

    # Linearized nodal constraints (from JAX-lowered non-convex);
    # nodes should already be validated and normalized in preprocessing.
    for idx_ncvx, constraint in enumerate(jax_constraints.nodal or ()):
        constr += [
            (
                g[idx_ncvx][node]
                + grad_g_x[idx_ncvx][node] @ dx[node]
                + grad_g_u[idx_ncvx][node] @ du[node]
            )
            == nu_vb[idx_ncvx][node]
            for node in constraint.nodes
        ]

    # Linearized cross-node constraints (from JAX-lowered non-convex):
    # g(X_bar, U_bar) + grad_g_X @ dX + grad_g_U @ dU == nu_vb, summed over
    # every trajectory node so that multiple nodes are coupled.
    for idx_cross, _ in enumerate(jax_constraints.cross_node or ()):
        residual = g_cross[idx_cross]
        for k in range(n_nodes):
            residual += grad_g_X_cross[idx_cross][k, :] @ dx[k]
            residual += grad_g_U_cross[idx_cross][k, :] @ du[k]
        constr += [residual == nu_vb_cross[idx_cross]]

    # Convex constraints (already lowered to CVXPy)
    if cvxpy_constraints.constraints:
        constr += cvxpy_constraints.constraints

    # Boundary conditions and boundary cost terms on the true states.
    for i in range(sim.true_state_slice.start, sim.true_state_slice.stop):
        initial_kind = sim.x.initial_type[i]
        final_kind = sim.x.final_type[i]
        if initial_kind == "Fix":
            constr += [x_nonscaled[0][i] == x_init[i]]
        if final_kind == "Fix":
            constr += [x_nonscaled[-1][i] == x_term[i]]
        if initial_kind == "Minimize":
            cost += lam_cost * x_nonscaled[0][i]
        if final_kind == "Minimize":
            cost += lam_cost * x_nonscaled[-1][i]
        if initial_kind == "Maximize":
            cost -= lam_cost * x_nonscaled[0][i]
        if final_kind == "Maximize":
            cost -= lam_cost * x_nonscaled[-1][i]

    # Optionally force the same time dilation at every node.
    if settings.scp.uniform_time_grid:
        dilation = sim.time_dilation_slice
        constr += [
            u_nonscaled[i][dilation] == u_nonscaled[i - 1][dilation]
            for i in range(1, n_nodes)
        ]

    # State / control error between the scaled variables and the reference.
    S_x_inv = np.linalg.inv(S_x)
    constr += [(x[i] - S_x_inv @ (x_bar[i] - c_x) - dx[i]) == 0 for i in range(n_nodes)]
    S_u_inv = np.linalg.inv(S_u)
    constr += [(u[i] - S_u_inv @ (u_bar[i] - c_u) - du[i]) == 0 for i in range(n_nodes)]

    # Dynamics constraint
    constr += [
        x_nonscaled[i]
        == A_d[i - 1] @ dx_nonscaled[i - 1]
        + B_d[i - 1] @ du_nonscaled[i - 1]
        + C_d[i - 1] @ du_nonscaled[i]
        + x_prop[i - 1]
        + nu[i - 1]
        for i in range(1, n_nodes)
    ]

    # Control constraints
    constr += [u_nonscaled[i] <= sim.u.max for i in range(n_nodes)]
    constr += [u_nonscaled[i] >= sim.u.min for i in range(n_nodes)]

    # TODO: (norrisg) formalize this
    # State constraints (also implemented in CTCS but included for numerical
    # stability)
    constr += [x_nonscaled[i][:] <= sim.x.max for i in range(n_nodes)]
    constr += [x_nonscaled[i][:] >= sim.x.min for i in range(n_nodes)]

    ########
    # COSTS
    ########

    # Trust region cost
    cost += sum(
        w_tr * cp.sum_squares(cp.hstack((dx[i], du[i]))) for i in range(n_nodes)
    )
    # Virtual control slack
    cost += sum(cp.sum(lam_vc[i] * cp.abs(nu[i])) for i in range(n_nodes - 1))

    # Virtual slack penalty for nodal constraints
    for idx_ncvx, _ in enumerate(jax_constraints.nodal or ()):
        cost += lam_vb * cp.sum(cp.pos(nu_vb[idx_ncvx]))

    # Virtual slack penalty for cross-node constraints
    for idx_cross, _ in enumerate(jax_constraints.cross_node or ()):
        cost += lam_vb * cp.pos(nu_vb_cross[idx_cross])

    # CTCS augmented states: bound the per-node increment over each interval
    # and pin the initial value to zero.
    ctcs_indices = np.arange(sim.ctcs_slice.start, sim.ctcs_slice.stop)
    for idx, nodes in zip(ctcs_indices, sim.ctcs_node_intervals):
        first_node = 1 if nodes[0] == 0 else nodes[0]
        constr += [
            cp.abs(x_nonscaled[i][idx] - x_nonscaled[i - 1][idx]) <= sim.x.max[idx]
            for i in range(first_node, nodes[1])
        ]
        constr += [x_nonscaled[0][idx] == 0]

    #########
    # PROBLEM
    #########
    prob = cp.Problem(cp.Minimize(cost), constr)

    if settings.cvx.cvxpygen:
        if not CVXPYGEN_AVAILABLE:
            raise ImportError(
                "cvxpygen is required for code generation but not installed. "
                "Install it with: pip install openscvx[cvxpygen] or pip install cvxpygen"
            )
        # Generate when no solver directory exists yet, when overriding is
        # enabled, or when the user confirms the overwrite at the prompt.
        # Short-circuiting keeps the prompt to the one case that needs it.
        should_generate = (
            not os.path.exists("solver")
            or settings.cvx.cvxpygen_override
            or input("Solver directory already exists. Overwrite? (y/n): ").lower() == "y"
        )
        if should_generate:
            cpg.generate_code(prob, solver=settings.cvx.solver, code_dir="solver", wrapper=True)

    return prob
@@ -0,0 +1,9 @@
1
+ """Symbolic expression system for trajectory optimization.
2
+
3
+ See openscvx.symbolic.expr for detailed documentation and examples.
4
+ """
5
+
6
+ from openscvx.symbolic.constraint_set import ConstraintSet
7
+ from openscvx.symbolic.problem import SymbolicProblem
8
+
9
+ __all__ = ["ConstraintSet", "SymbolicProblem"]