openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79)
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,760 @@
1
"""Symbolic expression lowering to executable code.

This module provides the main entry point for converting symbolic expressions
(AST nodes) into executable code for different backends (JAX, CVXPy, etc.).
The lowering process translates the symbolic expression tree into functions
that can be executed during optimization.

Architecture:
    The lowering process follows a visitor pattern where each backend implements
    a lowerer class (e.g., JaxLowerer, CVXPyLowerer) with visitor methods for
    each expression type. The `lower()` function dispatches expression nodes
    to the appropriate backend.

Lowering Flow:

    1. Symbolic expressions are built during problem specification
    2. lower_symbolic_problem() coordinates the full lowering process
    3. Backend-specific lowerers convert each expression node to executable code
    4. Automatic differentiation (``jax.jacfwd``) creates Jacobians for dynamics
       and non-convex constraints
    5. Result is a set of executable functions ready for numerical optimization

Backends:
    - JAX: For dynamics and non-convex constraints (with automatic differentiation)
    - CVXPy: For convex constraints (with disciplined convex programming)

Example:
    Basic lowering to JAX::

        import openscvx as ox
        from openscvx.symbolic.lower import lower_to_jax

        # Define symbolic expression
        x = ox.State("x", shape=(3,))
        u = ox.Control("u", shape=(2,))
        expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2

        # Lower to JAX function
        f = lower_to_jax(expr)
        # f is now a callable: f(x_val, u_val, node, params) -> scalar

    Full problem lowering::

        from openscvx.symbolic.lower import lower_symbolic_problem

        # `problem` is a SymbolicProblem that has already been preprocessed
        # (problem.is_preprocessed must be True)
        lowered = lower_symbolic_problem(problem)

        # Access via LoweredProblem dataclass
        dynamics = lowered.dynamics
        jax_constraints = lowered.jax_constraints
        # Now have executable JAX functions with Jacobians
"""
54
+
55
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple, Union
56
+
57
+ import cvxpy as cp
58
+ import jax
59
+ import numpy as np
60
+ from jax import jacfwd
61
+
62
+ from openscvx.expert import apply_byof
63
+ from openscvx.lowered import (
64
+ CVXPyVariables,
65
+ Dynamics,
66
+ LoweredCrossNodeConstraint,
67
+ LoweredCvxpyConstraints,
68
+ LoweredJaxConstraints,
69
+ LoweredNodalConstraint,
70
+ LoweredProblem,
71
+ )
72
+ from openscvx.symbolic.constraint_set import ConstraintSet
73
+ from openscvx.symbolic.expr import Expr, NodeReference
74
+
75
+ if TYPE_CHECKING:
76
+ from openscvx.lowered.unified import UnifiedState
77
+ from openscvx.symbolic.problem import SymbolicProblem
78
+
79
# Public entry points re-exported by ``from openscvx.symbolic.lower import *``;
# the underscore-prefixed helpers in this module are intentionally omitted.
__all__ = [
    "lower",
    "lower_to_jax",
    "lower_cvxpy_constraints",
    "create_cvxpy_variables",
    "lower_symbolic_problem",
]
86
+ from openscvx.lowered.unified import UnifiedControl, UnifiedState
87
+ from openscvx.symbolic.unified import unify_controls, unify_states
88
+
89
+
90
def lower(expr: Expr, lowerer: Any):
    """Lower a single symbolic expression with the given backend lowerer.

    Thin dispatch helper: the per-node-type visitor logic lives on the
    lowerer object, and this function simply forwards ``expr`` to the
    lowerer's ``lower()`` method.

    Args:
        expr: Symbolic expression to lower (any ``Expr`` subclass).
        lowerer: Backend lowerer instance (e.g. ``JaxLowerer`` or
            ``CVXPyLowerer``) exposing a ``lower(expr)`` method.

    Returns:
        Whatever the backend produces for ``expr``. For ``JaxLowerer`` this is
        a callable ``(x, u, node, params) -> result``; for ``CVXPyLowerer`` it
        is a CVXPy expression object.

    Raises:
        NotImplementedError: If the lowerer does not support the expression
            type (raised from within the lowerer).

    Example:
        Lower an expression with the JAX backend::

            from openscvx.symbolic.lowerers.jax import JaxLowerer

            x = ox.State("x", shape=(3,))
            f = lower(ox.Norm(x), JaxLowerer())
            # f(x_val, u_val, node, params) -> scalar
    """
    backend_lower = lowerer.lower
    return backend_lower(expr)
