openscvx 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +0 -0
- openscvx/_version.py +21 -0
- openscvx/augmentation/__init__.py +0 -0
- openscvx/augmentation/ctcs.py +44 -0
- openscvx/augmentation/dynamics_augmentation.py +122 -0
- openscvx/config.py +247 -0
- {constraints → openscvx/constraints}/ctcs.py +27 -3
- {constraints → openscvx/constraints}/nodal.py +17 -11
- openscvx/constraints/violation.py +67 -0
- openscvx/discretization.py +170 -0
- openscvx/dynamics.py +41 -0
- openscvx/integrators.py +139 -0
- openscvx/io.py +81 -0
- openscvx/ocp.py +160 -0
- openscvx/plotting.py +632 -0
- openscvx/post_processing.py +36 -0
- openscvx/propagation.py +135 -0
- openscvx/ptr.py +149 -0
- openscvx/trajoptproblem.py +337 -0
- openscvx/utils.py +80 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.2.dist-info}/METADATA +61 -18
- openscvx-0.1.2.dist-info/RECORD +27 -0
- openscvx-0.1.2.dist-info/top_level.txt +1 -0
- constraints/violation.py +0 -26
- openscvx-0.1.0.dist-info/RECORD +0 -10
- openscvx-0.1.0.dist-info/top_level.txt +0 -1
- {constraints → openscvx/constraints}/__init__.py +0 -0
- {constraints → openscvx/constraints}/boundary.py +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.2.dist-info}/WHEEL +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.2.dist-info}/licenses/LICENSE +0 -0
openscvx/propagation.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from openscvx.config import Config
|
|
4
|
+
from openscvx.integrators import solve_ivp_diffrax_prop
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def prop_aug_dy(
    tau: float,
    x: np.ndarray,
    u_current: np.ndarray,
    u_next: np.ndarray,
    tau_init: float,
    node: int,
    idx_s: int,
    state_dot: callable,
    dis_type: str,
    N: int,
) -> np.ndarray:
    """Evaluate the scaled state derivative integrated during propagation.

    The control is held at ``u_current`` ("ZOH") or blended linearly towards
    ``u_next`` ("FOH"). ``state_dot`` is then called with a batched state, the
    control without its last column, and ``node``; the result is multiplied by
    column ``idx_s`` of the control.

    Parameters:
        tau: Current value of the integration variable.
        x: State vector; a leading axis of length 1 is added before the call
            to ``state_dot``.
        u_current: Control at the start of the segment, indexed as
            ``(row, column)``.
        u_next: Control at the end of the segment, same layout as
            ``u_current``.
        tau_init: Value of ``tau`` at the start of the segment.
        node: Node index forwarded unchanged to ``state_dot``.
        idx_s: Column of the control that multiplies the derivative.
        state_dot: Callable ``state_dot(x, u, node)``.
        dis_type: ``"ZOH"`` or ``"FOH"``.
        N: Factor applied to ``tau - tau_init`` for the FOH blend.

    Returns:
        The squeezed output of ``state_dot`` scaled by ``u[:, idx_s]``.

    Raises:
        ValueError: If ``dis_type`` is neither ``"ZOH"`` nor ``"FOH"``.
    """
    x = x[None, :]

    if dis_type == "ZOH":
        beta = 0.0
    elif dis_type == "FOH":
        # NOTE(review): the tau grids in this module come from
        # np.linspace(0, 1, n), whose spacing is 1 / (n - 1). If N is that n,
        # beta ends a segment at N / (N - 1) rather than 1 -- confirm this
        # matches the discretization before changing it.
        beta = (tau - tau_init) * N
    else:
        # Previously an unknown type left `beta` unbound and surfaced as an
        # UnboundLocalError on the next line.
        raise ValueError(
            f"Unknown discretization type: {dis_type!r}, expected 'ZOH' or 'FOH'"
        )
    u = u_current + beta * (u_next - u_current)

    # The last control column is excluded from what state_dot receives.
    return u[:, idx_s] * state_dot(x, u[:, :-1], node).squeeze()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_propagation_solver(state_dot, params):
    """Build a closure that propagates a state across one tau interval.

    Parameters:
        state_dot: Dynamics callable forwarded to ``prop_aug_dy``.
        params: Configuration object; ``params.dis.dis_type`` and
            ``params.scp.n`` are read each time the returned solver is called.

    Returns:
        A function ``propagation_solver(V0, tau_grid, u_cur, u_next,
        tau_init, node, idx_s)`` that returns whatever
        ``solve_ivp_diffrax_prop`` returns for the interval
        ``tau_grid[0]`` to ``tau_grid[1]``.
    """

    def propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, idx_s):
        # Positional extras for prop_aug_dy, in its parameter order after
        # (tau, x).
        ode_args = (
            u_cur,
            u_next,
            tau_init,
            node,
            idx_s,
            state_dot,
            params.dis.dis_type,
            params.scp.n,
        )
        tau_start = tau_grid[0]
        tau_end = tau_grid[1]
        return solve_ivp_diffrax_prop(
            f=prop_aug_dy,
            tau_final=tau_end,
            y_0=V0,
            args=ode_args,
            tau_0=tau_start,
        )

    return propagation_solver
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def s_to_t(u, params: Config):
    """Accumulate node times from the last control column.

    Starting from ``t = 0``, each of the ``params.scp.n - 1`` steps over a
    uniform tau grid on [0, 1] adds ``d_tau`` times a dilation value taken
    from the last column of ``u``: the previous node's value when
    ``params.dis.dis_type`` is ``"ZOH"``, otherwise the mean of the previous
    and current node's values.

    Parameters:
        u: 2-D control array indexed as ``u[node, column]``.
        params: Configuration providing ``scp.n`` and ``dis.dis_type``.

    Returns:
        list: Node times, one entry per node, beginning with ``0``.
    """
    n_nodes = params.scp.n
    tau = np.linspace(0, 1, n_nodes)

    t = [0]
    for k in range(1, n_nodes):
        d_tau = tau[k] - tau[k - 1]
        s_prev = u[k - 1, -1]
        s_cur = u[k, -1]
        if params.dis.dis_type == "ZOH":
            increment = d_tau * s_prev
        else:
            increment = 0.5 * (s_cur + s_prev) * d_tau
        t.append(t[-1] + increment)
    return t
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def t_to_tau(u, t, u_nodal, t_nodal, params: Config):
    """Map query times to tau values and interpolate the controls at them.

    Each column of ``u`` is linearly interpolated (``np.interp``) from the
    nodal times ``t_nodal`` onto ``t``. For every query time after the first,
    the latest node whose time is strictly earlier is located and tau is
    advanced from that node's tau by the elapsed time divided by a dilation
    value: the node's own (``"ZOH"``) or the mean of the node's and the
    interpolated value at the query time (any other ``dis_type``).

    Parameters:
        u: 2-D control array indexed as ``u[node, column]``, sampled at
            ``t_nodal``.
        t: Sequence of query times. ``tau[0]`` is left at 0 regardless of
            ``t[0]``.
        u_nodal: 2-D nodal control array; its last column supplies the
            dilation of the preceding node.
        t_nodal: Nodal times (array or list).
        params: Configuration providing ``scp.n`` and ``dis.dis_type``.

    Returns:
        tuple: ``(tau, u_interp)`` where ``tau`` has one entry per query time
        and ``u_interp`` holds the interpolated controls, one row per query
        time.
    """
    # s_to_t returns a plain list; comparisons below need an array.
    t_nodal = np.asarray(t_nodal)

    def interp_controls(new_t):
        return np.array(
            [np.interp(new_t, t_nodal, u[:, i]) for i in range(u.shape[1])]
        ).T

    u_interp = np.array([interp_controls(t_i) for t_i in t])

    tau = np.zeros(len(t))
    tau_nodal = np.linspace(0, 1, params.scp.n)
    for k in range(1, len(t)):
        earlier = np.where(t_nodal < t[k])[0]
        # A query time that is not strictly after the first node (e.g. a
        # repeated initial time) has no earlier node; anchor it to the first
        # segment instead of raising IndexError.
        k_nodal = earlier[-1] if earlier.size else 0
        s_kp = u_nodal[k_nodal, -1]
        tp = t_nodal[k_nodal]
        tau_p = tau_nodal[k_nodal]

        s_k = u_interp[k, -1]
        if params.dis.dis_type == "ZOH":
            tau[k] = tau_p + (t[k] - tp) / s_kp
        else:
            tau[k] = tau_p + 2 * (t[k] - tp) / (s_k + s_kp)
    return tau, u_interp
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def simulate_nonlinear_time(x_0, u, tau_vals, t, params, propagation_solver):
    """Propagate the state segment by segment and sample it at ``tau_vals``.

    The tau axis is split into ``params.scp.n - 1`` uniform segments on
    [0, 1]. For each segment the controls are interpolated at the segment's
    two node times, ``propagation_solver`` is called from the current initial
    state, and the returned solution is evaluated at the requested tau values
    that fall in that segment. The next segment starts from the last entry of
    the solution's ``ys``.

    Parameters:
        x_0: Initial state vector.
        u: 2-D control array indexed as ``u[node, column]``, sampled at ``t``.
        tau_vals: Tau values at which to sample. They are consumed in order,
            segment by segment, so they are assumed to be sorted ascending --
            TODO confirm against callers.
        t: Node times associated with the rows of ``u``.
        params: Configuration providing ``scp.n`` and ``sim.idx_s``.
        propagation_solver: Callable with the signature produced by
            ``get_propagation_solver``; its result must expose ``ys`` and
            ``evaluate(tau)``.

    Returns:
        np.ndarray: Sampled states with one row per entry of ``tau_vals``.
    """
    n_nodes = params.scp.n
    tau = np.linspace(0, 1, n_nodes)

    def interp_controls(new_t):
        return np.array(
            [np.interp(new_t, t, u[:, i]) for i in range(u.shape[1])]
        ).T

    # Bin tau_vals with respect to the uniform tau grid.
    tau_inds = np.digitize(tau_vals, tau) - 1
    # A sample exactly at the final grid point belongs to the last segment.
    tau_inds = np.where(tau_inds == n_nodes - 1, n_nodes - 2, tau_inds)

    # idx_s is a one-column slice, so its `start` is that column's index.
    # `stop` (used previously) is one past it, which is out of bounds for
    # ordinary indexing and presumably only worked through JAX's clamping of
    # out-of-range indices.
    idx_s = params.sim.idx_s.start

    # Seed with an empty float block so the result keeps shape (n, 0) and
    # dtype when nothing is sampled; columns are joined once at the end
    # instead of re-allocating the whole array for every sample.
    columns = [np.empty((x_0.shape[0], 0))]
    prev_count = 0

    for k in range(n_nodes - 1):
        controls_current = np.squeeze(interp_controls(t[k]))[None, :]
        controls_next = np.squeeze(interp_controls(t[k + 1]))[None, :]

        # Number of requested samples that fall in segment k.
        count = np.sum((tau_inds >= k) & (tau_inds < k + 1))
        tau_cur = tau_vals[prev_count : prev_count + count]

        sol = propagation_solver(
            x_0,
            (tau[k], tau[k + 1]),
            controls_current,
            controls_next,
            np.array([[tau[k]]]),
            np.array([[k]]),
            idx_s,
        )

        for tau_i in tau_cur:
            columns.append(sol.evaluate(tau_i).reshape(-1, 1))

        x_0 = sol.ys[-1]
        prev_count += count

    return np.concatenate(columns, axis=1).T
