openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,30 @@
1
+ """Lowered problem dataclasses.
2
+
3
+ This module contains dataclasses representing the outputs of the lowering phase,
4
+ where symbolic expressions are converted to executable JAX and CVXPy code.
5
+ """
6
+
7
+ from openscvx.lowered.cvxpy_constraints import LoweredCvxpyConstraints
8
+ from openscvx.lowered.cvxpy_variables import CVXPyVariables
9
+ from openscvx.lowered.dynamics import Dynamics
10
+ from openscvx.lowered.jax_constraints import (
11
+ LoweredCrossNodeConstraint,
12
+ LoweredJaxConstraints,
13
+ LoweredNodalConstraint,
14
+ )
15
+ from openscvx.lowered.parameters import ParameterDict
16
+ from openscvx.lowered.problem import LoweredProblem
17
+ from openscvx.lowered.unified import UnifiedControl, UnifiedState
18
+
19
# Public API of the ``openscvx.lowered`` package: every name listed here is
# imported at the top of this module from its defining submodule.
__all__ = [
    "LoweredProblem",
    "LoweredJaxConstraints",
    "LoweredCvxpyConstraints",
    "LoweredNodalConstraint",
    "LoweredCrossNodeConstraint",
    "CVXPyVariables",
    "ParameterDict",
    "Dynamics",
    "UnifiedState",
    "UnifiedControl",
]
@@ -0,0 +1,23 @@
1
+ """CVXPy-lowered constraint dataclass."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING:
7
+ import cvxpy as cp
8
+
9
+
10
@dataclass
class LoweredCvxpyConstraints:
    """Holder for convex constraints already expressed as CVXPy objects.

    The constraints stored here are handed to the optimal control problem
    as-is; no linearization is applied to them.

    Attributes:
        constraints: CVXPy constraint objects (``cp.Constraint``), covering
            nodal as well as cross-node convex constraints. Every instance
            starts with its own empty list.
    """

    constraints: list["cp.Constraint"] = field(default_factory=list)
@@ -0,0 +1,124 @@
1
+ """CVXPy variables and parameters dataclass for the optimal control problem."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import TYPE_CHECKING, List
5
+
6
+ import numpy as np
7
+
8
+ if TYPE_CHECKING:
9
+ import cvxpy as cp
10
+
11
+
12
@dataclass
class CVXPyVariables:
    """Typed bundle of the CVXPy objects that make up the optimal control problem.

    Every CVXPy ``Variable`` and ``Parameter`` required to build and solve the
    optimal control problem lives on this dataclass, which supersedes the
    earlier untyped-dictionary approach with a self-documenting structure.

    Fields fall into these groups:

    - SCP weights: parameters for the trust region and penalty weights
    - State: variables/parameters describing the state trajectory
    - Control: variables/parameters describing the control trajectory
    - Dynamics: parameters of the discretized dynamics constraints
    - Nodal constraints: parameters of linearized non-convex nodal constraints
    - Cross-node constraints: parameters of linearized cross-node constraints
    - Scaling: affine scaling matrices and offset vectors
    - Scaled expressions: CVXPy expressions for scaled state/control per node

    Attributes:
        w_tr: Trust region weight parameter (scalar, nonneg)
        lam_cost: Cost function weight parameter (scalar, nonneg)
        lam_vc: Virtual control penalty weights (N-1 x n_states, nonneg)
        lam_vb: Virtual buffer penalty weight (scalar, nonneg)

        x: State variable (N x n_states)
        dx: State error variable (N x n_states)
        x_bar: Previous SCP state parameter (N x n_states)
        x_init: Initial state parameter (n_states,)
        x_term: Terminal state parameter (n_states,)

        u: Control variable (N x n_controls)
        du: Control error variable (N x n_controls)
        u_bar: Previous SCP control parameter (N x n_controls)

        A_d: Discretized state Jacobian parameter (N-1 x n_states*n_states)
        B_d: Discretized control Jacobian parameter (N-1 x n_states*n_controls)
        C_d: Discretized control Jacobian (next node) parameter
        x_prop: Propagated state parameter (N-1 x n_states)
        nu: Virtual control variable (N-1 x n_states)

        g: Constraint value parameters, one per nodal constraint
        grad_g_x: State gradient parameters, one per nodal constraint
        grad_g_u: Control gradient parameters, one per nodal constraint
        nu_vb: Virtual buffer variables, one per nodal constraint

        g_cross: Cross-node constraint value parameters
        grad_g_X_cross: Trajectory state gradient parameters
        grad_g_U_cross: Trajectory control gradient parameters
        nu_vb_cross: Cross-node virtual buffer variables

        S_x: State scaling matrix (n_states x n_states)
        inv_S_x: Inverse state scaling matrix
        c_x: State offset vector (n_states,)
        S_u: Control scaling matrix (n_controls x n_controls)
        inv_S_u: Inverse control scaling matrix
        c_u: Control offset vector (n_controls,)

        x_nonscaled: Scaled state expressions, one per node
        u_nonscaled: Scaled control expressions, one per node
        dx_nonscaled: Scaled state error expressions, one per node
        du_nonscaled: Scaled control error expressions, one per node
    """

    # --- SCP weights (required) ---
    w_tr: "cp.Parameter"
    lam_cost: "cp.Parameter"
    lam_vc: "cp.Parameter"
    lam_vb: "cp.Parameter"

    # --- State trajectory (required) ---
    x: "cp.Variable"
    dx: "cp.Variable"
    x_bar: "cp.Parameter"
    x_init: "cp.Parameter"
    x_term: "cp.Parameter"

    # --- Control trajectory (required) ---
    u: "cp.Variable"
    du: "cp.Variable"
    u_bar: "cp.Parameter"

    # --- Discretized dynamics (required) ---
    A_d: "cp.Parameter"
    B_d: "cp.Parameter"
    C_d: "cp.Parameter"
    x_prop: "cp.Parameter"
    nu: "cp.Variable"

    # --- Linearized nodal constraints: one list entry per constraint ---
    g: List["cp.Parameter"] = field(default_factory=list)
    grad_g_x: List["cp.Parameter"] = field(default_factory=list)
    grad_g_u: List["cp.Parameter"] = field(default_factory=list)
    nu_vb: List["cp.Variable"] = field(default_factory=list)

    # --- Linearized cross-node constraints: one list entry per constraint ---
    g_cross: List["cp.Parameter"] = field(default_factory=list)
    grad_g_X_cross: List["cp.Parameter"] = field(default_factory=list)
    grad_g_U_cross: List["cp.Parameter"] = field(default_factory=list)
    nu_vb_cross: List["cp.Variable"] = field(default_factory=list)

    # --- Affine scaling data as NumPy arrays (empty arrays until populated) ---
    S_x: np.ndarray = field(default_factory=lambda: np.array([]))
    inv_S_x: np.ndarray = field(default_factory=lambda: np.array([]))
    c_x: np.ndarray = field(default_factory=lambda: np.array([]))
    S_u: np.ndarray = field(default_factory=lambda: np.array([]))
    inv_S_u: np.ndarray = field(default_factory=lambda: np.array([]))
    c_u: np.ndarray = field(default_factory=lambda: np.array([]))

    # --- Per-node CVXPy expressions (lists of length N) ---
    x_nonscaled: List = field(default_factory=list)
    u_nonscaled: List = field(default_factory=list)
    dx_nonscaled: List = field(default_factory=list)
    du_nonscaled: List = field(default_factory=list)
