openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79)
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1302 @@
1
+ """CVXPy backend for lowering symbolic expressions to CVXPy format.
2
+
3
+ This module implements the CVXPy lowering backend that converts symbolic expression
4
+ AST nodes into CVXPy expressions for convex optimization. The lowering uses a visitor
5
+ pattern where each expression type has a corresponding visitor method.
6
+
7
+ Architecture:
8
+ The CVXPy lowerer follows a visitor pattern with centralized registration:
9
+
10
+ 1. **Visitor Registration**: The @visitor decorator registers handler functions
11
+ for each expression type in the _CVXPY_VISITORS dictionary
12
+ 2. **Dispatch**: The dispatch() function looks up and calls the appropriate
13
+ visitor based on the expression's type
14
+ 3. **Recursive Lowering**: Each visitor recursively lowers child expressions
15
+ and composes CVXPy operations
16
+ 4. **Translation Only**: This module only translates expressions; CVXPy itself
17
+ validates DCP (Disciplined Convex Programming) rules when the problem is
18
+ constructed/solved
19
+
20
+ Key Features:
21
+ - **Expression Translation**: Converts symbolic AST to CVXPy expression format
22
+ - **Variable Management**: Maps symbolic States/Controls to CVXPy variables
23
+ through a variable_map dictionary
24
+ - **Parameter Support**: Handles both constant parameters and CVXPy Parameters
25
+ for efficient parameter sweeps
26
+ - **Constraint Generation**: Produces CVXPy constraint objects from symbolic
27
+ equality and inequality expressions
28
+
29
+ Backend Usage:
30
+ CVXPy lowering is used for convex constraints in the SCP subproblem. Unlike
31
+ JAX lowering (which happens early during problem construction), CVXPy lowering
32
+ occurs later during Problem.initialize() when CVXPy variables are
33
+ available. See lower_symbolic_expressions() in symbolic/lower.py for details.
34
+
35
+ CVXPy Variable Mapping:
36
+ The lowerer requires a variable_map dictionary that maps symbolic variable names
37
+ to CVXPy expressions. For trajectory optimization::
38
+
39
+ variable_map = {
40
+ "x": cvxpy.Variable((n_x,)), # State vector
41
+ "u": cvxpy.Variable((n_u,)), # Control vector
42
+ "param_name": cvxpy.Parameter((3,)), # Runtime parameters
43
+ }
44
+
45
+ States and Controls use their slices (assigned during unification) to extract
46
+ the correct portion of the unified x and u vectors.
47
+
48
+ Example:
49
+ Basic usage::
50
+
51
+ import cvxpy as cp
52
+ from openscvx.symbolic.lowerers.cvxpy import CvxpyLowerer
53
+ import openscvx as ox
54
+
55
+ # Create symbolic expression
56
+ x = ox.State("x", shape=(3,))
57
+ u = ox.Control("u", shape=(2,))
58
+ expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
59
+
60
+ # Create CVXPy variables
61
+ cvx_x = cp.Variable(3)
62
+ cvx_u = cp.Variable(2)
63
+
64
+ # Lower to CVXPy
65
+ lowerer = CvxpyLowerer(variable_map={"x": cvx_x, "u": cvx_u})
66
+ cvx_expr = lowerer.lower(expr)
67
+
68
+ # Use in optimization problem
69
+ prob = cp.Problem(cp.Minimize(cvx_expr), constraints=[...])
70
+ prob.solve()
71
+
72
+ Constraint lowering::
73
+
74
+ # Symbolic constraint
75
+ constraint = ox.Norm(x) <= 1.0
76
+
77
+ # Lower to CVXPy constraint
78
+ cvx_constraint = lowerer.lower(constraint)
79
+
80
+ # Add to problem
81
+ prob = cp.Problem(cp.Minimize(cost), constraints=[cvx_constraint])
82
+
83
+ For Contributors:
84
+ **Adding Support for New Expression Types**
85
+
86
+ To add support for a new symbolic expression type to CVXPy lowering:
87
+
88
+ 1. **Define the visitor method** in CvxpyLowerer with the @visitor decorator::
89
+
90
+ @visitor(MyNewExpr)
91
+ def _visit_my_new_expr(self, node: MyNewExpr) -> cp.Expression:
92
+ # Lower child expressions recursively
93
+ operand = self.lower(node.operand)
94
+
95
+ # Return CVXPy expression
96
+ return cp.my_operation(operand)
97
+
98
+ 2. **Key requirements**:
99
+ - Use the @visitor(ExprType) decorator to register the handler
100
+ - Method name should be _visit_<expr_name> (private, lowercase, snake_case)
101
+ - Recursively lower all child expressions using self.lower()
102
+ - Return a cp.Expression or cp.Constraint object
103
+ - Use cp.* operations for CVXPy atoms
104
+
105
+ 3. **DCP considerations**:
106
+ - This module only translates; CVXPy validates DCP rules
107
+ - Document the mathematical properties in the docstring (convex, concave, affine)
108
+ - For non-DCP operations, raise NotImplementedError with helpful message
109
+ - See _visit_sin, _visit_cos, _visit_ctcs for examples
110
+
111
+ 4. **Example patterns**:
112
+ - Unary operation: ``return cp.my_func(self.lower(node.operand))``
113
+ - Binary operation: ``return self.lower(node.left) + self.lower(node.right)``
114
+ - Constraints: ``return self.lower(node.lhs) <= self.lower(node.rhs)``
115
+ - Not supported: Raise NotImplementedError with guidance
116
+
117
+ 5. **Testing**: Ensure your visitor works with:
118
+ - Simple expressions: Direct lowering to cp.Expression
119
+ - Constraint validation: CVXPy accepts the result
120
+ - DCP checking: CVXPy's problem.solve() validates correctly
121
+
122
+ See Also:
123
+ - lower_to_cvxpy(): Convenience wrapper for single expression lowering
124
+ - JaxLowerer: Alternative backend for non-convex constraints and dynamics
125
+ - lower_symbolic_expressions(): Main orchestrator in symbolic/lower.py
126
+ - CVXPy documentation: https://www.cvxpy.org/
127
+ """
128
+
129
+ from typing import Any, Callable, Dict, Type
130
+
131
+ import cvxpy as cp
132
+
133
+ from openscvx.symbolic.expr import (
134
+ CTCS,
135
+ Abs,
136
+ Add,
137
+ Block,
138
+ Concat,
139
+ Constant,
140
+ Cos,
141
+ CrossNodeConstraint,
142
+ Div,
143
+ Equality,
144
+ Exp,
145
+ Expr,
146
+ Hstack,
147
+ Huber,
148
+ Index,
149
+ Inequality,
150
+ Log,
151
+ LogSumExp,
152
+ MatMul,
153
+ Max,
154
+ Mul,
155
+ Neg,
156
+ NodeReference,
157
+ Norm,
158
+ Parameter,
159
+ PositivePart,
160
+ Power,
161
+ Sin,
162
+ SmoothReLU,
163
+ Sqrt,
164
+ Square,
165
+ Stack,
166
+ Sub,
167
+ Sum,
168
+ Tan,
169
+ Transpose,
170
+ Vstack,
171
+ )
172
+ from openscvx.symbolic.expr.control import Control
173
+ from openscvx.symbolic.expr.state import State
174
+
175
# Global registry: filled in by the @visitor decorator and read by dispatch().
# Keys are the exact Expr subclasses being lowered; values are visitor functions.
_CVXPY_VISITORS: Dict[Type[Expr], Callable] = dict()
"""Registry mapping expression types to their visitor functions."""
177
+
178
+
179
def visitor(expr_cls: Type[Expr]):
    """Build a decorator that records a CVXPy visitor for ``expr_cls``.

    The decorator stores the decorated function in ``_CVXPY_VISITORS`` under
    ``expr_cls`` and returns the function itself, so it remains usable as an
    ordinary method. Registering a second function for the same class replaces
    the first (plain dict assignment). ``dispatch()`` reads this registry when
    lowering.

    Args:
        expr_cls: The ``Expr`` subclass handled by the decorated function
            (e.g. ``Add``, ``Mul``, ``Norm``).

    Returns:
        A decorator that registers its argument and hands it back unchanged.

    Example:
        Register the visitor for ``Add``::

            @visitor(Add)
            def _visit_add(self, node: Add):
                ...

    Note:
        Stack decorators to share one function across several expression types::

            @visitor(Equality)
            @visitor(Inequality)
            def _visit_constraint(self, node):
                ...
    """

    def _record(fn: Callable[[Any, Expr], cp.Expression]):
        _CVXPY_VISITORS[expr_cls] = fn
        return fn

    return _record