121
+
122
+
123
+ # --- Convenience wrappers for common backends ---
124
+
125
+
126
def lower_to_jax(
    exprs: Union[Expr, Sequence[Expr]],
) -> Union[Callable[..., Any], List[Callable[..., Any]]]:
    """Lower symbolic expression(s) to JAX callable(s).

    Convenience wrapper that creates a single ``JaxLowerer`` and uses it to
    lower one or more symbolic expressions to JAX functions. The resulting
    functions can be JIT-compiled and automatically differentiated.

    Args:
        exprs: A single expression, or a sequence of expressions, to lower.

    Returns:
        - If ``exprs`` is a single ``Expr``: one callable with signature
          ``(x, u, node, params) -> array``.
        - Otherwise: a list of callables, in the same order as ``exprs``,
          each with that same signature.

    Example:
        Single expression::

            x = ox.State("x", shape=(3,))
            expr = ox.Norm(x)**2
            f = lower_to_jax(expr)
            # f(x_val, u_val, node_idx, params_dict) -> scalar

        Multiple expressions::

            exprs = [ox.Norm(x), ox.Norm(u), x @ A @ x]
            fns = lower_to_jax(exprs)
            # fns is [f1, f2, f3], each with same signature

    Note:
        All returned JAX functions have a uniform signature
        ``(x, u, node, params)`` regardless of whether they use all arguments.
        This standardization simplifies vectorization and differentiation.
    """
    # Imported lazily, as in the original module (presumably to avoid an
    # import cycle with the lowerer package).
    from openscvx.symbolic.lowerers.jax import JaxLowerer

    # One lowerer instance is shared by every expression in this call.
    jl = JaxLowerer()
    if isinstance(exprs, Expr):
        return lower(exprs, jl)
    return [lower(e, jl) for e in exprs]