@@ -0,0 +1,34 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Optional
3
+
4
+ import jax.numpy as jnp
5
+
6
+
7
@dataclass
class Dynamics:
    """Bundle of a system dynamics function together with its Jacobians.

    openscvx uses this dataclass internally to keep the compiled dynamics
    function and its gradients once symbolic expressions have been lowered
    to JAX. It is normally not instantiated by users.

    Attributes:
        f (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
            Continuous-time nonlinear system dynamics, evaluated as
            ``x_dot = f(x, u, ...params)``.

            - x: 1D array (state at a single node), shape (n_x,)
            - u: 1D array (control at a single node), shape (n_u,)
            - Additional parameters: supplied as keyword arguments whose
              names are the parameter name followed by an underscore
              (e.g., ``g_`` for ``Parameter('g')``).

            With vectorized integration or batch evaluation, x and u may
            instead be 2D arrays of shape (N, n_x) and (N, n_u).
        A (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
            Jacobian of ``f`` w.r.t. ``x``. Defaults to ``None``.
        B (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
            Jacobian of ``f`` w.r.t. ``u``. Defaults to ``None``.
    """

    # Only the dynamics function is mandatory; both Jacobians are optional.
    f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
    A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
@@ -0,0 +1,133 @@
1
+ """JAX-lowered constraint dataclass."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import TYPE_CHECKING, Callable, List, Optional
5
+
6
+ import jax.numpy as jnp
7
+
8
+ if TYPE_CHECKING:
9
+ from openscvx.symbolic.expr import CTCS
10
+
11
+
12
@dataclass
class LoweredNodalConstraint:
    """
    Container for a lowered symbolic constraint function and its Jacobians.

    A simplified drop-in replacement for NodalConstraint: it keeps only the
    essential lowered JAX callables and their Jacobians, with none of the
    convex/vectorized flags or post-initialization logic.

    Intended for symbolic expressions that were lowered to JAX and are going
    to be linearized for sequential convex programming.

    Args:
        func (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
            Lowered constraint function g(x, u, ...params) returning the
            constraint residuals, following the g(x, u) <= 0 convention.
            - x: 1D array (state at a single node), shape (n_x,)
            - u: 1D array (control at a single node), shape (n_u,)
            - Additional parameters: passed as keyword arguments

        grad_g_x (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
            Jacobian of g w.r.t. x. When None, it should be computed using
            jax.jacfwd.

        grad_g_u (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
            Jacobian of g w.r.t. u. When None, it should be computed using
            jax.jacfwd.

        nodes (Optional[List[int]]): Node indices at which the constraint
            applies. Set after lowering from NodalConstraint.
    """

    # Only the residual function is required; the rest default to None.
    func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
    grad_g_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    grad_g_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    nodes: Optional[List[int]] = None
46
+
47
+
48
@dataclass
class LoweredCrossNodeConstraint:
    """Lowered constraint evaluated on the whole trajectory at once.

    A regular LoweredNodalConstraint works on single-node vectors and is
    vmapped across the trajectory. A LoweredCrossNodeConstraint instead
    receives the full trajectory arrays so that it can relate several nodes
    simultaneously.

    Constraints that need this include:
        - Rate limits: x[k] - x[k-1] <= max_rate
        - Multi-step dependencies: x[k] = 2*x[k-1] - x[k-2]
        - Periodic boundaries: x[0] = x[N-1]

    Signature comparison with LoweredNodalConstraint:
        - Regular: f(x, u, node, params) -> scalar (vmapped to handle (N, n_x))
        - Cross-node: f(X, U, params) -> scalar (single constraint with fixed
          node indices)

    Attributes:
        func: Function (X, U, params) -> scalar residual,
            where X: (N, n_x), U: (N, n_u).
            The residual follows the g(X, U) <= 0 convention.
            The constraint references fixed trajectory nodes (e.g., X[5] - X[4]).
        grad_g_X: Function (X, U, params) -> (N, n_x) Jacobian wrt the full
            state trajectory. Typically sparse - most constraints only couple
            nearby nodes.
        grad_g_U: Function (X, U, params) -> (N, n_u) Jacobian wrt the full
            control trajectory. Often zero or very sparse for cross-node
            state constraints.

    Example:
        For rate constraint x[5] - x[4] <= r:

            func(X, U, params) -> scalar residual
            grad_g_X(X, U, params) -> (N, n_x) sparse Jacobian
                where grad_g_X[5, :] = ∂g/∂x[5] (derivative wrt node 5)
                and grad_g_X[4, :] = ∂g/∂x[4] (derivative wrt node 4)
                all other entries are zero

    Performance Note - Dense Jacobian Storage:
        grad_g_X and grad_g_U are stored as DENSE arrays of shape (N, n_x)
        and (N, n_u), even though most cross-node constraints couple only a
        few nearby nodes, which makes these matrices extremely sparse.

        A rate limit constraint x[k] - x[k-1] <= r, for instance, has
        non-zero Jacobian entries only at positions [k, :] and [k-1, :]. The
        remaining N-2 rows are zero yet still held in memory.

        Memory impact for large problems:
            - A single constraint with N=100 nodes, n_x=10 states requires
              ~8KB for grad_g_X (compared to ~160 bytes if sparse with 2
              non-zero rows)
            - Multiple cross-node constraints multiply this overhead
            - May cause issues for N > 1000 with many constraints

        Performance impact:
            - Slower autodiff (computes many zero gradients)
            - Inefficient constraint linearization in the SCP solver
            - Potential GPU memory limitations for very large problems

        The current implementation favours simplicity and compatibility with
        JAX's autodiff over memory efficiency. Sparse Jacobian formats (COO,
        CSR, or custom sparse representations) may be supported in future
        versions if this turns out to be a bottleneck in practice.
    """

    # All three callables are required and share the (X, U, params) signature.
    func: Callable[[jnp.ndarray, jnp.ndarray, dict], jnp.ndarray]
    grad_g_X: Callable[[jnp.ndarray, jnp.ndarray, dict], jnp.ndarray]
    grad_g_U: Callable[[jnp.ndarray, jnp.ndarray, dict], jnp.ndarray]
113
+
114
+
115
@dataclass
class LoweredJaxConstraints:
    """Non-convex constraints lowered to JAX, bundled with gradient functions.

    Holds constraints that were turned into JAX callables with automatically
    computed gradients; the SCP (Sequential Convex Programming) loop uses
    them for linearization.

    Attributes:
        nodal: LoweredNodalConstraint objects, each exposing the `func`,
            `grad_g_x`, `grad_g_u` callables and a `nodes` list.
        cross_node: LoweredCrossNodeConstraint objects, each exposing
            `func`, `grad_g_X`, `grad_g_U` for trajectory-level constraints.
        ctcs: CTCS constraints (unchanged from input, not lowered here).
    """

    # Each field defaults to its own fresh empty list.
    nodal: list[LoweredNodalConstraint] = field(default_factory=list)
    cross_node: list[LoweredCrossNodeConstraint] = field(default_factory=list)
    ctcs: list["CTCS"] = field(default_factory=list)
@@ -0,0 +1,54 @@
1
+ """Parameter dictionary that syncs between JAX and CVXPy."""
2
+
3
+ import numpy as np
4
+
5
+
6
class ParameterDict(dict):
    """Dictionary that mirrors every write into a JAX dict and CVXPy parameters.

    This allows users to naturally update parameters like::

        problem.parameters["obs_radius"] = 2.0

    Each stored value is first converted to a float NumPy array and then
    propagated to:

    1. The internal ``_parameters`` dict (a plain dict used by JAX).
    2. The matching CVXPy parameter in ``problem._lowered.cvxpy_params``,
       when the problem has a ``_lowered`` attribute that is not ``None``
       and the key is present there.

    Every insertion path (item assignment, the constructor, ``update``,
    ``setdefault`` and ``|=``) is routed through ``__setitem__``. The
    built-in ``dict`` versions of ``setdefault`` and ``|=`` write to the
    underlying storage directly and would otherwise skip both the float
    conversion and the synchronisation.

    NOTE: removals (``del``, ``pop``, ``popitem``, ``clear``) are not
    overridden and therefore are not mirrored to the internal dict or to
    CVXPy.
    """

    def __init__(self, problem, internal_dict, *args, **kwargs):
        """Create the dictionary and load any initial parameters.

        Args:
            problem: Owning object. Its ``_lowered`` attribute (if any) is
                looked up on every write to find the CVXPy parameters.
            internal_dict: Plain dict that is kept in sync with this one;
                it is stored by reference, not copied.
            *args: Optional initial content, either a mapping or an iterable
                of ``(key, value)`` pairs. Only the first positional
                argument is consumed.
            **kwargs: Additional initial parameters given by keyword.
        """
        self._problem = problem
        self._internal_dict = internal_dict  # Reference to plain dict for JAX
        super().__init__()
        # Populate through __setitem__ so float enforcement and syncing apply.
        if args:
            self._assign_all(args[0])
        self._assign_all(kwargs)

    def _assign_all(self, entries):
        """Assign each pair of a mapping or pair-iterable via ``__setitem__``."""
        pairs = entries.items() if hasattr(entries, "items") else entries
        for key, value in pairs:
            self[key] = value

    def __setitem__(self, key, value):
        """Store ``value`` as a float array and sync it to JAX and CVXPy."""
        # Enforce float dtype to prevent int/float mismatch bugs
        value = np.asarray(value, dtype=float)
        super().__setitem__(key, value)
        # Sync to internal dict for JAX
        self._internal_dict[key] = value
        # Sync to CVXPy if the problem has been lowered and knows this key
        lowered = getattr(self._problem, "_lowered", None)
        if lowered is not None and key in lowered.cvxpy_params:
            lowered.cvxpy_params[key].value = value

    def update(self, other=None, **kwargs):
        """Update multiple parameters and sync to internal dict and CVXPy.

        Args:
            other: Optional mapping or iterable of ``(key, value)`` pairs.
            **kwargs: Additional parameters given by keyword.
        """
        if other is not None:
            self._assign_all(other)
        self._assign_all(kwargs)

    def setdefault(self, key, default=None):
        """Insert ``default`` for a missing ``key`` (with syncing) and return the value.

        ``dict.setdefault`` does not call an overridden ``__setitem__``, so
        it is re-implemented here to keep the float conversion and the
        JAX/CVXPy synchronisation in effect.
        """
        if key not in self:
            self[key] = default
        return self[key]

    def __ior__(self, other):
        """Support ``params |= other`` with the same syncing as ``update``.

        ``dict.__ior__`` merges at the C level without calling the
        overridden ``update``, so it is redirected explicitly.
        """
        self.update(other)
        return self
@@ -0,0 +1,70 @@
1
+ """LoweredProblem dataclass - container for all lowering outputs."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Dict
5
+
6
+ from openscvx.lowered.cvxpy_constraints import LoweredCvxpyConstraints
7
+ from openscvx.lowered.cvxpy_variables import CVXPyVariables
8
+ from openscvx.lowered.dynamics import Dynamics
9
+ from openscvx.lowered.jax_constraints import LoweredJaxConstraints
10
+ from openscvx.lowered.unified import UnifiedControl, UnifiedState
11
+
12
+ if TYPE_CHECKING:
13
+ import cvxpy as cp
14
+
15
+
16
@dataclass
class LoweredProblem:
    """Aggregate of everything produced by lowering a symbolic problem.

    All results of turning symbolic expressions into executable JAX and
    CVXPy code are gathered on this dataclass, giving a clean, typed way to
    reach the individual components needed for optimization.

    Attributes:
        dynamics: Optimization dynamics with fields f, A, B (JAX functions)
        dynamics_prop: Propagation dynamics with fields f, A, B
        jax_constraints: Non-convex constraints lowered to JAX with gradients
        cvxpy_constraints: Convex constraints lowered to CVXPy
        x_unified: Aggregated optimization state interface
        u_unified: Aggregated optimization control interface
        x_prop_unified: Aggregated propagation state interface
        ocp_vars: Typed CVXPy variables and parameters for OCP construction
        cvxpy_params: Dict mapping user parameter names to CVXPy Parameter objects

    Example:
        After lowering a symbolic problem::

            lowered = lower_symbolic_problem(
                dynamics_aug=dynamics,
                states_aug=states,
                controls_aug=controls,
                constraints=constraint_set,
                parameters=params,
                N=50,
            )

            # Access components
            dx_dt = lowered.dynamics.f(x, u, node, params)
            jacobian_A = lowered.dynamics.A(x, u, node, params)

            # Use CVXPy objects
            ocp = OptimalControlProblem(settings, lowered)
    """

    # Dynamics lowered to JAX (optimization and propagation variants)
    dynamics: Dynamics
    dynamics_prop: Dynamics

    # Constraints, kept in separate containers for JAX and CVXPy
    jax_constraints: LoweredJaxConstraints
    cvxpy_constraints: LoweredCvxpyConstraints

    # Aggregated state/control interfaces
    x_unified: UnifiedState
    u_unified: UnifiedControl
    x_prop_unified: UnifiedState

    # CVXPy-side objects
    ocp_vars: CVXPyVariables
    cvxpy_params: Dict[str, "cp.Parameter"]