openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

@@ -1,22 +1,98 @@
1
1
  from dataclasses import dataclass
2
- from typing import Callable, Sequence, Tuple, Optional
3
- import functools
4
- import types
2
+ from typing import Callable, Optional, Tuple, Union
5
3
 
6
- from jax.lax import cond
7
4
  import jax.numpy as jnp
5
+ from jax.lax import cond
6
+ import functools
7
+ import inspect
8
8
 
9
+ from openscvx.backend.state import State, Variable
10
+ from openscvx.backend.control import Control
11
+ from openscvx.backend.parameter import Parameter
12
+
13
+ # TODO: (norrisg) Unclear if should specify behavior for `idx`, `jacfwd` behavior for Jacobians, etc. since that logic is handled elsewhere and could change
9
14
 
10
15
  @dataclass
11
16
  class CTCSConstraint:
12
- func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
17
+ """
18
+ Dataclass for continuous-time constraint satisfaction (CTCS) constraints over a trajectory interval.
19
+
20
+ A `CTCSConstraint` wraps a residual function `func(x, u)`, applies a
21
+ pointwise `penalty` to its outputs, and accumulates the penalized sum
22
+ only within a specified node interval [nodes[0], nodes[1]).
23
+
24
+ CTCS constraints are used for continuous-time constraints that need to be satisfied
25
+ over trajectory intervals rather than at specific nodes. The constraint function
26
+ should return residuals where positive values indicate constraint violations.
27
+
28
+ Usage examples:
29
+
30
+ ```python
31
+ @ctcs
32
+ def g(x_, u_):
33
+ return 2.0 - jnp.linalg.norm(x_[:3]) # ||x[:3]|| <= 2 constraint
34
+ ```
35
+ ```python
36
+ @ctcs(penalty="huber", nodes=(0, 10), idx=2)
37
+ def g(x_, u_):
38
+ return jnp.sin(x_) + u_ # sin(x) + u <= 0 constraint
39
+ ```
40
+ ```python
41
+ @ctcs(penalty="smooth_relu", scaling=0.5)
42
+ def g(x_, u_):
43
+ return x_[0]**2 + x_[1]**2 - 1.0 # ||x||^2 <= 1 constraint
44
+ ```
45
+
46
+ Or can directly wrap a function if a more lambda-function interface is desired:
47
+
48
+ ```python
49
+ constraint = ctcs(lambda x_, u_: jnp.maximum(0, x[0] - 1.0))
50
+ ```
51
+
52
+ Args:
53
+ func (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
54
+ Function computing constraint residuals g(x, u).
55
+ - x: 1D array (state at a single node), shape (n_x,)
56
+ - u: 1D array (control at a single node), shape (n_u,)
57
+ - Additional parameters: passed as keyword arguments with names matching the parameter name plus an underscore (e.g., g_ for Parameter('g')).
58
+ Should return positive values for constraint violations (g(x,u) > 0 indicates violation).
59
+ If you want to use parameters, include them as extra arguments with the underscore naming convention.
60
+ penalty (Callable[[jnp.ndarray], jnp.ndarray]):
61
+ Penalty function applied elementwise to g's output. Used to calculate and penalize
62
+ constraint violation during state augmentation. Common penalties include:
63
+ - "squared_relu": max(0, x)² (default)
64
+ - "huber": smooth approximation of absolute value
65
+ - "smooth_relu": differentiable version of ReLU
66
+ nodes (Optional[Tuple[int, int]]):
67
+ Half-open interval (start, end) of node indices where this constraint is active.
68
+ If None, the penalty applies at every node.
69
+ idx (Optional[int]):
70
+ Optional index used to group CTCS constraints. Used during automatic state augmentation.
71
+ All CTCS constraints with the same index must be active over the same `nodes` interval
72
+ grad_f_x (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
73
+ User-supplied gradient of `func` w.r.t. state `x`, signature (x, u) -> jacobian.
74
+ If None, computed via `jax.jacfwd(func, argnums=0)` during state augmentation.
75
+ grad_f_u (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
76
+ User-supplied gradient of `func` w.r.t. input `u`, signature (x, u) -> jacobian.
77
+ If None, computed via `jax.jacfwd(func, argnums=1)` during state augmentation.
78
+ scaling (float):
79
+ Scaling factor to apply to the penalized sum.
80
+ """
81
+ func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] # takes (x_expr, u_expr, *param_exprs)
13
82
  penalty: Callable[[jnp.ndarray], jnp.ndarray]