167
+
168
+
169
def create_cvxpy_variables(
    N: int,
    n_states: int,
    n_controls: int,
    S_x: np.ndarray,
    c_x: np.ndarray,
    S_u: np.ndarray,
    c_u: np.ndarray,
    n_nodal_constraints: int,
    n_cross_node_constraints: int,
) -> CVXPyVariables:
    """Build every CVXPy variable and parameter of the optimal control problem.

    The decision variables ``x``, ``dx``, ``u`` and ``du`` live in scaled
    coordinates. The returned ``*_nonscaled`` lists map each node back through
    the affine scaling ``S @ z + c``; the offset ``c`` is not applied to the
    ``dx``/``du`` increments.

    Args:
        N: Number of discretization nodes.
        n_states: Number of state variables.
        n_controls: Number of control variables.
        S_x: State scaling matrix (must be invertible).
        c_x: State offset vector.
        S_u: Control scaling matrix (must be invertible).
        c_u: Control offset vector.
        n_nodal_constraints: Number of non-convex nodal constraints; one set of
            linearization parameters is created for each.
        n_cross_node_constraints: Number of non-convex cross-node constraints;
            one set of linearization parameters is created for each.

    Returns:
        CVXPyVariables dataclass holding all CVXPy variables and parameters.
    """
    n_x, n_u = n_states, n_controls
    n_seg = N - 1  # number of intervals between consecutive nodes

    inv_S_x = np.linalg.inv(S_x)
    inv_S_u = np.linalg.inv(S_u)

    # --- non-negative weight parameters ---
    w_tr = cp.Parameter(nonneg=True, name="w_tr")
    lam_cost = cp.Parameter(nonneg=True, name="lam_cost")
    lam_vc = cp.Parameter((n_seg, n_x), nonneg=True, name="lam_vc")
    lam_vb = cp.Parameter(nonneg=True, name="lam_vb")

    # --- state: decision variables plus reference / boundary parameters ---
    x = cp.Variable((N, n_x), name="x")
    dx = cp.Variable((N, n_x), name="dx")
    x_bar = cp.Parameter((N, n_x), name="x_bar")
    x_init = cp.Parameter(n_x, name="x_init")
    x_term = cp.Parameter(n_x, name="x_term")

    # --- control: decision variables plus reference parameter ---
    u = cp.Variable((N, n_u), name="u")
    du = cp.Variable((N, n_u), name="du")
    u_bar = cp.Parameter((N, n_u), name="u_bar")

    # --- discretized augmented dynamics: one matrix/vector per interval ---
    A_d = cp.Parameter((n_seg, n_x, n_x), name="A_d")
    B_d = cp.Parameter((n_seg, n_x, n_u), name="B_d")
    C_d = cp.Parameter((n_seg, n_x, n_u), name="C_d")
    x_prop = cp.Parameter((n_seg, n_x), name="x_prop")
    nu = cp.Variable((n_seg, n_x), name="nu")  # virtual control on the dynamics

    # --- linearization data for non-convex nodal constraints (per node) ---
    g, grad_g_x, grad_g_u, nu_vb = [], [], [], []
    for i in range(n_nodal_constraints):
        g.append(cp.Parameter(N, name=f"g_{i}"))
        grad_g_x.append(cp.Parameter((N, n_x), name=f"grad_g_x_{i}"))
        grad_g_u.append(cp.Parameter((N, n_u), name=f"grad_g_u_{i}"))
        nu_vb.append(cp.Variable(N, name=f"nu_vb_{i}"))

    # --- linearization data for non-convex cross-node constraints ---
    # Each is a single scalar constraint whose gradient spans the whole
    # trajectory, hence the (N, n) gradient parameters.
    g_cross, grad_g_X_cross, grad_g_U_cross, nu_vb_cross = [], [], [], []
    for j in range(n_cross_node_constraints):
        g_cross.append(cp.Parameter(name=f"g_cross_{j}"))
        grad_g_X_cross.append(cp.Parameter((N, n_x), name=f"grad_g_X_cross_{j}"))
        grad_g_U_cross.append(cp.Parameter((N, n_u), name=f"grad_g_U_cross_{j}"))
        nu_vb_cross.append(cp.Variable(name=f"nu_vb_cross_{j}"))

    # --- undo the affine scaling node by node ---
    x_nonscaled, u_nonscaled, dx_nonscaled, du_nonscaled = [], [], [], []
    for node in range(N):
        x_nonscaled.append(S_x @ x[node] + c_x)
        u_nonscaled.append(S_u @ u[node] + c_u)
        dx_nonscaled.append(S_x @ dx[node])
        du_nonscaled.append(S_u @ du[node])

    return CVXPyVariables(
        w_tr=w_tr,
        lam_cost=lam_cost,
        lam_vc=lam_vc,
        lam_vb=lam_vb,
        x=x,
        dx=dx,
        x_bar=x_bar,
        x_init=x_init,
        x_term=x_term,
        u=u,
        du=du,
        u_bar=u_bar,
        A_d=A_d,
        B_d=B_d,
        C_d=C_d,
        x_prop=x_prop,
        nu=nu,
        g=g,
        grad_g_x=grad_g_x,
        grad_g_u=grad_g_u,
        nu_vb=nu_vb,
        g_cross=g_cross,
        grad_g_X_cross=grad_g_X_cross,
        grad_g_U_cross=grad_g_U_cross,
        nu_vb_cross=nu_vb_cross,
        S_x=S_x,
        inv_S_x=inv_S_x,
        c_x=c_x,
        S_u=S_u,
        inv_S_u=inv_S_u,
        c_u=c_u,
        x_nonscaled=x_nonscaled,
        u_nonscaled=u_nonscaled,
        dx_nonscaled=dx_nonscaled,
        du_nonscaled=du_nonscaled,
    )