215
+
216
+
217
def dispatch(lowerer: Any, expr: Expr):
    """Look up and invoke the CVXPy visitor registered for ``expr``.

    The registry key is ``type(expr)`` exactly. Subclasses of a registered type
    are not matched through their base classes, so each concrete expression
    type needs its own ``@visitor`` registration.

    Args:
        lowerer: The ``CvxpyLowerer`` instance; passed to the visitor as ``self``.
        expr: The expression node to lower.

    Returns:
        Whatever the visitor returns (a CVXPy expression or constraint).

    Raises:
        NotImplementedError: If no visitor is registered for ``type(expr)``.

    Example:
        Lower an addition node::

            lowerer = CvxpyLowerer(variable_map={...})
            cvx_expr = dispatch(lowerer, Add(x, y))  # runs _visit_add
    """
    node_type = type(expr)
    handler = _CVXPY_VISITORS.get(node_type)
    if handler is None:
        owner = lowerer.__class__.__name__
        raise NotImplementedError(f"{owner!r} has no visitor for {node_type.__name__}")
    return handler(lowerer, expr)
246
+
247
+
248
+ class CvxpyLowerer:
249
+ """CVXPy backend for lowering symbolic expressions to disciplined convex programs.
250
+
251
+ This class implements the visitor pattern for converting symbolic expression
252
+ AST nodes to CVXPy expressions and constraints. Each expression type has a
253
+ corresponding visitor method decorated with @visitor that handles the lowering
254
+ logic.
255
+
256
+ The lowering process is recursive: each visitor lowers its child expressions
257
+ first, then composes them into a CVXPy operation. CVXPy will validate DCP
258
+ (Disciplined Convex Programming) compliance when the problem is constructed.
259
+
260
+ Attributes:
261
+ variable_map (dict): Dictionary mapping variable names to CVXPy expressions.
262
+ Must include "x" for states and "u" for controls. May include parameter
263
+ names mapped to CVXPy Parameter objects or constants.
264
+
265
+ Example:
266
+ Lower an expression to CVXPy:
267
+
268
+ import cvxpy as cp
269
+ lowerer = CvxpyLowerer(variable_map={
270
+ "x": cp.Variable(3),
271
+ "u": cp.Variable(2),
272
+ })
273
+ expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
274
+ cvx_expr = lowerer.lower(expr)
275
+
276
+ Note:
277
+ The lowerer is stateful (stores variable_map) unlike JaxLowerer which
278
+ is stateless. Variables must be registered before lowering expressions
279
+ that reference them.
280
+ """
281
+
282
def __init__(self, variable_map: Dict[str, cp.Expression] = None):
    """Initialize the CVXPy lowerer.

    Args:
        variable_map: Dictionary mapping variable names to CVXPy expressions.
            For State/Control objects, keys should be "x" and "u" respectively.
            For Parameter objects, keys should match their names. If None, a
            fresh empty dictionary is created. Any dictionary passed in,
            including an empty one, is stored by reference. Later
            ``register_variable`` calls therefore write into the caller's dict.

    Example:
        Initialize the CVXPy lowerer with the variable map::

            cvx_x = cp.Variable(3, name="x")
            cvx_u = cp.Variable(2, name="u")
            lowerer = CvxpyLowerer({"x": cvx_x, "u": cvx_u})
    """
    # Identity check instead of ``variable_map or {}``: the ``or`` form treated a
    # caller-supplied *empty* dict as missing and replaced it with a new one,
    # unlike non-empty dicts, which are kept by reference.
    self.variable_map = {} if variable_map is None else variable_map