|
openscvx/ptr.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import numpy.linalg as la
|
|
3
|
+
import cvxpy as cp
|
|
4
|
+
import pickle
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
from openscvx.config import Config
|
|
8
|
+
|
|
9
|
+
import warnings
|
|
10
|
+
warnings.filterwarnings("ignore")
|
|
11
|
+
|
|
12
|
+
def PTR_init(ocp: cp.Problem, discretization_solver: callable, params: Config):
    """Prepare the PTR loop and run one warm-up subproblem.

    When ``params.cvx.cvxpygen`` is truthy, the generated solver entry point
    ``cpg_solve`` is imported from ``solver.cpg_solver`` and
    ``solver/problem.pickle`` is unpickled (the loaded object is not used
    afterwards). One call to ``PTR_subproblem`` is then made with the initial
    guess ``params.sim.x_bar`` / ``params.sim.u_bar``; its result is
    discarded.

    Parameters:
        ocp: Problem object handed to ``PTR_subproblem``.
        discretization_solver: Callable handed to ``PTR_subproblem`` as its
            discretization routine.
        params: Configuration object.

    Returns:
        The imported ``cpg_solve`` callable, or ``None`` when cvxpygen is not
        enabled.
    """
    cpg_solve = None
    if params.cvx.cvxpygen:
        from solver.cpg_solver import cpg_solve

        # NOTE(review): pickle executes arbitrary code on load; this path is
        # relative to the working directory -- only load files you generated.
        with open('solver/problem.pickle', 'rb') as pickle_file:
            _loaded_problem = pickle.load(pickle_file)

    # Warm-up solve on the initial guess (the original comment says this
    # initialises DPP and the JAX jacobians); the outputs are not needed.
    PTR_subproblem(
        cpg_solve,
        params.sim.x_bar,
        params.sim.u_bar,
        discretization_solver,
        ocp,
        params,
    )

    return cpg_solve