303
+
304
+
305
def _collect_state_control_vars(expr) -> Tuple[dict, dict]:
    """Return the State and Control nodes referenced by ``expr``, keyed by name."""
    from openscvx.symbolic.expr import traverse
    from openscvx.symbolic.expr.control import Control
    from openscvx.symbolic.expr.state import State

    state_vars = {}
    control_vars = {}

    def visit(node):
        if isinstance(node, State):
            state_vars[node.name] = node
        elif isinstance(node, Control):
            control_vars[node.name] = node

    traverse(expr, visit)
    return state_vars, control_vars


def _require_slices(variables: dict, kind: str) -> None:
    """Raise ValueError if any variable in ``variables`` has no ``_slice`` assigned."""
    for name, var in variables.items():
        if var._slice is None:
            raise ValueError(
                f"{kind} variable '{name}' has no slice assigned. "
                f"This indicates a bug in the preprocessing pipeline."
            )


def lower_cvxpy_constraints(
    constraints: ConstraintSet,
    x_cvxpy: List,
    u_cvxpy: List,
    parameters: Optional[dict] = None,
) -> Tuple[List, dict]:
    """Lower symbolic convex constraints to CVXPy constraints.

    Converts symbolic convex constraint expressions to CVXPy constraint objects
    that can be used in the optimal control problem. Nodal constraints are
    lowered once per node they apply to. Cross-node constraints are lowered
    once against the full stacked trajectory.

    Args:
        constraints: ConstraintSet containing nodal_convex and cross_node_convex.
        x_cvxpy: List of CVXPy expressions for the state at each node (length N).
            Typically the x_nonscaled list from create_cvxpy_variables().
        u_cvxpy: List of CVXPy expressions for the control at each node (length N).
            Typically the u_nonscaled list from create_cvxpy_variables().
        parameters: Optional dict of parameter values to use for any Parameter
            expressions in the constraints. If None (or a name is missing),
            the Parameter's own value is used.

    Returns:
        Tuple of:
            - List of CVXPy constraint objects ready for the OCP
            - Dict mapping parameter names to their CVXPy Parameter objects
              (one shared cp.Parameter per distinct symbolic Parameter name)

    Raises:
        ValueError: If a State or Control referenced by a constraint has no
            slice assigned (indicates a preprocessing bug). Raised even for a
            nodal constraint whose node list is empty.

    Example:
        After creating CVXPy variables::

            ocp_vars = create_cvxpy_variables(
                N, n_states, n_controls, S_x, c_x, S_u, c_u,
                n_nodal_constraints, n_cross_node_constraints,
            )
            cvxpy_constraints, cvxpy_params = lower_cvxpy_constraints(
                constraint_set,
                ocp_vars.x_nonscaled,
                ocp_vars.u_nonscaled,
                parameters,
            )

    Note:
        This function only processes convex constraints (nodal_convex and
        cross_node_convex). Non-convex constraints are lowered to JAX by
        lower_symbolic_problem() and handled via linearization in the SCP.
    """
    from openscvx.symbolic.expr import Parameter, traverse
    from openscvx.symbolic.lowerers.cvxpy import lower_to_cvxpy

    all_constraints = list(constraints.nodal_convex) + list(constraints.cross_node_convex)
    if not all_constraints:
        return [], {}

    # One cp.Parameter per distinct symbolic Parameter name, shared by every
    # constraint that references it.
    all_params = {}

    def collect_params(expr):
        if isinstance(expr, Parameter) and expr.name not in all_params:
            # Prefer an explicitly supplied value; fall back to the Parameter's own.
            if parameters and expr.name in parameters:
                param_value = parameters[expr.name]
            else:
                param_value = expr.value
            all_params[expr.name] = cp.Parameter(expr.shape, value=param_value, name=expr.name)

    for constraint in all_constraints:
        traverse(constraint.constraint, collect_params)

    cvxpy_constraints = []

    # Nodal constraints: lowered separately at each node they apply to.
    # Nodes are expected to be validated and normalized in preprocessing.
    for constraint in constraints.nodal_convex:
        state_vars, control_vars = _collect_state_control_vars(constraint.constraint)
        # Slice validation does not depend on the node, so do it once here.
        _require_slices(state_vars, "State")
        _require_slices(control_vars, "Control")

        for node in constraint.nodes:
            variable_map = {}
            if state_vars:
                variable_map["x"] = x_cvxpy[node]
            if control_vars:
                variable_map["u"] = u_cvxpy[node]
            variable_map.update(all_params)
            cvxpy_constraints.append(lower_to_cvxpy(constraint.constraint, variable_map))

    # Cross-node constraints: lowered once against the stacked (N, n_x) / (N, n_u)
    # trajectories; NodeReference handles node indexing internally.
    for constraint in constraints.cross_node_convex:
        state_vars, control_vars = _collect_state_control_vars(constraint.constraint)

        variable_map = {}
        if state_vars:
            variable_map["x"] = cp.vstack(x_cvxpy)
        if control_vars:
            variable_map["u"] = cp.vstack(u_cvxpy)
        variable_map.update(all_params)

        _require_slices(state_vars, "State")
        _require_slices(control_vars, "Control")

        cvxpy_constraints.append(lower_to_cvxpy(constraint.constraint, variable_map))

    return cvxpy_constraints, all_params