299
+
300
def lower(self, expr: Expr) -> cp.Expression:
    """Translate a symbolic expression into its CVXPy counterpart.

    This is the public entry point; it hands ``expr`` to ``dispatch()``, which
    selects the registered visitor for the expression's type.

    Args:
        expr: Any symbolic ``Expr`` node.

    Returns:
        A CVXPy expression for arithmetic nodes, or a CVXPy constraint for
        Equality/Inequality nodes.

    Raises:
        NotImplementedError: If no visitor is registered for the node type.
        ValueError: If a visitor needs a variable that is missing from
            ``variable_map``.

    Example:
        Lower a norm::

            lowerer = CvxpyLowerer(variable_map={"x": cvx_x, "u": cvx_u})
            x = ox.State("x", shape=(3,))
            cvx_expr = lowerer.lower(ox.Norm(x))
    """
    lowered = dispatch(self, expr)
    return lowered
326
+
327
def register_variable(self, name: str, cvx_expr: cp.Expression):
    """Add or overwrite an entry in ``variable_map``.

    Useful for supplying variables or parameters after the lowerer exists.

    Args:
        name: Key to bind (e.g. "x", "u", or a parameter name).
        cvx_expr: CVXPy expression to associate with ``name``.

    Example:
        Register a state vector and a parameter::

            lowerer = CvxpyLowerer()
            lowerer.register_variable("x", cp.Variable(3))
            lowerer.register_variable("obs_center", cp.Parameter(3))
    """
    registry = self.variable_map
    registry[name] = cvx_expr
345
+
346
@visitor(Constant)
def _visit_constant(self, node: Constant) -> cp.Expression:
    """Wrap a symbolic constant in ``cp.Constant``.

    Args:
        node: Constant expression node; its ``value`` is passed to CVXPy as-is.

    Returns:
        ``cp.Constant`` built from ``node.value``.
    """
    raw_value = node.value
    return cp.Constant(raw_value)
359
+
360
@visitor(State)
def _visit_state(self, node: State) -> cp.Expression:
    """Map a State onto its portion of the unified state vector "x".

    If unification assigned the state a slice (``node._slice``), that slice of
    ``variable_map["x"]`` is returned; otherwise the whole "x" entry is returned.

    Args:
        node: State expression node.

    Returns:
        ``x[node._slice]``, or ``x`` itself when no slice is assigned.

    Raises:
        ValueError: If "x" is not present in ``variable_map``.
    """
    variables = self.variable_map
    if "x" not in variables:
        raise ValueError("State vector 'x' not found in variable_map.")

    full_state = variables["x"]
    state_slice = node._slice
    return full_state if state_slice is None else full_state[state_slice]
386
+
387
@visitor(Control)
def _visit_control(self, node: Control) -> cp.Expression:
    """Map a Control onto its portion of the unified control vector "u".

    If unification assigned the control a slice (``node._slice``), that slice
    of ``variable_map["u"]`` is returned; otherwise the whole "u" entry is
    returned.

    Args:
        node: Control expression node.

    Returns:
        ``u[node._slice]``, or ``u`` itself when no slice is assigned.

    Raises:
        ValueError: If "u" is not present in ``variable_map``.
    """
    variables = self.variable_map
    if "u" not in variables:
        raise ValueError("Control vector 'u' not found in variable_map.")

    full_control = variables["u"]
    control_slice = node._slice
    return full_control if control_slice is None else full_control[control_slice]
413
+
414
@visitor(NodeReference)
def _visit_node_reference(self, node: "NodeReference") -> cp.Expression:
    """Lower NodeReference - extract value at a specific trajectory node.

    NodeReference enables cross-node constraints by referencing state/control
    values at specific discrete time points. For this lowering the
    variable_map entries "x"/"u" are indexed as ``var[node_idx, ...]``, i.e.
    they are expected to be full trajectory arrays such as (N, n_x) or
    (N, n_u) rather than single-node vectors.

    Args:
        node: NodeReference expression with ``base`` and ``node_idx``.

    Returns:
        CVXPy expression for the variable at the requested node:
        ``x[node_idx, slice]`` or ``u[node_idx, slice]``, or
        ``x[node_idx, :]`` / ``u[node_idx, :]`` when the base has no slice.

    Raises:
        ValueError: If the required trajectory variable is not in variable_map.
        NotImplementedError: If the base is neither a State nor a Control
            (e.g. a compound expression such as ``position[0].at(5)``).

    Example:
        For the cross-node constraint ``position.at(5) - position.at(4) <= 0.1``::

            variable_map = {
                "x": cp.vstack([x_nonscaled[k] for k in range(N)]),  # (N, n_x)
            }
            # position.at(5) lowers to x[5, position._slice]

    Note:
        ``node_idx`` is used directly as the row index; it is not adjusted here.
    """
    idx = node.node_idx
    base = node.base

    # State and Control share identical lowering logic and differ only in
    # which variable_map key they read and the wording of the error message.
    if isinstance(base, State):
        key, label, width = "x", "State", "n_x"
    elif isinstance(base, Control):
        key, label, width = "u", "Control", "n_u"
    else:
        # Compound bases would need lowering in a single-node context first.
        raise NotImplementedError(
            "Compound expressions in NodeReference are not yet supported for CVXPy lowering. "
            f"Base expression type: {type(base).__name__}. "
            "Only State and Control NodeReferences are currently supported."
        )

    if key not in self.variable_map:
        raise ValueError(
            f"{label} vector '{key}' not found in variable_map. "
            f"For cross-node constraints, '{key}' must be the full trajectory (N, {width})."
        )

    trajectory = self.variable_map[key]
    if base._slice is not None:
        return trajectory[idx, base._slice]
    # No slice means the base is the entire unified vector at this node.
    return trajectory[idx, :]
