openscvx 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

openscvx/__init__.py ADDED
File without changes
openscvx/_version.py ADDED
@@ -0,0 +1,21 @@
1
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# TYPE_CHECKING is hard-coded to False, so at runtime only the cheap ``object``
# placeholder is bound and ``typing`` is never imported; the typed alias in the
# first branch exists for static analysis only.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Public and dunder spellings refer to the same objects.
version = __version__ = "0.1.2"
version_tuple = __version_tuple__ = (0, 1, 2)
File without changes
@@ -0,0 +1,44 @@
1
+ from typing import List
2
+
3
+ from openscvx.constraints.ctcs import CTCSConstraint
4
+
5
def sort_ctcs_constraints(constraints_ctcs: List["CTCSConstraint"], N: int):
    """Assign a consistent ``idx`` to every CTCS constraint and collect the intervals.

    Constraints that share a node interval share one identifier ``idx``. User-supplied
    identifiers are registered before any identifier is auto-assigned, so the result
    does not depend on the order of ``constraints_ctcs``: an auto-assigned id never
    collides with an explicit one, and a constraint without an id reuses the id of
    any constraint (explicit or not) that has the same interval.

    Args:
        constraints_ctcs: CTCS constraints. Their ``nodes`` attribute is normalized in
            place (``None``/empty -> ``(0, N)``, otherwise converted to a tuple) and
            their ``idx`` attribute is filled in when it is ``None``.
        N: Value used as the end of the full-horizon interval ``(0, N)``.

    Returns:
        ``(constraints_ctcs, node_intervals, num_augmented_states)`` where
        ``node_intervals[k]`` is the interval of the k-th smallest ``idx`` and
        ``num_augmented_states`` is the number of distinct ids.

    Raises:
        ValueError: If the same user-supplied ``idx`` is used with two different intervals.
    """
    idx_to_nodes: dict[int, tuple] = {}

    # Normalize intervals first. tuple() makes a list and a tuple describing the
    # same interval compare equal (and makes the interval hashable).
    for c in constraints_ctcs:
        c.nodes = tuple(c.nodes) if c.nodes else (0, N)

    # Pass 1: reserve every user-supplied identifier and check it names one interval.
    for c in constraints_ctcs:
        if c.idx is None:
            continue
        known = idx_to_nodes.setdefault(c.idx, c.nodes)
        if known != c.nodes:
            raise ValueError(
                f"idx={c.idx} was first used with interval={known}, "
                f"but now you gave it interval={c.nodes}"
            )

    # Reverse lookup; if several explicit ids share an interval, the first one wins.
    nodes_to_idx: dict = {}
    for ident, nodes in idx_to_nodes.items():
        nodes_to_idx.setdefault(nodes, ident)

    # Pass 2: constraints without an identifier join an existing interval or get
    # the next id not already taken.
    next_idx = 0
    for c in constraints_ctcs:
        if c.idx is not None:
            continue
        existing = nodes_to_idx.get(c.nodes)
        if existing is not None:
            c.idx = existing
            continue
        while next_idx in idx_to_nodes:
            next_idx += 1
        c.idx = next_idx
        idx_to_nodes[next_idx] = c.nodes
        nodes_to_idx[c.nodes] = next_idx
        next_idx += 1

    # Intervals in ascending-idx order.
    ordered_ids = sorted(idx_to_nodes)
    node_intervals = [idx_to_nodes[i] for i in ordered_ids]
    return constraints_ctcs, node_intervals, len(ordered_ids)