478
+
479
+
480
def _lower_dynamics(dynamics_expr) -> Dynamics:
    """Turn a symbolic dynamics expression into a ``Dynamics`` bundle.

    The expression is lowered once to a JAX callable ``f(x, u, node, params)``.
    Its state and control Jacobians are then obtained by forward-mode
    differentiation with respect to the first and second arguments.

    Args:
        dynamics_expr: Symbolic right-hand side of ``dx/dt = f(x, u)``.

    Returns:
        ``Dynamics`` with ``f`` (the lowered function), ``A = df/dx`` and
        ``B = df/du``.
    """
    f = lower_to_jax(dynamics_expr)
    state_jacobian = jacfwd(f, argnums=0)
    control_jacobian = jacfwd(f, argnums=1)
    return Dynamics(f=f, A=state_jacobian, B=control_jacobian)
498
+
499
+
500
def _lower_jax_constraints(
    constraints: ConstraintSet,
) -> LoweredJaxConstraints:
    """Lower the non-convex constraints of ``constraints`` to JAX.

    Nodal constraints become functions vectorized over the node axis of the
    ``(N, n_x)`` / ``(N, n_u)`` trajectories, together with their ``x`` and
    ``u`` Jacobians. Cross-node constraints are lowered as trajectory-level
    functions with Jacobians taken with respect to the full ``X`` and ``U``.
    CTCS constraints are passed through unchanged (shallow-copied).

    Args:
        constraints: ConstraintSet holding ``nodal``, ``cross_node`` and
            ``ctcs`` constraint lists.

    Returns:
        LoweredJaxConstraints with ``nodal``, ``cross_node`` and ``ctcs`` lists.
    """
    # vmap over the leading axis of x and u; node and params are broadcast.
    node_batched = (0, 0, None, None)

    lowered_nodal: List[LoweredNodalConstraint] = []
    if constraints.nodal:
        nodal_fns = lower_to_jax(constraints.nodal)
        for symbolic, fn in zip(constraints.nodal, nodal_fns):
            lowered_nodal.append(
                LoweredNodalConstraint(
                    func=jax.vmap(fn, in_axes=node_batched),
                    grad_g_x=jax.vmap(jacfwd(fn, argnums=0), in_axes=node_batched),
                    grad_g_u=jax.vmap(jacfwd(fn, argnums=1), in_axes=node_batched),
                    nodes=symbolic.nodes,
                )
            )

    lowered_cross_node: List[LoweredCrossNodeConstraint] = []
    for cross in constraints.cross_node:
        # The cross-node visitor produces a trajectory-level function.
        fn = lower_to_jax(cross)
        lowered_cross_node.append(
            LoweredCrossNodeConstraint(
                func=fn,
                # For a scalar-valued constraint these match the shapes of X / U.
                grad_g_X=jacfwd(fn, argnums=0),
                grad_g_U=jacfwd(fn, argnums=1),
            )
        )

    return LoweredJaxConstraints(
        nodal=lowered_nodal,
        cross_node=lowered_cross_node,
        ctcs=list(constraints.ctcs),
    )
