openscvx 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

@@ -0,0 +1,170 @@
1
+ import jax.numpy as jnp
2
+ import numpy as np
3
+
4
+ from openscvx.dynamics import Dynamics
5
+ from openscvx.integrators import solve_ivp_rk45, solve_ivp_diffrax
6
+
7
+
8
def dVdt(
    tau: float,
    V: jnp.ndarray,
    u_cur: np.ndarray,
    u_next: np.ndarray,
    state_dot: callable,
    A: callable,
    B: callable,
    n_x: int,
    n_u: int,
    N: int,
    dis_type: str,
) -> jnp.ndarray:
    """Time derivative of the stacked augmented state used for discretization.

    Each of the ``N - 1`` intervals owns one row of the augmented state,
    laid out as::

        [ x (n_x) | Phi (n_x*n_x) | B_minus (n_x*n_u) | B_plus (n_x*n_u) | z (n_x) ]

    ``Phi`` is propagated by the (time-dilated) state Jacobian.
    ``B_minus``/``B_plus`` accumulate the control sensitivity weighted by
    ``alpha``/``beta``, which are the weights of ``u_cur``/``u_next``.
    ``z`` accumulates the affine defect term. The last control column is
    treated as a time-dilation factor ``s``: ``f`` and its Jacobians are
    scaled by it, and the derivative w.r.t. that column is ``f`` itself.

    Args:
        tau: Integration time within the interval. ``calculate_discretization``
            integrates each interval over ``[0, 1/(N-1)]``.
        V: Flattened augmented state with ``(N-1) * row_len`` entries.
        u_cur: Control at the start of each interval, one row per interval.
        u_next: Control at the end of each interval, one row per interval.
        state_dot: Called as ``state_dot(x, u[:, :-1], nodes)``; returns ``f``
            with shape ``(batch, n_x)``.
        A: Called the same way; returns ``df/dx`` with shape ``(batch, n_x, n_x)``.
        B: Called the same way; returns ``df/du`` with shape
            ``(batch, n_x, n_u - 1)``.
        n_x: State dimension.
        n_u: Control dimension, including the time-dilation column.
        N: Number of nodes (``N - 1`` intervals).
        dis_type: ``"ZOH"`` (hold ``u_cur``) or ``"FOH"`` (linear interpolation
            from ``u_cur`` to ``u_next``).

    Returns:
        Flattened derivative with the same number of entries as ``V``.

    Raises:
        ValueError: If ``dis_type`` is neither ``"ZOH"`` nor ``"FOH"``.
    """
    # Node indices forwarded to the user-supplied callables.
    nodes = jnp.arange(0, N - 1)

    # Offsets of each component inside one row of the augmented state.
    i0 = 0
    i1 = n_x
    i2 = i1 + n_x * n_x
    i3 = i2 + n_x * n_u
    i4 = i3 + n_x * n_u
    i5 = i4 + n_x

    # One row per interval.
    V = V.reshape(-1, i5)

    # Weight of u_next. Each interval is integrated over tau in [0, 1/(N-1)],
    # so FOH must scale by (N - 1) for beta to run from 0 to 1. Scaling by N
    # overshot to N/(N-1) at the end of the interval.
    if dis_type == "ZOH":
        beta = 0.0
    elif dis_type == "FOH":
        beta = tau * (N - 1)
    else:
        raise ValueError(
            f"Unknown discretization type {dis_type!r}; expected 'ZOH' or 'FOH'"
        )
    alpha = 1 - beta

    # Interpolated control; its last column is the time-dilation factor.
    u = u_cur + beta * (u_next - u_cur)
    s = u[:, -1]

    # Ensure x and u have the same batch size.
    x = V[:, :n_x]
    u = u[: x.shape[0]]
    u_phys = u[:, :-1]

    # Time-dilated nonlinear dynamics.
    f = state_dot(x, u_phys, nodes)
    F = s[:, None] * f

    # Time-dilated state Jacobian.
    sdfdx = s[:, None, None] * A(x, u_phys, nodes)

    # Control Jacobian; the column for the time-dilation input is f itself.
    dfdu = jnp.zeros((V.shape[0], n_x, n_u))
    dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * B(x, u_phys, nodes))
    dfdu = dfdu.at[:, :, -1].set(f)

    # Affine defect of the linearization about (x, u).
    z = F - jnp.einsum("ijk,ik->ij", sdfdx, x) - jnp.einsum("ijk,ik->ij", dfdu, u)

    # Stack up the results into the augmented state derivative.
    # fmt: off
    dV = jnp.zeros_like(V)
    dV = dV.at[:, i0:i1].set(F)
    dV = dV.at[:, i1:i2].set(jnp.matmul(sdfdx, V[:, i1:i2].reshape(-1, n_x, n_x)).reshape(-1, n_x * n_x))
    dV = dV.at[:, i2:i3].set((jnp.matmul(sdfdx, V[:, i2:i3].reshape(-1, n_x, n_u)) + dfdu * alpha).reshape(-1, n_x * n_u))
    dV = dV.at[:, i3:i4].set((jnp.matmul(sdfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u))
    dV = dV.at[:, i4:i5].set((jnp.matmul(sdfdx, V[:, i4:i5].reshape(-1, n_x)[..., None]).squeeze(-1) + z).reshape(-1, n_x))
    # fmt: on
    return dV.flatten()
80
+
81
+
82
def calculate_discretization(
    x,
    u,
    state_dot: callable,
    A: callable,
    B: callable,
    n_x: int,
    n_u: int,
    N: int,
    custom_integrator: bool,
    debug: bool,
    solver: str,
    rtol,
    atol,
    dis_type: str,
):
    """Discretize the dynamics along ``(x, u)``, all intervals at once.

    The augmented system ``dVdt`` is integrated from ``tau = 0`` to
    ``1/(N-1)`` for all ``N - 1`` intervals simultaneously. Each interval
    starts from its node state ``x[:-1]``, an identity state-transition block,
    and zero sensitivity/defect blocks.

    Args:
        x: Node states, ``N`` rows of ``n_x`` entries.
        u: Node controls, ``N`` rows; ``u[:-1]``/``u[1:]`` are each interval's
            start/end controls.
        state_dot, A, B, n_x, n_u, N, dis_type: Forwarded to ``dVdt``.
        custom_integrator: If true, use ``solve_ivp_rk45``; ``debug`` then
            selects its uncompiled loop. Otherwise use ``solve_ivp_diffrax``.
        solver, rtol, atol: diffrax solver name and tolerances (diffrax path only).

    Returns:
        Tuple ``(A_bar, B_bar, C_bar, z_bar, Vmulti)``:

        * ``A_bar`` is ``(N-1, n_x*n_x)``; ``B_bar`` and ``C_bar`` are
          ``(N-1, n_x*n_u)``. Each row is one interval's matrix flattened in
          column-major order.
        * ``z_bar`` is ``(N-1, n_x)``.
        * ``Vmulti`` is the transposed integrator output, one column per saved
          substep.
    """
    # Column offsets of the blocks inside one augmented-state row.
    phi_lo = n_x
    phi_hi = phi_lo + n_x * n_x
    bm_hi = phi_hi + n_x * n_u
    bp_hi = bm_hi + n_x * n_u
    row_len = bp_hi + n_x

    # Every interval starts at its node state with Phi = I.
    V0 = jnp.zeros((N - 1, row_len))
    V0 = V0.at[:, :n_x].set(x[:-1].astype(float))
    V0 = V0.at[:, phi_lo:phi_hi].set(
        jnp.eye(n_x).reshape(1, -1).repeat(N - 1, axis=0)
    )

    def rhs(t, y, *a):
        return dVdt(t, y, *a)

    tau_end = 1.0 / (N - 1)
    y0 = V0.reshape(-1)
    ode_args = (
        u[:-1].astype(float),
        u[1:].astype(float),
        state_dot, A, B, n_x, n_u, N, dis_type,
    )

    if custom_integrator:
        sol = solve_ivp_rk45(rhs, tau_end, y0, args=ode_args, is_not_compiled=debug)
    else:
        sol = solve_ivp_diffrax(
            rhs,
            tau_end,
            y0,
            args=ode_args,
            solver_name=solver,
            rtol=rtol,
            atol=atol,
            extra_kwargs=None,
        )

    # Final augmented state, one row per interval.
    Vend = sol[-1].reshape(-1, row_len)
    Vmulti = sol.T

    def per_interval_colmajor(cols, n_rows, n_cols):
        # Rows hold row-major (n_rows x n_cols) matrices; re-flatten each column-major.
        return (
            cols.reshape(N - 1, n_rows, n_cols)
            .transpose(0, 2, 1)
            .reshape(N - 1, n_rows * n_cols)
        )

    A_bar = per_interval_colmajor(Vend[:, phi_lo:phi_hi], n_x, n_x)
    B_bar = per_interval_colmajor(Vend[:, phi_hi:bm_hi], n_x, n_u)
    C_bar = per_interval_colmajor(Vend[:, bm_hi:bp_hi], n_x, n_u)
    z_bar = Vend[:, bp_hi:row_len]

    return A_bar, B_bar, C_bar, z_bar, Vmulti
152
+
153
+
154
def get_discretization_solver(dyn: Dynamics, params):
    """Return a ``discretize(x, u)`` callable bound to ``dyn`` and ``params``.

    ``dyn.f``/``dyn.A``/``dyn.B`` and every setting under ``params`` are read
    each time the returned callable runs, not when it is created.
    """

    def discretize(x, u):
        return calculate_discretization(
            x=x,
            u=u,
            state_dot=dyn.f,
            A=dyn.A,
            B=dyn.B,
            n_x=params.sim.n_states,
            n_u=params.sim.n_controls,
            N=params.scp.n,
            custom_integrator=params.dis.custom_integrator,
            debug=params.dev.debug,
            solver=params.dis.solver,
            rtol=params.dis.rtol,
            atol=params.dis.atol,
            dis_type=params.dis.dis_type,
        )

    return discretize
openscvx/dynamics.py ADDED
@@ -0,0 +1,41 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Optional
3
+ import functools
4
+
5
+ import jax.numpy as jnp
6
+
7
+
8
@dataclass
class Dynamics:
    """Bundle of a dynamics function and its optional Jacobians.

    Fields are positional in the order ``f, A, B``. ``A`` and ``B`` default
    to ``None`` when not supplied.

    NOTE(review): the discretization code calls these with a third ``nodes``
    argument, ``fn(x, u, nodes)``. The two-argument annotations may be too
    narrow; confirm.
    """

    # Dynamics function, called with (state, control, ...).
    f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
    # Optional state Jacobian of ``f``.
    A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    # Optional control Jacobian of ``f``.
    B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
13
+
14
def dynamics(
    _func=None,
    *,
    A: Optional[Callable] = None,
    B: Optional[Callable] = None,
):
    """Decorator to mark a function as defining the system dynamics.

    Use bare, with keyword arguments, or as a plain call::

        @dynamics
        def my_dynamics(x, u): ...

        @dynamics(A=my_grad_f_x, B=my_grad_f_u)
        def my_dynamics(x, u): ...

        dyn = dynamics(my_dynamics, A=my_grad_f_x, B=my_grad_f_u)

    Args:
        _func: The dynamics function when used bare or called directly;
            ``None`` when used as ``@dynamics(...)``.
        A: Optional Jacobian of the dynamics w.r.t. the state.
        B: Optional Jacobian of the dynamics w.r.t. the control.

    Returns:
        A ``Dynamics`` instance wrapping the function, or, when ``_func`` is
        ``None``, a decorator that produces one.
    """

    def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
        # Store f itself: it already carries its own name, docstring and
        # signature. The former ``functools.wraps(f)(f)`` set
        # ``f.__wrapped__ = f``, a self-reference that makes
        # ``inspect.unwrap``/``inspect.signature`` raise ValueError.
        return Dynamics(f=f, A=A, B=B)

    # ``@dynamics(...)`` passes no positional function: hand back the decorator.
    if _func is None:
        return decorator
    # ``@dynamics`` (bare) or ``dynamics(func)``: decorate immediately.
    return decorator(_func)
41
+
@@ -0,0 +1,139 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import diffrax as dfx
4
+
5
# Solver names accepted as ``solver_name`` by the diffrax wrappers. Each key
# is exactly the name of the diffrax solver class it maps to.
SOLVER_MAP = {
    name: getattr(dfx, name)
    for name in (
        "Tsit5",
        "Euler",
        "Heun",
        "Midpoint",
        "Ralston",
        "Dopri5",
        "Dopri8",
        "Bosh3",
        "ReversibleHeun",
        "ImplicitEuler",
        "KenCarp3",
        "KenCarp4",
        "KenCarp5",
    )
}
20
+
21
# fmt: off
def rk45_step(f, t, y, h, *args):
    """Advance ``y`` by one fixed step ``h`` from time ``t``.

    Uses the fourth-order solution of the Runge-Kutta-Fehlberg (RKF45)
    tableau. Only five stages are evaluated, so no fifth-order error estimate
    or step-size control is produced. ``f`` is called as ``f(t, y, *args)``.
    """
    k1 = f(t, y, *args)
    k2 = f(t + h/4, y + h*k1/4, *args)
    k3 = f(t + 3*h/8, y + 3*h*k1/32 + 9*h*k2/32, *args)
    k4 = f(t + 12*h/13, y + 1932*h*k1/2197 - 7200*h*k2/2197 + 7296*h*k3/2197, *args)
    k5 = f(t + h, y + 439*h*k1/216 - 8*h*k2 + 3680*h*k3/513 - 845*h*k4/4104, *args)
    y_next = y + h * (25*k1/216 + 1408*k3/2565 + 2197*k4/4104 - k5/5)
    return y_next
# fmt: on


def solve_ivp_rk45(
    f,
    tau_final: float,
    y_0,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    is_not_compiled: bool = False,
):
    """Integrate ``dy/dt = f(t, y, *args)`` with fixed ``rk45_step`` steps.

    Args:
        f: Right-hand side, called as ``f(t, y, *args)``.
        tau_final: End time.
        y_0: Initial state (1-D).
        args: Extra positional arguments forwarded to ``f``.
        tau_0: Start time.
        num_substeps: Number of saved points, including both endpoints; the
            step size is ``(tau_final - tau_0) / (num_substeps - 1)``.
        is_not_compiled: Use a plain Python loop (handy for debugging)
            instead of ``jax.lax.fori_loop``. Both paths compute the same steps.

    Returns:
        Array of shape ``(num_substeps, len(y_0))``. Row ``i`` is the state at
        ``tau_0 + i * h``; row 0 is ``y_0``.

    Raises:
        ValueError: If ``num_substeps < 2``.
    """
    if num_substeps < 2:
        raise ValueError(f"num_substeps must be at least 2, got {num_substeps}")

    h = (tau_final - tau_0) / (num_substeps - 1)
    solution = jnp.zeros((num_substeps, len(y_0)))
    solution = solution.at[0].set(y_0)

    def body_fun(i, sol):
        # Step i starts at t_{i-1}; the eager loop previously used t_i here,
        # which disagreed with the compiled loop.
        t = tau_0 + (i - 1) * h
        return sol.at[i].set(rk45_step(f, t, sol[i - 1], h, *args))

    if is_not_compiled:
        for i in range(1, num_substeps):
            solution = body_fun(i, solution)
    else:
        solution = jax.lax.fori_loop(1, num_substeps, body_fun, solution)

    return solution
65
+
66
+
67
def _diffeqsolve(
    f,
    tau_final,
    y_0,
    args,
    tau_0,
    num_substeps,
    solver_name,
    rtol,
    atol,
    extra_kwargs,
    dense,
):
    """Shared diffrax driver behind ``solve_ivp_diffrax`` and ``solve_ivp_diffrax_prop``.

    Saves the solution at ``num_substeps`` evenly spaced times in
    ``[tau_0, tau_final]``. With ``dense=True`` it also enables dense
    interpolation. Returns the full diffrax solution object.
    """
    substeps = jnp.linspace(tau_0, tau_final, num_substeps)

    solver_class = SOLVER_MAP.get(solver_name)
    if solver_class is None:
        raise ValueError(f"Unknown solver: {solver_name}")
    solver = solver_class()

    # diffrax passes ``args`` as a single object; unpack it for f.
    term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
    stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)
    return dfx.diffeqsolve(
        term,
        solver=solver,
        t0=tau_0,
        t1=tau_final,
        dt0=(tau_final - tau_0) / (len(substeps) - 1),
        y0=y_0,
        args=args,
        stepsize_controller=stepsize_controller,
        saveat=dfx.SaveAt(dense=dense, ts=substeps),
        **(extra_kwargs or {}),
    )


