openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,760 @@
|
|
|
1
|
+
"""Symbolic expression lowering to executable code.
|
|
2
|
+
|
|
3
|
+
This module provides the main entry point for converting symbolic expressions
|
|
4
|
+
(AST nodes) into executable code for different backends (JAX, CVXPy, etc.).
|
|
5
|
+
The lowering process translates the symbolic expression tree into functions
|
|
6
|
+
that can be executed during optimization.
|
|
7
|
+
|
|
8
|
+
Architecture:
|
|
9
|
+
The lowering process follows a visitor pattern where each backend implements
|
|
10
|
+
a lowerer class (e.g., JaxLowerer, CVXPyLowerer) with visitor methods for
|
|
11
|
+
each expression type. The `lower()` function dispatches expression nodes
|
|
12
|
+
to the appropriate backend.
|
|
13
|
+
|
|
14
|
+
Lowering Flow:
|
|
15
|
+
|
|
16
|
+
1. Symbolic expressions are built during problem specification
|
|
17
|
+
2. lower_symbolic_problem() coordinates the full lowering process
|
|
18
|
+
3. Backend-specific lowerers convert each expression node to executable code
|
|
19
|
+
4. Automatic differentiation creates Jacobians for dynamics and constraints
|
|
20
|
+
5. Result is a set of executable functions ready for numerical optimization
|
|
21
|
+
|
|
22
|
+
Backends:
|
|
23
|
+
- JAX: For dynamics and non-convex constraints (with automatic differentiation)
|
|
24
|
+
- CVXPy: For convex constraints (with disciplined convex programming)
|
|
25
|
+
|
|
26
|
+
Example:
|
|
27
|
+
Basic lowering to JAX::
|
|
28
|
+
|
|
29
|
+
import openscvx as ox
|
|
30
|
+
from openscvx.symbolic.lower import lower_to_jax
|
|
31
|
+
|
|
32
|
+
# Define symbolic expression
|
|
33
|
+
x = ox.State("x", shape=(3,))
|
|
34
|
+
u = ox.Control("u", shape=(2,))
|
|
35
|
+
expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
|
|
36
|
+
|
|
37
|
+
# Lower to JAX function
|
|
38
|
+
f = lower_to_jax(expr)
|
|
39
|
+
# f is now a callable: f(x_val, u_val, node, params) -> scalar
|
|
40
|
+
|
|
41
|
+
Full problem lowering::
|
|
42
|
+
|
|
43
|
+
# After building and preprocessing the SymbolicProblem...
lowered = lower_symbolic_problem(problem)
|
|
49
|
+
# Access via LoweredProblem dataclass
|
|
50
|
+
dynamics = lowered.dynamics
|
|
51
|
+
jax_constraints = lowered.jax_constraints
|
|
52
|
+
# Now have executable JAX functions with Jacobians
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple, Union

import cvxpy as cp
import jax
import numpy as np
from jax import jacfwd

from openscvx.expert import apply_byof
from openscvx.lowered import (
    CVXPyVariables,
    Dynamics,
    LoweredCrossNodeConstraint,
    LoweredCvxpyConstraints,
    LoweredJaxConstraints,
    LoweredNodalConstraint,
    LoweredProblem,
)
from openscvx.symbolic.constraint_set import ConstraintSet
from openscvx.symbolic.expr import Expr, NodeReference
|
|
74
|
+
|
|
75
|
+
if TYPE_CHECKING:
|
|
76
|
+
from openscvx.lowered.unified import UnifiedState
|
|
77
|
+
from openscvx.symbolic.problem import SymbolicProblem
|
|
78
|
+
|
|
79
|
+
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = [
    "lower", "lower_to_jax",
    "lower_cvxpy_constraints", "create_cvxpy_variables",
    "lower_symbolic_problem",
]
|
|
86
|
+
from openscvx.lowered.unified import UnifiedControl, UnifiedState
|
|
87
|
+
from openscvx.symbolic.unified import unify_controls, unify_states
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def lower(expr: Expr, lowerer: Any):
    """Dispatch an expression node to the appropriate lowerer backend.

    This is the main entry point for lowering a single symbolic expression to
    executable code. It simply delegates to ``lowerer.lower(expr)``; the
    lowerer implements the visitor pattern that dispatches on expression type.

    Args:
        expr: Symbolic expression to lower (any Expr subclass).
        lowerer: Backend lowerer instance (e.g., JaxLowerer, CVXPyLowerer).
            Any object exposing a ``lower(expr)`` method is accepted.

    Returns:
        Whatever ``lowerer.lower(expr)`` returns. For JaxLowerer this is a
        callable with signature (x, u, node, params) -> result; for
        CVXPyLowerer it is a CVXPy expression object.

    Raises:
        NotImplementedError: Propagated from the lowerer when it does not
            support the expression type.

    Example:
        Lower an expression to the JAX backend::

            import openscvx as ox
            from openscvx.symbolic.lowerers.jax import JaxLowerer

            x = ox.State("x", shape=(3,))
            expr = ox.Norm(x)
            lowerer = JaxLowerer()
            f = lower(expr, lowerer)
            # f is now callable: f(x_val, u_val, node, params) -> scalar
    """
    return lowerer.lower(expr)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# --- Convenience wrappers for common backends ---
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def lower_to_jax(exprs: Union[Expr, Sequence[Expr]]) -> Union[Callable, List[Callable]]:
    """Lower symbolic expression(s) to JAX callable(s).

    Convenience wrapper that creates a single JaxLowerer and lowers one or more
    symbolic expressions to JAX functions. The resulting functions can be
    JIT-compiled and automatically differentiated.

    Args:
        exprs: Single expression or sequence of expressions to lower.

    Returns:
        - If ``exprs`` is a single Expr: a single callable with signature
          (x, u, node, params) -> array.
        - Otherwise: a list of callables (one per input expression, in order)
          with the same signature.

    Example:
        Single expression::

            x = ox.State("x", shape=(3,))
            expr = ox.Norm(x)**2
            f = lower_to_jax(expr)
            # f(x_val, u_val, node_idx, params_dict) -> scalar

        Multiple expressions::

            u = ox.Control("u", shape=(2,))
            fns = lower_to_jax([ox.Norm(x), ox.Norm(u)])
            # fns is [f1, f2], each with the same signature

    Note:
        All returned JAX functions have a uniform signature
        (x, u, node, params) regardless of whether they use all arguments.
        This standardization simplifies vectorization and differentiation.
    """
    # Local import, as in the rest of this module's lowerer wiring.
    from openscvx.symbolic.lowerers.jax import JaxLowerer

    jl = JaxLowerer()
    if isinstance(exprs, Expr):
        return lower(exprs, jl)
    # One shared lowerer instance for the whole batch.
    return [lower(e, jl) for e in exprs]
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def create_cvxpy_variables(
    N: int,
    n_states: int,
    n_controls: int,
    S_x: np.ndarray,
    c_x: np.ndarray,
    S_u: np.ndarray,
    c_u: np.ndarray,
    n_nodal_constraints: int,
    n_cross_node_constraints: int,
) -> CVXPyVariables:
    """Build the CVXPy variables and parameters of the optimal control problem.

    Decision quantities are ``cp.Variable`` objects; everything supplied from
    outside the convex solve (weights, reference trajectory, discretized
    dynamics, constraint linearizations) is a ``cp.Parameter``.

    Args:
        N: Number of discretization nodes.
        n_states: Number of state variables.
        n_controls: Number of control variables.
        S_x: State scaling matrix.
        c_x: State offset vector.
        S_u: Control scaling matrix.
        c_u: Control offset vector.
        n_nodal_constraints: Number of non-convex nodal constraints
            (one set of linearization parameters is created per constraint).
        n_cross_node_constraints: Number of non-convex cross-node constraints.

    Returns:
        CVXPyVariables holding all variables/parameters, the scaling matrices
        with their inverses, and per-node lists of unscaled expressions
        (``S @ z[k] + c`` for x/u and ``S @ dz[k]`` for dx/du).
    """
    # Leading dimension of the quantities defined between consecutive nodes.
    n_seg = N - 1

    inv_S_x = np.linalg.inv(S_x)
    inv_S_u = np.linalg.inv(S_u)

    # --- weights / penalty parameters ---
    w_tr = cp.Parameter(nonneg=True, name="w_tr")
    lam_cost = cp.Parameter(nonneg=True, name="lam_cost")
    lam_vc = cp.Parameter((n_seg, n_states), nonneg=True, name="lam_vc")
    lam_vb = cp.Parameter(nonneg=True, name="lam_vb")

    # --- state: current iterate, error, previous SCP iterate, boundary values ---
    state_shape = (N, n_states)
    x = cp.Variable(state_shape, name="x")
    dx = cp.Variable(state_shape, name="dx")
    x_bar = cp.Parameter(state_shape, name="x_bar")
    x_init = cp.Parameter(n_states, name="x_init")
    x_term = cp.Parameter(n_states, name="x_term")

    # --- control: current iterate, error, previous SCP iterate ---
    control_shape = (N, n_controls)
    u = cp.Variable(control_shape, name="u")
    du = cp.Variable(control_shape, name="du")
    u_bar = cp.Parameter(control_shape, name="u_bar")

    # --- discretized augmented dynamics + virtual control ---
    A_d = cp.Parameter((n_seg, n_states, n_states), name="A_d")
    B_d = cp.Parameter((n_seg, n_states, n_controls), name="B_d")
    C_d = cp.Parameter((n_seg, n_states, n_controls), name="C_d")
    x_prop = cp.Parameter((n_seg, n_states), name="x_prop")
    nu = cp.Variable((n_seg, n_states), name="nu")

    # --- linearized non-convex nodal constraints (one group per constraint) ---
    g, grad_g_x, grad_g_u, nu_vb = [], [], [], []
    for j in range(n_nodal_constraints):
        g.append(cp.Parameter(N, name=f"g_{j}"))
        grad_g_x.append(cp.Parameter(state_shape, name=f"grad_g_x_{j}"))
        grad_g_u.append(cp.Parameter(control_shape, name=f"grad_g_u_{j}"))
        nu_vb.append(cp.Variable(N, name=f"nu_vb_{j}"))  # virtual buffer slack

    # --- linearized cross-node constraints: scalar value and scalar slack each ---
    g_cross, grad_g_X_cross, grad_g_U_cross, nu_vb_cross = [], [], [], []
    for j in range(n_cross_node_constraints):
        g_cross.append(cp.Parameter(name=f"g_cross_{j}"))
        grad_g_X_cross.append(cp.Parameter(state_shape, name=f"grad_g_X_cross_{j}"))
        grad_g_U_cross.append(cp.Parameter(control_shape, name=f"grad_g_U_cross_{j}"))
        nu_vb_cross.append(cp.Variable(name=f"nu_vb_cross_{j}"))  # virtual buffer slack

    # Affine un-scaling, evaluated node by node in (x, u, dx, du) order.
    unscaled = [
        (S_x @ x[k] + c_x, S_u @ u[k] + c_u, S_x @ dx[k], S_u @ du[k])
        for k in range(N)
    ]
    x_nonscaled = [row[0] for row in unscaled]
    u_nonscaled = [row[1] for row in unscaled]
    dx_nonscaled = [row[2] for row in unscaled]
    du_nonscaled = [row[3] for row in unscaled]

    return CVXPyVariables(
        w_tr=w_tr,
        lam_cost=lam_cost,
        lam_vc=lam_vc,
        lam_vb=lam_vb,
        x=x,
        dx=dx,
        x_bar=x_bar,
        x_init=x_init,
        x_term=x_term,
        u=u,
        du=du,
        u_bar=u_bar,
        A_d=A_d,
        B_d=B_d,
        C_d=C_d,
        x_prop=x_prop,
        nu=nu,
        g=g,
        grad_g_x=grad_g_x,
        grad_g_u=grad_g_u,
        nu_vb=nu_vb,
        g_cross=g_cross,
        grad_g_X_cross=grad_g_X_cross,
        grad_g_U_cross=grad_g_U_cross,
        nu_vb_cross=nu_vb_cross,
        S_x=S_x,
        inv_S_x=inv_S_x,
        c_x=c_x,
        S_u=S_u,
        inv_S_u=inv_S_u,
        c_u=c_u,
        x_nonscaled=x_nonscaled,
        u_nonscaled=u_nonscaled,
        dx_nonscaled=dx_nonscaled,
        du_nonscaled=du_nonscaled,
    )
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def lower_cvxpy_constraints(
    constraints: ConstraintSet,
    x_cvxpy: List,
    u_cvxpy: List,
    parameters: Optional[dict] = None,
) -> Tuple[List, dict]:
    """Lower symbolic convex constraints to CVXPy constraints.

    Converts symbolic convex constraint expressions to CVXPy constraint objects
    that can be used in the optimal control problem. This function handles both
    nodal constraints (applied at specific trajectory nodes) and cross-node
    constraints (relating multiple nodes).

    Args:
        constraints: ConstraintSet containing nodal_convex and cross_node_convex.
        x_cvxpy: List of CVXPy expressions for state at each node (length N).
            Typically the x_nonscaled list from create_cvxpy_variables().
        u_cvxpy: List of CVXPy expressions for control at each node (length N).
            Typically the u_nonscaled list from create_cvxpy_variables().
        parameters: Optional dict of parameter values to use for any Parameter
            expressions in the constraints. For names that are missing (or when
            None), the symbolic Parameter's own ``value`` is used.

    Returns:
        Tuple of:
            - List of CVXPy constraint objects ready for the OCP: one per
              (nodal constraint, node) pair, followed by one per cross-node
              constraint.
            - Dict mapping parameter names to their CVXPy Parameter objects.

    Raises:
        ValueError: If a referenced State or Control has no slice assigned
            (this indicates a bug in the preprocessing pipeline).

    Example:
        After creating CVXPy variables::

            ocp_vars = create_cvxpy_variables(
                N, n_states, n_controls, S_x, c_x, S_u, c_u,
                n_nodal_constraints, n_cross_node_constraints,
            )
            cvxpy_constraints, cvxpy_params = lower_cvxpy_constraints(
                constraint_set,
                ocp_vars.x_nonscaled,
                ocp_vars.u_nonscaled,
                parameters,
            )

    Note:
        This function only processes convex constraints (nodal_convex and
        cross_node_convex). Non-convex constraints are lowered to JAX in
        lower_symbolic_problem() and handled via linearization in the SCP.
    """
    # Function-local imports kept as in the original; ``cp`` is module-level.
    from openscvx.symbolic.expr import Parameter, traverse
    from openscvx.symbolic.expr.control import Control
    from openscvx.symbolic.expr.state import State
    from openscvx.symbolic.lowerers.cvxpy import lower_to_cvxpy

    all_constraints = list(constraints.nodal_convex) + list(constraints.cross_node_convex)
    if not all_constraints:
        return [], {}

    # One shared cp.Parameter per unique symbolic Parameter name, across all constraints.
    all_params = {}

    def collect_params(expr):
        if not isinstance(expr, Parameter) or expr.name in all_params:
            return
        # Prefer the caller-supplied value, otherwise the Parameter's own value.
        if parameters and expr.name in parameters:
            param_value = parameters[expr.name]
        else:
            param_value = expr.value
        all_params[expr.name] = cp.Parameter(expr.shape, value=param_value, name=expr.name)

    for constraint in all_constraints:
        traverse(constraint.constraint, collect_params)

    def collect_vars(expr_tree):
        """Return (state_vars, control_vars): State/Control leaves of expr_tree keyed by name."""
        state_vars = {}
        control_vars = {}

        def visit(expr):
            if isinstance(expr, State):
                state_vars[expr.name] = expr
            elif isinstance(expr, Control):
                control_vars[expr.name] = expr

        traverse(expr_tree, visit)
        return state_vars, control_vars

    def check_slices(state_vars, control_vars):
        """Raise ValueError if any referenced State/Control has no slice (should be set by preprocessing)."""
        for state_name, state_var in state_vars.items():
            if state_var._slice is None:
                raise ValueError(
                    f"State variable '{state_name}' has no slice assigned. "
                    f"This indicates a bug in the preprocessing pipeline."
                )
        for control_name, control_var in control_vars.items():
            if control_var._slice is None:
                raise ValueError(
                    f"Control variable '{control_name}' has no slice assigned. "
                    f"This indicates a bug in the preprocessing pipeline."
                )

    cvxpy_constraints = []

    # Nodal constraints: lowered once per node, with x/u bound to that node's expressions.
    for constraint in constraints.nodal_convex:
        # nodes should already be validated and normalized in preprocessing
        nodes = constraint.nodes
        state_vars, control_vars = collect_vars(constraint.constraint)
        # Slice validity does not depend on the node, so check it once up front.
        check_slices(state_vars, control_vars)

        for node in nodes:
            variable_map = {}
            if state_vars:
                variable_map["x"] = x_cvxpy[node]
            if control_vars:
                variable_map["u"] = u_cvxpy[node]
            # NOTE(review): a Parameter named "x" or "u" would overwrite the entries above.
            variable_map.update(all_params)
            cvxpy_constraints.append(lower_to_cvxpy(constraint.constraint, variable_map))

    # Cross-node constraints: lowered once against the full (N, n) trajectories;
    # NodeReference handles node indexing internally.
    for constraint in constraints.cross_node_convex:
        state_vars, control_vars = collect_vars(constraint.constraint)

        variable_map = {}
        if state_vars:
            variable_map["x"] = cp.vstack(x_cvxpy)
        if control_vars:
            variable_map["u"] = cp.vstack(u_cvxpy)
        variable_map.update(all_params)

        check_slices(state_vars, control_vars)
        cvxpy_constraints.append(lower_to_cvxpy(constraint.constraint, variable_map))

    return cvxpy_constraints, all_params
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _lower_dynamics(dynamics_expr) -> Dynamics:
    """Turn a symbolic dynamics expression into a JAX function plus its Jacobians.

    Args:
        dynamics_expr: Symbolic dynamics expression (dx/dt = f(x, u)).

    Returns:
        Dynamics with ``f`` (the lowered function), ``A`` = df/dx (forward-mode
        Jacobian w.r.t. argument 0) and ``B`` = df/du (w.r.t. argument 1).
    """
    f = lower_to_jax(dynamics_expr)
    df_dx = jacfwd(f, argnums=0)
    df_du = jacfwd(f, argnums=1)
    return Dynamics(f=f, A=df_dx, B=df_du)
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _lower_jax_constraints(
    constraints: ConstraintSet,
) -> LoweredJaxConstraints:
    """Lower the non-convex constraints to JAX callables with Jacobians.

    Nodal constraints are vectorized over their first two arguments (x, u),
    with node and params broadcast, so they accept stacked (N, n_x) / (N, n_u)
    inputs. Cross-node constraints are kept as trajectory-level functions.
    CTCS constraints are passed through unchanged (shallow-copied list).

    Args:
        constraints: ConstraintSet containing nodal, cross_node and ctcs lists.

    Returns:
        LoweredJaxConstraints with nodal, cross_node, and ctcs lists.
    """
    # vmap over x and u; node index and params are shared across the batch.
    batched = (0, 0, None, None)

    lowered_nodal: List[LoweredNodalConstraint] = []
    if constraints.nodal:
        nodal_fns = lower_to_jax(constraints.nodal)
        for symbolic, fn in zip(constraints.nodal, nodal_fns):
            lowered_nodal.append(
                LoweredNodalConstraint(
                    func=jax.vmap(fn, in_axes=batched),
                    grad_g_x=jax.vmap(jacfwd(fn, argnums=0), in_axes=batched),
                    grad_g_u=jax.vmap(jacfwd(fn, argnums=1), in_axes=batched),
                    nodes=symbolic.nodes,
                )
            )

    lowered_cross_node: List[LoweredCrossNodeConstraint] = []
    for symbolic in constraints.cross_node:
        # The CrossNodeConstraint visitor produces the trajectory-level wrapper.
        traj_fn = lower_to_jax(symbolic)
        lowered_cross_node.append(
            LoweredCrossNodeConstraint(
                func=traj_fn,
                grad_g_X=jacfwd(traj_fn, argnums=0),  # dg/dX
                grad_g_U=jacfwd(traj_fn, argnums=1),  # dg/dU
            )
        )

    return LoweredJaxConstraints(
        nodal=lowered_nodal,
        cross_node=lowered_cross_node,
        ctcs=list(constraints.ctcs),
    )
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
def _scaling_bounds(unified) -> Tuple[np.ndarray, np.ndarray]:
    """Return float (lower, upper) bounds used to build an affine scaling.

    Each side independently prefers ``scaling_min``/``scaling_max`` when it is
    not None and otherwise falls back to ``min``/``max``.
    """
    lower = unified.scaling_min if unified.scaling_min is not None else unified.min
    upper = unified.scaling_max if unified.scaling_max is not None else unified.max
    return np.array(lower, dtype=float), np.array(upper, dtype=float)


def _lower_cvxpy(
    constraints: ConstraintSet,
    parameters: dict,
    N: int,
    x_unified: UnifiedState,
    u_unified: UnifiedControl,
    jax_constraints: LoweredJaxConstraints,
) -> Tuple[CVXPyVariables, LoweredCvxpyConstraints, dict]:
    """Create CVXPy variables and lower convex constraints.

    Creates all CVXPy variables/parameters needed for the OCP and lowers
    convex constraints to CVXPy constraint objects.

    Args:
        constraints: ConstraintSet containing convex constraints.
        parameters: Dict of parameter values for constraint lowering.
        N: Number of discretization nodes.
        x_unified: Unified state interface (for dimensions and scaling).
        u_unified: Unified control interface (for dimensions and scaling).
        jax_constraints: Lowered JAX constraints; only the lengths of its
            ``nodal`` and ``cross_node`` lists are used, to size the
            linearization parameters.

    Returns:
        Tuple of:
            - CVXPyVariables dataclass with all OCP variables
            - LoweredCvxpyConstraints with CVXPy constraint objects
            - Dict mapping parameter names to CVXPy Parameter objects
    """
    from openscvx.config import get_affine_scaling_matrices

    n_states = len(x_unified.max)
    n_controls = len(u_unified.max)

    # Affine scaling from the unified bounds (explicit scaling bounds take priority).
    lower_x, upper_x = _scaling_bounds(x_unified)
    S_x, c_x = get_affine_scaling_matrices(n_states, lower_x, upper_x)

    lower_u, upper_u = _scaling_bounds(u_unified)
    S_u, c_u = get_affine_scaling_matrices(n_controls, lower_u, upper_u)

    ocp_vars = create_cvxpy_variables(
        N=N,
        n_states=n_states,
        n_controls=n_controls,
        S_x=S_x,
        c_x=c_x,
        S_u=S_u,
        c_u=c_u,
        n_nodal_constraints=len(jax_constraints.nodal),
        n_cross_node_constraints=len(jax_constraints.cross_node),
    )

    # Convex constraints are expressed on the unscaled (physical) variables.
    lowered_cvxpy_constraint_list, cvxpy_params = lower_cvxpy_constraints(
        constraints,
        ocp_vars.x_nonscaled,
        ocp_vars.u_nonscaled,
        parameters,
    )

    cvxpy_constraints = LoweredCvxpyConstraints(
        constraints=lowered_cvxpy_constraint_list,
    )

    return ocp_vars, cvxpy_constraints, cvxpy_params
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def _contains_node_reference(expr: Expr) -> bool:
    """Report whether ``expr`` or any of its descendants is a NodeReference.

    Internal helper for routing constraints during lowering: a NodeReference
    anywhere in the tree marks a cross-node constraint. The search is a
    depth-first walk over ``expr.children()`` that stops at the first hit.

    Args:
        expr: Expression tree to inspect.

    Returns:
        True if at least one NodeReference is found, False otherwise.

    Example:
        position = State("pos", shape=(3,))

        # Regular expression - no NodeReference
        _contains_node_reference(position)  # False

        # Cross-node expression - has NodeReference
        _contains_node_reference(position.at(10) - position.at(9))  # True
    """
    return isinstance(expr, NodeReference) or any(
        _contains_node_reference(child) for child in expr.children()
    )
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def lower_symbolic_problem(
    problem: "SymbolicProblem", byof: Optional[dict] = None
) -> LoweredProblem:
    """Lower symbolic problem specification to executable JAX and CVXPy code.

    This is the main orchestrator for converting a preprocessed SymbolicProblem
    into executable numerical code. It coordinates the lowering of dynamics,
    constraints, and state/control interfaces from symbolic AST representations
    to JAX functions (with automatic differentiation) and CVXPy constraints.

    This is pure translation - no validation, shape checking, or augmentation occurs
    here. The input problem must be preprocessed (problem.is_preprocessed == True).

    Args:
        problem: Preprocessed SymbolicProblem from preprocess_symbolic_problem().
            Must have is_preprocessed == True.
        byof: Optional dict of raw JAX functions for expert users, forwarded to
            ``apply_byof``. Supported keys:
            - "nodal_constraints": List of f(x, u, node, params) -> residual
            - "cross_nodal_constraints": List of f(X, U, params) -> residual
            - "ctcs_constraints": List of dicts with "constraint_fn", "penalty", "bounds"

    Returns:
        LoweredProblem dataclass containing the lowered problem.

    Example:
        After preprocessing::

            problem = preprocess_symbolic_problem(...)
            lowered = lower_symbolic_problem(problem)

            # Access dynamics (uniform (x, u, node, params) signature)
            dx = lowered.dynamics.f(x_val, u_val, 0, params)

    Raises:
        AssertionError: If problem.is_preprocessed is False. Raised explicitly
            rather than via ``assert`` so the check still runs under ``python -O``.
    """
    # Explicit raise: a bare ``assert`` is stripped when Python runs with -O.
    if not problem.is_preprocessed:
        raise AssertionError("Problem must be preprocessed before lowering")

    # Create unified state/control interfaces
    x_unified = unify_states(problem.states, name="x")
    u_unified = unify_controls(problem.controls)
    x_prop_unified = unify_states(problem.states_prop, name="x_prop")

    # Lower dynamics to JAX
    dynamics = _lower_dynamics(problem.dynamics)
    dynamics_prop = _lower_dynamics(problem.dynamics_prop)

    # Lower non-convex constraints to JAX
    jax_constraints = _lower_jax_constraints(problem.constraints)

    # Handle byof (bring-your-own-functions) for expert users.
    # This must happen BEFORE CVXPy variable creation since CTCS constraints
    # augment the state dimension.
    if byof is not None:
        dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified = apply_byof(
            byof,
            dynamics,
            dynamics_prop,
            jax_constraints,
            x_unified,
            x_prop_unified,
            u_unified,
            problem.states,
            problem.states_prop,
            problem.N,
        )

    # Create CVXPy variables and lower convex constraints
    ocp_vars, cvxpy_constraints, cvxpy_params = _lower_cvxpy(
        constraints=problem.constraints,
        parameters=problem.parameters,
        N=problem.N,
        x_unified=x_unified,
        u_unified=u_unified,
        jax_constraints=jax_constraints,
    )

    return LoweredProblem(
        dynamics=dynamics,
        dynamics_prop=dynamics_prop,
        jax_constraints=jax_constraints,
        cvxpy_constraints=cvxpy_constraints,
        x_unified=x_unified,
        u_unified=u_unified,
        x_prop_unified=x_prop_unified,
        ocp_vars=ocp_vars,
        cvxpy_params=cvxpy_params,
    )
|