|
|
24
|
+
|
|
25
|
+
def PTR_main(params: Config, prob: cp.Problem, aug_dy: callable, cpg_solve, emitter_function) -> dict:
    """Run the iterative PTR loop until tolerances are met or ``k_max`` is hit.

    Each iteration solves one subproblem about the current reference
    (``x_bar``, ``u_bar``), adopts the solution as the new reference, grows
    ``params.scp.w_tr`` by ``w_tr_adapt`` (capped at ``w_tr_max``), scales
    ``params.scp.lam_cost`` by ``cost_relax`` once the iteration number
    exceeds ``cost_drop``, and reports a summary through
    ``emitter_function``. Note that ``params.scp`` is mutated in place.

    Parameters:
        params: Configuration object.
        prob: Problem object forwarded to ``PTR_subproblem``.
        aug_dy: Discretization callable forwarded to ``PTR_subproblem``.
        cpg_solve: Generated solver entry point or ``None``.
        emitter_function: Callable receiving one dict per iteration with the
            keys ``iter``, ``dis_time``, ``subprop_time`` (both multiplied by
            1000), ``J_total``, ``J_tr``, ``J_vb``, ``J_vc``, ``cost`` and
            ``prob_stat``.

    Returns:
        dict: ``converged``, ``t_final``, ``u``, ``x``, ``x_history``,
        ``u_history``, ``discretization_history``, ``J_tr_history``,
        ``J_vb_history`` and ``J_vc_history``. The three ``J_*_history``
        entries hold the vectors of the last iteration only (``None`` if no
        iteration ran).
    """
    J_vb = 1E2
    J_vc = 1E2
    J_tr = 1E2

    x_bar = params.sim.x_bar
    u_bar = params.sim.u_bar

    # Defined up front so the result can still be assembled if the loop body
    # never executes (e.g. k_max < 1).
    x, u = x_bar, u_bar
    J_tr_vec = J_vb_vec = J_vc_vec = None

    scp_trajs = [x_bar]
    scp_controls = [u_bar]
    V_multi_shoot_traj = []

    k = 1

    while k <= params.scp.k_max and (
        (J_tr >= params.scp.ep_tr)
        or (J_vb >= params.scp.ep_vb)
        or (J_vc >= params.scp.ep_vc)
    ):
        (
            x,
            u,
            cost,
            J_total,
            J_vb_vec,
            J_vc_vec,
            J_tr_vec,
            prob_stat,
            V_multi_shoot,
            subprop_time,
            dis_time,
        ) = PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params)

        V_multi_shoot_traj.append(V_multi_shoot)

        x_bar = x
        u_bar = u

        J_tr = np.sum(np.array(J_tr_vec))
        J_vb = np.sum(np.array(J_vb_vec))
        J_vc = np.sum(np.array(J_vc_vec))
        scp_trajs.append(x)
        scp_controls.append(u)

        params.scp.w_tr = min(params.scp.w_tr * params.scp.w_tr_adapt, params.scp.w_tr_max)
        if k > params.scp.cost_drop:
            params.scp.lam_cost = params.scp.lam_cost * params.scp.cost_relax

        emitter_function(
            {
                "iter": k,
                "dis_time": dis_time * 1000.0,
                "subprop_time": subprop_time * 1000.0,
                "J_total": J_total,
                "J_tr": J_tr,
                "J_vb": J_vb,
                "J_vc": J_vc,
                "cost": cost[-1],
                "prob_stat": prob_stat,
            }
        )

        k += 1

    # Convergence is decided by the tolerances themselves. The previous test
    # (`k <= k_max`) reported failure when the tolerances were met exactly on
    # the final permitted iteration, and reported success when the loop
    # stopped because a metric became NaN.
    converged = bool(
        J_tr < params.scp.ep_tr
        and J_vb < params.scp.ep_vb
        and J_vc < params.scp.ep_vc
    )

    result = dict(
        converged=converged,
        t_final=x[:, params.sim.idx_t][-1],
        u=u,
        x=x,
        x_history=scp_trajs,
        u_history=scp_controls,
        discretization_history=V_multi_shoot_traj,
        J_tr_history=J_tr_vec,
        J_vb_history=J_vb_vec,
        J_vc_history=J_vc_vec,
    )
    return result
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params: Config):
    """Populate, solve and post-process one convex subproblem.

    The reference trajectory, the discretization matrices returned by
    ``aug_dy``, the linearised non-convex nodal constraints and the weights
    ``w_tr`` / ``lam_cost`` are written into ``prob.param_dict``. The problem
    is solved (through the registered ``'CPG'`` method when
    ``params.cvx.cvxpygen`` is truthy, otherwise with ``params.cvx.solver``),
    and the solution is mapped back with ``S_x``/``c_x`` and ``S_u``/``c_u``.

    Parameters:
        cpg_solve: Generated solver entry point; only used with cvxpygen.
        x_bar: Reference state trajectory.
        u_bar: Reference control trajectory.
        aug_dy: Callable ``aug_dy(x_bar, u_bar)`` returning
            ``(A_bar, B_bar, C_bar, z_bar, V_multi_shoot)``.
        prob: Problem object exposing ``param_dict``, ``var_dict``,
            ``solve``, ``register_solve``, ``value`` and ``status``.
        params: Configuration object.

    Returns:
        tuple: ``(x, u, costs, prob.value, J_vb_vec, J_vc_vec, J_tr_vec,
        prob.status, V_multi_shoot, subprop_time, dis_time)``. ``costs`` is an
        array with one entry per row of ``x``: the sum of the state columns
        marked ``'Minimize'`` minus those marked ``'Maximize'`` in
        ``params.sim.final_state.type``.
    """
    prob.param_dict['x_bar'].value = x_bar
    prob.param_dict['u_bar'].value = u_bar

    t0 = time.time()
    A_bar, B_bar, C_bar, z_bar, V_multi_shoot = aug_dy(x_bar, u_bar.astype(float))

    prob.param_dict['A_d'].value = A_bar.__array__()
    prob.param_dict['B_d'].value = B_bar.__array__()
    prob.param_dict['C_d'].value = C_bar.__array__()
    prob.param_dict['z_d'].value = z_bar.__array__()
    dis_time = time.time() - t0

    # Treat a missing constraint list as empty for both loops below.
    nodal_constraints = params.sim.constraints_nodal or []

    for g_id, constraint in enumerate(nodal_constraints):
        if not constraint.convex:
            prob.param_dict['g_' + str(g_id)].value = np.asarray(constraint.g(x_bar, u_bar))
            prob.param_dict['grad_g_x_' + str(g_id)].value = np.asarray(constraint.grad_g_x(x_bar, u_bar))
            prob.param_dict['grad_g_u_' + str(g_id)].value = np.asarray(constraint.grad_g_u(x_bar, u_bar))

    prob.param_dict['w_tr'].value = params.scp.w_tr
    prob.param_dict['lam_cost'].value = params.scp.lam_cost

    t0 = time.time()
    if params.cvx.cvxpygen:
        prob.register_solve('CPG', cpg_solve)
        prob.solve(method='CPG', **params.cvx.solver_args)
    else:
        prob.solve(solver=params.cvx.solver, enforce_dpp=True, **params.cvx.solver_args)
    subprop_time = time.time() - t0

    # Map the solver variables back through the affine scaling.
    x = (params.sim.S_x @ prob.var_dict['x'].value.T + np.expand_dims(params.sim.c_x, axis=1)).T
    u = (params.sim.S_u @ prob.var_dict['u'].value.T + np.expand_dims(params.sim.c_u, axis=1)).T

    # Accumulate numerically. The previous `costs = [0]; costs += x[:, i]`
    # extended a Python list instead of adding, so several 'Minimize' entries
    # were not summed and mixing 'Minimize' with 'Maximize' failed.
    costs = np.zeros(x.shape[0])
    for i, objective in enumerate(params.sim.final_state.type):
        if objective == 'Minimize':
            costs = costs + x[:, i]
        if objective == 'Maximize':
            costs = costs - x[:, i]

    # Block-diagonal inverse scaling for the stacked (state, control) vector.
    inv_S_x = params.sim.inv_S_x
    inv_S_u = params.sim.inv_S_u
    inv_block_diag = np.block([
        [inv_S_x, np.zeros((inv_S_x.shape[0], inv_S_u.shape[1]))],
        [np.zeros((inv_S_u.shape[0], inv_S_x.shape[1])), inv_S_u],
    ])

    # Squared 2-norm of the scaled deviation from the reference, per node.
    J_tr_vec = la.norm(inv_block_diag @ np.hstack((x - x_bar, u - u_bar)).T, axis=0) ** 2
    J_vc_vec = np.sum(np.abs(prob.var_dict['nu'].value), axis=1)

    # Positive part of each non-convex constraint's 'nu_vb_<j>' variable,
    # where j counts non-convex constraints only.
    id_ncvx = 0
    J_vb_vec = 0
    for constraint in nodal_constraints:
        if not constraint.convex:
            J_vb_vec += np.maximum(0, prob.var_dict['nu_vb_' + str(id_ncvx)].value)
            id_ncvx += 1
    return x, u, costs, prob.value, J_vb_vec, J_vc_vec, J_tr_vec, prob.status, V_multi_shoot, subprop_time, dis_time
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from typing import List, Union
|
|
3
|
+
import queue
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
import cvxpy as cp
|
|
8
|
+
import jax
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from openscvx.config import (
|
|
12
|
+
ScpConfig,
|
|
13
|
+
SimConfig,
|
|
14
|
+
ConvexSolverConfig,
|
|
15
|
+
DiscretizationConfig,
|
|
16
|
+
PropagationConfig,
|
|
17
|
+
DevConfig,
|
|
18
|
+
Config,
|
|
19
|
+
)
|
|
20
|
+
from openscvx.dynamics import Dynamics
|
|
21
|
+
from openscvx.augmentation.dynamics_augmentation import build_augmented_dynamics
|
|
22
|
+
from openscvx.augmentation.ctcs import sort_ctcs_constraints
|
|
23
|
+
from openscvx.constraints.violation import get_g_funcs, CTCSViolation
|
|
24
|
+
from openscvx.discretization import get_discretization_solver
|
|
25
|
+
from openscvx.propagation import get_propagation_solver
|
|
26
|
+
from openscvx.constraints.boundary import BoundaryConstraint
|
|
27
|
+
from openscvx.constraints.ctcs import CTCSConstraint
|
|
28
|
+
from openscvx.constraints.nodal import NodalConstraint
|
|
29
|
+
from openscvx.ptr import PTR_init, PTR_main
|
|
30
|
+
from openscvx.post_processing import propagate_trajectory_results
|
|
31
|
+
from openscvx.ocp import OptimalControlProblem
|
|
32
|
+
from openscvx import io
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# TODO: (norrisg) Decide whether to have `constraints`, `cost`, etc. alongside `dynamics`.
|
|
36
|
+
class TrajOptProblem:
    """Trajectory optimization problem: setup, solve and post-processing.

    Typical use is ``initialize()`` followed by ``solve()`` and then
    ``post_process(result)``.

    Attributes set by ``__init__``:
        dynamics_augmented: Result of ``build_augmented_dynamics``.
        params: ``Config`` bundling the sim/scp/dis/dev/cvx/prp settings.
        optimal_control_problem, discretization_solver, propagation_solver,
            cpg_solve: ``None`` until ``initialize()`` fills them in.
        emitter_function: Callable receiving per-iteration data; it enqueues
            onto ``print_queue`` when printing is enabled and is a no-op
            otherwise.
        print_queue: ``queue.Queue`` when printing is enabled, else ``None``.
        timing_init, timing_solve, timing_post: Wall-clock durations in
            seconds, ``None`` until the corresponding phase has run.
    """

    def __init__(
        self,
        dynamics: Dynamics,
        constraints: List[Union[CTCSConstraint, NodalConstraint]],
        idx_time: int,
        N: int,
        time_init: float,
        x_guess: jnp.ndarray,
        u_guess: jnp.ndarray,
        initial_state: BoundaryConstraint,
        final_state: BoundaryConstraint,
        x_max: jnp.ndarray,
        x_min: jnp.ndarray,
        u_max: jnp.ndarray,
        u_min: jnp.ndarray,
        scp: ScpConfig = None,
        dis: DiscretizationConfig = None,
        prp: PropagationConfig = None,
        sim: SimConfig = None,
        dev: DevConfig = None,
        cvx: ConvexSolverConfig = None,
        licq_min=0.0,
        licq_max=1e-4,
        time_dilation_factor_min=0.3,
        time_dilation_factor_max=3.0,
    ):
        """Assemble the augmented problem and its configuration.

        The state is extended by one column per augmented CTCS state (bounded
        by ``licq_min`` / ``licq_max`` and initialised to 0) and the control
        by one time-dilation column (bounded by the dilation factors times
        ``time_init`` and initialised to ``time_init``).

        Parameters:
            dynamics: System dynamics.
            constraints: ``CTCSConstraint`` and ``NodalConstraint`` objects.
            idx_time: Index of the time entry in the (unaugmented) state.
            N: Number of nodes; must equal ``scp.n`` when ``scp`` is given.
            time_init: Initial guess for the total time.
            x_guess: Initial state trajectory guess, one row per node.
            u_guess: Initial control trajectory guess, one row per node.
            initial_state: Boundary constraint at the first node.
            final_state: Boundary constraint at the last node.
            x_max, x_min: State bounds.
            u_max, u_min: Control bounds.
            scp, dis, prp, sim, dev, cvx: Optional configuration objects;
                defaults are constructed for any that are ``None``.
            licq_min, licq_max: Bounds for the augmented states.
            time_dilation_factor_min, time_dilation_factor_max: Multipliers
                of ``time_init`` bounding the time-dilation control.

        Raises:
            ValueError: If a constraint is neither a ``CTCSConstraint`` nor a
                ``NodalConstraint``.
            AssertionError: If ``idx_time`` is outside the state vector, or a
                supplied ``scp`` has ``n != N``.
        """

        # TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
        constraints_ctcs = []
        constraints_nodal = []
        for constraint in constraints:
            if isinstance(constraint, CTCSConstraint):
                constraints_ctcs.append(constraint)
            elif isinstance(constraint, NodalConstraint):
                constraints_nodal.append(constraint)
            else:
                raise ValueError(
                    f"Unknown constraint type: {type(constraint)}, All constraints must be decorated with @ctcs or @nodal"
                )

        constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)

        # Index tracking
        idx_x_true = slice(0, len(x_max))
        idx_u_true = slice(0, len(u_max))
        idx_constraint_violation = slice(
            idx_x_true.stop, idx_x_true.stop + num_augmented_states
        )

        idx_time_dilation = slice(idx_u_true.stop, idx_u_true.stop + 1)

        # check that idx_time is in the correct range
        assert idx_time >= 0 and idx_time < len(
            x_max
        ), "idx_time must be in the range of the state vector and non-negative"
        idx_time = slice(idx_time, idx_time + 1)

        x_min_augmented = np.hstack([x_min, np.repeat(licq_min, num_augmented_states)])
        x_max_augmented = np.hstack([x_max, np.repeat(licq_max, num_augmented_states)])

        u_min_augmented = np.hstack([u_min, time_dilation_factor_min * time_init])
        u_max_augmented = np.hstack([u_max, time_dilation_factor_max * time_init])

        x_bar_augmented = np.hstack([x_guess, np.full((x_guess.shape[0], num_augmented_states), 0)])
        u_bar_augmented = np.hstack(
            [u_guess, np.full((u_guess.shape[0], 1), time_init)]
        )

        if dis is None:
            dis = DiscretizationConfig()

        if sim is None:
            sim = SimConfig(
                x_bar=x_bar_augmented,
                u_bar=u_bar_augmented,
                initial_state=initial_state,
                final_state=final_state,
                max_state=x_max_augmented,
                min_state=x_min_augmented,
                max_control=u_max_augmented,
                min_control=u_min_augmented,
                total_time=time_init,
                n_states=len(x_max),
                idx_x_true=idx_x_true,
                idx_u_true=idx_u_true,
                idx_t=idx_time,
                idx_y=idx_constraint_violation,
                idx_s=idx_time_dilation,
                ctcs_node_intervals=node_intervals,
            )

        if scp is None:
            scp = ScpConfig(
                n=N,
                k_max=200,
                w_tr=1e1,  # Weight on the Trust Region
                lam_cost=1e1,  # Weight on the Nonlinear Cost
                lam_vc=1e2,  # Weight on the Virtual Control Objective
                lam_vb=0e0,  # Weight on the Virtual Buffer Objective (only for penalized nodal constraints)
                ep_tr=1e-4,  # Trust Region Tolerance
                ep_vb=1e-4,  # Virtual Buffer Tolerance
                ep_vc=1e-8,  # Virtual Control Tolerance for CTCS
                cost_drop=4,  # SCP iteration to relax minimal final time objective
                cost_relax=0.5,  # Minimal Time Relaxation Factor
                w_tr_adapt=1.2,  # Trust Region Adaptation Factor
                w_tr_max_scaling_factor=1e2,  # Maximum Trust Region Weight
            )
        else:
            # Compare against the supplied config; `self.scp` does not exist
            # on this object, so the earlier `self.scp.n` raised
            # AttributeError whenever a config was passed in.
            assert (
                scp.n == N
            ), "Number of segments must be the same as in the config"

        if dev is None:
            dev = DevConfig()
        if cvx is None:
            cvx = ConvexSolverConfig()
        if prp is None:
            prp = PropagationConfig()

        sim.constraints_ctcs = constraints_ctcs
        sim.constraints_nodal = constraints_nodal

        ctcs_violation_funcs = get_g_funcs(constraints_ctcs)
        self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)

        self.params = Config(
            sim=sim,
            scp=scp,
            dis=dis,
            dev=dev,
            cvx=cvx,
            prp=prp,
        )

        self.optimal_control_problem: cp.Problem = None
        self.discretization_solver: callable = None
        self.propagation_solver: callable = None
        self.cpg_solve = None

        # set up emitter & thread only if printing is enabled
        if self.params.dev.printing:
            self.print_queue = queue.Queue()
            self.emitter_function = lambda data: self.print_queue.put(data)
            self.print_thread = threading.Thread(
                target=io.intermediate,
                args=(self.print_queue, self.params),
                daemon=True,
            )
            self.print_thread.start()
        else:
            # no-op emitter; nothing ever gets queued or printed
            self.print_queue = None
            self.emitter_function = lambda data: None

        self.timing_init = None
        self.timing_solve = None
        self.timing_post = None

    def initialize(self):
        """Compile the dynamics, build the solvers and warm up the PTR loop.

        Vectorises/JIT-compiles the augmented dynamics and the non-convex
        nodal constraint functions, creates the discretization and
        propagation solvers and the optimal control problem, runs
        ``PTR_init``, and (unless ``params.dev.debug`` is set) ahead-of-time
        compiles both solvers. The elapsed time is stored in ``timing_init``
        and printed. With ``params.dev.profiling`` enabled a profile is
        written to ``profiling_initialize.prof``.
        """
        io.intro()

        # Enable the profiler
        if self.params.dev.profiling:
            import cProfile

            pr = cProfile.Profile()
            pr.enable()

        t_0_while = time.time()
        # Ensure parameter sizes and normalization are correct
        self.params.scp.__post_init__()
        self.params.sim.__post_init__()

        # Compile dynamics and jacobians
        self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f)
        self.dynamics_augmented.A = jax.jit(jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0)))
        self.dynamics_augmented.B = jax.jit(jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0)))

        for constraint in self.params.sim.constraints_nodal:
            if not constraint.convex:
                # TODO: (haynec) switch to AOT instead of JIT
                constraint.g = jax.jit(constraint.g)
                constraint.grad_g_x = jax.jit(constraint.grad_g_x)
                constraint.grad_g_u = jax.jit(constraint.grad_g_u)

        # Generate solvers and optimal control problem
        self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
        self.propagation_solver = get_propagation_solver(self.dynamics_augmented.f, self.params)
        self.optimal_control_problem = OptimalControlProblem(self.params)

        # Initialize the PTR loop
        self.cpg_solve = PTR_init(
            self.optimal_control_problem,
            self.discretization_solver,
            self.params,
        )

        # Compile the solvers
        if not self.params.dev.debug:
            # The arrays below only fix the argument shapes and dtypes for
            # lowering; their values are placeholders.
            self.discretization_solver = (
                jax.jit(self.discretization_solver)
                .lower(
                    np.ones((self.params.scp.n, self.params.sim.n_states)),
                    np.ones((self.params.scp.n, self.params.sim.n_controls)),
                )
                .compile()
            )

            self.propagation_solver = (
                jax.jit(self.propagation_solver)
                .lower(
                    np.ones((self.params.sim.n_states)),
                    (0.0, 0.0),
                    np.ones((1, self.params.sim.n_controls)),
                    np.ones((1, self.params.sim.n_controls)),
                    np.ones((1, 1)),
                    np.ones((1, 1)).astype("int"),
                    0,
                )
                .compile()
            )

        t_f_while = time.time()
        self.timing_init = t_f_while - t_0_while
        print("Total Initialization Time: ", self.timing_init)

        if self.params.dev.profiling:
            pr.disable()
            # Save results so they can be visualized with snakeviz
            pr.dump_stats("profiling_initialize.prof")

    def solve(self):
        """Run the PTR loop and return its result dictionary.

        Returns:
            dict: The value returned by ``PTR_main``.

        Raises:
            ValueError: If ``initialize()`` has not been called first.
        """
        # Ensure parameter sizes and normalization are correct
        self.params.scp.__post_init__()
        self.params.sim.__post_init__()

        if self.optimal_control_problem is None or self.discretization_solver is None:
            raise ValueError(
                "Problem has not been initialized. Call initialize() before solve()"
            )

        # Enable the profiler
        if self.params.dev.profiling:
            import cProfile

            pr = cProfile.Profile()
            pr.enable()

        t_0_while = time.time()
        # Print top header for solver results
        io.header()

        result = PTR_main(
            self.params,
            self.optimal_control_problem,
            self.discretization_solver,
            self.cpg_solve,
            self.emitter_function,
        )

        t_f_while = time.time()
        self.timing_solve = t_f_while - t_0_while

        # Let the printing thread drain its queue before the footer. The
        # queue only exists when printing was enabled at construction time;
        # touching it unconditionally raised AttributeError otherwise.
        print_queue = getattr(self, "print_queue", None)
        if print_queue is not None:
            while print_queue.qsize() > 0:
                time.sleep(0.1)

        # Print bottom footer for solver results as well as total computation time
        io.footer(self.timing_solve)

        # Disable the profiler
        if self.params.dev.profiling:
            pr.disable()
            # Save results so they can be visualized with snakeviz
            pr.dump_stats("profiling_solve.prof")

        return result

    def post_process(self, result):
        """Propagate the solved trajectory and time the operation.

        Parameters:
            result: Result dictionary returned by ``solve()``.

        Returns:
            The value returned by ``propagate_trajectory_results``.
        """
        # Enable the profiler
        if self.params.dev.profiling:
            import cProfile

            pr = cProfile.Profile()
            pr.enable()

        t_0_post = time.time()
        result = propagate_trajectory_results(self.params, result, self.propagation_solver)
        t_f_post = time.time()

        self.timing_post = t_f_post - t_0_post
        print("Total Post Processing Time: ", self.timing_post)

        # Disable the profiler
        if self.params.dev.profiling:
            pr.disable()
            # Save results so they can be visualized with snakeviz
            pr.dump_stats("profiling_postprocess.prof")
        return result