def solve_ivp_diffrax(
    f,
    tau_final,
    y_0,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    solver_name="Dopri8",
    rtol: float = 1e-3,
    atol: float = 1e-6,
    extra_kwargs=None,
):
    """Integrate ``dy/dt = f(t, y, *args)`` with an adaptive diffrax solver.

    Args:
        f: Right-hand side, called as ``f(t, y, *args)``.
        tau_final: End time.
        y_0: Initial state.
        args: Extra positional arguments forwarded to ``f``.
        tau_0: Start time.
        num_substeps: Number of evenly spaced save points (both endpoints
            included); also sets the initial step ``dt0``.
        solver_name: Key into ``SOLVER_MAP``.
        rtol: Relative tolerance for the PID step-size controller.
        atol: Absolute tolerance for the PID step-size controller.
        extra_kwargs: Extra keyword arguments for ``diffrax.diffeqsolve``.

    Returns:
        The saved states (``solution.ys``), one row per save point.

    Raises:
        ValueError: If ``solver_name`` is not in ``SOLVER_MAP``.
    """
    return _diffeqsolve(
        f, tau_final, y_0, args, tau_0, num_substeps,
        solver_name, rtol, atol, extra_kwargs, dense=False,
    ).ys


def solve_ivp_diffrax_prop(
    f,
    tau_final,
    y_0,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    solver_name="Dopri8",
    rtol: float = 1e-3,
    atol: float = 1e-6,
    extra_kwargs=None,
):
    """Like ``solve_ivp_diffrax``, but return the full diffrax solution.

    Dense output is enabled, so the returned solution can be interpolated
    between save points.

    Raises:
        ValueError: If ``solver_name`` is not in ``SOLVER_MAP``.
    """
    return _diffeqsolve(
        f, tau_final, y_0, args, tau_0, num_substeps,
        solver_name, rtol, atol, extra_kwargs, dense=True,
    )