554
+
555
+
556
def _scaling_bounds(unified) -> Tuple[np.ndarray, np.ndarray]:
    """Return float arrays ``(lower, upper)`` used to build an affine scaling.

    ``scaling_min`` / ``scaling_max`` take precedence when set; otherwise the
    unified object's ``min`` / ``max`` are used. Each side falls back
    independently.
    """
    lower = unified.scaling_min if unified.scaling_min is not None else unified.min
    upper = unified.scaling_max if unified.scaling_max is not None else unified.max
    return np.array(lower, dtype=float), np.array(upper, dtype=float)


def _lower_cvxpy(
    constraints: ConstraintSet,
    parameters: dict,
    N: int,
    x_unified: UnifiedState,
    u_unified: UnifiedControl,
    jax_constraints: LoweredJaxConstraints,
) -> Tuple[CVXPyVariables, LoweredCvxpyConstraints, dict]:
    """Create CVXPy variables and lower convex constraints.

    Builds the state/control affine scaling from the unified bounds, creates
    all CVXPy variables/parameters needed for the OCP, and lowers convex
    constraints to CVXPy constraint objects.

    Args:
        constraints: ConstraintSet containing convex constraints.
        parameters: Dict of parameter values for constraint lowering.
        N: Number of discretization nodes.
        x_unified: Unified state interface (for dimensions and scaling).
        u_unified: Unified control interface (for dimensions and scaling).
        jax_constraints: Lowered JAX constraints (their counts size the
            linearization parameters).

    Returns:
        Tuple of:
            - CVXPyVariables dataclass with all OCP variables
            - LoweredCvxpyConstraints with CVXPy constraint objects
            - Dict mapping parameter names to CVXPy Parameter objects
    """
    from openscvx.config import get_affine_scaling_matrices

    n_states = len(x_unified.max)
    n_controls = len(u_unified.max)

    lower_x, upper_x = _scaling_bounds(x_unified)
    S_x, c_x = get_affine_scaling_matrices(n_states, lower_x, upper_x)

    lower_u, upper_u = _scaling_bounds(u_unified)
    S_u, c_u = get_affine_scaling_matrices(n_controls, lower_u, upper_u)

    ocp_vars = create_cvxpy_variables(
        N=N,
        n_states=n_states,
        n_controls=n_controls,
        S_x=S_x,
        c_x=c_x,
        S_u=S_u,
        c_u=c_u,
        n_nodal_constraints=len(jax_constraints.nodal),
        n_cross_node_constraints=len(jax_constraints.cross_node),
    )

    # Convex constraints are expressed in physical (unscaled) coordinates.
    lowered_cvxpy_constraint_list, cvxpy_params = lower_cvxpy_constraints(
        constraints,
        ocp_vars.x_nonscaled,
        ocp_vars.u_nonscaled,
        parameters,
    )

    cvxpy_constraints = LoweredCvxpyConstraints(
        constraints=lowered_cvxpy_constraint_list,
    )

    return ocp_vars, cvxpy_constraints, cvxpy_params