492
+
493
@visitor(CrossNodeConstraint)
def _visit_cross_node_constraint(self, node: CrossNodeConstraint) -> cp.Constraint:
    """Lower a CrossNodeConstraint by lowering the constraint it wraps.

    The wrapper adds no CVXPy structure of its own. Any NodeReference nodes
    inside the wrapped constraint do the per-node indexing into the trajectory
    arrays stored under "x" and "u" in ``variable_map``.

    Args:
        node: CrossNodeConstraint wrapping an inner constraint (``node.constraint``).

    Returns:
        The CVXPy constraint produced by lowering ``node.constraint``.

    Example:
        With ``variable_map = {"x": cp.vstack([x[k] for k in range(N)])}``, the
        constraint ``position.at(5) - position.at(4) <= max_step`` lowers to::

            x[5, pos_slice] - x[4, pos_slice] <= max_step
    """
    wrapped = node.constraint
    return self.lower(wrapped)
528
+
529
@visitor(Parameter)
def _visit_parameter(self, node: Parameter) -> cp.Expression:
    """Resolve a symbolic Parameter through ``variable_map`` by name.

    The mapped value is returned unchanged. It may be a ``cp.Parameter`` (so
    values can be swept without rebuilding the problem) or a fixed value such
    as ``cp.Constant`` or a numpy array.

    Args:
        node: Parameter expression node; ``node.name`` is the lookup key.

    Returns:
        The object stored in ``variable_map`` under ``node.name``.

    Raises:
        ValueError: If ``node.name`` is not a key of ``variable_map``.
    """
    lookup_key = node.name
    if lookup_key not in self.variable_map:
        raise ValueError(
            f"Parameter '{lookup_key}' not found in variable_map. "
            "Add it during CVXPy lowering or use cp.Parameter for parameter sweeps."
        )
    return self.variable_map[lookup_key]
557
+
558
@visitor(Add)
def _visit_add(self, node: Add) -> cp.Expression:
    """Lower an n-ary addition to a left-associated chain of CVXPy ``+``.

    Every term is lowered first, then the terms are summed from left to right.
    Addition is affine, so it never breaks DCP by itself.

    Args:
        node: Add expression node; ``node.terms`` holds the summands.

    Returns:
        CVXPy expression equal to ``terms[0] + terms[1] + ...``.
    """
    lowered_terms = [self.lower(child) for child in node.terms]
    running_total, remaining = lowered_terms[0], lowered_terms[1:]
    for summand in remaining:
        running_total = running_total + summand
    return running_total
576
+
577
@visitor(Sub)
def _visit_sub(self, node: Sub) -> cp.Expression:
    """Lower ``left - right`` to the CVXPy difference of the lowered operands.

    The left operand is lowered before the right one. Subtraction is affine.

    Args:
        node: Sub expression node.

    Returns:
        CVXPy expression for ``left - right``.
    """
    return self.lower(node.left) - self.lower(node.right)
592
+
593
@visitor(Mul)
def _visit_mul(self, node: Mul) -> cp.Expression:
    """Lower element-wise multiplication to a CVXPy expression.

    All factors are lowered first and then combined left to right with
    ``cp.multiply``. ``cp.multiply`` is CVXPy's element-wise product, and it
    also handles scalar-times-array products. For matrix products, use MatMul.

    Args:
        node: Mul expression node; ``node.factors`` holds the operands.

    Returns:
        CVXPy expression for the element-wise product of all factors. A single
        factor is returned unchanged.

    Note:
        For DCP, typically all but one factor should be constant. CVXPy raises
        a DCP error when the product is not DCP-compliant.
    """
    factors = [self.lower(factor) for factor in node.factors]
    result = factors[0]
    for factor in factors[1:]:
        # Do not use ``*`` here: on two non-scalar CVXPy operands it performs
        # *matrix* multiplication (with a deprecation warning), e.g. two (n,)
        # vectors collapse to their inner product instead of an (n,) result.
        result = cp.multiply(result, factor)
    return result
615
+
616
@visitor(Div)
def _visit_div(self, node: Div) -> cp.Expression:
    """Lower ``left / right`` using CVXPy's ``/`` operator.

    The numerator is lowered before the denominator. CVXPy decides whether the
    quotient is DCP-valid; a constant denominator is the common safe case.

    Args:
        node: Div expression node.

    Returns:
        CVXPy expression for ``left / right``.
    """
    return self.lower(node.left) / self.lower(node.right)
635
+
636
@visitor(MatMul)
def _visit_matmul(self, node: MatMul) -> cp.Expression:
    """Lower a matrix product to CVXPy's ``@`` operator.

    The left operand is lowered before the right one. The product is normally
    DCP-valid only when one side is constant (e.g. ``A @ x``).

    Args:
        node: MatMul expression node.

    Returns:
        CVXPy expression for ``left @ right``.
    """
    return self.lower(node.left) @ self.lower(node.right)
652
+
653
@visitor(Neg)
def _visit_neg(self, node: Neg) -> cp.Expression:
    """Lower unary minus by negating the lowered operand.

    Negation swaps convex and concave curvature; affine stays affine.

    Args:
        node: Neg expression node.

    Returns:
        CVXPy expression for ``-operand``.
    """
    return -self.lower(node.operand)
667
+
668
@visitor(Sum)
def _visit_sum(self, node: Sum) -> cp.Expression:
    """Lower a full reduction-sum with ``cp.sum`` (no axis argument).

    Summing preserves curvature (a sum of convex terms is convex).

    Args:
        node: Sum expression node.

    Returns:
        CVXPy expression for the sum of every element of the operand.
    """
    return cp.sum(self.lower(node.operand))