openscvx/io.py ADDED
@@ -0,0 +1,81 @@
1
+ import sys
2
+ import warnings
3
+ warnings.filterwarnings("ignore")
4
+ import queue
5
+ import time
6
+
7
+ from termcolor import colored
8
+
9
# termcolor colour names for console output. col_pos marks values that meet
# their tolerance (and an 'optimal' status); col_neg marks the rest.
# col_main is not referenced within this module.
col_main, col_pos, col_neg = "blue", "green", "red"
13
+
14
def intro():
    """Print the OpenSCvx ASCII-art banner and author credits to stdout.

    Side effect: installs a process-wide ``warnings.filterwarnings("ignore")``.
    """
    # NOTE(review): this silences *all* warnings process-wide, not just syntax
    # warnings; kept for backward compatibility.
    warnings.filterwarnings("ignore")
    # Raw string: the art is full of backslashes. In a normal string they are
    # invalid escape sequences, which raise a SyntaxWarning at compile time
    # that the runtime filter above cannot suppress. The trailing backslash on
    # the last art line also acted as a line continuation and glued the next
    # line onto it.
    # fmt: off
    ascii_art = r'''

____ _____ _____
/ __ \ / ____|/ ____|
| | | |_ __ ___ _ __ | (___ | | __ ____ __
| | | | '_ \ / _ \ '_ \ \___ \| | \ \ / /\ \/ /
| |__| | |_) | __/ | | |____) | |___\ V / > <
\____/| .__/ \___|_| |_|_____/ \_____\_/ /_/\_\
| |
|_|
---------------------------------------------------------------------------------------------------------
Author: Chris Hayner and Griffin Norris
Autonomous Controls Laboratory
University of Washington
---------------------------------------------------------------------------------------------------------
'''
    # fmt: on
    print(ascii_art)
