openscvx 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +0 -0
- openscvx/_version.py +21 -0
- openscvx/augmentation/__init__.py +0 -0
- openscvx/augmentation/ctcs.py +44 -0
- openscvx/augmentation/dynamics_augmentation.py +122 -0
- openscvx/config.py +247 -0
- {constraints → openscvx/constraints}/ctcs.py +27 -3
- {constraints → openscvx/constraints}/nodal.py +17 -11
- openscvx/constraints/violation.py +67 -0
- openscvx/discretization.py +170 -0
- openscvx/dynamics.py +41 -0
- openscvx/integrators.py +139 -0
- openscvx/io.py +81 -0
- openscvx/ocp.py +160 -0
- openscvx/plotting.py +632 -0
- openscvx/post_processing.py +36 -0
- openscvx/propagation.py +135 -0
- openscvx/ptr.py +149 -0
- openscvx/trajoptproblem.py +337 -0
- openscvx/utils.py +80 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.2.dist-info}/METADATA +61 -18
- openscvx-0.1.2.dist-info/RECORD +27 -0
- openscvx-0.1.2.dist-info/top_level.txt +1 -0
- constraints/violation.py +0 -26
- openscvx-0.1.0.dist-info/RECORD +0 -10
- openscvx-0.1.0.dist-info/top_level.txt +0 -1
- {constraints → openscvx/constraints}/__init__.py +0 -0
- {constraints → openscvx/constraints}/boundary.py +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.2.dist-info}/WHEEL +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.2.dist-info}/licenses/LICENSE +0 -0
openscvx/__init__.py
ADDED
|
File without changes
|
openscvx/_version.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