682
+
683
@visitor(Norm)
def _visit_norm(self, node: Norm) -> cp.Expression:
    """Lower a norm with ``cp.norm``, passing ``node.ord`` through unchanged.

    Norms are convex, so ``norm(expr) <= c`` and ``minimize(norm(expr))`` are
    DCP-valid for affine ``expr``.

    Args:
        node: Norm expression node; ``node.ord`` becomes cp.norm's order
            argument (e.g. 1, 2, "inf", "fro").

    Returns:
        CVXPy expression for the norm of the lowered operand.
    """
    inner = self.lower(node.operand)
    order = node.ord
    return cp.norm(inner, order)
701
+
702
@visitor(Index)
def _visit_index(self, node: Index) -> cp.Expression:
    """Lower indexing/slicing by applying ``node.index`` to the lowered base.

    Indexing selects entries and keeps their curvature.

    Args:
        node: Index expression node.

    Returns:
        CVXPy expression for ``base[node.index]``.
    """
    return self.lower(node.base)[node.index]
716
+
717
@visitor(Concat)
def _visit_concat(self, node: Concat) -> cp.Expression:
    """Lower concatenation with ``cp.hstack``.

    All children are lowered first. Any 0-d (scalar) result is reshaped to
    shape ``(1,)`` so that ``cp.hstack`` can join it with vector pieces.

    Args:
        node: Concat expression node; ``node.exprs`` lists the pieces.

    Returns:
        CVXPy expression produced by ``cp.hstack`` over the (promoted) pieces.
    """
    lowered_parts = [self.lower(part) for part in node.exprs]
    promoted = [
        cp.reshape(part, (1,), order="C") if part.ndim == 0 else part
        for part in lowered_parts
    ]
    return cp.hstack(promoted)
742
+
743
@visitor(Sin)
def _visit_sin(self, node: Sin) -> cp.Expression:
    """Reject Sin: sine is neither convex nor concave, so it has no DCP form.

    Args:
        node: Sin expression node (unused).

    Raises:
        NotImplementedError: Always.

    Note:
        Handle trigonometric terms in the JAX dynamics/constraint layer, or
        replace them with piecewise-linear approximations.
    """
    message = (
        "Trigonometric functions like Sin are not DCP-compliant in CVXPy. "
        "Consider using piecewise-linear approximations or handle these constraints "
        "in the dynamics (JAX) layer instead."
    )
    raise NotImplementedError(message)
765
+
766
@visitor(Cos)
def _visit_cos(self, node: Cos) -> cp.Expression:
    """Reject Cos: cosine is neither convex nor concave, so it has no DCP form.

    Args:
        node: Cos expression node (unused).

    Raises:
        NotImplementedError: Always.

    Note:
        Handle trigonometric terms in the JAX dynamics/constraint layer, or
        replace them with piecewise-linear approximations.
    """
    message = (
        "Trigonometric functions like Cos are not DCP-compliant in CVXPy. "
        "Consider using piecewise-linear approximations or handle these constraints "
        "in the dynamics (JAX) layer instead."
    )
    raise NotImplementedError(message)
788
+
789
@visitor(Tan)
def _visit_tan(self, node: Tan) -> cp.Expression:
    """Reject Tan: tangent is neither convex nor concave, so it has no DCP form.

    Args:
        node: Tan expression node (unused).

    Raises:
        NotImplementedError: Always.

    Note:
        Handle trigonometric terms in the JAX dynamics/constraint layer, or
        replace them with piecewise-linear approximations.
    """
    message = (
        "Trigonometric functions like Tan are not DCP-compliant in CVXPy. "
        "Consider using piecewise-linear approximations or handle these constraints "
        "in the dynamics (JAX) layer instead."
    )
    raise NotImplementedError(message)