36
+
37
def header():
    """Print the iteration-table column titles followed by a rule line.

    NOTE(review): several titles are wider than their 7-character fields, so
    the columns do not line up with the data rows printed by ``intermediate``.
    """
    columns = (
        "Iter", "Dis Time (ms)", "Solve Time (ms)", "J_total",
        "J_tr", "J_vb", "J_vc", "Cost", "Solver Status",
    )
    title_fmt = "{:^4} | {:^7} | {:^7} | {:^7} | {:^7} | {:^7} | {:^7} | {:^7} | {:^14}"
    rule = "---------------------------------------------------------------------------------------------------------"
    print(title_fmt.format(*columns))
    print(colored(rule))
41
+
42
def intermediate(print_queue, params):
    """Render SCP iteration rows taken from ``print_queue``; never returns.

    Polls the queue about 30 times per second. Each item is a dict with keys
    ``iter``, ``dis_time``, ``subprop_time``, ``J_total``, ``J_tr``, ``J_vb``,
    ``J_vc``, ``cost`` and ``prob_stat``. For each item:

    * If ``iter != 1``, the rule and title lines left at the bottom by the
      previous item are erased with ANSI cursor-up/clear-line codes.
    * The data row is printed, then the rule and titles are re-printed below
      it so they stay at the bottom of the table.

    ``J_tr``/``J_vb``/``J_vc`` use ``col_pos`` when they are at most
    ``params.scp.ep_tr``/``ep_vb``/``ep_vc`` and ``col_neg`` otherwise. The
    status uses ``col_pos`` only when it equals ``'optimal'``. If the fourth
    character of ``prob_stat`` is ``'f'``, only its first character is shown.

    There is no exit condition; presumably this runs on a daemon thread.
    Confirm with the caller.
    """
    poll_period = 1.0 / 30.0
    columns = (
        "Iter", "Dis Time (ms)", "Solve Time (ms)", "J_total",
        "J_tr", "J_vb", "J_vc", "Cost", "Solver Status",
    )
    title_fmt = "{:^4} | {:^7} | {:^7} | {:^7} | {:^7} | {:^7} | {:^7} | {:^7} | {:^14}"
    row_fmt = "{:^4} | {:^6.2f} | {:^6.2F} | {:^7} | {:^7} | {:^7} | {:^7} | {:^7} | {:^14}"
    rule = "---------------------------------------------------------------------------------------------------------"

    def _judged(value, tol):
        # Scientific-notation cell, coloured by whether the metric meets its tolerance.
        return colored("{:.1e}".format(value), col_pos if value <= tol else col_neg)

    while True:
        tick = time.time()
        try:
            data = print_queue.get(timeout=poll_period)
            if data["iter"] != 1:
                # Drop the rule + title lines printed after the previous row.
                sys.stdout.write('\x1b[1A\x1b[2K\x1b[1A\x1b[2K')
            if data["prob_stat"][3] == 'f':
                # Only show the first element of the string
                data["prob_stat"] = data["prob_stat"][0]

            iter_cell = colored("{:4d}".format(data["iter"]))
            total_cell = colored("{:.1e}".format(data["J_total"]))
            tr_cell = _judged(data["J_tr"], params.scp.ep_tr)
            vb_cell = _judged(data["J_vb"], params.scp.ep_vb)
            vc_cell = _judged(data["J_vc"], params.scp.ep_vc)
            cost_cell = colored("{:.1e}".format(data["cost"]))
            status = data["prob_stat"]
            status_cell = colored(status, col_pos if status == 'optimal' else col_neg)

            print(row_fmt.format(
                iter_cell, data["dis_time"], data["subprop_time"], total_cell,
                tr_cell, vb_cell, vc_cell, cost_cell, status_cell,
            ))
            print(colored(rule))
            print(title_fmt.format(*columns))
        except queue.Empty:
            pass
        time.sleep(max(0.0, poll_period - (time.time() - tick)))