14
83
  nodes: Optional[Tuple[int, int]] = None
15
84
  idx: Optional[int] = None
16
- grad_f_x: Optional[Callable] = None
17
- grad_f_u: Optional[Callable] = None
85
+ grad_f_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
86
+ grad_f_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
87
+ scaling: float = 1.0
18
88
 
19
89
  def __post_init__(self):
90
+ """
91
+ Adapt user-provided gradients to the three-argument signature (x, u, node).
92
+
93
+ If `grad_f_x` or `grad_f_u` are given as functions of (x, u), wrap them
94
+ so they accept the extra `node` argument to match `__call__`.
95
+ """
20
96
  if self.grad_f_x is not None:
21
97
  _grad_f_x = self.grad_f_x
22
98
  self.grad_f_x = lambda x, u, nodes: _grad_f_x(x, u)
@@ -24,62 +100,141 @@ class CTCSConstraint:
24
100
  _grad_f_u = self.grad_f_u
25
101
  self.grad_f_u = lambda x, u, nodes: _grad_f_u(x, u)
26
102
 
27
- def __call__(self, x, u, node):
103
+ def __call__(self, x, u, node: int, *params):
104
+ """
105
+ Evaluate the penalized constraint at a given node index.
106
+ The penalty is summed only if `node` lies within the active interval.
107
+
108
+ Args:
109
+ x (jnp.ndarray): State vector at this node.
110
+ u (jnp.ndarray): Input vector at this node.
111
+ node (int): Trajectory time-step index.
112
+ *params (tuple): Sequence of (name, value) pairs for parameters.
113
+
114
+ Returns:
115
+ jnp.ndarray or float:
116
+ The total penalty (sum over selected residuals) if inside interval,
117
+ otherwise zero.
118
+ """
119
+ x_expr = x.expr if isinstance(x, (State, Variable)) else x
120
+ u_expr = u.expr if isinstance(u, (Control, Variable)) else u
121
+
122
+ # Inspect function signature for expected parameter names
123
+ func_signature = inspect.signature(self.func)
124
+ expected_args = set(func_signature.parameters.keys())
125
+
126
+ # Only include params whose name (with underscore) is in the function's signature
127
+ filtered_params = {
128
+ f"{name}_": value for name, value in params if f"{name}_" in expected_args
129
+ }
130
+
28
131
  return cond(
29
- jnp.all((self.nodes[0] <= node) & (node < self.nodes[1])),
30
- lambda _: jnp.sum(self.penalty(self.func(x, u))),
132
+ jnp.all((self.nodes[0] <= node) & (node < self.nodes[1]))
133
+ if self.nodes is not None else True,
134
+ lambda _: self.scaling * jnp.sum(self.penalty(self.func(x_expr, u_expr, **filtered_params))),
31
135
  lambda _: 0.0,
32
136
  operand=None,
33
137
  )
34
138
 
35
139
 