|
openscvx/utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def qdcm(q: jnp.ndarray) -> jnp.ndarray:
    """Convert a quaternion to a 3x3 direction cosine matrix.

    The quaternion is read in ``(w, x, y, z)`` order and is normalised before
    use, so it does not need to have unit length.

    Parameters:
        q: Four-component quaternion.

    Returns:
        jnp.ndarray: The 3x3 direction cosine matrix.
    """
    q_norm = (q[0] ** 2 + q[1] ** 2 + q[2] ** 2 + q[3] ** 2) ** 0.5
    w, x, y, z = q / q_norm

    xx, yy, zz = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    row_0 = [1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)]
    row_1 = [2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)]
    row_2 = [2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)]
    return jnp.array([row_0, row_1, row_2])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def SSMP(w: jnp.ndarray):
    """Build the 4x4 skew-symmetric matrix of a 3-component angular rate.

    Parameters:
        w: Three components ``(x, y, z)``.

    Returns:
        jnp.ndarray: The 4x4 matrix
        ``[[0, -x, -y, -z], [x, 0, z, -y], [y, -z, 0, x], [z, y, -x, 0]]``.
    """
    x, y, z = w
    rows = [
        [0, -x, -y, -z],
        [x, 0, z, -y],
        [y, -z, 0, x],
        [z, y, -x, 0],
    ]
    return jnp.array(rows)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def SSM(w: jnp.ndarray):
    """Build the 3x3 skew-symmetric matrix of a 3-component angular rate.

    Parameters:
        w: Three components ``(x, y, z)``.

    Returns:
        jnp.ndarray: The 3x3 matrix ``[[0, -z, y], [z, 0, -x], [-y, x, 0]]``.
    """
    x, y, z = w
    rows = [
        [0, -z, y],
        [z, 0, -x],
        [-y, x, 0],
    ]
    return jnp.array(rows)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def generate_orthogonal_unit_vectors(vectors=None):
    """
    Orthonormalize a set of axes with a QR decomposition.

    Parameters:
    vectors (np.ndarray): Optional matrix whose columns are the axes to be
    orthonormalized. When omitted, a 3x3 matrix is drawn uniformly from
    [0, 1) using the fixed key ``PRNGKey(0)``, so the default result is the
    same on every call.

    Returns:
    The Q factor of the decomposition; its columns are orthogonal unit
    vectors (3x3 for the default input).
    """
    basis = vectors
    if basis is None:
        # Fixed seed: reproducible rather than freshly random per call.
        seed_key = jax.random.PRNGKey(0)
        basis = jax.random.uniform(seed_key, (3, 3))
    orthonormal, _upper = jnp.linalg.qr(basis)
    return orthonormal
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Matrix applied to the corner offsets in gen_vertices; built from the cosine
# and sine of pi / 2 about the third axis.
rot = np.array(
    [
        [np.cos(np.pi / 2), np.sin(np.pi / 2), 0],
        [-np.sin(np.pi / 2), np.cos(np.pi / 2), 0],
        [0, 0, 1],
    ]
)


