openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openscvx/_version.py +2 -2
- openscvx/augmentation/dynamics_augmentation.py +22 -7
- openscvx/config.py +310 -192
- openscvx/constraints/__init__.py +0 -3
- openscvx/constraints/ctcs.py +188 -33
- openscvx/constraints/nodal.py +150 -11
- openscvx/constraints/violation.py +12 -2
- openscvx/discretization.py +115 -37
- openscvx/dynamics.py +150 -11
- openscvx/integrators.py +135 -16
- openscvx/io.py +129 -17
- openscvx/ocp.py +86 -67
- openscvx/plotting.py +72 -215
- openscvx/post_processing.py +57 -16
- openscvx/propagation.py +155 -55
- openscvx/ptr.py +96 -57
- openscvx/results.py +153 -0
- openscvx/trajoptproblem.py +359 -114
- openscvx/utils.py +50 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/METADATA +129 -41
- openscvx-0.2.1.dev0.dist-info/RECORD +27 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/WHEEL +1 -1
- openscvx/constraints/boundary.py +0 -49
- openscvx-0.1.2.dist-info/RECORD +0 -27
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/top_level.txt +0 -0
openscvx/constraints/ctcs.py
CHANGED
@@ -1,22 +1,98 @@
 from dataclasses import dataclass
-from typing import Callable, Optional, Tuple
-import functools
-import types
+from typing import Callable, Optional, Tuple, Union

-from jax.lax import cond
 import jax.numpy as jnp
+from jax.lax import cond
+import functools
+import inspect

+from openscvx.backend.state import State, Variable
+from openscvx.backend.control import Control
+from openscvx.backend.parameter import Parameter
+
+# TODO: (norrisg) Unclear if should specify behavior for `idx`, `jacfwd` behavior for Jacobians, etc. since that logic is handled elsewhere and could change

 @dataclass
 class CTCSConstraint:
-
+    """
+    Dataclass for continuous-time constraint satisfaction (CTCS) constraints over a trajectory interval.
+
+    A `CTCSConstraint` wraps a residual function `func(x, u)`, applies a
+    pointwise `penalty` to its outputs, and accumulates the penalized sum
+    only within a specified node interval [nodes[0], nodes[1]).
+
+    CTCS constraints are used for continuous-time constraints that need to be satisfied
+    over trajectory intervals rather than at specific nodes. The constraint function
+    should return residuals where positive values indicate constraint violations.
+
+    Usage examples:
+
+    ```python
+    @ctcs
+    def g(x_, u_):
+        return 2.0 - jnp.linalg.norm(x_[:3])  # ||x[:3]|| <= 2 constraint
+    ```
+    ```python
+    @ctcs(penalty="huber", nodes=(0, 10), idx=2)
+    def g(x_, u_):
+        return jnp.sin(x_) + u_  # sin(x) + u <= 0 constraint
+    ```
+    ```python
+    @ctcs(penalty="smooth_relu", scaling=0.5)
+    def g(x_, u_):
+        return x_[0]**2 + x_[1]**2 - 1.0  # ||x||^2 <= 1 constraint
+    ```
+
+    A function can also be wrapped directly if a lambda-style interface is preferred:
+
+    ```python
+    constraint = ctcs(lambda x_, u_: jnp.maximum(0, x_[0] - 1.0))
+    ```
+
+    Args:
+        func (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
+            Function computing constraint residuals g(x, u).
+            - x: 1D array (state at a single node), shape (n_x,)
+            - u: 1D array (control at a single node), shape (n_u,)
+            - Additional parameters: passed as keyword arguments with names matching the parameter name plus an underscore (e.g., g_ for Parameter('g')).
+            Should return positive values for constraint violations (g(x, u) > 0 indicates violation).
+            To use parameters, include them as extra arguments following the underscore naming convention.
+        penalty (Callable[[jnp.ndarray], jnp.ndarray]):
+            Penalty function applied elementwise to g's output. Used to calculate and penalize
+            constraint violation during state augmentation. Common penalties include:
+            - "squared_relu": max(0, x)² (default)
+            - "huber": smooth approximation of the absolute value
+            - "smooth_relu": differentiable version of ReLU
+        nodes (Optional[Tuple[int, int]]):
+            Half-open interval (start, end) of node indices where this constraint is active.
+            If None, the penalty applies at every node.
+        idx (Optional[int]):
+            Optional index used to group CTCS constraints during automatic state augmentation.
+            All CTCS constraints with the same index must be active over the same `nodes` interval.
+        grad_f_x (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            User-supplied gradient of `func` w.r.t. state `x`, signature (x, u) -> jacobian.
+            If None, computed via `jax.jacfwd(func, argnums=0)` during state augmentation.
+        grad_f_u (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            User-supplied gradient of `func` w.r.t. input `u`, signature (x, u) -> jacobian.
+            If None, computed via `jax.jacfwd(func, argnums=1)` during state augmentation.
+        scaling (float):
+            Scaling factor applied to the penalized sum.
+    """
+    func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]  # takes (x_expr, u_expr, *param_exprs)
     penalty: Callable[[jnp.ndarray], jnp.ndarray]
     nodes: Optional[Tuple[int, int]] = None
     idx: Optional[int] = None
-    grad_f_x: Optional[Callable] = None
-    grad_f_u: Optional[Callable] = None
+    grad_f_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
+    grad_f_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
+    scaling: float = 1.0

     def __post_init__(self):
+        """
+        Adapt user-provided gradients to the three-argument signature (x, u, node).
+
+        If `grad_f_x` or `grad_f_u` are given as functions of (x, u), wrap them
+        so they accept the extra `node` argument to match `__call__`.
+        """
         if self.grad_f_x is not None:
             _grad_f_x = self.grad_f_x
             self.grad_f_x = lambda x, u, nodes: _grad_f_x(x, u)
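The dataclass docstring above names three built-in penalties; the `pen` definitions in the next hunk implement them. As a minimal standalone sketch (the math is taken verbatim from the diff; the top-level function names are mine):

```python
import jax.numpy as jnp

def squared_relu(x):
    # max(0, x)^2: zero for satisfied residuals, quadratic in the violation
    return jnp.maximum(0, x) ** 2

def huber(x, delta=0.25):
    # quadratic near zero, linear beyond delta; note that 0.2.1.dev0 drops the
    # jnp.maximum(0, x) clamp that 0.1.2 applied before this branch
    return jnp.where(x < delta, 0.5 * x**2, x - 0.5 * delta)

def smooth_relu(x, c=1e-8):
    # differentiable softening of max(0, x); c controls the smoothing
    return (jnp.maximum(0, x) ** 2 + c**2) ** 0.5 - c
```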
@@ -24,62 +100,141 @@ class CTCSConstraint:
             _grad_f_u = self.grad_f_u
             self.grad_f_u = lambda x, u, nodes: _grad_f_u(x, u)

-    def __call__(self, x, u, node):
+    def __call__(self, x, u, node: int, *params):
+        """
+        Evaluate the penalized constraint at a given node index.
+        The penalty is summed only if `node` lies within the active interval.
+
+        Args:
+            x (jnp.ndarray): State vector at this node.
+            u (jnp.ndarray): Input vector at this node.
+            node (int): Trajectory time-step index.
+            *params (tuple): Sequence of (name, value) pairs for parameters.
+
+        Returns:
+            jnp.ndarray or float:
+                The total penalty (sum over selected residuals) if inside the interval,
+                otherwise zero.
+        """
+        x_expr = x.expr if isinstance(x, (State, Variable)) else x
+        u_expr = u.expr if isinstance(u, (Control, Variable)) else u
+
+        # Inspect function signature for expected parameter names
+        func_signature = inspect.signature(self.func)
+        expected_args = set(func_signature.parameters.keys())
+
+        # Only include params whose name (with underscore) is in the function's signature
+        filtered_params = {
+            f"{name}_": value for name, value in params if f"{name}_" in expected_args
+        }
+
         return cond(
-            jnp.all((self.nodes[0] <= node) & (node < self.nodes[1]))
-
+            jnp.all((self.nodes[0] <= node) & (node < self.nodes[1]))
+            if self.nodes is not None else True,
+            lambda _: self.scaling * jnp.sum(self.penalty(self.func(x_expr, u_expr, **filtered_params))),
             lambda _: 0.0,
             operand=None,
         )


 def ctcs(
-    _func=None,
+    _func: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
     *,
-    penalty: str = "squared_relu",
-    nodes: Optional[
+    penalty: Union[str, Callable[[jnp.ndarray], jnp.ndarray]] = "squared_relu",
+    nodes: Optional[Tuple[int, int]] = None,
     idx: Optional[int] = None,
-    grad_f_x: Optional[Callable] = None,
-    grad_f_u: Optional[Callable] = None,
-
-
-
-
-
+    grad_f_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
+    grad_f_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
+    scaling: float = 1.0,
+) -> Union[Callable, CTCSConstraint]:
+    """
+    Decorator to build a CTCSConstraint from a raw constraint function.
+
+    Supports built-in penalties by name or a custom penalty function.
+
+    Usage examples:
+
+    ```python
+    @ctcs
+    def g(x, u):
+        return jnp.maximum(0, x - 1)
+    ```
+    ```python
+    @ctcs(penalty="huber", nodes=(0, 10), idx=2)
+    def g2(x, u):
+        return jnp.sin(x) + u
+    ```
+
+    A function can also be wrapped directly if a lambda-style interface is preferred:
+
+    ```python
+    constraint = ctcs(lambda x, u: ...)
+    ```
+
+    Args:
+        _func (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            The function to wrap; provided automatically when using bare @ctcs.
+        penalty (Union[str, Callable[[jnp.ndarray], jnp.ndarray]]):
+            Name of a built-in penalty ('squared_relu', 'huber', 'smooth_relu')
+            or a custom elementwise penalty function.
+        nodes (Optional[Tuple[int, int]]):
+            Half-open interval (start, end) of node indices where this constraint is active.
+            If None, the penalty applies at every node.
+        idx (Optional[int]):
+            Optional index used to group CTCS constraints during automatic state augmentation.
+            All CTCS constraints with the same index must be active over the same `nodes` interval.
+        grad_f_x (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            User-supplied gradient of `func` w.r.t. state `x`.
+            If None, computed via `jax.jacfwd(func, argnums=0)` during state augmentation.
+        grad_f_u (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            User-supplied gradient of `func` w.r.t. input `u`.
+            If None, computed via `jax.jacfwd(func, argnums=1)` during state augmentation.
+        scaling (float):
+            Scaling factor applied to the penalized sum.
+
+    Returns:
+        Union[Callable, CTCSConstraint]:
+            A decorator if called without a function, or a CTCSConstraint instance
+            when applied to a function.
+
+    Raises:
+        ValueError: If `penalty` is a string that is not one of the supported names.
     """
     # prepare penalty function once
     if penalty == "squared_relu":
         pen = lambda x: jnp.maximum(0, x) ** 2
     elif penalty == "huber":
         delta = 0.25
-
-        def pen(x):
-            r = jnp.maximum(0, x)
-            return jnp.where(r < delta, 0.5 * r**2, r - 0.5 * delta)
+        def pen(x): return jnp.where(x < delta, 0.5 * x**2, x - 0.5 * delta)
     elif penalty == "smooth_relu":
         c = 1e-8
         pen = lambda x: (jnp.maximum(0, x) ** 2 + c**2) ** 0.5 - c
-    elif
+    elif callable(penalty):
         pen = penalty
     else:
         raise ValueError(f"Unknown penalty {penalty}")

-    def decorator(f: Callable
-        #
-
-
-
+    def decorator(f: Callable):
+        @functools.wraps(f)  # preserves name, docstring, and signature on the wrapper
+        def wrapper(*args, **kwargs):
+            return f(*args, **kwargs)
+
+        # Now attach original function as attribute if needed
+        wrapper._original_func = f
+
+        # Return your CTCSConstraint with the original function, but keep the wrapper around it
+        constraint = CTCSConstraint(
+            func=wrapper,
             penalty=pen,
             nodes=nodes,
             idx=idx,
             grad_f_x=grad_f_x,
             grad_f_u=grad_f_u,
+            scaling=scaling,
         )
+        return constraint

-    # if called as @ctcs or @ctcs(...), _func will be None and we return decorator
     if _func is None:
         return decorator
-    # if called as ctcs(func), we immediately decorate
     else:
         return decorator(_func)
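A minimal usage sketch of the new parameter-aware call path, assuming openscvx 0.2.1.dev0 is importable; the constraint body, the parameter name `r_max`, and the array sizes are illustrative, not taken from the package:

```python
import jax.numpy as jnp
from openscvx.constraints.ctcs import ctcs

# Extra parameters are declared with a trailing underscore; __call__ forwards
# only the (name, value) pairs whose underscored name appears in the signature.
@ctcs(penalty="huber", nodes=(0, 10), scaling=0.5)
def g(x_, u_, r_max_):
    return jnp.linalg.norm(x_[:3]) - r_max_  # ||x[:3]|| <= r_max

x, u = jnp.zeros(6), jnp.zeros(3)
active = g(x, u, 4, ("r_max", 2.0))     # node 4 is inside [0, 10): penalty accumulates
inactive = g(x, u, 12, ("r_max", 2.0))  # node 12 is outside: the cond branch returns 0.0
```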
openscvx/constraints/nodal.py
CHANGED
@@ -1,20 +1,109 @@
 from dataclasses import dataclass
-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Union

 import jax.numpy as jnp
-from jax import
+from jax import vmap, jacfwd


 @dataclass
 class NodalConstraint:
+    """
+    Encapsulates a constraint function applied at specific trajectory nodes.
+
+    A `NodalConstraint` wraps a function `g(x, u)` that computes constraint residuals
+    for a given state `x` and input `u`. It can optionally apply only at
+    a subset of trajectory nodes, support vectorized evaluation across nodes,
+    and integrate with convex solvers when `convex=True`.
+
+    **Expected input types:**
+
+    | Case                           | x, u type/shape                                 |
+    |--------------------------------|-------------------------------------------------|
+    | convex=False, vectorized=False | 1D arrays, shape (n_x,), (n_u,) (single node)   |
+    | convex=False, vectorized=True  | 2D arrays, shape (N, n_x), (N, n_u) (all nodes) |
+    | convex=True, vectorized=False  | list of cvxpy variables, one per node           |
+    | convex=True, vectorized=True   | list of cvxpy variables, one per node           |
+
+    **Expected output:**
+
+    | Case                           | Output type                               |
+    |--------------------------------|-------------------------------------------|
+    | convex=False, vectorized=False | float (single node)                       |
+    | convex=False, vectorized=True  | float array (per node)                    |
+    | convex=True, vectorized=False  | cvxpy expression (single node)            |
+    | convex=True, vectorized=True   | list of cvxpy expressions (one per node)  |
+
+    Nonconvex examples:
+
+    ```python
+    @nodal
+    def g(x_, u_):
+        return 1 - x_[0]  # enforces 1 - x[0] <= 0
+    ```
+    ```python
+    @nodal(nodes=[0, 3])
+    def g(x_, u_):
+        return jnp.linalg.norm(x_) - 1.0
+    ```
+
+    A function can also be wrapped directly if a lambda-style interface is preferred:
+
+    ```python
+    constraint = nodal(lambda x_, u_: 1 - x_[0])
+    ```
+
+    Convex examples:
+
+    ```python
+    @nodal(convex=True)
+    def g(x_, u_):
+        return cp.norm(x_) <= 1.0  # cvxpy expression following DPP rules
+    ```
+
+    Args:
+        func (Callable):
+            The user-supplied constraint function. The expected input and output types depend on the values of `convex` and `vectorized`:
+
+            | Case                           | x, u type/shape                                 | Output type                              |
+            |--------------------------------|-------------------------------------------------|------------------------------------------|
+            | convex=False, vectorized=False | 1D arrays, shape (n_x,), (n_u,) (single node)   | float (single node)                      |
+            | convex=False, vectorized=True  | 2D arrays, shape (N, n_x), (N, n_u) (all nodes) | float array (per node)                   |
+            | convex=True, vectorized=False  | list of cvxpy variables, one per node           | cvxpy expression (single node)           |
+            | convex=True, vectorized=True   | list of cvxpy variables, one per node           | list of cvxpy expressions (one per node) |
+
+            Additional parameters are always passed as keyword arguments with names matching the parameter name plus an underscore (e.g., `g_` for `Parameter('g')`).
+            For nonconvex constraints, the function should return constraint residuals (g(x, u) <= 0). For convex constraints, the function should return a cvxpy expression.
+        nodes (Optional[List[int]]):
+            Specific node indices where this constraint applies. If None, applies at all nodes.
+        convex (bool):
+            If True, the provided cvxpy expression is passed directly to the cvxpy problem.
+        vectorized (bool):
+            If False, automatically vectorizes `func` and its Jacobians over
+            the node dimension using `jax.vmap`. If True, assumes `func` already
+            handles vectorization.
+        grad_g_x (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            User-supplied gradient of `func` w.r.t. `x`. If None, computed via
+            `jax.jacfwd(func, argnums=0)`.
+        grad_g_u (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            User-supplied gradient of `func` w.r.t. `u`. If None, computed via
+            `jax.jacfwd(func, argnums=1)`.
+    """
+
     func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
     nodes: Optional[List[int]] = None
     convex: bool = False
     vectorized: bool = False
-    grad_g_x: Optional[Callable] = None
-    grad_g_u: Optional[Callable] = None
+    grad_g_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
+    grad_g_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None

     def __post_init__(self):
+        """Initialize gradients and vectorization after instance creation.
+
+        If the constraint is not convex, this method:
+        1. Sets up the constraint function
+        2. Computes gradients if not provided
+        3. Vectorizes the functions if needed
+        """
         if not self.convex:
             # single-node but still using JAX
             self.g = self.func
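Following the input/output tables above, a sketch of a vectorized nonconvex constraint; the function name, state layout, and bound are illustrative:

```python
import jax.numpy as jnp
from openscvx.constraints.nodal import nodal

# vectorized=True: the function receives every node at once, x_ with shape
# (N, n_x) and u_ with shape (N, n_u), and returns one residual per node.
@nodal(vectorized=True)
def speed_limit(x_, u_):
    return jnp.linalg.norm(x_[:, 3:6], axis=1) - 10.0  # g(x, u) <= 0 per node
```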
@@ -29,21 +118,66 @@
         # if convex=True assume an external solver (e.g. CVX) will handle it

     def __call__(self, x: jnp.ndarray, u: jnp.ndarray):
+        """Evaluate the constraint function at the given state and control.
+
+        Args:
+            x (jnp.ndarray): The state vector.
+            u (jnp.ndarray): The control vector.
+
+        Returns:
+            jnp.ndarray: The constraint violation values.
+        """
         return self.func(x, u)


 def nodal(
-    _func=None,
+    _func: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
     *,
     nodes: Optional[List[int]] = None,
     convex: bool = False,
     vectorized: bool = False,
-    grad_g_x: Optional[Callable] = None,
-    grad_g_u: Optional[Callable] = None,
-):
-
+    grad_g_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
+    grad_g_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
+) -> Union[Callable, NodalConstraint]:
+    """
+    Decorator to build a `NodalConstraint` from a constraint function.
+
+    Usage examples:
+
+    ```python
+    @nodal
+    def g(x, u):
+        ...
+    ```
+    ```python
+    @nodal(nodes=[0, -1], convex=True, vectorized=False)
+    def g(x, u):
+        ...
+    ```
+
+    A function can also be wrapped directly if a lambda-style interface is preferred:
+
+    ```python
+    constraint = nodal(lambda x, u: ...)
+    ```
+
+    Args:
+        _func (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
+            The function to wrap; populated automatically when using bare @nodal.
+            When `convex=False`, this is a standard function g(x, u) that should
+            return constraint residuals (g(x, u) <= 0). When `convex=True`, it
+            must return a cvxpy expression following the DPP ruleset.
+        nodes (Optional[List[int]]):
+            Node indices where the constraint applies; the default None applies it at all nodes.
+        convex (bool):
+            If True, the provided cvxpy expression is passed directly to the cvxpy problem.
+        vectorized (bool):
+            If False, auto-vectorize over nodes using `jax.vmap`. If True, assumes
+            the function already handles vectorization.
+    """
+    def decorator(f: Callable):
         return NodalConstraint(
-        func=f,
+            func=f,
             nodes=nodes,
             convex=convex,
             vectorized=vectorized,
@@ -51,4 +185,9 @@ def nodal(
             grad_g_u=grad_g_u,
         )

-
+    if _func is None:
+        # Called with arguments, e.g., @nodal(nodes=[0, 1])
+        return decorator
+    else:
+        # Called as a bare decorator, e.g., @nodal
+        return decorator(_func)
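And a sketch of the convex path in lambda form, mirroring the docstring's cvxpy example (the norm bound is illustrative); with `convex=True` the returned expression is handed to the cvxpy problem unchanged:

```python
import cvxpy as cp
from openscvx.constraints.nodal import nodal

# With convex=True the callable receives cvxpy variables and must return a
# DPP-compliant cvxpy constraint expression rather than a JAX residual.
trust_region = nodal(lambda x_, u_: cp.norm(u_) <= 2.0, convex=True)
```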
openscvx/constraints/violation.py
CHANGED

@@ -9,6 +9,16 @@ from openscvx.constraints.ctcs import CTCSConstraint

 @dataclass
 class CTCSViolation:
+    """Class representing a continuous-time constraint satisfaction (CTCS) violation.
+
+    This class holds the constraint function and its gradients for computing
+    constraint violations in continuous-time optimization problems.
+
+    Attributes:
+        g (Callable): The constraint function that computes violations.
+        g_grad_x (Optional[Callable]): Gradient of g with respect to state x.
+        g_grad_u (Optional[Callable]): Gradient of g with respect to control u.
+    """
     g: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]
     g_grad_x: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
     g_grad_u: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
@@ -35,8 +45,8 @@ def get_g_grad_u(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarra


 def get_g_func(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
-    def g_func(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
-        return sum(c(x, u, node) for c in constraints_ctcs)
+    def g_func(x: jnp.array, u: jnp.array, node: int, *params) -> jnp.array:
+        return sum(c(x, u, node, *params) for c in constraints_ctcs)

     return g_func
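A sketch of how the updated `g_func` threads parameter pairs through to each aggregated constraint, again assuming openscvx 0.2.1.dev0 is importable; the constraint bodies are illustrative:

```python
import jax.numpy as jnp
from openscvx.constraints.ctcs import ctcs
from openscvx.constraints.violation import get_g_func

g1 = ctcs(lambda x_, u_: x_[0] - 1.0)                                  # x[0] <= 1
g2 = ctcs(lambda x_, u_: jnp.abs(u_[0]) - 2.0, penalty="smooth_relu")  # |u[0]| <= 2

# g_func sums every constraint's penalized residual at the given node;
# (name, value) parameter pairs are forwarded only to constraints that declare them.
g_total = get_g_func([g1, g2])
value = g_total(jnp.ones(3), jnp.ones(2), 0)
```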