36
140
  def ctcs(
37
- _func=None,
141
+ _func: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
38
142
  *,
39
- penalty: str = "squared_relu",
40
- nodes: Optional[Sequence[Tuple[int, int]]] = None,
143
+ penalty: Union[str, Callable[[jnp.ndarray], jnp.ndarray]] = "squared_relu",
144
+ nodes: Optional[Tuple[int, int]] = None,
41
145
  idx: Optional[int] = None,
42
- grad_f_x: Optional[Callable] = None,
43
- grad_f_u: Optional[Callable] = None,
44
- ):
45
- """Decorator to mark a function as a 'ctcs' constraint.
46
-
47
- Use as:
48
- @ctcs(nodes=[(0,10)], idx=2, penalty='huber')
49
- def my_constraint(x,u): ...
146
+ grad_f_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
147
+ grad_f_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
148
+ scaling: float = 1.0,
149
+ ) -> Union[Callable, CTCSConstraint]:
150
+ """
151
+ Decorator to build a CTCSConstraint from a raw constraint function.
152
+
153
+ Supports built-in penalties by name or a custom penalty function.
154
+
155
+ Usage examples:
156
+
157
+ ```python
158
+ @ctcs
159
+ def g(x, u):
160
+ return jnp.maximum(0, x - 1)
161
+ ```
162
+ ```python
163
+ @ctcs("huber", nodes=[(0, 10)], idx=2)
164
+ def g2(x, u):
165
+ return jnp.sin(x) + u
166
+ ```
167
+
168
+ Or can directly wrap a function if a more lambda-function interface is desired:
169
+
170
+ ```python
171
+ constraint = [ctcs(lambda x, u: ...)]
172
+ ```
173
+
174
+ Args:
175
+ _func (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
176
+ The function to wrap; provided automatically when using bare @ctcs.
177
+ penalty (Union[str, Callable[[jnp.ndarray], jnp.ndarray]]):
178
+ Name of a built-in penalty ('squared_relu', 'huber', 'smooth_relu')
179
+ or a custom elementwise penalty function.
180
+ nodes (Optional[Tuple[int, int]]):
181
+ Half-open interval (start, end) of node indices where this constraint is active.
182
+ If None, the penalty applies at every node.
183
+ idx (Optional[int]):
184
+ Optional index used to group CTCS constraints. Used during automatic state augmentation.
185
+ All CTCS constraints with the same index must be active over the same `nodes` interval
186
+ grad_f_x (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
187
+ User-supplied gradient of `func` w.r.t state `x`.
188
+ If None, computed via `jax.jacfwd(func, argnums=0)` during state augmentation.
189
+ grad_f_u (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
190
+ User-supplied gradient of `func` w.r.t input `u`.
191
+ If None, computed via `jax.jacfwd(func, argnums=1)` during state augmentation.
192
+ scaling (float):
193
+ Scaling factor to apply to the penalized sum.
194
+
195
+ Returns:
196
+ Union[Callable, CTCSConstraint]
197
+ A decorator if called without a function, or a CTCSConstraint instance
198
+ when applied to a function.
199
+
200
+ Raises:
201
+ ValueError: If `penalty` string is not one of the supported names.
50
202
  """
51
203
  # prepare penalty function once
52
204
  if penalty == "squared_relu":
53
205
  pen = lambda x: jnp.maximum(0, x) ** 2
54
206
  elif penalty == "huber":
55
207
  delta = 0.25
56
-
57
- def pen(x):
58
- r = jnp.maximum(0, x)
59
- return jnp.where(r < delta, 0.5 * r**2, r - 0.5 * delta)
208
+ def pen(x): return jnp.where(x < delta, 0.5 * x**2, x - 0.5 * delta)
60
209
  elif penalty == "smooth_relu":
61
210
  c = 1e-8
62
211
  pen = lambda x: (jnp.maximum(0, x) ** 2 + c**2) ** 0.5 - c
63
- elif isinstance(penalty, types.LambdaType):
212
+ elif callable(penalty):
64
213
  pen = penalty
65
214
  else:
66
215
  raise ValueError(f"Unknown penalty {penalty}")
67
216
 
68
- def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
69
- # wrap so name, doc, signature stay on f
70
- wrapped = functools.wraps(f)(f)
71
- return CTCSConstraint(
72
- func=wrapped,
217
+ def decorator(f: Callable):
218
+ @functools.wraps(f) # preserves name, docstring, and signature on the wrapper
219
+ def wrapper(*args, **kwargs):
220
+ return f(*args, **kwargs)
221
+
222
+ # Now attach original function as attribute if needed
223
+ wrapper._original_func = f
224
+
225
+ # Return your CTCSConstraint with the original function, but keep the wrapper around it
226
+ constraint = CTCSConstraint(
227
+ func=wrapper,
73
228
  penalty=pen,
74
229
  nodes=nodes,
75
230
  idx=idx,
76
231
  grad_f_x=grad_f_x,
77
232
  grad_f_u=grad_f_u,
233
+ scaling=scaling,
78
234
  )
235
+ return constraint
79
236
 
80
- # if called as @ctcs or @ctcs(...), _func will be None and we return decorator
81
237
  if _func is None:
82
238
  return decorator
83
- # if called as ctcs(func), we immediately decorate
84
239
  else:
85
240
  return decorator(_func)
@@ -1,20 +1,109 @@
1
1
  from dataclasses import dataclass
2
- from typing import Callable, Optional, List
2
+ from typing import Callable, Optional, List, Union
3
3
 
4
4
  import jax.numpy as jnp
5
- from jax import jit, vmap, jacfwd
5
+ from jax import vmap, jacfwd
6
6
 
7
7
 
8
8
  @dataclass
9
9
  class NodalConstraint:
10
+ """
11
+ Encapsulates a constraint function applied at specific trajectory nodes.
12
+
13
+ A `NodalConstraint` wraps a function `g(x, u)` that computes constraint residuals
14
+ for given state `x` and input `u`. It can optionally apply only at
15
+ a subset of trajectory nodes, support vectorized evaluation across nodes,
16
+ and integrate with convex solvers when `convex=True`.
17
+
18
+ **Expected input types:**
19
+
20
+ | Case | x, u type/shape |
21
+ |-----------------------------|-------------------------------------------------|
22
+ | convex=False, vectorized=False | 1D arrays, shape (n_x,), (n_u,) (single node) |
23
+ | convex=False, vectorized=True | 2D arrays, shape (N, n_x), (N, n_u) (all nodes) |
24
+ | convex=True, vectorized=False | list of cvxpy variables, one per node |
25
+ | convex=True, vectorized=True | list of cvxpy variables, one per node |
26
+
27
+ **Expected output:**
28
+
29
+ | Case | Output type |
30
+ |-----------------------------|--------------------------------------------------|
31
+ | convex=False, vectorized=False | float (single node) |
32
+ | convex=False, vectorized=True | float array (per node) |
33
+ | convex=True, vectorized=False | cvxpy expression (single node) |
34
+ | convex=True, vectorized=True | list of cvxpy expressions (one per node) |
35
+
36
+ Nonconvex examples:
37
+
38
+ ```python
39
+ @nodal
40
+ def g(x_, u_):
41
+ return 1 - x_[0] <= 0
42
+ ```
43
+ ```python
44
+ @nodal(nodes=[0, 3])
45
+ def g(x_, u_):
46
+ return jnp.linalg.norm(x_) - 1.0
47
+ ```
48
+
49
+ Or can directly wrap a function if a more lambda-function interface is desired:
50
+
51
+ ```python
52
+ constraint = nodal(lambda x_, u_: 1 - x_[0])
53
+ ```
54
+
55
+ Convex Examples:
56
+
57
+ ```python
58
+ @nodal(convex=True)
59
+ def g(x_, u_):
60
+ return cp.norm(x_) <= 1.0 # cvxpy expression following DPP rules
61
+ ```
62
+
63
+ Args:
64
+ func (Callable):
65
+ The user-supplied constraint function. The expected input and output types depend on the values of `convex` and `vectorized`:
66
+
67
+ | Case | x, u type/shape | Output type |
68
+ |-------------------------------|---------------------------------------------------|--------------------------------------------------|
69
+ | convex=False, vectorized=False | 1D arrays, shape (n_x,), (n_u,) (single node) | float (single node) |
70
+ | convex=False, vectorized=True | 2D arrays, shape (N, n_x), (N, n_u) (all nodes) | float array (per node) |
71
+ | convex=True, vectorized=False | list of cvxpy variables, one per node | cvxpy expression (single node) |
72
+ | convex=True, vectorized=True | list of cvxpy variables, one per node | list of cvxpy expressions (one per node) |
73
+
74
+ Additional parameters: always passed as keyword arguments with names matching the parameter name plus an underscore (e.g., `g_` for `Parameter('g')`).
75
+ For nonconvex constraints, the function should return constraint residuals (g(x,u) <= 0). For convex constraints, the function should return a cvxpy expression.
76
+ nodes (Optional[List[int]]):
77
+ Specific node indices where this constraint applies. If None, applies at all nodes.
78
+ convex (bool):
79
+ If True, the provided cvxpy.expression is directly passed to the cvxpy.problem.
80
+ vectorized (bool):
81
+ If False, automatically vectorizes `func` and its jacobians over
82
+ the node dimension using `jax.vmap`. If True, assumes `func` already
83
+ handles vectorization.
84
+ grad_g_x (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
85
+ User-supplied gradient of `func` wrt `x`. If None, computed via
86
+ `jax.jacfwd(func, argnums=0)`.
87
+ grad_g_u (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
88
+ User-supplied gradient of `func` wrt `u`. If None, computed via
89
+ `jax.jacfwd(func, argnums=1)`.
90
+ """
91
+
10
92
  func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
11
93
  nodes: Optional[List[int]] = None
12
94
  convex: bool = False
13
95
  vectorized: bool = False
14
- grad_g_x: Optional[Callable] = None
15
- grad_g_u: Optional[Callable] = None
96
+ grad_g_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
97
+ grad_g_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
16
98
 
17
99
  def __post_init__(self):
100
+ """Initialize gradients and vectorization after instance creation.
101
+
102
+ If the constraint is not convex, this method:
103
+ 1. Sets up the constraint function
104
+ 2. Computes gradients if not provided
105
+ 3. Vectorizes the functions if needed
106
+ """
18
107
  if not self.convex:
19
108
  # single-node but still using JAX
20
109
  self.g = self.func
@@ -29,21 +118,66 @@ class NodalConstraint:
29
118
  # if convex=True assume an external solver (e.g. CVX) will handle it
30
119
 
31
120
  def __call__(self, x: jnp.ndarray, u: jnp.ndarray):
121
+ """Evaluate the constraint function at the given state and control.
122
+
123
+ Args:
124
+ x (jnp.ndarray): The state vector.
125
+ u (jnp.ndarray): The control vector.
126
+
127
+ Returns:
128
+ jnp.ndarray: The constraint violation values.
129
+ """
32
130
  return self.func(x, u)
33
131
 
34
132
 
35
133
  def nodal(
36
- _func=None,
134
+ _func: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
37
135
  *,
38
136
  nodes: Optional[List[int]] = None,
39
137
  convex: bool = False,
40
138
  vectorized: bool = False,
41
- grad_g_x: Optional[Callable] = None,
42
- grad_g_u: Optional[Callable] = None,
43
- ):
44
- def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
139
+ grad_g_x: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
140
+ grad_g_u: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
141
+ ) -> Union[Callable, NodalConstraint]:
142
+ """
143
+ Decorator to build a `NodalConstraint` from a constraint function.
144
+
145
+ Usage examples:
146
+
147
+ ```python
148
+ @nodal
149
+ def g(x, u):
150
+ ...
151
+ ```
152
+ ```python
153
+ @nodal(nodes=[0, -1], convex=True, vectorized=False)
154
+ def g(x, u):
155
+ ...
156
+ ```
157
+
158
+ Or can directly wrap a function if a more lambda-function interface is desired:
159
+
160
+ ```python
161
+ constraint = nodal(lambda x, u: ...)
162
+ ```
163
+
164
+ Args:
165
+ _func (Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]):
166
+ The function to wrap; populated automatically when using bare @nodal.
167
+ When `convex=False`, this is a standard function g(x, u) that should
168
+ return constraint residuals (g(x,u) <= 0). When `convex=True`, this
169
+ must be a cvxpy expression following the DPP ruleset.
170
+ nodes (Optional[List[int]]):
171
+ Node indices where the constraint applies; default None applies to all.
172
+ convex (bool):
173
+ If True, the provided cvxpy.expression is directly passed to the cvxpy.problem.
174
+ vectorized (bool):
175
+ If False, auto-vectorize over nodes using `jax.vmap`. If True, assumes
176
+ the function already handles vectorization.
177
+ """
178
+ def decorator(f: Callable):
45
179
  return NodalConstraint(
46
- func=f, # no wraps, just keep the original
180
+ func=f,
47
181
  nodes=nodes,
48
182
  convex=convex,
49
183
  vectorized=vectorized,
@@ -51,4 +185,9 @@ def nodal(
51
185
  grad_g_u=grad_g_u,
52
186
  )
53
187
 
54
- return decorator if _func is None else decorator(_func)
188
+ if _func is None:
189
+ # Called with arguments, e.g., @nodal(nodes=[0, 1])
190
+ return decorator
191
+ else:
192
+ # Called as a bare decorator, e.g., @nodal
193
+ return decorator(_func)
@@ -9,6 +9,16 @@ from openscvx.constraints.ctcs import CTCSConstraint
9
9
 
10
10
  @dataclass
11
11
  class CTCSViolation:
12
+ """Class representing a continuous-time constraint satisfaction (CTCS) violation.
13
+
14
+ This class holds the constraint function and its gradients for computing
15
+ constraint violations in continuous-time optimization problems.
16
+
17
+ Attributes:
18
+ g (Callable): The constraint function that computes violations.
19
+ g_grad_x (Optional[Callable]): Gradient of g with respect to state x.
20
+ g_grad_u (Optional[Callable]): Gradient of g with respect to control u.
21
+ """
12
22
  g: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]
13
23
  g_grad_x: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
14
24
  g_grad_u: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
@@ -35,8 +45,8 @@ def get_g_grad_u(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarra
35
45
 
36
46
 
37
47
  def get_g_func(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
38
- def g_func(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
39
- return sum(c(x, u, node) for c in constraints_ctcs)
48
+ def g_func(x: jnp.array, u: jnp.array, node: int, *params) -> jnp.array:
49
+ return sum(c(x, u, node, *params) for c in constraints_ctcs)
40
50
 
41
51
  return g_func
42
52