639
+
640
+
641
def _contains_node_reference(expr: Expr) -> bool:
    """Report whether ``expr`` or any of its descendants is a NodeReference.

    Internal helper used when routing constraints during lowering: the
    presence of a NodeReference marks a cross-node constraint. The search is
    depth-first and stops at the first match.

    Args:
        expr: Root of the expression tree to inspect.

    Returns:
        True if at least one NodeReference is found, otherwise False.

    Example:
        position = State("pos", shape=(3,))

        _contains_node_reference(position)                             # False
        _contains_node_reference(position.at(10) - position.at(9))     # True
    """
    return isinstance(expr, NodeReference) or any(
        _contains_node_reference(child) for child in expr.children()
    )
673
+
674
+
675
def lower_symbolic_problem(
    problem: "SymbolicProblem", byof: Optional[dict] = None
) -> LoweredProblem:
    """Lower symbolic problem specification to executable JAX and CVXPy code.

    This is the main orchestrator for converting a preprocessed SymbolicProblem
    into executable numerical code. It coordinates the lowering of dynamics,
    constraints, and state/control interfaces from symbolic AST representations
    to JAX functions (with automatic differentiation) and CVXPy constraints.

    This is pure translation - no validation, shape checking, or augmentation
    occurs here, apart from any augmentation performed by ``byof`` handling.
    The input problem must be preprocessed (problem.is_preprocessed == True).

    Args:
        problem: Preprocessed SymbolicProblem from preprocess_symbolic_problem().
            Must have is_preprocessed == True.
        byof: Optional dict of raw JAX functions for expert users. Supported keys:
            - "nodal_constraints": List of f(x, u, node, params) -> residual
            - "cross_nodal_constraints": List of f(X, U, params) -> residual
            - "ctcs_constraints": List of dicts with "constraint_fn", "penalty", "bounds"

    Returns:
        LoweredProblem dataclass containing the lowered problem.

    Raises:
        AssertionError: If problem.is_preprocessed is False. This is raised
            explicitly, so it is also raised under ``python -O``.

    Example:
        After preprocessing::

            problem = preprocess_symbolic_problem(...)
            lowered = lower_symbolic_problem(problem)

            # Access dynamics (lowered JAX functions take (x, u, node, params))
            dx = lowered.dynamics.f(x_val, u_val, node, params)

            # Use CVXPy objects for OCP
            ocp = OptimalControlProblem(settings, lowered)
    """
    # An explicit raise instead of `assert` keeps this precondition active under
    # `python -O`; AssertionError is preserved for existing callers.
    if not problem.is_preprocessed:
        raise AssertionError("Problem must be preprocessed before lowering")

    # Unified state/control interfaces
    x_unified = unify_states(problem.states, name="x")
    u_unified = unify_controls(problem.controls)
    x_prop_unified = unify_states(problem.states_prop, name="x_prop")

    # Dynamics (optimization and propagation) to JAX with Jacobians
    dynamics = _lower_dynamics(problem.dynamics)
    dynamics_prop = _lower_dynamics(problem.dynamics_prop)

    # Non-convex constraints to JAX
    jax_constraints = _lower_jax_constraints(problem.constraints)

    # Bring-your-own-functions must be applied BEFORE CVXPy variables are
    # created, because CTCS constraints augment the state dimension.
    if byof is not None:
        dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified = apply_byof(
            byof,
            dynamics,
            dynamics_prop,
            jax_constraints,
            x_unified,
            x_prop_unified,
            u_unified,
            problem.states,
            problem.states_prop,
            problem.N,
        )

    # CVXPy variables and convex constraints
    ocp_vars, cvxpy_constraints, cvxpy_params = _lower_cvxpy(
        constraints=problem.constraints,
        parameters=problem.parameters,
        N=problem.N,
        x_unified=x_unified,
        u_unified=u_unified,
        jax_constraints=jax_constraints,
    )

    return LoweredProblem(
        dynamics=dynamics,
        dynamics_prop=dynamics_prop,
        jax_constraints=jax_constraints,
        cvxpy_constraints=cvxpy_constraints,
        x_unified=x_unified,
        u_unified=u_unified,
        x_prop_unified=x_prop_unified,
        ocp_vars=ocp_vars,
        cvxpy_params=cvxpy_params,
    )