811
+
812
+ @visitor(Exp)
813
+ def _visit_exp(self, node: Exp) -> cp.Expression:
814
+ """Lower exponential function to CVXPy expression.
815
+
816
+ Exponential is a convex function and DCP-compliant when used in
817
+ appropriate contexts (e.g., minimizing exp(x) or constraints like
818
+ exp(x) <= c).
819
+
820
+ Args:
821
+ node: Exp expression node
822
+
823
+ Returns:
824
+ CVXPy expression representing exp(operand)
825
+
826
+ Note:
827
+ Exponential is convex increasing, so it's valid in:
828
+ - Objective: minimize exp(x)
829
+ - Constraints: exp(x) <= c (convex constraint)
830
+ """
831
+ operand = self.lower(node.operand)
832
+ return cp.exp(operand)
833
+
834
+ @visitor(Log)
835
+ def _visit_log(self, node: Log) -> cp.Expression:
836
+ """Lower natural logarithm to CVXPy expression.
837
+
838
+ Logarithm is a concave function and DCP-compliant when used in
839
+ appropriate contexts (e.g., maximizing log(x) or constraints like
840
+ log(x) >= c).
841
+
842
+ Args:
843
+ node: Log expression node
844
+
845
+ Returns:
846
+ CVXPy expression representing log(operand)
847
+
848
+ Note:
849
+ Logarithm is concave increasing, so it's valid in:
850
+ - Objective: maximize log(x)
851
+ - Constraints: log(x) >= c (concave constraint, or equivalently c <= log(x))
852
+ """
853
+ operand = self.lower(node.operand)
854
+ return cp.log(operand)
855
+
856
+ @visitor(Abs)
857
+ def _visit_abs(self, node: Abs) -> cp.Expression:
858
+ """Lower absolute value to CVXPy expression.
859
+
860
+ Absolute value is a convex function and DCP-compliant when used in
861
+ appropriate contexts (e.g., minimizing |x| or constraints like |x| <= c).
862
+
863
+ Args:
864
+ node: Abs expression node
865
+
866
+ Returns:
867
+ CVXPy expression representing |operand|
868
+
869
+ Note:
870
+ Absolute value is convex, so it's valid in:
871
+ - Objective: minimize abs(x)
872
+ - Constraints: abs(x) <= c (convex constraint)
873
+ """
874
+ operand = self.lower(node.operand)
875
+ return cp.abs(operand)
876
+
877
+ @visitor(Equality)
878
+ def _visit_equality(self, node: Equality) -> cp.Constraint:
879
+ """Lower equality constraint to CVXPy constraint (lhs == rhs).
880
+
881
+ Equality constraints require affine expressions on both sides for
882
+ DCP compliance.
883
+
884
+ Args:
885
+ node: Equality constraint node
886
+
887
+ Returns:
888
+ CVXPy equality constraint object
889
+
890
+ Note:
891
+ For DCP compliance, both lhs and rhs must be affine. CVXPy will
892
+ raise a DCP error if either side is non-affine.
893
+ """
894
+ left = self.lower(node.lhs)
895
+ right = self.lower(node.rhs)
896
+ return left == right
897
+
898
+ @visitor(Inequality)
899
+ def _visit_inequality(self, node: Inequality) -> cp.Constraint:
900
+ """Lower inequality constraint to CVXPy constraint (lhs <= rhs).
901
+
902
+ Inequality constraints must satisfy DCP rules: convex <= concave.
903
+
904
+ Args:
905
+ node: Inequality constraint node
906
+
907
+ Returns:
908
+ CVXPy inequality constraint object
909
+
910
+ Note:
911
+ For DCP compliance: lhs must be convex and rhs must be concave.
912
+ Common form: convex_expr(x) <= constant
913
+ """
914
+ left = self.lower(node.lhs)
915
+ right = self.lower(node.rhs)
916
+ return left <= right
917
+
918
+ @visitor(CTCS)
919
+ def _visit_ctcs(self, node: CTCS) -> cp.Expression:
920
+ """Raise NotImplementedError for CTCS constraints.
921
+
922
+ CTCS (Continuous-Time Constraint Satisfaction) constraints are handled
923
+ through dynamics augmentation using JAX, not CVXPy. They represent
924
+ non-convex continuous-time constraints.
925
+
926
+ Args:
927
+ node: CTCS constraint node
928
+
929
+ Raises:
930
+ NotImplementedError: Always raised since CTCS uses JAX, not CVXPy
931
+
932
+ Note:
933
+ CTCS constraints are lowered to JAX during dynamics augmentation.
934
+ They add virtual states and controls to enforce constraints over
935
+ continuous time intervals. See JaxLowerer.visit_ctcs() instead.
936
+ """
937
+ raise NotImplementedError(
938
+ "CTCS constraints are for continuous-time constraint satisfaction and "
939
+ "should be handled through dynamics augmentation with JAX lowering, "
940
+ "not CVXPy lowering. CTCS constraints represent non-convex dynamics "
941
+ "augmentation."
942
+ )
943
+
944
+ @visitor(PositivePart)
945
+ def _visit_pos(self, node: PositivePart) -> cp.Expression:
946
+ """Lower positive part function to CVXPy.
947
+
948
+ Computes max(x, 0), which is convex. Used in penalty methods for
949
+ inequality constraints.
950
+
951
+ Args:
952
+ node: PositivePart expression node
953
+
954
+ Returns:
955
+ CVXPy expression representing max(operand, 0)
956
+
957
+ Note:
958
+ Positive part is convex and commonly used in hinge loss and
959
+ penalty methods for inequality constraints.
960
+ """
961
+ operand = self.lower(node.x)
962
+ return cp.maximum(operand, 0.0)
963
+
964
+ @visitor(Square)
965
+ def _visit_square(self, node: Square) -> cp.Expression:
966
+ """Lower square function to CVXPy.
967
+
968
+ Computes x^2, which is convex. Used in quadratic penalty methods
969
+ and least-squares objectives.
970
+
971
+ Args:
972
+ node: Square expression node
973
+
974
+ Returns:
975
+ CVXPy expression representing operand^2
976
+
977
+ Note:
978
+ Square is convex increasing for x >= 0 and convex decreasing for
979
+ x <= 0. It's always convex overall.
980
+ """
981
+ operand = self.lower(node.x)
982
+ return cp.square(operand)
983
+
984
+ @visitor(Huber)
985
+ def _visit_huber(self, node: Huber) -> cp.Expression:
986
+ """Lower Huber penalty function to CVXPy.
987
+
988
+ Huber penalty is quadratic for small values and linear for large values,
989
+ providing robustness to outliers. It is convex and DCP-compliant.
990
+
991
+ The Huber function is defined as:
992
+ - |x| <= delta: 0.5 * x^2
993
+ - |x| > delta: delta * (|x| - 0.5 * delta)
994
+
995
+ Args:
996
+ node: Huber expression node with delta parameter
997
+
998
+ Returns:
999
+ CVXPy expression representing Huber penalty
1000
+
1001
+ Note:
1002
+ Huber loss is convex and combines the benefits of squared error
1003
+ (smooth, differentiable) and absolute error (robust to outliers).
1004
+ """
1005
+ operand = self.lower(node.x)
1006
+ return cp.huber(operand, M=node.delta)
1007
+
1008
+ @visitor(SmoothReLU)
1009
+ def _visit_srelu(self, node: SmoothReLU) -> cp.Expression:
1010
+ """Lower smooth ReLU penalty function to CVXPy.
1011
+
1012
+ Smooth approximation to ReLU: sqrt(max(x, 0)^2 + c^2) - c
1013
+ Differentiable everywhere, approaches ReLU as c -> 0. Convex.
1014
+
1015
+ Args:
1016
+ node: SmoothReLU expression node with smoothing parameter c
1017
+
1018
+ Returns:
1019
+ CVXPy expression representing smooth ReLU penalty
1020
+
1021
+ Note:
1022
+ This provides a smooth, convex approximation to the ReLU function
1023
+ max(x, 0). The parameter c controls the smoothness: smaller c gives
1024
+ a better approximation but less smoothness.
1025
+ """
1026
+ operand = self.lower(node.x)
1027
+ c = node.c
1028
+ # smooth_relu(x) = sqrt(max(x, 0)^2 + c^2) - c
1029
+ pos_part = cp.maximum(operand, 0.0)
1030
+ # For SmoothReLU, we use the 2-norm formulation
1031
+ return cp.sqrt(cp.sum_squares(pos_part) + c**2) - c
1032
+
1033
+ @visitor(Sqrt)
1034
+ def _visit_sqrt(self, node: Sqrt) -> cp.Expression:
1035
+ """Lower square root to CVXPy expression.
1036
+
1037
+ Square root is concave and DCP-compliant when used appropriately
1038
+ (e.g., maximizing sqrt(x) or constraints like sqrt(x) >= c).
1039
+
1040
+ Args:
1041
+ node: Sqrt expression node
1042
+
1043
+ Returns:
1044
+ CVXPy expression representing sqrt(operand)
1045
+
1046
+ Note:
1047
+ Square root is concave increasing for x > 0. Valid in:
1048
+ - Objective: maximize sqrt(x)
1049
+ - Constraints: sqrt(x) >= c (concave constraint)
1050
+ """
1051
+ operand = self.lower(node.operand)
1052
+ return cp.sqrt(operand)
1053
+
1054
+ @visitor(Max)
1055
+ def _visit_max(self, node: Max) -> cp.Expression:
1056
+ """Lower element-wise maximum to CVXPy expression.
1057
+
1058
+ Maximum is convex (pointwise max of convex functions is convex).
1059
+
1060
+ Args:
1061
+ node: Max expression node with multiple operands
1062
+
1063
+ Returns:
1064
+ CVXPy expression representing element-wise maximum
1065
+
1066
+ Note:
1067
+ For multiple operands, chains binary maximum operations.
1068
+ Maximum preserves convexity.
1069
+ """
1070
+ operands = [self.lower(op) for op in node.operands]
1071
+ # CVXPy's maximum can take multiple arguments
1072
+ if len(operands) == 2:
1073
+ return cp.maximum(operands[0], operands[1])
1074
+ else:
1075
+ # For more than 2 operands, chain maximum calls
1076
+ result = cp.maximum(operands[0], operands[1])
1077
+ for op in operands[2:]:
1078
+ result = cp.maximum(result, op)
1079
+ return result
1080
+
1081
+ @visitor(LogSumExp)
1082
+ def _visit_logsumexp(self, node: LogSumExp) -> cp.Expression:
1083
+ """Lower log-sum-exp to CVXPy expression.
1084
+
1085
+ Log-sum-exp is convex and is a smooth approximation to the maximum function.
1086
+ CVXPy's log_sum_exp atom computes log(sum(exp(x_i))) for stacked operands.
1087
+
1088
+ Args:
1089
+ node: LogSumExp expression node with multiple operands
1090
+
1091
+ Returns:
1092
+ CVXPy expression representing log-sum-exp
1093
+
1094
+ Note:
1095
+ Log-sum-exp is convex and DCP-compliant. It satisfies:
1096
+ max(x₁, ..., xₙ) ≤ logsumexp(x₁, ..., xₙ) ≤ max(x₁, ..., xₙ) + log(n)
1097
+ """
1098
+ operands = [self.lower(op) for op in node.operands]
1099
+
1100
+ # CVXPy's log_sum_exp expects a stacked expression with an axis parameter
1101
+ # For element-wise log-sum-exp, we stack along a new axis and reduce along it
1102
+ if len(operands) == 1:
1103
+ return operands[0]
1104
+
1105
+ # Stack operands along a new axis (axis 0) and compute log_sum_exp along that axis
1106
+ stacked = cp.vstack(operands)
1107
+ return cp.log_sum_exp(stacked, axis=0)
1108
+
1109
+ @visitor(Transpose)
1110
+ def _visit_transpose(self, node: Transpose) -> cp.Expression:
1111
+ """Lower matrix transpose to CVXPy expression.
1112
+
1113
+ Transpose preserves DCP properties (transpose of convex is convex).
1114
+
1115
+ Args:
1116
+ node: Transpose expression node
1117
+
1118
+ Returns:
1119
+ CVXPy expression representing operand.T
1120
+ """
1121
+ operand = self.lower(node.operand)
1122
+ return operand.T
1123
+
1124
+ @visitor(Power)
1125
+ def _visit_power(self, node: Power) -> cp.Expression:
1126
+ """Lower element-wise power (base**exponent) to CVXPy expression.
1127
+
1128
+ Power is DCP-compliant for specific exponent values:
1129
+ - exponent >= 1: convex (when base >= 0)
1130
+ - 0 <= exponent <= 1: concave (when base >= 0)
1131
+
1132
+ Args:
1133
+ node: Power expression node
1134
+
1135
+ Returns:
1136
+ CVXPy expression representing base**exponent
1137
+
1138
+ Note:
1139
+ CVXPy will verify DCP compliance at problem construction time.
1140
+ Common convex cases: x^2, x^3, x^4 (even powers)
1141
+ """
1142
+ base = self.lower(node.base)
1143
+ exponent = self.lower(node.exponent)
1144
+ return cp.power(base, exponent)
1145
+
1146
+ @visitor(Stack)
1147
+ def _visit_stack(self, node: Stack) -> cp.Expression:
1148
+ """Lower vertical stacking to CVXPy expression.
1149
+
1150
+ Stacks expressions vertically using cp.vstack. Preserves DCP properties.
1151
+
1152
+ Args:
1153
+ node: Stack expression node with multiple rows
1154
+
1155
+ Returns:
1156
+ CVXPy expression representing vertical stack of rows
1157
+
1158
+ Note:
1159
+ Each row is stacked along axis 0 to create a 2D array.
1160
+ """
1161
+ rows = [self.lower(row) for row in node.rows]
1162
+ # Stack rows vertically
1163
+ return cp.vstack(rows)
1164
+
1165
+ @visitor(Hstack)
1166
+ def _visit_hstack(self, node: Hstack) -> cp.Expression:
1167
+ """Lower horizontal stacking to CVXPy expression.
1168
+
1169
+ For 1D arrays, uses cp.hstack (concatenation). For 2D+ arrays, uses
1170
+ cp.bmat with a single row to achieve proper horizontal stacking along
1171
+ axis 1, matching numpy.hstack semantics.
1172
+
1173
+ Args:
1174
+ node: Hstack expression node with multiple arrays
1175
+
1176
+ Returns:
1177
+ CVXPy expression representing horizontal stack of arrays
1178
+ """
1179
+ arrays = [self.lower(arr) for arr in node.arrays]
1180
+
1181
+ # Check dimensionality from the symbolic node's shape
1182
+ shape = node.check_shape()
1183
+ if len(shape) == 1:
1184
+ # 1D: simple concatenation
1185
+ return cp.hstack(arrays)
1186
+ else:
1187
+ # 2D+: use bmat with single row for proper horizontal stacking
1188
+ return cp.bmat([arrays])
1189
+
1190
+ @visitor(Vstack)
1191
+ def _visit_vstack(self, node: Vstack) -> cp.Expression:
1192
+ """Lower vertical stacking to CVXPy expression.
1193
+
1194
+ Stacks expressions vertically using cp.vstack. Preserves DCP properties.
1195
+
1196
+ Args:
1197
+ node: Vstack expression node with multiple arrays
1198
+
1199
+ Returns:
1200
+ CVXPy expression representing vertical stack of arrays
1201
+ """
1202
+ arrays = [self.lower(arr) for arr in node.arrays]
1203
+ return cp.vstack(arrays)
1204
+
1205
+ @visitor(Block)
1206
+ def _visit_block(self, node: Block) -> cp.Expression:
1207
+ """Lower block matrix construction to CVXPy expression.
1208
+
1209
+ Assembles a block matrix from nested lists of expressions using cp.bmat.
1210
+ This is the CVXPy equivalent of numpy.block() for block matrix construction.
1211
+
1212
+ Args:
1213
+ node: Block expression node with 2D nested structure of expressions
1214
+
1215
+ Returns:
1216
+ CVXPy expression representing the assembled block matrix
1217
+
1218
+ Raises:
1219
+ NotImplementedError: If any block has more than 2 dimensions
1220
+
1221
+ Note:
1222
+ cp.bmat preserves DCP properties when all blocks are DCP-compliant.
1223
+ Block matrices are commonly used for constraint aggregation.
1224
+ For 3D+ tensors, use JAX lowering instead.
1225
+ """
1226
+ # Check for 3D+ blocks - CVXPy's bmat only supports 2D
1227
+ for i, row in enumerate(node.blocks):
1228
+ for j, block in enumerate(row):
1229
+ block_shape = block.check_shape()
1230
+ if len(block_shape) > 2:
1231
+ raise NotImplementedError(
1232
+ f"CVXPy does not support Block with tensors of dimension > 2. "
1233
+ f"Block[{i}][{j}] has shape {block_shape} ({len(block_shape)}D). "
1234
+ f"For N-D tensor block assembly, use JAX lowering instead."
1235
+ )
1236
+
1237
+ # Lower each block expression
1238
+ block_exprs = [[self.lower(block) for block in row] for row in node.blocks]
1239
+ return cp.bmat(block_exprs)
1240
+
1241
+
1242
+ def lower_to_cvxpy(expr: Expr, variable_map: Dict[str, cp.Expression] = None) -> cp.Expression:
1243
+ """Lower symbolic expression to CVXPy expression or constraint.
1244
+
1245
+ Convenience wrapper that creates a CvxpyLowerer and lowers a single
1246
+ symbolic expression to a CVXPy expression. The result can be used in
1247
+ CVXPy optimization problems.
1248
+
1249
+ Args:
1250
+ expr: Symbolic expression to lower (any Expr subclass)
1251
+ variable_map: Dictionary mapping variable names to CVXPy expressions.
1252
+ Must include "x" for states and "u" for controls. May include
1253
+ parameter names mapped to CVXPy Parameters or constants.
1254
+
1255
+ Returns:
1256
+ CVXPy expression for arithmetic expressions (Add, Mul, Norm, etc.)
1257
+ or CVXPy constraint for constraint expressions (Equality, Inequality)
1258
+
1259
+ Raises:
1260
+ NotImplementedError: If the expression type is not supported (e.g., Sin, Cos, CTCS)
1261
+ ValueError: If required variables are missing from variable_map
1262
+
1263
+ Example:
1264
+ Basic expression lowering::
1265
+
1266
+ import cvxpy as cp
1267
+ import openscvx as ox
1268
+
1269
+ # Create CVXPy variables
1270
+ cvx_x = cp.Variable(3, name="x")
1271
+ cvx_u = cp.Variable(2, name="u")
1272
+
1273
+ # Create symbolic expression
1274
+ x = ox.State("x", shape=(3,))
1275
+ u = ox.Control("u", shape=(2,))
1276
+ expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
1277
+
1278
+ # Lower to CVXPy
1279
+ cvx_expr = lower_to_cvxpy(expr, {"x": cvx_x, "u": cvx_u})
1280
+
1281
+ # Use in optimization problem
1282
+ prob = cp.Problem(cp.Minimize(cvx_expr))
1283
+ prob.solve()
1284
+
1285
+ Constraint lowering::
1286
+
1287
+ # Symbolic constraint
1288
+ constraint = ox.Norm(x) <= 1.0
1289
+
1290
+ # Lower to CVXPy constraint
1291
+ cvx_constraint = lower_to_cvxpy(constraint, {"x": cvx_x, "u": cvx_u})
1292
+
1293
+ # Use in problem
1294
+ prob = cp.Problem(cp.Minimize(cost), constraints=[cvx_constraint])
1295
+
1296
+ See Also:
1297
+ - CvxpyLowerer: The underlying lowerer class
1298
+ - lower_to_jax(): Convenience wrapper for JAX lowering
1299
+ - lower_symbolic_expressions(): Main orchestrator in symbolic/lower.py
1300
+ """
1301
+ lowerer = CvxpyLowerer(variable_map)
1302
+ return lowerer.lower(expr)