openscvx 0.1.3__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. See the registry's advisory page for this release for more details.

openscvx/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.3'
21
- __version_tuple__ = version_tuple = (0, 1, 3)
20
+ __version__ = version = '0.2.1'
21
+ __version_tuple__ = version_tuple = (0, 2, 1)
@@ -1,4 +1,5 @@
1
1
  from typing import Callable, List, Tuple
2
+ import inspect
2
3
 
3
4
  import jax
4
5
  import jax.numpy as jnp
@@ -26,18 +27,32 @@ def build_augmented_dynamics(
26
27
 
27
28
 
28
29
  def get_augmented_dynamics(
29
- dynamics: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
30
+ dynamics: Callable[..., jnp.ndarray],
30
31
  violations: List[CTCSViolation],
31
32
  idx_x_true: slice,
32
33
  idx_u_true: slice,
33
- ) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
34
- def dynamics_augmented(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
35
- x_dot = dynamics(x[idx_x_true], u[idx_u_true])
34
+ ) -> Callable[..., jnp.ndarray]:
35
+
36
+ def dynamics_augmented(x: jnp.ndarray, u: jnp.ndarray, node: int, *params) -> jnp.ndarray:
37
+ x_true = x[idx_x_true]
38
+ u_true = u[idx_u_true]
39
+
40
+ # Determine the arguments of dynamics function
41
+ func_signature = inspect.signature(dynamics)
42
+ expected_args = set(func_signature.parameters.keys())
43
+ # Filter params to only those expected by the dynamics function
44
+ filtered_params = {
45
+ f"{name}_": value for name, value in params if f"{name}_" in expected_args
46
+ }
47
+
48
+ if "node" in expected_args:
49
+ filtered_params["node"] = node
50
+
51
+ x_dot = dynamics(x_true, u_true, **filtered_params)
36
52
 
37
- # Iterate through the g_func dictionary and stack the output each function
38
- # to x_dot
39
53
  for v in violations:
40
- x_dot = jnp.hstack([x_dot, v.g(x[idx_x_true], u[idx_u_true], node)])
54
+ g_val = v.g(x_true, u_true, node, *params)
55
+ x_dot = jnp.hstack([x_dot, g_val])
41
56
 
42
57
  return x_dot
43
58
 
openscvx/config.py CHANGED
@@ -1,8 +1,9 @@
1
1
  import numpy as np
2
2
  from dataclasses import dataclass, field
3
- from typing import Dict, List
3
+ from typing import Dict, List, Optional, Callable
4
4
 
5
- from openscvx.constraints.boundary import BoundaryConstraint
5
+ from openscvx.backend.state import State
6
+ from openscvx.backend.control import Control
6
7
 
7
8
 
8
9
  def get_affine_scaling_matrices(n, minimum, maximum):
@@ -13,222 +14,334 @@ def get_affine_scaling_matrices(n, minimum, maximum):
13
14
 
14
15
  @dataclass
15
16
  class DiscretizationConfig:
16
- dis_type: str = "FOH"
17
- custom_integrator: bool = True
18
- solver: str = "Tsit5"
19
- args: Dict = field(default_factory=dict)
20
- atol: float = 1e-3
21
- rtol: float = 1e-6
22
-
23
- """
24
- Configuration class for discretization settings.
25
-
26
- This class defines the parameters required for discretizing system dynamics.
27
-
28
- Main arguments:
29
- These are the arguments most commonly used day-to-day.
30
-
31
- Args:
32
- dis_type (str): The type of discretization to use (e.g., "FOH" for First-Order Hold). Defaults to "FOH".
33
- custom_integrator (bool): This enables our custom fixed-step RK45 algorthim. This tends to be faster then Diffrax but unless your going for speed, its reccomended to stick with Diffrax for robustness and other solver options. Defaults to False.
34
- solver (str): Not used if custom_integrator is enabled. Any choice of solver in Diffrax is valid, please refer here, [How to Choose a Solver](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/). Defaults to "Tsit5".
35
-
36
- Other arguments:
37
- These arguments are less frequently used, and for most purposes you shouldn't need to understand these.
38
-
39
- Args:
40
- args (Dict): Additional arguments to pass to the solver which can be found [here](https://docs.kidger.site/diffrax/api/diffeqsolve/). Defaults to an empty dictionary.
41
- atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
42
- rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
43
- """
44
17
 
18
+ def __init__(
19
+ self,
20
+ dis_type: str = "FOH",
21
+ custom_integrator: bool = False,
22
+ solver: str = "Tsit5",
23
+ args: Dict = None,
24
+ atol: float = 1e-3,
25
+ rtol: float = 1e-6,
26
+ ):
27
+ """
28
+ Configuration class for discretization settings.
29
+
30
+ This class defines the parameters required for discretizing system dynamics.
31
+
32
+ Main arguments:
33
+ These are the arguments most commonly used day-to-day.
34
+
35
+ Args:
36
+ dis_type (str): The type of discretization to use (e.g., "FOH" for First-Order Hold). Defaults to "FOH".
37
+ custom_integrator (bool): This enables our custom fixed-step RK45 algorithm. This tends to be faster than Diffrax but unless you're going for speed, it's recommended to stick with Diffrax for robustness and other solver options. Defaults to False.
38
+ solver (str): Not used if custom_integrator is enabled. Any choice of solver in Diffrax is valid, please refer here, [How to Choose a Solver](https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/). Defaults to "Tsit5".
39
+
40
+ Other arguments:
41
+ These arguments are less frequently used, and for most purposes you shouldn't need to understand these.
42
+
43
+ Args:
44
+ args (Dict): Additional arguments to pass to the solver which can be found [here](https://docs.kidger.site/diffrax/api/diffeqsolve/). Defaults to an empty dictionary.
45
+ atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
46
+ rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
47
+ """
48
+ self.dis_type = dis_type
49
+ self.custom_integrator = custom_integrator
50
+ self.solver = solver
51
+ self.args = args if args is not None else {}
52
+ self.atol = atol
53
+ self.rtol = rtol
45
54
 
46
55
  @dataclass
47
56
  class DevConfig:
48
- profiling: bool = False
49
- debug: bool = False
50
- printing: bool = True
51
57
 
52
- """
53
- Configuration class for development settings.
58
+ def __init__(self, profiling: bool = False, debug: bool = False, printing: bool = True):
59
+ """
60
+ Configuration class for development settings.
54
61
 
55
- This class defines the parameters used for development and debugging purposes.
62
+ This class defines the parameters used for development and debugging purposes.
56
63
 
57
- Main arguments:
58
- These are the arguments most commonly used day-to-day.
59
-
60
- Args:
61
- profiling (bool): Whether to enable profiling for performance analysis. Defaults to False.
62
- debug (bool): Disables all precompilation so you can place breakpoints and inspect values. Defaults to False.
63
- """
64
+ Main arguments:
65
+ These are the arguments most commonly used day-to-day.
64
66
 
67
+ Args:
68
+ profiling (bool): Whether to enable profiling for performance analysis. Defaults to False.
69
+ debug (bool): Disables all precompilation so you can place breakpoints and inspect values. Defaults to False.
70
+ printing (bool): Whether to enable printing during development. Defaults to True.
71
+ """
72
+ self.profiling = profiling
73
+ self.debug = debug
74
+ self.printing = printing
65
75
 
66
76
  @dataclass
67
77
  class ConvexSolverConfig:
68
- solver: str = "QOCO"
69
- solver_args: dict = field(default_factory=lambda: {"abstol": 1e-6, "reltol": 1e-9})
70
- cvxpygen: bool = False
71
-
72
- """
73
- Configuration class for convex solver settings.
74
-
75
- This class defines the parameters required for configuring a convex solver.
76
-
77
- These are the arguments most commonly used day-to-day. Generally I have found [QOCO](https://qoco-org.github.io/qoco/index.html) to be the most performant of the CVXPY solvers for these types of problems (I do have a bias as the author is from my group) and can handle up to SOCP's.
78
- [CLARABEL](https://clarabel.org/stable/) is also a great option with feasibility checking and can handle a few more problem types.
79
- [CVXPYGen](https://github.com/cvxgrp/cvxpygen) is also great if your problem isn't too large and allows. I have found qocogen to be the most performant of the CVXPYGen solvers.
80
-
81
- Args:
82
- solver (str): The name of the CVXPY solver to use. A list of options can be found [here](https://www.cvxpy.org/tutorial/solvers/index.html). Defaults to "QOCO".
83
- solver_args (dict): Ensure you are using the correct arguments for your solver as they are not all common. Additional arguments to configure the solver, such as tolerances.
84
- Defaults to {"abstol": 1e-6, "reltol": 1e-9}.
85
- cvxpygen (bool): Whether to enable CVXPY code generation for the solver. Defaults to False.
86
- """
78
+ def __init__(
79
+ self,
80
+ solver: str = "QOCO",
81
+ solver_args: dict = {"abstol": 1e-6, "reltol": 1e-9, "enforce_dpp": True},
82
+ cvxpygen: bool = False,
83
+ cvxpygen_override: bool = False,
84
+ ):
85
+ """
86
+ Configuration class for convex solver settings.
87
+
88
+ This class defines the parameters required for configuring a convex solver.
89
+
90
+ These are the arguments most commonly used day-to-day. Generally I have found [QOCO](https://qoco-org.github.io/qoco/index.html)
91
+ to be the most performant of the CVXPY solvers for these types of problems (I do have a bias as the author is from my group)
92
+ and can handle up to SOCP's. [CLARABEL](https://clarabel.org/stable/) is also a great option with feasibility checking and
93
+ can handle a few more problem types. [CVXPYGen](https://github.com/cvxgrp/cvxpygen) is also great if your problem isn't too large.
94
+ I have found qocogen to be the most performant of the CVXPYGen solvers.
95
+
96
+ Args:
97
+ solver (str): The name of the CVXPY solver to use. A list of options can be found
98
+ [here](https://www.cvxpy.org/tutorial/solvers/index.html). Defaults to "QOCO".
99
+ solver_args (dict, optional): Ensure you are using the correct arguments for your solver as they are not all common.
100
+ Additional arguments to configure the solver, such as tolerances.
101
+ Defaults to {"abstol": 1e-6, "reltol": 1e-9}.
102
+ cvxpygen (bool): Whether to enable CVXPY code generation for the solver. Defaults to False.
103
+ """
104
+ self.solver = solver
105
+ self.solver_args = solver_args if solver_args is not None else {"abstol": 1e-6, "reltol": 1e-9}
106
+ self.cvxpygen = cvxpygen
107
+ self.cvxpygen_override = cvxpygen_override
87
108
 
88
109
 
89
110
  @dataclass
90
111
  class PropagationConfig:
91
- inter_sample: int = 30
92
- dt: float = 0.1
93
- solver: str = "Dopri8"
94
- args: Dict = field(default_factory=dict)
95
- atol: float = 1e-3
96
- rtol: float = 1e-6
97
-
98
- """
99
- Configuration class for propagation settings.
100
-
101
- This class defines the parameters required for propagating the nonlinear system dynamics using the optimal control sequence.
102
-
103
- Main arguments:
104
- These are the arguments most commonly used day-to-day.
105
-
106
- Args:
107
- dt (float): The time step for propagation. Defaults to 0.1.
108
- inter_sample (int): How dense the propagation within multishot discretization should be.
109
-
110
- Other arguments:
111
- The solver should likley not to be changed as it is a high accuracy 8th order runga kutta method.
112
-
113
- Args:
114
- solver (str): The numerical solver to use for propagation (e.g., "Dopri8"). Defaults to "Dopri8".
115
- args (Dict): Additional arguments to pass to the solver. Defaults to an empty dictionary.
116
- atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
117
- rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
118
- """
119
-
120
-
121
- @dataclass
112
+ def __init__(
113
+ self,
114
+ inter_sample: int = 30,
115
+ dt: float = 0.01,
116
+ solver: str = "Dopri8",
117
+ max_tau_len: int = 1000,
118
+ args: Optional[Dict] = None,
119
+ atol: float = 1e-3,
120
+ rtol: float = 1e-6,
121
+ ):
122
+ """
123
+ Configuration class for propagation settings.
124
+
125
+ This class defines the parameters required for propagating the nonlinear system dynamics
126
+ using the optimal control sequence.
127
+
128
+ Main arguments:
129
+ These are the arguments most commonly used day-to-day.
130
+
131
+ Other arguments:
132
+ The solver should likely not be changed as it is a high accuracy 8th-order Runge-Kutta method.
133
+
134
+ Args:
135
+ inter_sample (int): How dense the propagation within multishot discretization should be. Defaults to 30.
136
+ dt (float): The time step for propagation. Defaults to 0.1.
137
+ solver (str): The numerical solver to use for propagation (e.g., "Dopri8"). Defaults to "Dopri8".
138
+ max_tau_len (int): The maximum length of the time vector for propagation. Defaults to 1000.
139
+ args (Dict, optional): Additional arguments to pass to the solver. Defaults to an empty dictionary.
140
+ atol (float): Absolute tolerance for the solver. Defaults to 1e-3.
141
+ rtol (float): Relative tolerance for the solver. Defaults to 1e-6.
142
+ """
143
+
144
+ self.inter_sample = inter_sample
145
+ self.dt = dt
146
+ self.solver = solver
147
+ self.max_tau_len = max_tau_len
148
+ self.args = args if args is not None else {}
149
+ self.atol = atol
150
+ self.rtol = rtol
151
+
152
+ @dataclass(init=False)
122
153
  class SimConfig:
123
- x_bar: np.ndarray
124
- u_bar: np.ndarray
125
- initial_state: BoundaryConstraint
126
- initial_state_prop: BoundaryConstraint
127
- final_state: np.ndarray
128
- max_state: np.ndarray
129
- min_state: np.ndarray
130
- max_control: np.ndarray
131
- min_control: np.ndarray
132
- total_time: float
133
- idx_x_true: slice
134
- idx_x_true_prop: slice
135
- idx_u_true: slice
136
- idx_t: slice
137
- idx_y: slice
138
- idx_y_prop: slice
139
- idx_s: slice
140
- ctcs_node_intervals: list = None
141
- constraints_ctcs: List[callable] = field(
142
- default_factory=list
143
- ) # TODO (norrisg): clean this up, consider moving to dedicated `constraints` dataclass
144
- constraints_nodal: List[callable] = field(default_factory=list)
145
- n_states: int = None
146
- n_states_prop: int = None
147
- n_controls: int = None
148
- S_x: np.ndarray = None
149
- inv_S_x: np.ndarray = None
150
- c_x: np.ndarray = None
151
- S_u: np.ndarray = None
152
- inv_S_u: np.ndarray = None
153
- c_u: np.ndarray = None
154
+ # No class-level field declarations
155
+
156
+ def __init__(
157
+ self,
158
+ x: State,
159
+ x_prop: State,
160
+ u: Control,
161
+ total_time: float,
162
+ idx_x_true: slice,
163
+ idx_x_true_prop: slice,
164
+ idx_u_true: slice,
165
+ idx_t: slice,
166
+ idx_y: slice,
167
+ idx_y_prop: slice,
168
+ idx_s: slice,
169
+ save_compiled: bool = True,
170
+ ctcs_node_intervals: Optional[list] = None,
171
+ constraints_ctcs: Optional[List[Callable]] = None,
172
+ constraints_nodal: Optional[List[Callable]] = None,
173
+ n_states: Optional[int] = None,
174
+ n_states_prop: Optional[int] = None,
175
+ n_controls: Optional[int] = None,
176
+ scaling_x_overrides: Optional[list] = None,
177
+ scaling_u_overrides: Optional[list] = None,
178
+ ):
179
+ """
180
+ Configuration class for simulation settings.
181
+
182
+ This class defines the parameters required for simulating a trajectory optimization problem.
183
+
184
+ Main arguments:
185
+ These are the arguments most commonly used day-to-day.
186
+
187
+ Args:
188
+ x (State): State object, must have .min and .max attributes for bounds.
189
+ x_prop (State): Propagation state object, must have .min and .max attributes for bounds.
190
+ u (Control): Control object, must have .min and .max attributes for bounds.
191
+ total_time (float): The total simulation time.
192
+ idx_x_true (slice): Slice for true state indices.
193
+ idx_x_true_prop (slice): Slice for true propagation state indices.
194
+ idx_u_true (slice): Slice for true control indices.
195
+ idx_t (slice): Slice for time index.
196
+ idx_y (slice): Slice for constraint violation indices.
197
+ idx_y_prop (slice): Slice for propagation constraint violation indices.
198
+ idx_s (slice): Slice for time dilation index.
199
+ save_compiled (bool): If True, save and reuse compiled solver functions. Defaults to True.
200
+ ctcs_node_intervals (list, optional): Node intervals for CTCS constraints.
201
+ constraints_ctcs (list, optional): List of CTCS constraints.
202
+ constraints_nodal (list, optional): List of nodal constraints.
203
+ n_states (int, optional): The number of state variables. Defaults to `None` (inferred from x.max).
204
+ n_states_prop (int, optional): The number of propagation state variables. Defaults to `None` (inferred from x_prop.max).
205
+ n_controls (int, optional): The number of control variables. Defaults to `None` (inferred from u.max).
206
+ scaling_x_overrides (list, optional): List of (upper_bound, lower_bound, idx) for custom state scaling. Each can be scalar or array, idx can be int, list, or slice.
207
+ scaling_u_overrides (list, optional): List of (upper_bound, lower_bound, idx) for custom control scaling. Each can be scalar or array, idx can be int, list, or slice.
208
+
209
+ Note:
210
+ You can specify custom scaling for specific states/controls using scaling_x_overrides and scaling_u_overrides. Any indices not covered by overrides will use the default min/max bounds.
211
+ """
212
+ # Assign all arguments to self
213
+ self.x = x
214
+ self.x_prop = x_prop
215
+ self.u = u
216
+ self.total_time = total_time
217
+ self.idx_x_true = idx_x_true
218
+ self.idx_x_true_prop = idx_x_true_prop
219
+ self.idx_u_true = idx_u_true
220
+ self.idx_t = idx_t
221
+ self.idx_y = idx_y
222
+ self.idx_y_prop = idx_y_prop
223
+ self.idx_s = idx_s
224
+ self.save_compiled = save_compiled
225
+ self.ctcs_node_intervals = ctcs_node_intervals
226
+ self.constraints_ctcs = constraints_ctcs if constraints_ctcs is not None else []
227
+ self.constraints_nodal = constraints_nodal if constraints_nodal is not None else []
228
+ self.n_states = n_states
229
+ self.n_states_prop = n_states_prop
230
+ self.n_controls = n_controls
231
+ self.scaling_x_overrides = scaling_x_overrides
232
+ self.scaling_u_overrides = scaling_u_overrides
233
+
234
+ # Then call post init logic
235
+ self.__post_init__()
154
236
 
155
237
  def __post_init__(self):
156
- self.n_states = len(self.max_state)
157
- self.n_controls = len(self.max_control)
158
-
159
- assert (
160
- len(self.initial_state.value) == self.n_states - (self.idx_y.stop - self.idx_y.start)
161
- ), f"Initial state must have {self.n_states - (self.idx_y.stop - self.idx_y.start)} elements"
162
- assert (
163
- len(self.final_state.value) == self.n_states - (self.idx_y.stop - self.idx_y.start)
164
- ), f"Final state must have {self.n_states - (self.idx_y.stop - self.idx_y.start)} elements"
165
- assert (
166
- self.max_state.shape[0] == self.n_states
167
- ), f"Max state must have {self.n_states} elements"
168
- assert (
169
- self.min_state.shape[0] == self.n_states
170
- ), f"Min state must have {self.n_states} elements"
171
- assert (
172
- self.max_control.shape[0] == self.n_controls
173
- ), f"Max control must have {self.n_controls} elements"
174
- assert (
175
- self.min_control.shape[0] == self.n_controls
176
- ), f"Min control must have {self.n_controls} elements"
177
-
178
- if self.S_x is None or self.c_x is None:
179
- self.S_x, self.c_x = get_affine_scaling_matrices(
180
- self.n_states, self.min_state, self.max_state
181
- )
182
- # Use the fact that S_x is diagonal to compute the inverse
183
- self.inv_S_x = np.diag(1 / np.diag(self.S_x))
184
- if self.S_u is None or self.c_u is None:
185
- self.S_u, self.c_u = get_affine_scaling_matrices(
186
- self.n_controls, self.min_control, self.max_control
187
- )
188
- self.inv_S_u = np.diag(1 / np.diag(self.S_u))
238
+ self.n_states = len(self.x.max)
239
+ self.n_controls = len(self.u.max)
240
+
241
+ # Helper to apply overrides
242
+ def apply_overrides(size, overrides, min_arr, max_arr):
243
+ upper = np.array(max_arr, dtype=float)
244
+ lower = np.array(min_arr, dtype=float)
245
+ if overrides is not None:
246
+ for ub, lb, idx in overrides:
247
+ if isinstance(idx, int):
248
+ idxs = [idx]
249
+ elif isinstance(idx, slice):
250
+ idxs = list(range(*idx.indices(size)))
251
+ else:
252
+ idxs = list(idx)
253
+ if np.isscalar(ub):
254
+ ub_vals = [ub] * len(idxs)
255
+ else:
256
+ ub_vals = ub
257
+ if np.isscalar(lb):
258
+ lb_vals = [lb] * len(idxs)
259
+ else:
260
+ lb_vals = lb
261
+ for i, uval, lval in zip(idxs, ub_vals, lb_vals):
262
+ upper[i] = uval
263
+ lower[i] = lval
264
+ return upper, lower
265
+
266
+ # State scaling
267
+ min_x = np.array(self.x.min)
268
+ max_x = np.array(self.x.max)
269
+ upper_x, lower_x = apply_overrides(self.n_states, self.scaling_x_overrides, min_x, max_x)
270
+ S_x, c_x = get_affine_scaling_matrices(self.n_states, lower_x, upper_x)
271
+ self.S_x = S_x
272
+ self.c_x = c_x
273
+ self.inv_S_x = np.diag(1 / np.diag(self.S_x))
274
+
275
+ # Control scaling
276
+ min_u = np.array(self.u.min)
277
+ max_u = np.array(self.u.max)
278
+ upper_u, lower_u = apply_overrides(self.n_controls, self.scaling_u_overrides, min_u, max_u)
279
+ S_u, c_u = get_affine_scaling_matrices(self.n_controls, lower_u, upper_u)
280
+ self.S_u = S_u
281
+ self.c_u = c_u
282
+ self.inv_S_u = np.diag(1 / np.diag(self.S_u))
283
+
189
284
 
190
285
 
191
286
  @dataclass
192
287
  class ScpConfig:
193
- n: int = None
194
- k_max: int = 200
195
- w_tr: float = 1e0
196
- lam_vc: float = 1e0
197
- ep_tr: float = 1e-4
198
- ep_vb: float = 1e-4
199
- ep_vc: float = 1e-8
200
- lam_cost: float = 0.0
201
- lam_vb: float = 0.0
202
- uniform_time_grid: bool = False
203
- cost_drop: int = -1
204
- cost_relax: float = 1.0
205
- w_tr_adapt: float = 1.0
206
- w_tr_max: float = None
207
- w_tr_max_scaling_factor: float = None
208
-
209
- """
210
- Configuration class for Sequential Convex Programming (SCP).
211
-
212
- This class defines the parameters used to configure the SCP solver. You will very likely need to modify
213
- the weights for your problem. Please refer to my guide [here](https://haynec.github.io/openscvx/hyperparameter_tuning) for more information.
214
-
215
- Attributes:
216
- n (int): The number of discretization nodes. Defaults to `None`.
217
- k_max (int): The maximum number of SCP iterations. Defaults to 200.
218
- w_tr (float): The trust region weight. Defaults to 1.0.
219
- lam_vc (float): The penalty weight for virtual control. Defaults to 1.0.
220
- ep_tr (float): The trust region convergence tolerance. Defaults to 1e-4.
221
- ep_vb (float): The boundary constraint convergence tolerance. Defaults to 1e-4.
222
- ep_vc (float): The virtual constraint convergence tolerance. Defaults to 1e-8.
223
- lam_cost (float): The weight for original cost. Defaults to 0.0.
224
- lam_vb (float): The weight for virtual buffer. This is only used if there are nonconvex nodal constraints present. Defaults to 0.0.
225
- uniform_time_grid (bool): Whether to use a uniform time grid. TODO haynec add a link to the time dilation page. Defaults to `False`.
226
- cost_drop (int): The number of iterations to allow for cost stagnation before termination. Defaults to -1 (disabled).
227
- cost_relax (float): The relaxation factor for cost reduction. Defaults to 1.0.
228
- w_tr_adapt (float): The adaptation factor for the trust region weight. Defaults to 1.0.
229
- w_tr_max (float): The maximum allowable trust region weight. Defaults to `None`.
230
- w_tr_max_scaling_factor (float): The scaling factor for the maximum trust region weight. Defaults to `None`.
231
- """
288
+
289
+ def __init__(
290
+ self,
291
+ n: Optional[int] = None,
292
+ k_max: int = 200,
293
+ w_tr: float = 1.0,
294
+ lam_vc: float = 1.0,
295
+ ep_tr: float = 1e-4,
296
+ ep_vb: float = 1e-4,
297
+ ep_vc: float = 1e-8,
298
+ lam_cost: float = 0.0,
299
+ lam_vb: float = 0.0,
300
+ uniform_time_grid: bool = False,
301
+ cost_drop: int = -1,
302
+ cost_relax: float = 1.0,
303
+ w_tr_adapt: float = 1.0,
304
+ w_tr_max: Optional[float] = None,
305
+ w_tr_max_scaling_factor: Optional[float] = None,
306
+ ):
307
+ """
308
+ Configuration class for Sequential Convex Programming (SCP).
309
+
310
+ This class defines the parameters used to configure the SCP solver. You will very likely need to modify
311
+ the weights for your problem. Please refer to my guide [here](https://haynec.github.io/openscvx/hyperparameter_tuning) for more information.
312
+
313
+ Attributes:
314
+ n (int): The number of discretization nodes. Defaults to `None`.
315
+ k_max (int): The maximum number of SCP iterations. Defaults to 200.
316
+ w_tr (float): The trust region weight. Defaults to 1.0.
317
+ lam_vc (float): The penalty weight for virtual control. Defaults to 1.0.
318
+ ep_tr (float): The trust region convergence tolerance. Defaults to 1e-4.
319
+ ep_vb (float): The boundary constraint convergence tolerance. Defaults to 1e-4.
320
+ ep_vc (float): The virtual constraint convergence tolerance. Defaults to 1e-8.
321
+ lam_cost (float): The weight for original cost. Defaults to 0.0.
322
+ lam_vb (float): The weight for virtual buffer. This is only used if there are nonconvex nodal constraints present. Defaults to 0.0.
323
+ uniform_time_grid (bool): Whether to use a uniform time grid. Defaults to `False`.
324
+ cost_drop (int): The number of iterations to allow for cost stagnation before termination. Defaults to -1 (disabled).
325
+ cost_relax (float): The relaxation factor for cost reduction. Defaults to 1.0.
326
+ w_tr_adapt (float): The adaptation factor for the trust region weight. Defaults to 1.0.
327
+ w_tr_max (float): The maximum allowable trust region weight. Defaults to `None`.
328
+ w_tr_max_scaling_factor (float): The scaling factor for the maximum trust region weight. Defaults to `None`.
329
+ """
330
+ self.n = n
331
+ self.k_max = k_max
332
+ self.w_tr = w_tr
333
+ self.lam_vc = lam_vc
334
+ self.ep_tr = ep_tr
335
+ self.ep_vb = ep_vb
336
+ self.ep_vc = ep_vc
337
+ self.lam_cost = lam_cost
338
+ self.lam_vb = lam_vb
339
+ self.uniform_time_grid = uniform_time_grid
340
+ self.cost_drop = cost_drop
341
+ self.cost_relax = cost_relax
342
+ self.w_tr_adapt = w_tr_adapt
343
+ self.w_tr_max = w_tr_max
344
+ self.w_tr_max_scaling_factor = w_tr_max_scaling_factor
232
345
 
233
346
  def __post_init__(self):
234
347
  keys_to_scale = ["w_tr", "lam_vc", "lam_cost", "lam_vb"]
@@ -239,7 +352,6 @@ class ScpConfig:
239
352
  if self.w_tr_max_scaling_factor is not None and self.w_tr_max is None:
240
353
  self.w_tr_max = self.w_tr_max_scaling_factor * self.w_tr
241
354
 
242
-
243
355
  @dataclass
244
356
  class Config:
245
357
  sim: SimConfig
@@ -1,10 +1,7 @@
1
- from .boundary import boundary, BoundaryConstraint
2
1
  from .ctcs import ctcs, CTCSConstraint
3
2
  from .nodal import nodal, NodalConstraint
4
3
 
5
4
  __all__ = [
6
- "boundary",
7
- "BoundaryConstraint",
8
5
  "ctcs",
9
6
  "CTCSConstraint",
10
7
  "nodal",