openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Lowered problem dataclasses.
|
|
2
|
+
|
|
3
|
+
This module contains dataclasses representing the outputs of the lowering phase,
|
|
4
|
+
where symbolic expressions are converted to executable JAX and CVXPy code.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from openscvx.lowered.cvxpy_constraints import LoweredCvxpyConstraints
|
|
8
|
+
from openscvx.lowered.cvxpy_variables import CVXPyVariables
|
|
9
|
+
from openscvx.lowered.dynamics import Dynamics
|
|
10
|
+
from openscvx.lowered.jax_constraints import (
|
|
11
|
+
LoweredCrossNodeConstraint,
|
|
12
|
+
LoweredJaxConstraints,
|
|
13
|
+
LoweredNodalConstraint,
|
|
14
|
+
)
|
|
15
|
+
from openscvx.lowered.parameters import ParameterDict
|
|
16
|
+
from openscvx.lowered.problem import LoweredProblem
|
|
17
|
+
from openscvx.lowered.unified import UnifiedControl, UnifiedState
|
|
18
|
+
|
|
19
|
+
# Public API of the ``openscvx.lowered`` subpackage: the dataclasses imported
# above from its submodules, re-exported so callers can write
# ``from openscvx.lowered import LoweredProblem`` etc.
__all__ = [
    "LoweredProblem",
    "LoweredJaxConstraints",
    "LoweredCvxpyConstraints",
    "LoweredNodalConstraint",
    "LoweredCrossNodeConstraint",
    "CVXPyVariables",
    "ParameterDict",
    "Dynamics",
    "UnifiedState",
    "UnifiedControl",
]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""CVXPy-lowered constraint dataclass."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
import cvxpy as cp
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class LoweredCvxpyConstraints:
    """Convex constraints that are already expressed as CVXPy objects.

    These constraints are added to the optimal control problem unchanged.
    Unlike the JAX-lowered constraints, they are never linearized.

    Attributes:
        constraints: CVXPy constraint objects (``cp.Constraint``). The list
            holds both nodal and cross-node convex constraints. By default
            every instance gets its own fresh, empty list.
    """

    # default_factory avoids sharing one mutable list between instances.
    constraints: list["cp.Constraint"] = field(default_factory=list)
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""CVXPy variables and parameters dataclass for the optimal control problem."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import TYPE_CHECKING, List
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
import cvxpy as cp
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class CVXPyVariables:
    """Typed bundle of the CVXPy Variables and Parameters of the OCP.

    All CVXPy objects needed to build and solve the optimal control
    subproblem are held here as named fields, not in an untyped dict.
    The fields are grouped as follows:

    * SCP weights: trust-region and penalty weight parameters.
    * State: variables and parameters for the state trajectory.
    * Control: variables and parameters for the control trajectory.
    * Dynamics: parameters of the discretized dynamics constraints.
    * Nodal constraints: linearization data for non-convex nodal constraints.
    * Cross-node constraints: linearization data for cross-node constraints.
    * Scaling: affine scaling matrices and offset vectors (NumPy arrays).
    * Node expressions: per-node CVXPy expressions for state/control.

    Attributes:
        w_tr: Trust-region weight parameter (scalar, nonneg).
        lam_cost: Cost-function weight parameter (scalar, nonneg).
        lam_vc: Virtual-control penalty weights (N-1 x n_states, nonneg).
        lam_vb: Virtual-buffer penalty weight (scalar, nonneg).

        x: State variable (N x n_states).
        dx: State error variable (N x n_states).
        x_bar: State from the previous SCP iterate (N x n_states).
        x_init: Initial-state parameter (n_states,).
        x_term: Terminal-state parameter (n_states,).

        u: Control variable (N x n_controls).
        du: Control error variable (N x n_controls).
        u_bar: Control from the previous SCP iterate (N x n_controls).

        A_d: Discretized state-Jacobian parameter (N-1 x n_states*n_states).
        B_d: Discretized control-Jacobian parameter
            (N-1 x n_states*n_controls).
        C_d: Discretized control-Jacobian parameter for the next node.
        x_prop: Propagated-state parameter (N-1 x n_states).
        nu: Virtual-control variable (N-1 x n_states).

        g: Constraint-value parameters, one per nodal constraint.
        grad_g_x: State-gradient parameters, one per nodal constraint.
        grad_g_u: Control-gradient parameters, one per nodal constraint.
        nu_vb: Virtual-buffer variables, one per nodal constraint.

        g_cross: Constraint-value parameters, one per cross-node constraint.
        grad_g_X_cross: Trajectory state-gradient parameters.
        grad_g_U_cross: Trajectory control-gradient parameters.
        nu_vb_cross: Virtual-buffer variables, one per cross-node constraint.

        S_x: State scaling matrix (n_states x n_states).
        inv_S_x: Inverse of the state scaling matrix.
        c_x: State offset vector (n_states,).
        S_u: Control scaling matrix (n_controls x n_controls).
        inv_S_u: Inverse of the control scaling matrix.
        c_u: Control offset vector (n_controls,).

        x_nonscaled: Per-node CVXPy state expressions.
        u_nonscaled: Per-node CVXPy control expressions.
        dx_nonscaled: Per-node CVXPy state-error expressions.
        du_nonscaled: Per-node CVXPy control-error expressions.

        NOTE(review): earlier docs call the four ``*_nonscaled`` lists
        "scaled expressions", but their names suggest physical (unscaled)
        values derived via S/c. Confirm against the code that fills them.
    """

    # -- SCP weight parameters ------------------------------------------
    w_tr: "cp.Parameter"
    lam_cost: "cp.Parameter"
    lam_vc: "cp.Parameter"
    lam_vb: "cp.Parameter"

    # -- State trajectory ------------------------------------------------
    x: "cp.Variable"
    dx: "cp.Variable"
    x_bar: "cp.Parameter"
    x_init: "cp.Parameter"
    x_term: "cp.Parameter"

    # -- Control trajectory ----------------------------------------------
    u: "cp.Variable"
    du: "cp.Variable"
    u_bar: "cp.Parameter"

    # -- Discretized dynamics --------------------------------------------
    A_d: "cp.Parameter"
    B_d: "cp.Parameter"
    C_d: "cp.Parameter"
    x_prop: "cp.Parameter"
    nu: "cp.Variable"

    # -- Nodal constraint linearization (one entry per constraint) -------
    g: list["cp.Parameter"] = field(default_factory=list)
    grad_g_x: list["cp.Parameter"] = field(default_factory=list)
    grad_g_u: list["cp.Parameter"] = field(default_factory=list)
    nu_vb: list["cp.Variable"] = field(default_factory=list)

    # -- Cross-node constraint linearization (one entry per constraint) --
    g_cross: list["cp.Parameter"] = field(default_factory=list)
    grad_g_X_cross: list["cp.Parameter"] = field(default_factory=list)
    grad_g_U_cross: list["cp.Parameter"] = field(default_factory=list)
    nu_vb_cross: list["cp.Variable"] = field(default_factory=list)

    # -- Scaling matrices and offsets; each defaults to a fresh empty
    #    float64 array of shape (0,) --------------------------------------
    S_x: np.ndarray = field(default_factory=lambda: np.zeros(0))
    inv_S_x: np.ndarray = field(default_factory=lambda: np.zeros(0))
    c_x: np.ndarray = field(default_factory=lambda: np.zeros(0))
    S_u: np.ndarray = field(default_factory=lambda: np.zeros(0))
    inv_S_u: np.ndarray = field(default_factory=lambda: np.zeros(0))
    c_u: np.ndarray = field(default_factory=lambda: np.zeros(0))

    # -- Per-node CVXPy expressions (length N once populated) ------------
    x_nonscaled: list = field(default_factory=list)
    u_nonscaled: list = field(default_factory=list)
    dx_nonscaled: list = field(default_factory=list)
    du_nonscaled: list = field(default_factory=list)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Callable, Optional