72
+
73
def footer(computation_time):
    """Close the iteration table and report the total computation time.

    Prints a rule line, then a ``RESULTS`` banner with the word in bold (ANSI
    SGR codes 1 and 0), then ``computation_time`` as given.
    """
    bold_on, bold_off = "\033[1m", "\033[0m"
    rule = "---------------------------------------------------------------------------------------------------------"
    left = "------------------------------------------------ "
    right = " ------------------------------------------------"

    print(colored(rule))
    print(left + bold_on + "RESULTS" + bold_off + right)
    print("Total Computation Time: ", computation_time)
openscvx/ocp.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import numpy.linalg as la
3
+ from numpy import block
4
+ import numpy as np
5
+ import cvxpy as cp
6
+ from cvxpygen import cpg
7
+ from openscvx.config import Config
8
+ from cvxpygen import cpg
9
+
10
+
11
def OptimalControlProblem(params: Config):
    """Build the parameterized convex subproblem solved at each SCP iteration.

    Decision variables ``x``/``u`` are scaled. The physical values are
    ``S_x @ x[k] + c_x`` and ``S_u @ u[k] + c_u``.

    The cost is the sum of:

    * ``lam_cost`` times each boundary-state component marked ``'Minimize'``
      (added) or ``'Maximize'`` (subtracted);
    * a per-node trust-region penalty
      ``w_tr * ||blockdiag(inv_S_x, inv_S_u) @ [dx; du]||^2``;
    * an L1 virtual-control penalty ``lam_vc * |nu|``;
    * ``lam_vb * pos(nu_vb)`` for each nonconvex nodal constraint.

    The constraints are:

    * convex nodal constraints as given;
    * nonconvex nodal constraints, linearized and relaxed by ``nu_vb``;
    * fixed initial/final states;
    * an optional uniform time grid;
    * the definitions of ``dx``/``du``;
    * the discretized dynamics ``x[k] = A_d x[k-1] + B_d u[k-1] + C_d u[k] + z_d + nu``;
    * control and state box bounds;
    * per-interval increment bounds on the ``idx_y`` states, which must start at 0.

    If ``params.cvx.cvxpygen`` is set, C code for the problem is generated in
    ``./solver``. When that directory already exists, the user is prompted
    on stdin before it is overwritten.

    Args:
        params: Configuration; reads ``params.scp``, ``params.sim`` and ``params.cvx``.

    Returns:
        cp.Problem: ``Minimize(cost)`` subject to the constraints above. The
        parameters get explicit names: ``w_tr``, ``lam_cost``, ``x_bar``,
        ``u_bar``, ``A_d``, ``B_d``, ``C_d``, ``z_d``, and ``g_<k>``,
        ``grad_g_x_<k>``, ``grad_g_u_<k>`` for the k-th nodal constraint when
        it is nonconvex. Presumably these names let callers set values via
        ``prob.param_dict``; confirm against callers.
    """
    n = params.scp.n
    n_x = params.sim.n_states
    n_u = params.sim.n_controls

    ########################
    # VARIABLES & PARAMETERS
    ########################

    # Parameters
    w_tr = cp.Parameter(nonneg=True, name='w_tr')
    lam_cost = cp.Parameter(nonneg=True, name='lam_cost')

    # State
    x = cp.Variable((n, n_x), name='x')           # Current State
    dx = cp.Variable((n, n_x), name='dx')         # State Error
    x_bar = cp.Parameter((n, n_x), name='x_bar')  # Previous SCP State

    # Affine Scaling for State
    S_x = params.sim.S_x
    inv_S_x = params.sim.inv_S_x
    c_x = params.sim.c_x

    # Control
    u = cp.Variable((n, n_u), name='u')           # Current Control
    du = cp.Variable((n, n_u), name='du')         # Control Error
    u_bar = cp.Parameter((n, n_u), name='u_bar')  # Previous SCP Control

    # Affine Scaling for Control
    S_u = params.sim.S_u
    inv_S_u = params.sim.inv_S_u
    c_u = params.sim.c_u

    # Discretized Augmented Dynamics Constraints (one flattened matrix per interval)
    A_d = cp.Parameter((n - 1, n_x * n_x), name='A_d')
    B_d = cp.Parameter((n - 1, n_x * n_u), name='B_d')
    C_d = cp.Parameter((n - 1, n_x * n_u), name='C_d')
    z_d = cp.Parameter((n - 1, n_x), name='z_d')
    nu = cp.Variable((n - 1, n_x), name='nu')  # Virtual Control

    # Linearized Nonconvex Nodal Constraints. Parameter names use the
    # constraint's position in constraints_nodal; the lists hold only the
    # nonconvex constraints, in order.
    if params.sim.constraints_nodal:
        g = []
        grad_g_x = []
        grad_g_u = []
        nu_vb = []
        for idx_ncvx, constraint in enumerate(params.sim.constraints_nodal):
            if not constraint.convex:
                g.append(cp.Parameter(n, name='g_' + str(idx_ncvx)))
                grad_g_x.append(cp.Parameter((n, n_x), name='grad_g_x_' + str(idx_ncvx)))
                grad_g_u.append(cp.Parameter((n, n_u), name='grad_g_u_' + str(idx_ncvx)))
                nu_vb.append(cp.Variable(n, name='nu_vb_' + str(idx_ncvx)))  # Virtual Control for VB

    # Applying the affine scaling to state and control
    x_nonscaled = []
    u_nonscaled = []
    for k in range(n):
        x_nonscaled.append(S_x @ x[k] + c_x)
        u_nonscaled.append(S_u @ u[k] + c_u)

    constr = []
    cost = lam_cost * 0

    #############
    # CONSTRAINTS
    #############
    idx_ncvx = 0
    if params.sim.constraints_nodal:
        for constraint in params.sim.constraints_nodal:
            nodes = range(n) if constraint.nodes is None else constraint.nodes

            if constraint.convex:
                constr += [constraint(x_nonscaled[node], u_nonscaled[node]) for node in nodes]
            else:
                constr += [
                    (g[idx_ncvx][node]
                     + grad_g_x[idx_ncvx][node] @ dx[node]
                     + grad_g_u[idx_ncvx][node] @ du[node]) == nu_vb[idx_ncvx][node]
                    for node in nodes
                ]
                idx_ncvx += 1

    # Boundary conditions and boundary-state objectives.
    for i in range(params.sim.idx_x_true.start, params.sim.idx_x_true.stop):
        initial_type = params.sim.initial_state.type[i]
        final_type = params.sim.final_state.type[i]
        if initial_type == 'Fix':
            constr += [x_nonscaled[0][i] == params.sim.initial_state.value[i]]  # Initial Boundary Conditions
        if final_type == 'Fix':
            constr += [x_nonscaled[-1][i] == params.sim.final_state.value[i]]  # Final Boundary Conditions
        if initial_type == 'Minimize':
            cost += lam_cost * x_nonscaled[0][i]
        if final_type == 'Minimize':
            cost += lam_cost * x_nonscaled[-1][i]
        # The problem minimizes ``cost``, so maximized terms enter negated.
        # They used to be added, which made 'Maximize' behave like 'Minimize'.
        if initial_type == 'Maximize':
            cost -= lam_cost * x_nonscaled[0][i]
        if final_type == 'Maximize':
            cost -= lam_cost * x_nonscaled[-1][i]

    if params.scp.uniform_time_grid:
        t_idx = params.sim.idx_t
        constr += [
            x_nonscaled[i][t_idx] - x_nonscaled[i-1][t_idx]
            == x_nonscaled[i-1][t_idx] - x_nonscaled[i-2][t_idx]
            for i in range(2, n)
        ]  # Uniform Time Step

    # The inverses are the same for every node, so compute them once.
    # NOTE(review): presumably equal to params.sim.inv_S_x / inv_S_u; confirm
    # before substituting.
    S_x_inv = la.inv(S_x)
    S_u_inv = la.inv(S_u)
    constr += [0 == S_x_inv @ (x_nonscaled[i] - x_bar[i] - dx[i]) for i in range(n)]  # State Error
    constr += [0 == S_u_inv @ (u_nonscaled[i] - u_bar[i] - du[i]) for i in range(n)]  # Control Error

    # Dynamics Constraint
    constr += [
        x_nonscaled[i]
        == cp.reshape(A_d[i-1], (n_x, n_x)) @ x_nonscaled[i-1]
        + cp.reshape(B_d[i-1], (n_x, n_u)) @ u_nonscaled[i-1]
        + cp.reshape(C_d[i-1], (n_x, n_u)) @ u_nonscaled[i]
        + z_d[i-1]
        + nu[i-1]
        for i in range(1, n)
    ]

    # Control Constraints
    constr += [u_nonscaled[i] <= params.sim.max_control for i in range(n)]
    constr += [u_nonscaled[i] >= params.sim.min_control for i in range(n)]

    # State Constraints (Also implemented in CTCS but included for numerical stability)
    idx_x_true = params.sim.idx_x_true
    constr += [x_nonscaled[i][idx_x_true] <= params.sim.max_state[idx_x_true] for i in range(n)]
    constr += [x_nonscaled[i][idx_x_true] >= params.sim.min_state[idx_x_true] for i in range(n)]

    ########
    # COSTS
    ########

    # Trust Region Cost, measured in scaled coordinates.
    inv_S = block([
        [inv_S_x, np.zeros((S_x.shape[0], S_u.shape[1]))],
        [np.zeros((S_u.shape[0], S_x.shape[1])), inv_S_u],
    ])
    cost += sum(w_tr * cp.sum_squares(inv_S @ cp.hstack((dx[i], du[i]))) for i in range(n))
    cost += sum(params.scp.lam_vc * cp.sum(cp.abs(nu[i-1])) for i in range(1, n))  # Virtual Control Slack

    # Virtual buffer penalty: one term per nonconvex nodal constraint.
    if params.sim.constraints_nodal:
        for slack in nu_vb:
            cost += params.scp.lam_vb * cp.sum(cp.pos(slack))

    # idx_y states: bound their change between consecutive nodes over each
    # configured interval, and pin them to 0 at the first node.
    for idx, interval in zip(np.arange(params.sim.idx_y.start, params.sim.idx_y.stop), params.sim.ctcs_node_intervals):
        start_idx = 1 if interval[0] == 0 else interval[0]
        constr += [
            cp.abs(x_nonscaled[i][idx] - x_nonscaled[i-1][idx]) <= params.sim.max_state[idx]
            for i in range(start_idx, interval[1])
        ]
        constr += [x_nonscaled[0][idx] == 0]

    #########
    # PROBLEM
    #########
    prob = cp.Problem(cp.Minimize(cost), constr)
    if params.cvx.cvxpygen:
        # Generate solver code unless a 'solver' directory already exists and
        # the user declines to overwrite it.
        if os.path.exists('solver'):
            overwrite = input("Solver directory already exists. Overwrite? (y/n): ")
            regenerate = overwrite.lower() == 'y'
        else:
            regenerate = True
        if regenerate:
            cpg.generate_code(prob, solver=params.cvx.solver, code_dir='solver', wrapper=True)
    return prob