def gen_vertices(center, radii):
    """
    Obtains the four corner vertices of the gate.

    The corners are offset from ``center`` by ``+/- radii[0]`` along the
    first axis and ``+/- radii[2]`` along the third axis, each offset being
    transformed by ``rot``; ``radii[1]`` is not used.

    Returns:
    list: Four vertices in the order (+, +), (-, +), (-, -), (+, -).
    """
    half_width = radii[0]
    half_height = radii[2]
    corner_offsets = (
        [half_width, 0, half_height],
        [-half_width, 0, half_height],
        [-half_width, 0, -half_height],
        [half_width, 0, -half_height],
    )
    return [center + rot @ offset for offset in corner_offsets]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# TODO (haynec): make this less hardcoded
|
|
72
|
+
def get_kp_pose(t, init_pose):
    """Evaluate a hard-coded periodic path at time ``t``.

    The path has a period of 40.0 and an amplitude of 20.0. With phase
    ``2 * pi * t / 40`` the offsets are ``x = 20 * sin(phase)``,
    ``y = x * cos(phase)`` and ``z = 0.5 * x * sin(phase)``.

    Parameters:
        t: Time (scalar or array).
        init_pose: Offset added to the path.

    Returns:
        jnp.ndarray: ``[x, y, z]`` (transposed) plus ``init_pose``.
    """
    loop_time = 40.0
    loop_radius = 20.0

    phase = t / loop_time * (2 * jnp.pi)
    sin_phase = jnp.sin(phase)
    x = loop_radius * sin_phase
    y = x * jnp.cos(phase)
    z = 0.5 * x * sin_phase
    path_offset = jnp.array([x, y, z])
    return path_offset.T + init_pose
|