|
|
3
|
+
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class Dynamics:
    """Compiled JAX system dynamics plus optional Jacobians.

    openscvx fills this container when symbolic dynamics are lowered to JAX.
    It is an internal structure and is not normally constructed by hand.

    Attributes:
        f: Continuous-time dynamics, ``x_dot = f(x, u, ...params)``.
            - For a single node, ``x`` has shape ``(n_x,)`` and ``u`` has
              shape ``(n_u,)``.
            - Under vectorized integration or batch evaluation they may be
              ``(N, n_x)`` and ``(N, n_u)``.
            - Additional parameters are passed as keyword arguments named
              after the parameter plus a trailing underscore (e.g. ``g_``
              for ``Parameter('g')``).
        A: Jacobian of ``f`` with respect to ``x``; ``None`` if not supplied.
        B: Jacobian of ``f`` with respect to ``u``; ``None`` if not supplied.
    """

    # The dynamics function itself is the only required field.
    f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
    # Jacobians w.r.t. state (A) and control (B) are optional.
    A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""JAX-lowered constraint dataclass."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import TYPE_CHECKING, Callable, List, Optional
|
|
5
|
+
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from openscvx.symbolic.expr import CTCS
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class LoweredNodalConstraint:
    """A lowered per-node constraint function and its Jacobians.

    This is a lightweight stand-in for ``NodalConstraint``. It keeps only the
    lowered JAX function and its Jacobians. It has no convex/vectorized
    flags and no post-initialization logic. It is meant for symbolic
    expressions that were lowered to JAX and will be linearized inside the
    sequential convex programming loop.

    Attributes:
        func: Lowered constraint ``g(x, u, ...params)`` returning residuals,
            following the ``g(x, u) <= 0`` convention.
            - ``x``: state at a single node, shape ``(n_x,)``.
            - ``u``: control at a single node, shape ``(n_u,)``.
            - Additional parameters are passed as keyword arguments.
        grad_g_x: Jacobian of ``g`` w.r.t. ``x``. When ``None`` it is
            expected to be computed with ``jax.jacfwd``.
        grad_g_u: Jacobian of ``g`` w.r.t. ``u``. When ``None`` it is
            expected to be computed with ``jax.jacfwd``.
        nodes: Indices of the nodes where the constraint applies. Populated
            after lowering from ``NodalConstraint``.
    """

    # Required: the residual function.
    func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
    # Optional pieces, filled in later in the pipeline.
    grad_g_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    grad_g_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    nodes: Optional[List[int]] = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class LoweredCrossNodeConstraint:
    """A lowered constraint evaluated on whole trajectories.

    ``LoweredNodalConstraint`` works on single-node vectors and is vmapped
    across the trajectory. This class instead receives the full trajectory
    arrays, so one constraint can relate several nodes at once. Typical uses:

    - Rate limits: ``x[k] - x[k-1] <= max_rate``
    - Multi-step dependencies: ``x[k] = 2*x[k-1] - x[k-2]``
    - Periodic boundaries: ``x[0] = x[N-1]``

    Signature comparison:

    - Nodal: ``f(x, u, node, params) -> scalar``, vmapped over ``(N, n_x)``.
    - Cross-node: ``f(X, U, params) -> scalar``, a single constraint whose
      node indices are fixed.

    Attributes:
        func: ``(X, U, params) -> scalar`` residual, with ``X`` of shape
            ``(N, n_x)`` and ``U`` of shape ``(N, n_u)``. The residual
            follows the ``g(X, U) <= 0`` convention and refers to fixed
            trajectory nodes (e.g. ``X[5] - X[4]``).
        grad_g_X: ``(X, U, params) -> (N, n_x)`` Jacobian with respect to
            the whole state trajectory. It is usually sparse, because most
            constraints only couple neighboring nodes.
        grad_g_U: ``(X, U, params) -> (N, n_u)`` Jacobian with respect to
            the whole control trajectory. It is often zero or very sparse
            for state-only constraints.

    Example:
        For the rate constraint ``x[5] - x[4] <= r``::

            func(X, U, params)     -> scalar residual
            grad_g_X(X, U, params) -> (N, n_x) Jacobian where
                grad_g_X[5, :] = dg/dx[5],
                grad_g_X[4, :] = dg/dx[4],
                and every other row is zero.

    Performance note (dense Jacobian storage):
        ``grad_g_X`` and ``grad_g_U`` are stored densely with shapes
        ``(N, n_x)`` and ``(N, n_u)``, even though most cross-node
        constraints touch only a few nodes. A rate limit
        ``x[k] - x[k-1] <= r``, for instance, has non-zero rows only at
        ``k`` and ``k-1``. The other ``N-2`` rows are stored zeros.

        Memory cost:
        - With ``N=100`` and ``n_x=10``, one ``grad_g_X`` holds 1000 entries
          (about 4 KB in float32, about 8 KB in float64). Two sparse rows
          would need only 20 entries.
        - The overhead grows with the number of cross-node constraints.
        - It may become a problem for ``N > 1000`` with many constraints.

        Runtime cost:
        - Autodiff spends work computing gradients that are zero.
        - Constraint linearization in the SCP solver is less efficient.
        - Very large problems may hit GPU memory limits.

        The dense layout favours simplicity and compatibility with JAX
        autodiff over memory efficiency. If it becomes a bottleneck, sparse
        Jacobian formats (COO, CSR, or a custom representation) may be
        supported in the future.
    """

    # All three callables share the trajectory-level (X, U, params) signature.
    func: Callable[[jnp.ndarray, jnp.ndarray, dict], jnp.ndarray]
    grad_g_X: Callable[[jnp.ndarray, jnp.ndarray, dict], jnp.ndarray]
    grad_g_U: Callable[[jnp.ndarray, jnp.ndarray, dict], jnp.ndarray]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass
class LoweredJaxConstraints:
    """Non-convex constraints lowered to JAX callables with gradients.

    These are the constraints that get linearized on every iteration of the
    SCP (Sequential Convex Programming) loop.

    Attributes:
        nodal: ``LoweredNodalConstraint`` objects. Each carries ``func``,
            ``grad_g_x`` and ``grad_g_u`` callables plus a ``nodes`` list.
        cross_node: ``LoweredCrossNodeConstraint`` objects. Each carries
            trajectory-level ``func``, ``grad_g_X`` and ``grad_g_U``.
        ctcs: CTCS constraints, passed through unchanged (not lowered here).

    Every list defaults to a fresh empty list per instance.
    """

    nodal: list[LoweredNodalConstraint] = field(default_factory=list)
    cross_node: list[LoweredCrossNodeConstraint] = field(default_factory=list)
    # CTCS is only imported for type checking, hence the string annotation.
    ctcs: list["CTCS"] = field(default_factory=list)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Parameter dictionary that syncs between JAX and CVXPy."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ParameterDict(dict):
    """Dict of problem parameters that mirrors every write to JAX and CVXPy.

    Users can update parameters with plain dict syntax::

        problem.parameters["obs_radius"] = 2.0

    Each assigned value is coerced to a float NumPy array and propagated to:
        1. the internal ``_parameters`` dict (a plain dict used by JAX), and
        2. the CVXPy parameter of the same name in
           ``problem._lowered.cvxpy_params``, when the problem has been
           lowered and that name is present.

    Every write path funnels through :meth:`__setitem__`: item assignment,
    construction, :meth:`update`, :meth:`setdefault` and ``|=``. The built-in
    ``dict.setdefault`` and ``dict.__ior__`` do not call an overridden
    ``__setitem__``, so they are overridden here. Otherwise they would skip
    both the float coercion and the synchronization.

    Args:
        problem: Object that may expose a ``_lowered`` attribute. When that
            attribute is present and not ``None``, it must provide a
            ``cvxpy_params`` mapping of parameter name -> CVXPy Parameter.
        internal_dict: Plain dict kept in sync for JAX; mutated in place.
        *args: Optional initial contents, given as a mapping or an iterable
            of ``(key, value)`` pairs. Only the first positional argument
            is used.
        **kwargs: Additional initial key/value pairs.
    """

    def __init__(self, problem, internal_dict, *args, **kwargs):
        self._problem = problem
        self._internal_dict = internal_dict  # Reference to plain dict for JAX
        super().__init__()
        # Route initial contents through __setitem__ so they receive float
        # enforcement and are mirrored into internal_dict / CVXPy.
        if args:
            self._assign_from(args[0])
        self._assign_from(kwargs)

    def _assign_from(self, other):
        """Assign every pair from a mapping or an iterable of (key, value) pairs."""
        pairs = other.items() if hasattr(other, "items") else other
        for key, value in pairs:
            self[key] = value

    def __setitem__(self, key, value):
        # Enforce float dtype to prevent int/float mismatch bugs
        value = np.asarray(value, dtype=float)
        # Push to CVXPy first. If its Parameter rejects the value (the setter
        # raises), the exception propagates before either dict is modified,
        # so the three copies never disagree.
        lowered = getattr(self._problem, "_lowered", None)
        if lowered is not None and key in lowered.cvxpy_params:
            lowered.cvxpy_params[key].value = value
        super().__setitem__(key, value)
        # Sync to internal dict for JAX
        self._internal_dict[key] = value

    def update(self, other=None, **kwargs):
        """Update multiple parameters and sync to internal dict and CVXPy."""
        if other is not None:
            self._assign_from(other)
        self._assign_from(kwargs)

    def setdefault(self, key, default=None):
        """Set ``key`` to ``default`` (via ``__setitem__``) if absent; return its value.

        Because of the float coercion, the stored value is a float array.
        A ``default`` of ``None`` is therefore stored as a NaN array.
        """
        if key not in self:
            self[key] = default
        return self[key]

    def __ior__(self, other):
        """Support ``params |= other`` with the same coercion and sync as ``update``."""
        self._assign_from(other)
        return self
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""LoweredProblem dataclass - container for all lowering outputs."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING, Dict
|
|
5
|
+
|
|
6
|
+
from openscvx.lowered.cvxpy_constraints import LoweredCvxpyConstraints
|
|
7
|
+
from openscvx.lowered.cvxpy_variables import CVXPyVariables
|
|
8
|
+
from openscvx.lowered.dynamics import Dynamics
|
|
9
|
+
from openscvx.lowered.jax_constraints import LoweredJaxConstraints
|
|
10
|
+
from openscvx.lowered.unified import UnifiedControl, UnifiedState
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
import cvxpy as cp
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class LoweredProblem:
    """Everything produced by lowering a symbolic problem, in one container.

    Lowering turns symbolic expressions into executable JAX and CVXPy code.
    This dataclass gathers those results behind named, typed fields so the
    optimization stages can reach each component directly.

    Attributes:
        dynamics: Dynamics used for optimization (fields ``f``, ``A``,
            ``B``; JAX functions).
        dynamics_prop: Dynamics used for propagation (fields ``f``, ``A``,
            ``B``).
        jax_constraints: Non-convex constraints lowered to JAX, together
            with their gradients.
        cvxpy_constraints: Convex constraints lowered to CVXPy.
        x_unified: Aggregated optimization-state interface.
        u_unified: Aggregated optimization-control interface.
        x_prop_unified: Aggregated propagation-state interface.
        ocp_vars: Typed CVXPy variables and parameters used to build the OCP.
        cvxpy_params: Maps user parameter names to CVXPy ``Parameter``
            objects.

    Example:
        After lowering a symbolic problem::

            lowered = lower_symbolic_problem(
                dynamics_aug=dynamics,
                states_aug=states,
                controls_aug=controls,
                constraints=constraint_set,
                parameters=params,
                N=50,
            )

            # Access components
            dx_dt = lowered.dynamics.f(x, u, node, params)
            jacobian_A = lowered.dynamics.A(x, u, node, params)

            # Use CVXPy objects
            ocp = OptimalControlProblem(settings, lowered)
    """

    # --- JAX dynamics: optimization and propagation variants ---
    dynamics: Dynamics
    dynamics_prop: Dynamics

    # --- Lowered constraints, split by backend ---
    jax_constraints: LoweredJaxConstraints
    cvxpy_constraints: LoweredCvxpyConstraints

    # --- Aggregated state / control interfaces ---
    x_unified: UnifiedState
    u_unified: UnifiedControl
    x_prop_unified: UnifiedState

    # --- CVXPy-side objects ---
    ocp_vars: CVXPyVariables
    cvxpy_params: dict[str, "cp.Parameter"]
|