# Public names exported by this module.
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Hard-coded to False, so at runtime the ``typing`` imports below are never
# executed and the ``else`` branch is the one that runs.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Annotation-only declarations; the actual values are bound further down.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# The same value is exposed under both the dunder and the plain name.
__version__ = version = '0.1.2'
__version_tuple__ = version_tuple = (0, 1, 2)
|
|
File without changes
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from openscvx.constraints.ctcs import CTCSConstraint
|
|
4
|
+
|
|
5
|
+
def sort_ctcs_constraints(constraints_ctcs: List[CTCSConstraint], N: int):
    """Normalize node intervals and give every CTCS constraint an ``idx``.

    Constraints are processed in list order and are modified in place:

    * a falsy ``nodes`` (e.g. ``None``) is replaced by the full horizon
      ``(0, N)``;
    * a constraint that already carries an ``idx`` registers its interval
      under that ``idx``; reusing the ``idx`` with a different interval is an
      error;
    * a constraint without an ``idx`` inherits the ``idx`` of an interval
      registered earlier that compares equal, otherwise it receives the
      smallest unused non-negative integer at or above the running counter.

    Because processing is sequential, an explicit ``idx`` appearing later in
    the list can collide with an automatically assigned one.

    Args:
        constraints_ctcs: Constraints exposing mutable ``nodes`` and ``idx``
            attributes.
        N: Upper bound used for the default interval ``(0, N)``.

    Returns:
        A 3-tuple ``(constraints_ctcs, node_intervals, num_augmented_states)``
        where ``node_intervals`` lists one interval per distinct ``idx`` in
        ascending ``idx`` order and ``num_augmented_states`` is the number of
        distinct ``idx`` values.

    Raises:
        ValueError: If the same ``idx`` is used with two different intervals.
    """
    idx_to_nodes: dict[int, tuple] = {}
    next_idx = 0

    for c in constraints_ctcs:
        # normalize None to full horizon
        c.nodes = c.nodes or (0, N)
        key = c.nodes

        if c.idx is not None:
            # user supplied an identifier: ensure it always points to the same interval
            registered = idx_to_nodes.setdefault(c.idx, key)
            if registered != key:
                raise ValueError(
                    f"idx={c.idx} was first used with interval={registered}, "
                    f"but now you gave it interval={key}"
                )
            continue

        # no identifier: see if this interval already has one
        # (linear scan on purpose: intervals are compared with ==, so they
        # need not be hashable)
        existing_id = next(
            (ident for ident, nodes in idx_to_nodes.items() if nodes == key),
            None,
        )
        if existing_id is not None:
            c.idx = existing_id
            continue

        # brand-new interval: pick the next free auto-id
        while next_idx in idx_to_nodes:
            next_idx += 1
        c.idx = next_idx
        idx_to_nodes[next_idx] = key
        next_idx += 1

    # Extract the intervals in ascending-idx order
    node_intervals = [idx_to_nodes[ident] for ident in sorted(idx_to_nodes)]
    num_augmented_states = len(node_intervals)

    return constraints_ctcs, node_intervals, num_augmented_states
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from typing import Callable, List, Tuple
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
|
|
6
|
+
from openscvx.constraints.violation import CTCSViolation
|
|
7
|
+
from openscvx.dynamics import Dynamics
|
|
8
|
+
|
|
9
|
+
def build_augmented_dynamics(
    dynamics_non_augmented: Dynamics,
    violations: List[CTCSViolation],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Dynamics:
    """Wrap the original dynamics and the CTCS violations into one ``Dynamics``.

    The returned object's ``f`` comes from ``get_augmented_dynamics`` and its
    ``A``/``B`` attributes are the pair returned by ``get_jacobians``.

    Args:
        dynamics_non_augmented: Original dynamics; its ``f`` is the right-hand
            side and it is also handed to ``get_jacobians``.
        violations: Violation functions whose outputs are appended to ``f``.
        idx_x_true: Slice selecting the original states from the augmented
            state vector.
        idx_u_true: Slice selecting the original controls from the augmented
            control vector.

    Returns:
        A new ``Dynamics`` instance with ``f``, ``A`` and ``B`` populated.
    """
    augmented = Dynamics(
        f=get_augmented_dynamics(
            dynamics_non_augmented.f, violations, idx_x_true, idx_u_true
        ),
    )
    # Read ``f`` back from the instance so the Jacobians are built from exactly
    # what the Dynamics object stores.
    augmented.A, augmented.B = get_jacobians(
        augmented.f, dynamics_non_augmented, violations, idx_x_true, idx_u_true
    )
    return augmented
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_augmented_dynamics(
    dynamics: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    violations: List[CTCSViolation],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Build a right-hand side that appends every violation output to ``dynamics``.

    Args:
        dynamics: Original dynamics, called as ``dynamics(x_true, u_true)``.
        violations: Objects whose ``g(x_true, u_true, node)`` output is
            stacked after the dynamics output, in list order.
        idx_x_true: Slice selecting the original states from ``x``.
        idx_u_true: Slice selecting the original controls from ``u``.

    Returns:
        A function ``(x, u, node)`` returning the horizontally stacked
        dynamics and violation outputs.
    """

    def dynamics_augmented(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
        x_true = x[idx_x_true]
        u_true = u[idx_u_true]

        # First piece is the original dynamics; each violation contributes one more.
        pieces = [dynamics(x_true, u_true)]
        pieces.extend(v.g(x_true, u_true, node) for v in violations)

        # With no violations the dynamics output is returned untouched.
        if len(pieces) == 1:
            return pieces[0]
        return jnp.hstack(pieces)

    return dynamics_augmented
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_jacobians(
    dyn_augmented: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
    dynamics_non_augmented: Dynamics,
    violations: List[CTCSViolation],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Tuple[
    Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
    Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
]:
    """Return state and control Jacobian functions of the augmented dynamics.

    If neither the dynamics nor any violation carries a user-supplied
    gradient, both Jacobians are obtained by forward-mode autodiff of
    ``dyn_augmented``. Otherwise the Jacobians are assembled block-wise,
    using each user-supplied gradient where present and ``jax.jacfwd`` where
    not.

    Args:
        dyn_augmented: Augmented dynamics ``(x_aug, u_aug, node)``.
        dynamics_non_augmented: Original dynamics; ``f`` is required, ``A`` and
            ``B`` are optional callables invoked as ``(x_true, u_true)``.
        violations: Violations providing ``g`` and optional ``g_grad_x`` /
            ``g_grad_u``, each invoked as ``(x_true, u_true, node)``.
        idx_x_true: Slice selecting the original states from ``x_aug``.
        idx_u_true: Slice selecting the original controls from ``u_aug``.

    Returns:
        ``(A, B)``, both callable as ``(x_aug, u_aug, node)``. In the assembled
        case ``A`` has one zero column per violation output row appended to
        the right of every block; ``B`` stacks the control-gradient blocks.
    """
    # 1) Early return if absolutely no custom grads anywhere
    no_dyn_grads = dynamics_non_augmented.A is None and dynamics_non_augmented.B is None
    no_vio_grads = all(v.g_grad_x is None and v.g_grad_u is None for v in violations)

    if no_dyn_grads and no_vio_grads:
        return (
            jax.jacfwd(dyn_augmented, argnums=0),
            jax.jacfwd(dyn_augmented, argnums=1),
        )

    # 2) Build the *true-state* Jacobians of f(x_true, u_true)
    f_fn = dynamics_non_augmented.f
    A_f = dynamics_non_augmented.A
    if A_f is None:
        A_f = jax.jacfwd(f_fn, argnums=0)

    B_f = dynamics_non_augmented.B
    if B_f is None:
        B_f = jax.jacfwd(f_fn, argnums=1)

    # 3) Resolve per-violation gradients once, instead of on every A/B call:
    #    use the user-provided gradient if present, otherwise autodiff viol.g
    #    with respect to x (argnums=0) or u (argnums=1).
    grad_x_fns = [v.g_grad_x or jax.jacfwd(v.g, argnums=0) for v in violations]
    grad_u_fns = [v.g_grad_u or jax.jacfwd(v.g, argnums=1) for v in violations]

    # 4) Assemble A_aug, B_aug
    def A(x_aug, u_aug, node):
        x_true = x_aug[idx_x_true]
        u_true = u_aug[idx_u_true]

        # dynamics block + zero-pad
        Af = A_f(x_true, u_true)  # (n_f, n_x_true)
        # total number of violation rows; g is evaluated only for its shape
        nv = sum(v.g(x_true, u_true, node).shape[0] for v in violations)
        rows = [jnp.hstack([Af, jnp.zeros((Af.shape[0], nv))])]  # (n_f, n_x_true + n_v)

        # violation blocks
        for grad_x in grad_x_fns:
            dx_i = grad_x(x_true, u_true, node)  # (n_gi, n_x_true)
            pad_i = jnp.zeros((dx_i.shape[0], nv))  # (n_gi, n_v)
            rows.append(jnp.hstack([dx_i, pad_i]))  # (n_gi, n_x_true + n_v)

        return jnp.vstack(rows)

    def B(x_aug, u_aug, node):
        x_true = x_aug[idx_x_true]
        u_true = u_aug[idx_u_true]

        rows = [B_f(x_true, u_true)]  # (n_f, n_u_true)
        rows.extend(
            grad_u(x_true, u_true, node)  # (n_gi, n_u_true)
            for grad_u in grad_u_fns
        )

        return jnp.vstack(rows)

    return A, B
|
openscvx/config.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Dict, List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_affine_scaling_matrices(n, minimum, maximum):
    """Return the diagonal scaling matrix and offset for the given bounds.

    Args:
        n: Number of variables; sets the length of the unit floor vector.
        minimum: Array of lower bounds.
        maximum: Array of upper bounds.

    Returns:
        ``(S, c)`` where ``S`` is a diagonal matrix whose entries are half the
        bound range, floored at 1, and ``c`` is the midpoint of the bounds.
    """
    half_range = abs(minimum - maximum) / 2
    # Never scale by less than one.
    diagonal = np.maximum(np.ones(n), half_range)
    midpoint = (maximum + minimum) / 2
    return np.diag(diagonal), midpoint
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class DiscretizationConfig:
    """
    Configuration class for discretization settings.

    This class defines the parameters required for discretizing system dynamics.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        dis_type (str): The type of discretization to use (e.g., "FOH" for First-Order Hold). Defaults to "FOH".
        custom_integrator (bool): Enables our custom fixed-step RK45 algorithm. This tends to be faster than Diffrax, but unless you're going for speed it's recommended to stick with Diffrax for robustness and other solver options. Defaults to True.
        solver (str): Not used if custom_integrator is enabled. Any choice of solver in Diffrax is valid, please refer here, [How to Choose a Solver](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/). Defaults to "Tsit5".

    Other arguments:
    These arguments are less frequently used, and for most purposes you shouldn't need to understand these.

    Args:
        args (Dict): Additional arguments to pass to the solver which can be found [here](https://docs.kidger.site/diffrax/api/diffeqsolve/). Defaults to an empty dictionary.
        atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
        rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
    """

    # The docstring must be the first statement of the class body; placed after
    # the fields it would be a discarded expression and never reach ``__doc__``.
    dis_type: str = "FOH"
    custom_integrator: bool = True
    solver: str = "Tsit5"
    args: Dict = field(default_factory=dict)
    atol: float = 1e-3
    rtol: float = 1e-6
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class DevConfig:
    """
    Configuration class for development settings.

    This class defines the parameters used for development and debugging purposes.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        profiling (bool): Whether to enable profiling for performance analysis. Defaults to False.
        debug (bool): Disables all precompilation so you can place breakpoints and inspect values. Defaults to False.
        printing (bool): Defaults to True. Presumably toggles console output; its use is not visible in this module.
    """

    # The docstring must be the first statement of the class body; placed after
    # the fields it would be a discarded expression and never reach ``__doc__``.
    profiling: bool = False
    debug: bool = False
    printing: bool = True
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class ConvexSolverConfig:
    """
    Configuration class for convex solver settings.

    This class defines the parameters required for configuring a convex solver.

    These are the arguments most commonly used day-to-day. Generally I have found [QOCO](https://qoco-org.github.io/qoco/index.html) to be the most performant of the CVXPY solvers for these types of problems (I do have a bias as the author is from my group) and it can handle up to SOCPs.
    [CLARABEL](https://clarabel.org/stable/) is also a great option with feasibility checking and can handle a few more problem types.
    [CVXPYGen](https://github.com/cvxgrp/cvxpygen) is also great if your problem isn't too large. I have found qocogen to be the most performant of the CVXPYGen solvers.

    Args:
        solver (str): The name of the CVXPY solver to use. A list of options can be found [here](https://www.cvxpy.org/tutorial/solvers/index.html). Defaults to "QOCO".
        solver_args (dict): Additional arguments to configure the solver, such as tolerances. Ensure you are using the correct arguments for your solver as they are not all common.
            Defaults to {"abstol": 1e-6, "reltol": 1e-9}.
        cvxpygen (bool): Whether to enable CVXPY code generation for the solver. Defaults to False.
    """

    # The docstring must be the first statement of the class body; placed after
    # the fields it would be a discarded expression and never reach ``__doc__``.
    solver: str = "QOCO"
    # A factory is used so every instance gets its own dict.
    solver_args: dict = field(default_factory=lambda: {"abstol": 1e-6, "reltol": 1e-9})
    cvxpygen: bool = False
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class PropagationConfig:
    """
    Configuration class for propagation settings.

    This class defines the parameters required for propagating the nonlinear system dynamics using the optimal control sequence.

    Main arguments:
    These are the arguments most commonly used day-to-day.

    Args:
        dt (float): The time step for propagation. Defaults to 0.1.
        inter_sample (int): How dense the propagation within multishot discretization should be. Defaults to 30.

    Other arguments:
    The solver should likely not be changed as it is a high-accuracy 8th-order Runge-Kutta method.

    Args:
        solver (str): The numerical solver to use for propagation (e.g., "Dopri8"). Defaults to "Dopri8".
        args (Dict): Additional arguments to pass to the solver. Defaults to an empty dictionary.
        atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
        rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
    """

    # The docstring must be the first statement of the class body; placed after
    # the fields it would be a discarded expression and never reach ``__doc__``.
    inter_sample: int = 30
    dt: float = 0.1
    solver: str = "Dopri8"
    args: Dict = field(default_factory=dict)
    atol: float = 1e-3
    rtol: float = 1e-6
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclass
class SimConfig:
    """Problem-level data: bounds, index slices, constraints and scaling.

    Only what ``__post_init__`` does with the fields is documented here; how
    the remaining fields are consumed is not visible in this module.

    Derived in ``__post_init__`` (any value passed in is overwritten):
        n_states: ``len(max_state)``.
        n_controls: ``len(max_control)``.

    Validated in ``__post_init__``:
        ``initial_state.value`` and ``final_state.value`` must each have
        ``n_states`` minus the width of ``idx_y`` elements. Both are annotated
        as arrays but are accessed through ``.value``, so the objects passed
        in must expose that attribute.
        ``max_state`` / ``min_state`` must have ``n_states`` entries and
        ``max_control`` / ``min_control`` must have ``n_controls`` entries.

    Scaling:
        If ``S_x`` or ``c_x`` is missing, both are recomputed from the state
        bounds via ``get_affine_scaling_matrices`` and ``inv_S_x`` is set to
        the matching diagonal inverse. If both are supplied but ``inv_S_x``
        is not, ``inv_S_x`` is computed as the matrix inverse of ``S_x``. A
        supplied ``inv_S_x`` alongside supplied ``S_x``/``c_x`` is kept
        as-is. The same rules apply to ``S_u`` / ``c_u`` / ``inv_S_u``.
    """

    x_bar: np.ndarray
    u_bar: np.ndarray
    initial_state: np.ndarray
    final_state: np.ndarray
    max_state: np.ndarray
    min_state: np.ndarray
    max_control: np.ndarray
    min_control: np.ndarray
    total_time: float
    idx_x_true: slice
    idx_u_true: slice
    idx_t: slice
    idx_y: slice
    idx_s: slice
    ctcs_node_intervals: list = None
    constraints_ctcs: List[callable] = field(
        default_factory=list
    )  # TODO (norrisg): clean this up, consider moving to dedicated `constraints` dataclass
    constraints_nodal: List[callable] = field(default_factory=list)
    n_states: int = None
    n_controls: int = None
    S_x: np.ndarray = None
    inv_S_x: np.ndarray = None
    c_x: np.ndarray = None
    S_u: np.ndarray = None
    inv_S_u: np.ndarray = None
    c_u: np.ndarray = None

    def __post_init__(self):
        """Derive sizes, validate array lengths and fill in missing scaling."""
        self.n_states = len(self.max_state)
        self.n_controls = len(self.max_control)

        assert (
            len(self.initial_state.value) == self.n_states - (self.idx_y.stop - self.idx_y.start)
        ), f"Initial state must have {self.n_states - (self.idx_y.stop - self.idx_y.start)} elements"
        assert (
            len(self.final_state.value) == self.n_states - (self.idx_y.stop - self.idx_y.start)
        ), f"Final state must have {self.n_states - (self.idx_y.stop - self.idx_y.start)} elements"
        assert (
            self.max_state.shape[0] == self.n_states
        ), f"Max state must have {self.n_states} elements"
        assert (
            self.min_state.shape[0] == self.n_states
        ), f"Min state must have {self.n_states} elements"
        assert (
            self.max_control.shape[0] == self.n_controls
        ), f"Max control must have {self.n_controls} elements"
        assert (
            self.min_control.shape[0] == self.n_controls
        ), f"Min control must have {self.n_controls} elements"

        if self.S_x is None or self.c_x is None:
            self.S_x, self.c_x = get_affine_scaling_matrices(
                self.n_states, self.min_state, self.max_state
            )
            # Use the fact that S_x is diagonal to compute the inverse
            self.inv_S_x = np.diag(1 / np.diag(self.S_x))
        elif self.inv_S_x is None:
            # Caller supplied S_x and c_x only. A user matrix is not known to
            # be diagonal, so use a general inverse rather than leaving None.
            self.inv_S_x = np.linalg.inv(self.S_x)

        if self.S_u is None or self.c_u is None:
            self.S_u, self.c_u = get_affine_scaling_matrices(
                self.n_controls, self.min_control, self.max_control
            )
            self.inv_S_u = np.diag(1 / np.diag(self.S_u))
        elif self.inv_S_u is None:
            # Same reasoning as for inv_S_x above.
            self.inv_S_u = np.linalg.inv(self.S_u)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@dataclass
class ScpConfig:
    """
    Configuration class for Sequential Convex Programming (SCP).

    This class defines the parameters used to configure the SCP solver. You will very likely need to modify
    the weights for your problem. Please refer to my guide [here](https://haynec.github.io/openscvx/hyperparameter_tuning) for more information.

    Note that on construction ``w_tr``, ``lam_vc``, ``lam_cost`` and ``lam_vb`` are
    divided in place by the largest of the four, so the stored values are relative weights.

    Attributes:
        n (int): The number of discretization nodes. Defaults to `None`.
        k_max (int): The maximum number of SCP iterations. Defaults to 200.
        w_tr (float): The trust region weight. Defaults to 1.0.
        lam_vc (float): The penalty weight for virtual control. Defaults to 1.0.
        ep_tr (float): The trust region convergence tolerance. Defaults to 1e-4.
        ep_vb (float): The boundary constraint convergence tolerance. Defaults to 1e-4.
        ep_vc (float): The virtual constraint convergence tolerance. Defaults to 1e-8.
        lam_cost (float): The weight for original cost. Defaults to 0.0.
        lam_vb (float): The weight for virtual buffer. This is only used if there are nonconvex nodal constraints present. Defaults to 0.0.
        uniform_time_grid (bool): Whether to use a uniform time grid. TODO haynec add a link to the time dilation page. Defaults to `False`.
        cost_drop (int): The number of iterations to allow for cost stagnation before termination. Defaults to -1 (disabled).
        cost_relax (float): The relaxation factor for cost reduction. Defaults to 1.0.
        w_tr_adapt (float): The adaptation factor for the trust region weight. Defaults to 1.0.
        w_tr_max (float): The maximum allowable trust region weight. Defaults to `None`.
        w_tr_max_scaling_factor (float): The scaling factor for the maximum trust region weight. When given and
            `w_tr_max` is `None`, `w_tr_max` is set to this factor times the normalized `w_tr`. Defaults to `None`.
    """

    # The docstring must be the first statement of the class body; placed after
    # the fields it would be a discarded expression and never reach ``__doc__``.
    n: int = None
    k_max: int = 200
    w_tr: float = 1e0
    lam_vc: float = 1e0
    ep_tr: float = 1e-4
    ep_vb: float = 1e-4
    ep_vc: float = 1e-8
    lam_cost: float = 0.0
    lam_vb: float = 0.0
    uniform_time_grid: bool = False
    cost_drop: int = -1
    cost_relax: float = 1.0
    w_tr_adapt: float = 1.0
    w_tr_max: float = None
    w_tr_max_scaling_factor: float = None

    def __post_init__(self):
        """Normalize the four weights by their maximum and derive ``w_tr_max``."""
        keys_to_scale = ["w_tr", "lam_vc", "lam_cost", "lam_vb"]
        # NOTE(review): if all four weights are zero this division raises
        # ZeroDivisionError; left unchanged to preserve existing behavior.
        scale = max(getattr(self, key) for key in keys_to_scale)
        for key in keys_to_scale:
            setattr(self, key, getattr(self, key) / scale)

        # Uses the already-normalized w_tr; an explicit w_tr_max is left untouched.
        if self.w_tr_max_scaling_factor is not None and self.w_tr_max is None:
            self.w_tr_max = self.w_tr_max_scaling_factor * self.w_tr
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@dataclass
class Config:
    """Container bundling the six per-topic configuration objects."""

    sim: SimConfig  # simulation / problem data
    scp: ScpConfig  # SCP settings
    cvx: ConvexSolverConfig  # convex solver settings
    dis: DiscretizationConfig  # discretization settings
    prp: PropagationConfig  # propagation settings
    dev: DevConfig  # development settings

    def __post_init__(self):
        """Do nothing; placeholder hook run after the generated ``__init__``."""
        return None
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import Callable, Sequence, Tuple, Optional
|
|
3
3
|
import functools
|
|
4
|
+
import types
|
|
4
5
|
|
|
5
6
|
from jax.lax import cond
|
|
6
7
|
import jax.numpy as jnp
|
|
@@ -10,8 +11,18 @@ import jax.numpy as jnp
|
|
|
10
11
|
class CTCSConstraint:
|
|
11
12
|
func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
|
|
12
13
|
penalty: Callable[[jnp.ndarray], jnp.ndarray]
|
|
13
|
-
nodes: Optional[
|
|
14
|
+
nodes: Optional[Tuple[int, int]] = None
|
|
14
15
|
idx: Optional[int] = None
|
|
16
|
+
grad_f_x: Optional[Callable] = None
|
|
17
|
+
grad_f_u: Optional[Callable] = None
|
|
18
|
+
|
|
19
|
+
def __post_init__(self):
|
|
20
|
+
if self.grad_f_x is not None:
|
|
21
|
+
_grad_f_x = self.grad_f_x
|
|
22
|
+
self.grad_f_x = lambda x, u, nodes: _grad_f_x(x, u)
|
|
23
|
+
if self.grad_f_u is not None:
|
|
24
|
+
_grad_f_u = self.grad_f_u
|
|
25
|
+
self.grad_f_u = lambda x, u, nodes: _grad_f_u(x, u)
|
|
15
26
|
|
|
16
27
|
def __call__(self, x, u, node):
|
|
17
28
|
return cond(
|
|
@@ -28,6 +39,8 @@ def ctcs(
|
|
|
28
39
|
penalty: str = "squared_relu",
|
|
29
40
|
nodes: Optional[Sequence[Tuple[int, int]]] = None,
|
|
30
41
|
idx: Optional[int] = None,
|
|
42
|
+
grad_f_x: Optional[Callable] = None,
|
|
43
|
+
grad_f_u: Optional[Callable] = None,
|
|
31
44
|
):
|
|
32
45
|
"""Decorator to mark a function as a 'ctcs' constraint.
|
|
33
46
|
|
|
@@ -44,14 +57,25 @@ def ctcs(
|
|
|
44
57
|
def pen(x):
|
|
45
58
|
r = jnp.maximum(0, x)
|
|
46
59
|
return jnp.where(r < delta, 0.5 * r**2, r - 0.5 * delta)
|
|
47
|
-
|
|
60
|
+
elif penalty == "smooth_relu":
|
|
61
|
+
c = 1e-8
|
|
62
|
+
pen = lambda x: (jnp.maximum(0, x) ** 2 + c**2) ** 0.5 - c
|
|
63
|
+
elif isinstance(penalty, types.LambdaType):
|
|
64
|
+
pen = penalty
|
|
48
65
|
else:
|
|
49
66
|
raise ValueError(f"Unknown penalty {penalty}")
|
|
50
67
|
|
|
51
68
|
def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
|
|
52
69
|
# wrap so name, doc, signature stay on f
|
|
53
70
|
wrapped = functools.wraps(f)(f)
|
|
54
|
-
return CTCSConstraint(
|
|
71
|
+
return CTCSConstraint(
|
|
72
|
+
func=wrapped,
|
|
73
|
+
penalty=pen,
|
|
74
|
+
nodes=nodes,
|
|
75
|
+
idx=idx,
|
|
76
|
+
grad_f_x=grad_f_x,
|
|
77
|
+
grad_f_u=grad_f_u,
|
|
78
|
+
)
|
|
55
79
|
|
|
56
80
|
# if called as @ctcs or @ctcs(...), _func will be None and we return decorator
|
|
57
81
|
if _func is None:
|
|
@@ -11,20 +11,22 @@ class NodalConstraint:
|
|
|
11
11
|
nodes: Optional[List[int]] = None
|
|
12
12
|
convex: bool = False
|
|
13
13
|
vectorized: bool = False
|
|
14
|
+
grad_g_x: Optional[Callable] = None
|
|
15
|
+
grad_g_u: Optional[Callable] = None
|
|
14
16
|
|
|
15
17
|
def __post_init__(self):
|
|
16
18
|
if not self.convex:
|
|
17
|
-
#
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
self.
|
|
21
|
-
|
|
22
|
-
self.grad_g_u =
|
|
23
|
-
|
|
24
|
-
self.g = vmap(
|
|
25
|
-
self.grad_g_x =
|
|
26
|
-
self.grad_g_u =
|
|
27
|
-
# if convex=True
|
|
19
|
+
# single-node but still using JAX
|
|
20
|
+
self.g = self.func
|
|
21
|
+
if self.grad_g_x is None:
|
|
22
|
+
self.grad_g_x = jacfwd(self.func, argnums=0)
|
|
23
|
+
if self.grad_g_u is None:
|
|
24
|
+
self.grad_g_u = jacfwd(self.func, argnums=1)
|
|
25
|
+
if not self.vectorized:
|
|
26
|
+
self.g = vmap(self.g, in_axes=(0, 0))
|
|
27
|
+
self.grad_g_x = vmap(self.grad_g_x, in_axes=(0, 0))
|
|
28
|
+
self.grad_g_u = vmap(self.grad_g_u, in_axes=(0, 0))
|
|
29
|
+
# if convex=True assume an external solver (e.g. CVX) will handle it
|
|
28
30
|
|
|
29
31
|
def __call__(self, x: jnp.ndarray, u: jnp.ndarray):
|
|
30
32
|
return self.func(x, u)
|
|
@@ -36,6 +38,8 @@ def nodal(
|
|
|
36
38
|
nodes: Optional[List[int]] = None,
|
|
37
39
|
convex: bool = False,
|
|
38
40
|
vectorized: bool = False,
|
|
41
|
+
grad_g_x: Optional[Callable] = None,
|
|
42
|
+
grad_g_u: Optional[Callable] = None,
|
|
39
43
|
):
|
|
40
44
|
def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
|
|
41
45
|
return NodalConstraint(
|
|
@@ -43,6 +47,8 @@ def nodal(
|
|
|
43
47
|
nodes=nodes,
|
|
44
48
|
convex=convex,
|
|
45
49
|
vectorized=vectorized,
|
|
50
|
+
grad_g_x=grad_g_x,
|
|
51
|
+
grad_g_u=grad_g_u,
|
|
46
52
|
)
|
|
47
53
|
|
|
48
54
|
return decorator if _func is None else decorator(_func)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import List, Optional, Tuple, Callable
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
from openscvx.constraints.ctcs import CTCSConstraint
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class CTCSViolation:
    """A violation function together with its optional gradient callables.

    All three callables take ``(x, u, node)``.
    """

    # Violation function.
    g: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]
    # Optional gradient of ``g`` with respect to ``x``; ``None`` when not supplied.
    g_grad_x: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
    # Optional gradient of ``g`` with respect to ``u``; ``None`` when not supplied.
    g_grad_u: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_g_grad_x(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Return a function summing the ``grad_f_x`` of the given constraints.

    Constraints whose ``grad_f_x`` is ``None`` are ignored. The returned
    function yields ``None`` when no constraint contributes a gradient.
    """

    def g_grad_x(x: jnp.ndarray, u: jnp.ndarray, node: int) -> jnp.ndarray:
        terms = []
        for constraint in constraints_ctcs:
            if constraint.grad_f_x is None:
                continue
            terms.append(constraint.grad_f_x(x, u, node))

        if not terms:
            return None
        return sum(terms)

    return g_grad_x
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_g_grad_u(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Return a function summing the ``grad_f_u`` of the given constraints.

    Constraints whose ``grad_f_u`` is ``None`` are ignored. The returned
    function yields ``None`` when no constraint contributes a gradient.
    """

    def g_grad_u(x: jnp.ndarray, u: jnp.ndarray, node: int) -> jnp.ndarray:
        terms = []
        for constraint in constraints_ctcs:
            if constraint.grad_f_u is None:
                continue
            terms.append(constraint.grad_f_u(x, u, node))

        if not terms:
            return None
        return sum(terms)

    return g_grad_u
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_g_func(constraints_ctcs: List[CTCSConstraint]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Return a function that adds up every constraint evaluated at ``(x, u, node)``.

    With an empty constraint list the returned function yields the integer 0.
    """

    def g_func(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
        # Start from 0 and add left to right, as the built-in ``sum`` does.
        total = 0
        for constraint in constraints_ctcs:
            total = total + constraint(x, u, node)
        return total

    return g_func
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_g_funcs(constraints_ctcs: List[CTCSConstraint]) -> List[CTCSViolation]:
    """Group CTCS constraints by ``idx`` and build one ``CTCSViolation`` per group.

    Args:
        constraints_ctcs: Constraints that all carry a non-``None`` ``idx``.

    Returns:
        One ``CTCSViolation`` per distinct ``idx``, in ascending ``idx`` order.
        A group's ``g_grad_x`` (``g_grad_u``) is set only when every
        constraint in the group provides ``grad_f_x`` (``grad_f_u``);
        otherwise it is ``None``.

    Raises:
        ValueError: If any constraint has ``idx`` set to ``None``.
    """
    # Bucket by idx
    groups: dict[int, List[CTCSConstraint]] = defaultdict(list)
    for c in constraints_ctcs:
        if c.idx is None:
            raise ValueError(f"CTCS constraint {c} has no .idx assigned")
        groups[c.idx].append(c)

    # For each bucket, build one CTCSViolation
    violations: List[CTCSViolation] = []
    for idx in sorted(groups):
        bucket = groups[idx]
        have_grad_x = all(c.grad_f_x is not None for c in bucket)
        have_grad_u = all(c.grad_f_u is not None for c in bucket)

        violations.append(
            CTCSViolation(
                g=get_g_func(bucket),
                # The x-gradient must come from the x builder and the
                # u-gradient from the u builder (they used to be crossed).
                g_grad_x=get_g_grad_x(bucket) if have_grad_x else None,
                g_grad_u=get_g_grad_u(bucket) if have_grad_u else None,
            )
        )

    return violations
|