@@ -0,0 +1,122 @@
1
+ from typing import Callable, List, Tuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ from openscvx.constraints.violation import CTCSViolation
7
+ from openscvx.dynamics import Dynamics
8
+
9
def build_augmented_dynamics(
    dynamics_non_augmented: Dynamics,
    violations: List[CTCSViolation],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Dynamics:
    """Create the augmented ``Dynamics``: true dynamics followed by CTCS violation terms.

    The new object's ``f`` is produced by ``get_augmented_dynamics``. Its ``A``/``B``
    Jacobians are produced by ``get_jacobians``, which uses any gradients supplied on
    ``dynamics_non_augmented`` or on the violations and falls back to autodiff otherwise.
    """
    f_aug = get_augmented_dynamics(
        dynamics_non_augmented.f, violations, idx_x_true, idx_u_true
    )
    augmented = Dynamics(f=f_aug)

    jac_x, jac_u = get_jacobians(
        augmented.f, dynamics_non_augmented, violations, idx_x_true, idx_u_true
    )
    augmented.A = jac_x
    augmented.B = jac_u
    return augmented
26
+
27
+
28
def get_augmented_dynamics(
    dynamics: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    violations: List[CTCSViolation],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Wrap ``dynamics`` so that it also returns the CTCS violation values.

    The returned ``f(x, u, node)`` evaluates ``dynamics`` on the true state and
    control (``x[idx_x_true]``, ``u[idx_u_true]``). It then appends, in list order,
    each violation's ``g`` evaluated on the same slices and ``node``.
    """

    def dynamics_augmented(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
        x_true = x[idx_x_true]
        u_true = u[idx_u_true]
        x_dot = dynamics(x_true, u_true)
        if not violations:
            return x_dot
        # Concatenate once rather than growing the vector per violation.
        extra = [v.g(x_true, u_true, node) for v in violations]
        return jnp.hstack([x_dot, *extra])

    return dynamics_augmented
45
+
46
+
47
def get_jacobians(
    dyn_augmented: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
    dynamics_non_augmented: "Dynamics",
    violations: List["CTCSViolation"],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Tuple[
    Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
    Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
]:
    """Build Jacobian functions ``A(x, u, node)`` and ``B(x, u, node)`` of the augmented dynamics.

    If neither the dynamics nor any violation supplies a custom gradient, both
    Jacobians come from forward-mode autodiff of ``dyn_augmented``. Otherwise they are
    assembled block-wise:
      * the dynamics block uses ``dynamics_non_augmented.A`` / ``.B`` when given,
        and autodiff of ``dynamics_non_augmented.f`` otherwise;
      * each violation contributes ``g_grad_x`` / ``g_grad_u`` when given, and
        autodiff of its ``g`` otherwise.

    Args:
        dyn_augmented: Augmented dynamics ``f(x, u, node)``. Only differentiated when
            no custom gradients exist.
        dynamics_non_augmented: Object with ``f(x_true, u_true)`` and optional
            ``A`` / ``B`` Jacobian callables of the same signature.
        violations: Violation terms whose ``g(x_true, u_true, node)`` outputs are
            stacked below the true dynamics.
        idx_x_true: Slice selecting the true state from the augmented state.
        idx_u_true: Slice selecting the true control from the augmented control.

    Returns:
        ``(A, B)``: callables ``(x_aug, u_aug, node) -> array``.
    """
    # 1) No custom gradients anywhere: differentiate the augmented dynamics directly.
    no_dyn_grads = dynamics_non_augmented.A is None and dynamics_non_augmented.B is None
    no_vio_grads = all(v.g_grad_x is None and v.g_grad_u is None for v in violations)
    if no_dyn_grads and no_vio_grads:
        return (
            jax.jacfwd(dyn_augmented, argnums=0),
            jax.jacfwd(dyn_augmented, argnums=1),
        )

    # 2) Jacobians of the true dynamics f(x_true, u_true): user-supplied or autodiff.
    f_fn = dynamics_non_augmented.f
    A_f = dynamics_non_augmented.A
    if A_f is None:
        A_f = jax.jacfwd(f_fn, argnums=0)
    B_f = dynamics_non_augmented.B
    if B_f is None:
        B_f = jax.jacfwd(f_fn, argnums=1)

    # 3) Per-violation gradients, built once instead of on every A/B call.
    # User-supplied gradients win; otherwise autodiff ``g`` w.r.t. x (argnums=0)
    # or u (argnums=1).
    grads_x = [v.g_grad_x or jax.jacfwd(v.g, argnums=0) for v in violations]
    grads_u = [v.g_grad_u or jax.jacfwd(v.g, argnums=1) for v in violations]

    # 4) Assemble the augmented Jacobians.
    def A(x_aug, u_aug, node):
        x_true = x_aug[idx_x_true]
        u_true = u_aug[idx_u_true]
        Af = A_f(x_true, u_true)  # (n_f, n_x_true)
        dxs = [grad(x_true, u_true, node) for grad in grads_x]  # each (n_gi, n_x_true)
        # Total number of violation rows. The rows of d g_i / dx equal the length
        # of g_i, so no extra evaluation of g is needed.
        nv = sum(dx.shape[0] for dx in dxs)
        # Neither f nor any g reads the augmented part of the state, so those
        # columns are zero. NOTE(review): assumes x_aug = [x_true, nv violation
        # states]; confirm against the state layout.
        rows = [jnp.hstack([Af, jnp.zeros((Af.shape[0], nv))])]
        rows.extend(jnp.hstack([dx, jnp.zeros((dx.shape[0], nv))]) for dx in dxs)
        return jnp.vstack(rows)

    def B(x_aug, u_aug, node):
        x_true = x_aug[idx_x_true]
        u_true = u_aug[idx_u_true]
        rows = [B_f(x_true, u_true)]  # (n_f, n_u_true)
        rows.extend(grad(x_true, u_true, node) for grad in grads_u)  # each (n_gi, n_u_true)
        return jnp.vstack(rows)

    return A, B
openscvx/config.py ADDED
@@ -0,0 +1,247 @@
1
+ import numpy as np
2
+ from dataclasses import dataclass, field
3
+ from typing import Dict, List
4
+
5
+
6
def get_affine_scaling_matrices(n, minimum, maximum):
    """Return a diagonal scaling matrix ``S`` and offset ``c`` for the box ``[minimum, maximum]``.

    The diagonal of ``S`` is half the box width per entry, floored at 1. ``c`` is
    the box center.

    Args:
        n: Number of entries; ``np.ones(n)`` is compared elementwise with the half-widths.
        minimum: Lower bounds (array-like supporting arithmetic with ``maximum``).
        maximum: Upper bounds.

    Returns:
        ``(S, c)``.
    """
    half_width = np.abs(maximum - minimum) / 2
    S = np.diag(np.maximum(half_width, np.ones(n)))
    center = (minimum + maximum) / 2
    return S, center
10
+
11
+
12
@dataclass
class DiscretizationConfig:
    """
    Configuration class for discretization settings.

    This class defines the parameters required for discretizing system dynamics.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        dis_type (str): The type of discretization to use (e.g., "FOH" for First-Order Hold). Defaults to "FOH".
        custom_integrator (bool): Enables our custom fixed-step RK45 algorithm. This tends to be faster than Diffrax, but unless you're going for speed, it's recommended to stick with Diffrax for robustness and other solver options. Defaults to True.
        solver (str): Not used if custom_integrator is enabled. Any choice of solver in Diffrax is valid; please refer to [How to Choose a Solver](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/). Defaults to "Tsit5".

    Other arguments:
    These arguments are less frequently used, and for most purposes you shouldn't need to understand them.

    Args:
        args (Dict): Additional arguments to pass to the solver, which can be found [here](https://docs.kidger.site/diffrax/api/diffeqsolve/). Defaults to an empty dictionary.
        atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
        rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
    """

    # The docstring sits above the fields so it is the class docstring (PEP 257);
    # placed after the fields it would only be a stray string expression.
    dis_type: str = "FOH"
    custom_integrator: bool = True
    solver: str = "Tsit5"
    args: Dict = field(default_factory=dict)
    atol: float = 1e-3
    rtol: float = 1e-6
42
+
43
+
44
@dataclass
class DevConfig:
    """
    Configuration class for development settings.

    This class defines the parameters used for development and debugging purposes.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        profiling (bool): Whether to enable profiling for performance analysis. Defaults to False.
        debug (bool): Disables all precompilation so you can place breakpoints and inspect values. Defaults to False.
        printing (bool): Presumably toggles console output while solving (TODO: confirm against its uses). Defaults to True.
    """

    profiling: bool = False
    debug: bool = False
    printing: bool = True
62
+
63
+
64
@dataclass
class ConvexSolverConfig:
    """
    Configuration class for convex solver settings.

    This class defines the parameters required for configuring a convex solver.

    These are the arguments most commonly used day-to-day. Generally I have found [QOCO](https://qoco-org.github.io/qoco/index.html) to be the most performant of the CVXPY solvers for these types of problems (I do have a bias, as the author is from my group), and it can handle up to SOCPs.
    [CLARABEL](https://clarabel.org/stable/) is also a great option with feasibility checking and can handle a few more problem types.
    [CVXPYGen](https://github.com/cvxgrp/cvxpygen) is also great if your problem isn't too large. I have found qocogen to be the most performant of the CVXPYGen solvers.

    Args:
        solver (str): The name of the CVXPY solver to use. A list of options can be found [here](https://www.cvxpy.org/tutorial/solvers/index.html). Defaults to "QOCO".
        solver_args (dict): Additional arguments to configure the solver, such as tolerances. Ensure you are using the correct arguments for your solver, as they are not all common.
            Defaults to {"abstol": 1e-6, "reltol": 1e-9}.
        cvxpygen (bool): Whether to enable CVXPY code generation for the solver. Defaults to False.
    """

    solver: str = "QOCO"
    # default_factory gives every instance its own dict.
    solver_args: dict = field(default_factory=lambda: {"abstol": 1e-6, "reltol": 1e-9})
    cvxpygen: bool = False
85
+
86
+
87
@dataclass
class PropagationConfig:
    """
    Configuration class for propagation settings.

    This class defines the parameters required for propagating the nonlinear system dynamics using the optimal control sequence.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        dt (float): The time step for propagation. Defaults to 0.1.
        inter_sample (int): How dense the propagation within multishot discretization should be. Defaults to 30.

    Other arguments:
    The solver should likely not be changed, as it is a high-accuracy 8th-order Runge-Kutta method.

    Args:
        solver (str): The numerical solver to use for propagation (e.g., "Dopri8"). Defaults to "Dopri8".
        args (Dict): Additional arguments to pass to the solver. Defaults to an empty dictionary.
        atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
        rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
    """

    inter_sample: int = 30
    dt: float = 0.1
    solver: str = "Dopri8"
    args: Dict = field(default_factory=dict)
    atol: float = 1e-3
    rtol: float = 1e-6
117
+
118
+
119
@dataclass
class SimConfig:
    """Problem data: reference trajectories, boundary conditions, bounds and index layout.

    On construction:
      * ``n_states`` / ``n_controls`` are set from ``len(max_state)`` /
        ``len(max_control)``, overwriting any value passed in;
      * the sizes of the boundary conditions and bounds are checked with ``assert``.
        ``initial_state`` / ``final_state`` are objects exposing ``.value``, whose
        length must be ``n_states`` minus the length of the ``idx_y`` slice;
      * ``S_x``/``c_x`` (and ``S_u``/``c_u``) are computed from the bounds via
        ``get_affine_scaling_matrices`` unless both members of the pair are given;
      * ``inv_S_x`` / ``inv_S_u`` are always recomputed from the diagonal scalings.
    """

    x_bar: np.ndarray
    u_bar: np.ndarray
    initial_state: np.ndarray
    final_state: np.ndarray
    max_state: np.ndarray
    min_state: np.ndarray
    max_control: np.ndarray
    min_control: np.ndarray
    total_time: float
    idx_x_true: slice
    idx_u_true: slice
    idx_t: slice
    idx_y: slice
    idx_s: slice
    ctcs_node_intervals: list = None
    # TODO (norrisg): clean this up, consider moving to dedicated `constraints` dataclass
    constraints_ctcs: List[callable] = field(default_factory=list)
    constraints_nodal: List[callable] = field(default_factory=list)
    n_states: int = None
    n_controls: int = None
    S_x: np.ndarray = None
    inv_S_x: np.ndarray = None
    c_x: np.ndarray = None
    S_u: np.ndarray = None
    inv_S_u: np.ndarray = None
    c_u: np.ndarray = None

    def __post_init__(self):
        self.n_states = len(self.max_state)
        self.n_controls = len(self.max_control)

        # Boundary conditions omit the idx_y block of the state vector.
        n_bc = self.n_states - (self.idx_y.stop - self.idx_y.start)
        for label, bc in (("Initial state", self.initial_state), ("Final state", self.final_state)):
            assert len(bc.value) == n_bc, f"{label} must have {n_bc} elements"

        for label, bound, size in (
            ("Max state", self.max_state, self.n_states),
            ("Min state", self.min_state, self.n_states),
            ("Max control", self.max_control, self.n_controls),
            ("Min control", self.min_control, self.n_controls),
        ):
            assert bound.shape[0] == size, f"{label} must have {size} elements"

        if self.S_x is None or self.c_x is None:
            self.S_x, self.c_x = get_affine_scaling_matrices(
                self.n_states, self.min_state, self.max_state
            )
        if self.S_u is None or self.c_u is None:
            self.S_u, self.c_u = get_affine_scaling_matrices(
                self.n_controls, self.min_control, self.max_control
            )

        # Both scalings are diagonal, so the inverse is the reciprocal of the diagonal.
        self.inv_S_x = np.diag(1 / np.diag(self.S_x))
        self.inv_S_u = np.diag(1 / np.diag(self.S_u))
183
+
184
+
185
@dataclass
class ScpConfig:
    """
    Configuration class for Sequential Convex Programming (SCP).

    This class defines the parameters used to configure the SCP solver. You will very likely need to modify
    the weights for your problem. Please refer to my guide [here](https://haynec.github.io/openscvx/hyperparameter_tuning) for more information.

    Weight normalization: after construction, ``w_tr``, ``lam_vc``, ``lam_cost`` and ``lam_vb`` are
    divided by the largest of the four, so only their ratios matter (the largest becomes 1.0).
    If ``w_tr_max`` is not given but ``w_tr_max_scaling_factor`` is, ``w_tr_max`` is set to
    ``w_tr_max_scaling_factor`` times the *normalized* ``w_tr``.

    Attributes:
        n (int): The number of discretization nodes. Defaults to `None`.
        k_max (int): The maximum number of SCP iterations. Defaults to 200.
        w_tr (float): The trust region weight. Defaults to 1.0.
        lam_vc (float): The penalty weight for virtual control. Defaults to 1.0.
        ep_tr (float): The trust region convergence tolerance. Defaults to 1e-4.
        ep_vb (float): The virtual buffer convergence tolerance. Defaults to 1e-4.
        ep_vc (float): The virtual control convergence tolerance. Defaults to 1e-8.
        lam_cost (float): The weight for the original cost. Defaults to 0.0.
        lam_vb (float): The weight for the virtual buffer. This is only used if there are nonconvex nodal constraints present. Defaults to 0.0.
        uniform_time_grid (bool): Whether to use a uniform time grid. TODO haynec add a link to the time dilation page. Defaults to `False`.
        cost_drop (int): The number of iterations to allow for cost stagnation before termination. Defaults to -1 (disabled).
        cost_relax (float): The relaxation factor for cost reduction. Defaults to 1.0.
        w_tr_adapt (float): The adaptation factor for the trust region weight. Defaults to 1.0.
        w_tr_max (float): The maximum allowable trust region weight. Defaults to `None`.
        w_tr_max_scaling_factor (float): The scaling factor for the maximum trust region weight. Defaults to `None`.

    Raises:
        ValueError: If none of ``w_tr``, ``lam_vc``, ``lam_cost``, ``lam_vb`` is positive
            (the normalization would divide by zero or flip signs).
    """

    n: int = None
    k_max: int = 200
    w_tr: float = 1e0
    lam_vc: float = 1e0
    ep_tr: float = 1e-4
    ep_vb: float = 1e-4
    ep_vc: float = 1e-8
    lam_cost: float = 0.0
    lam_vb: float = 0.0
    uniform_time_grid: bool = False
    cost_drop: int = -1
    cost_relax: float = 1.0
    w_tr_adapt: float = 1.0
    w_tr_max: float = None
    w_tr_max_scaling_factor: float = None

    def __post_init__(self):
        # Normalize the weights so the largest is 1.0 and only their ratios matter.
        keys_to_scale = ["w_tr", "lam_vc", "lam_cost", "lam_vb"]
        scale = max(getattr(self, key) for key in keys_to_scale)
        if scale <= 0:
            raise ValueError(
                "At least one of w_tr, lam_vc, lam_cost, lam_vb must be positive "
                f"to normalize the weights; the largest is {scale}"
            )
        for key in keys_to_scale:
            setattr(self, key, getattr(self, key) / scale)

        # Derived from the normalized w_tr, unless w_tr_max was given explicitly.
        if self.w_tr_max_scaling_factor is not None and self.w_tr_max is None:
            self.w_tr_max = self.w_tr_max_scaling_factor * self.w_tr
235
+
236
+
237
@dataclass
class Config:
    """Top-level bundle of every configuration section.

    Attributes:
        sim: Problem data (bounds, boundary conditions, index layout, scaling).
        scp: Sequential Convex Programming parameters.
        cvx: Convex solver settings.
        dis: Dynamics discretization settings.
        prp: Nonlinear propagation settings.
        dev: Development / debugging settings.
    """

    sim: SimConfig
    scp: ScpConfig
    cvx: ConvexSolverConfig
    dis: DiscretizationConfig
    prp: PropagationConfig
    dev: DevConfig

    def __post_init__(self):
        """Currently performs no extra validation."""
        return None
@@ -1,6 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import Callable, Sequence, Tuple, Optional
3
3
  import functools
4
+ import types
4
5
 
5
6
  from jax.lax import cond
6
7
  import jax.numpy as jnp
@@ -10,8 +11,18 @@ import jax.numpy as jnp
10
11
  class CTCSConstraint:
11
12
  func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
12
13
  penalty: Callable[[jnp.ndarray], jnp.ndarray]
13
- nodes: Optional[Sequence[Tuple[int, int]]] = None
14
+ nodes: Optional[Tuple[int, int]] = None
14
15
  idx: Optional[int] = None
16
+ grad_f_x: Optional[Callable] = None
17
+ grad_f_u: Optional[Callable] = None
18
+
19
+ def __post_init__(self):
20
+ if self.grad_f_x is not None:
21
+ _grad_f_x = self.grad_f_x
22
+ self.grad_f_x = lambda x, u, nodes: _grad_f_x(x, u)
23
+ if self.grad_f_u is not None:
24
+ _grad_f_u = self.grad_f_u
25
+ self.grad_f_u = lambda x, u, nodes: _grad_f_u(x, u)
15
26
 
16
27
  def __call__(self, x, u, node):
17
28
  return cond(
@@ -28,6 +39,8 @@ def ctcs(
28
39
  penalty: str = "squared_relu",
29
40
  nodes: Optional[Sequence[Tuple[int, int]]] = None,
30
41
  idx: Optional[int] = None,
42
+ grad_f_x: Optional[Callable] = None,
43
+ grad_f_u: Optional[Callable] = None,
31
44
  ):
32
45
  """Decorator to mark a function as a 'ctcs' constraint.
33
46
 
@@ -44,14 +57,25 @@ def ctcs(
44
57
  def pen(x):
45
58
  r = jnp.maximum(0, x)
46
59
  return jnp.where(r < delta, 0.5 * r**2, r - 0.5 * delta)
47
-
60
+ elif penalty == "smooth_relu":
61
+ c = 1e-8
62
+ pen = lambda x: (jnp.maximum(0, x) ** 2 + c**2) ** 0.5 - c
63
+ elif isinstance(penalty, types.LambdaType):
64
+ pen = penalty
48
65
  else:
49
66
  raise ValueError(f"Unknown penalty {penalty}")
50
67
 
51
68
  def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
52
69
  # wrap so name, doc, signature stay on f
53
70
  wrapped = functools.wraps(f)(f)
54
- return CTCSConstraint(func=wrapped, penalty=pen, nodes=nodes, idx=idx)
71
+ return CTCSConstraint(
72
+ func=wrapped,
73
+ penalty=pen,
74
+ nodes=nodes,
75
+ idx=idx,
76
+ grad_f_x=grad_f_x,
77
+ grad_f_u=grad_f_u,
78
+ )
55
79
 
56
80
  # if called as @ctcs or @ctcs(...), _func will be None and we return decorator
57
81
  if _func is None:
@@ -11,20 +11,22 @@ class NodalConstraint:
11
11
  nodes: Optional[List[int]] = None
12
12
  convex: bool = False
13
13
  vectorized: bool = False
14
+ grad_g_x: Optional[Callable] = None
15
+ grad_g_u: Optional[Callable] = None
14
16
 
15
17
  def __post_init__(self):
16
18
  if not self.convex:
17
- # TODO: (haynec) switch to AOT instead of JIT
18
- if self.vectorized:
19
- # single-node but still using JAX
20
- self.g = jit(self.func)
21
- self.grad_g_x = jit(jacfwd(self.func, argnums=0))
22
- self.grad_g_u = jit(jacfwd(self.func, argnums=1))
23
- else:
24
- self.g = vmap(jit(self.func), in_axes=(0, 0))
25
- self.grad_g_x = jit(vmap(jacfwd(self.func, argnums=0), in_axes=(0, 0)))
26
- self.grad_g_u = jit(vmap(jacfwd(self.func, argnums=1), in_axes=(0, 0)))
27
- # if convex=True and inter_nodal=False, assume an external solver (e.g. CVX) will handle it
19
+ # single-node but still using JAX
20
+ self.g = self.func
21
+ if self.grad_g_x is None:
22
+ self.grad_g_x = jacfwd(self.func, argnums=0)
23
+ if self.grad_g_u is None:
24
+ self.grad_g_u = jacfwd(self.func, argnums=1)
25
+ if not self.vectorized:
26
+ self.g = vmap(self.g, in_axes=(0, 0))
27
+ self.grad_g_x = vmap(self.grad_g_x, in_axes=(0, 0))
28
+ self.grad_g_u = vmap(self.grad_g_u, in_axes=(0, 0))
29
+ # if convex=True assume an external solver (e.g. CVX) will handle it
28
30
 
29
31
  def __call__(self, x: jnp.ndarray, u: jnp.ndarray):
30
32
  return self.func(x, u)
@@ -36,6 +38,8 @@ def nodal(
36
38
  nodes: Optional[List[int]] = None,
37
39
  convex: bool = False,
38
40
  vectorized: bool = False,
41
+ grad_g_x: Optional[Callable] = None,
42
+ grad_g_u: Optional[Callable] = None,
39
43
  ):
40
44
  def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
41
45
  return NodalConstraint(
@@ -43,6 +47,8 @@ def nodal(
43
47
  nodes=nodes,
44
48
  convex=convex,
45
49
  vectorized=vectorized,
50
+ grad_g_x=grad_g_x,
51
+ grad_g_u=grad_g_u,
46
52
  )
47
53
 
48
54
  return decorator if _func is None else decorator(_func)
@@ -0,0 +1,67 @@
1
+ from collections import defaultdict
2
+ from typing import List, Optional, Tuple, Callable
3
+ from dataclasses import dataclass
4
+
5
+ import jax.numpy as jnp
6
+
7
+ from openscvx.constraints.ctcs import CTCSConstraint
8
+
9
+
10
@dataclass
class CTCSViolation:
    """One grouped CTCS violation term.

    Attributes:
        g: ``g(x, u, node)`` returning the violation value for this group.
        g_grad_x: Optional ``(x, u, node)`` callable giving the gradient of ``g``
            w.r.t. ``x``. ``None`` means no custom gradient was supplied.
        g_grad_u: Optional ``(x, u, node)`` callable giving the gradient of ``g``
            w.r.t. ``u``. ``None`` means no custom gradient was supplied.
    """

    # All three callables share the signature (x, u, node) -> array.
    g: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]
    g_grad_x: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
    g_grad_u: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
15
+
16
+
17
def get_g_grad_x(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Combine the user-supplied ``grad_f_x`` of a bucket of CTCS constraints.

    Returns:
        ``g_grad_x(x, u, node)``. It adds up ``c.grad_f_x(x, u, node)`` over the
        constraints whose ``grad_f_x`` is not ``None`` (in list order). It returns
        ``None`` when no constraint provides one.

    NOTE(review): the gradients are summed exactly as supplied. Confirm they are
    gradients of the full per-constraint term ``c(x, u, node)`` (penalty and node
    gating included) and not of the bare constraint function.
    """

    def g_grad_x(x: jnp.ndarray, u: jnp.ndarray, node: int) -> jnp.ndarray:
        providers = [c.grad_f_x for c in constraints_ctcs if c.grad_f_x is not None]
        if not providers:
            return None
        return sum(grad(x, u, node) for grad in providers)

    return g_grad_x
25
+
26
+
27
def get_g_grad_u(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Combine the user-supplied ``grad_f_u`` of a bucket of CTCS constraints.

    Returns:
        ``g_grad_u(x, u, node)``. It adds up ``c.grad_f_u(x, u, node)`` over the
        constraints whose ``grad_f_u`` is not ``None`` (in list order). It returns
        ``None`` when no constraint provides one.

    NOTE(review): the gradients are summed exactly as supplied. Confirm they are
    gradients of the full per-constraint term ``c(x, u, node)`` (penalty and node
    gating included) and not of the bare constraint function.
    """

    def g_grad_u(x: jnp.ndarray, u: jnp.ndarray, node: int) -> jnp.ndarray:
        providers = [c.grad_f_u for c in constraints_ctcs if c.grad_f_u is not None]
        if not providers:
            return None
        return sum(grad(x, u, node) for grad in providers)

    return g_grad_u
35
+
36
+
37
def get_g_func(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Merge a bucket of CTCS constraints into one violation function.

    Returns:
        ``g_func(x, u, node)``: the running sum, starting from ``0``, of
        ``c(x, u, node)`` over ``constraints_ctcs`` in list order.
    """

    def g_func(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
        total = 0
        for constraint in constraints_ctcs:
            total = total + constraint(x, u, node)
        return total

    return g_func
42
+
43
+
44
def get_g_funcs(constraints_ctcs: List[CTCSConstraint]) -> List[CTCSViolation]:
    """Group CTCS constraints by ``idx`` and build one ``CTCSViolation`` per group.

    Args:
        constraints_ctcs: Constraints whose ``idx`` has already been assigned
            (e.g. by ``sort_ctcs_constraints``).

    Returns:
        One ``CTCSViolation`` per distinct ``idx``, in ascending ``idx`` order. A
        group's ``g_grad_x`` (``g_grad_u``) is set only when every constraint in the
        group provides ``grad_f_x`` (``grad_f_u``); otherwise it is ``None``.

    Raises:
        ValueError: If any constraint has no ``idx``.
    """
    groups: dict[int, List[CTCSConstraint]] = defaultdict(list)
    for c in constraints_ctcs:
        if c.idx is None:
            raise ValueError(f"CTCS constraint {c} has no .idx assigned")
        groups[c.idx].append(c)

    violations: List[CTCSViolation] = []
    for idx in sorted(groups):
        bucket = groups[idx]
        g = get_g_func(bucket)
        # Pair each gradient with its own builder: the x-gradient comes from the
        # constraints' grad_f_x, the u-gradient from grad_f_u. (These were swapped.)
        g_grad_x = (
            get_g_grad_x(bucket) if all(c.grad_f_x is not None for c in bucket) else None
        )
        g_grad_u = (
            get_g_grad_u(bucket) if all(c.grad_f_u is not None for c in bucket) else None
        )
        violations.append(CTCSViolation(g=g, g_grad_x=g_grad_x, g_grad_u=g_grad_u))

    return violations