openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/_version.py +2 -2
- openscvx/augmentation/dynamics_augmentation.py +22 -7
- openscvx/config.py +310 -192
- openscvx/constraints/__init__.py +0 -3
- openscvx/constraints/ctcs.py +188 -33
- openscvx/constraints/nodal.py +150 -11
- openscvx/constraints/violation.py +12 -2
- openscvx/discretization.py +115 -37
- openscvx/dynamics.py +150 -11
- openscvx/integrators.py +135 -16
- openscvx/io.py +129 -17
- openscvx/ocp.py +86 -67
- openscvx/plotting.py +72 -215
- openscvx/post_processing.py +57 -16
- openscvx/propagation.py +155 -55
- openscvx/ptr.py +96 -57
- openscvx/results.py +153 -0
- openscvx/trajoptproblem.py +359 -114
- openscvx/utils.py +50 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/METADATA +129 -41
- openscvx-0.2.1.dev0.dist-info/RECORD +27 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/WHEEL +1 -1
- openscvx/constraints/boundary.py +0 -49
- openscvx-0.1.2.dist-info/RECORD +0 -27
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/top_level.txt +0 -0
openscvx/plotting.py
CHANGED
|
@@ -6,12 +6,13 @@ import pickle
|
|
|
6
6
|
|
|
7
7
|
from openscvx.utils import qdcm, get_kp_pose
|
|
8
8
|
from openscvx.config import Config
|
|
9
|
+
from openscvx.results import OptimizationResults
|
|
9
10
|
|
|
10
|
-
def full_subject_traj_time(results, params):
|
|
11
|
-
x_full = results
|
|
12
|
-
x_nodes = results
|
|
11
|
+
def full_subject_traj_time(results: OptimizationResults, params: Config):
|
|
12
|
+
x_full = results.x_full
|
|
13
|
+
x_nodes = results.x
|
|
13
14
|
t_nodes = x_nodes[:,params.sim.idx_t]
|
|
14
|
-
t_full = results
|
|
15
|
+
t_full = results.t_full
|
|
15
16
|
subs_traj = []
|
|
16
17
|
subs_traj_node = []
|
|
17
18
|
subs_traj_sen = []
|
|
@@ -19,11 +20,12 @@ def full_subject_traj_time(results, params):
|
|
|
19
20
|
|
|
20
21
|
# if hasattr(params.dyn, 'get_kp_pose'):
|
|
21
22
|
if "moving_subject" in results and "init_poses" in results:
|
|
22
|
-
init_poses = results["init_poses"]
|
|
23
|
+
init_poses = results.plotting_data["init_poses"]
|
|
23
24
|
subs_traj.append(get_kp_pose(t_full, init_poses))
|
|
24
25
|
subs_traj_node.append(get_kp_pose(t_nodes, init_poses))
|
|
25
26
|
elif "init_poses" in results:
|
|
26
|
-
|
|
27
|
+
init_poses = results.plotting_data["init_poses"]
|
|
28
|
+
for pose in init_poses:
|
|
27
29
|
# repeat the pose for all time steps
|
|
28
30
|
pose_full = np.repeat(pose[:,np.newaxis], x_full.shape[0], axis=1).T
|
|
29
31
|
subs_traj.append(pose_full)
|
|
@@ -34,7 +36,7 @@ def full_subject_traj_time(results, params):
|
|
|
34
36
|
raise ValueError("No valid method to get keypoint poses.")
|
|
35
37
|
|
|
36
38
|
if "R_sb" in results:
|
|
37
|
-
R_sb = results["R_sb"]
|
|
39
|
+
R_sb = results.plotting_data["R_sb"]
|
|
38
40
|
for sub_traj in subs_traj:
|
|
39
41
|
sub_traj_sen = []
|
|
40
42
|
for i in range(x_full.shape[0]):
|
|
@@ -50,7 +52,7 @@ def full_subject_traj_time(results, params):
|
|
|
50
52
|
subs_traj_sen_node.append(np.array(sub_traj_sen_node).squeeze())
|
|
51
53
|
return subs_traj, subs_traj_sen, subs_traj_node, subs_traj_sen_node
|
|
52
54
|
else:
|
|
53
|
-
raise ValueError("`R_sb` not found in results
|
|
55
|
+
raise ValueError("`R_sb` not found in results. Cannot compute sensor frame.")
|
|
54
56
|
|
|
55
57
|
def save_gate_parameters(gates, params: Config):
|
|
56
58
|
gate_centers = []
|
|
@@ -75,26 +77,26 @@ def frame_args(duration):
|
|
|
75
77
|
"transition": {"duration": duration, "easing": "linear"},
|
|
76
78
|
}
|
|
77
79
|
|
|
78
|
-
def plot_constraint_violation(result, params: Config):
|
|
80
|
+
def plot_constraint_violation(result: OptimizationResults, params: Config):
|
|
79
81
|
fig = make_subplots(rows=2, cols=3, subplot_titles=(r'$\text{Obstacle Violation}$', r'$\text{Sub VP Violation}$', r'$\text{Sub Min Violation}$', r'$\text{Sub Max Violation}$', r'$\text{Sub Direc Violation}$', r'$\text{State Bound Violation}$', r'$\text{Total Violation}$'))
|
|
80
82
|
fig.update_layout(template='plotly_dark', title=r'$\text{Constraint Violation}$')
|
|
81
83
|
|
|
82
84
|
if "obs_vio" in result:
|
|
83
|
-
obs_vio = result["obs_vio"]
|
|
85
|
+
obs_vio = result.plotting_data["obs_vio"]
|
|
84
86
|
for i in range(obs_vio.shape[0]):
|
|
85
87
|
color = f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'
|
|
86
88
|
fig.add_trace(go.Scatter(y=obs_vio[i], mode='lines', showlegend=False, line=dict(color=color, width = 2)), row=1, col=1)
|
|
87
89
|
i = 0
|
|
88
90
|
else:
|
|
89
|
-
print("'obs_vio' not found in result
|
|
91
|
+
print("'obs_vio' not found in result.")
|
|
90
92
|
|
|
91
93
|
# Make names of each state in the state vector
|
|
92
94
|
state_names = ['x', 'y', 'z', 'vx', 'vy', 'vz', 'q0', 'q1', 'q2', 'q3', 'wx', 'wy', 'wz', 'ctcs']
|
|
93
95
|
|
|
94
96
|
if "sub_vp_vio" in result and "sub_min_vio" in result and "sub_max_vio" in result:
|
|
95
|
-
sub_vp_vio = result["sub_vp_vio"]
|
|
96
|
-
sub_min_vio = result["sub_min_vio"]
|
|
97
|
-
sub_max_vio = result["sub_max_vio"]
|
|
97
|
+
sub_vp_vio = result.plotting_data["sub_vp_vio"]
|
|
98
|
+
sub_min_vio = result.plotting_data["sub_min_vio"]
|
|
99
|
+
sub_max_vio = result.plotting_data["sub_max_vio"]
|
|
98
100
|
for i in range(sub_vp_vio.shape[0]):
|
|
99
101
|
color = f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'
|
|
100
102
|
fig.add_trace(go.Scatter(y=sub_vp_vio[i], mode='lines', showlegend=True, name = 'LoS ' + str(i) + ' Error', line=dict(color=color, width = 2)), row=1, col=2)
|
|
@@ -106,28 +108,28 @@ def plot_constraint_violation(result, params: Config):
|
|
|
106
108
|
fig.add_trace(go.Scatter(y=[], mode='lines', showlegend=False, line=dict(color=color, width = 2)), row=2, col=1)
|
|
107
109
|
i = 0
|
|
108
110
|
else:
|
|
109
|
-
print("'sub_vp_vio', 'sub_min_vio', or 'sub_max_vio' not found in result
|
|
111
|
+
print("'sub_vp_vio', 'sub_min_vio', or 'sub_max_vio' not found in result.")
|
|
110
112
|
|
|
111
113
|
if "sub_direc_vio" in result:
|
|
112
|
-
sub_direc_vio = result["sub_direc_vio"]
|
|
114
|
+
sub_direc_vio = result.plotting_data["sub_direc_vio"]
|
|
113
115
|
# fig.add_trace(go.Scatter(y=sub_direc_vio, mode='lines', showlegend=False, line=dict(color='red', width = 2)), row=2, col=2)
|
|
114
116
|
else:
|
|
115
|
-
print("'sub_direc_vio' not found in result
|
|
117
|
+
print("'sub_direc_vio' not found in result.")
|
|
116
118
|
|
|
117
119
|
if "state_bound_vio" in result:
|
|
118
|
-
state_bound_vio = result["state_bound_vio"]
|
|
120
|
+
state_bound_vio = result.plotting_data["state_bound_vio"]
|
|
119
121
|
for i in range(state_bound_vio.shape[0]):
|
|
120
122
|
color = f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'
|
|
121
123
|
fig.add_trace(go.Scatter(y=state_bound_vio[:,i], mode='lines', showlegend=True, name = state_names[i] + ' Error', line=dict(color=color, width = 2)), row=2, col=3)
|
|
122
124
|
else:
|
|
123
|
-
print("'state_bound_vio' not found in result
|
|
125
|
+
print("'state_bound_vio' not found in result.")
|
|
124
126
|
|
|
125
127
|
fig.show()
|
|
126
128
|
|
|
127
|
-
def plot_initial_guess(result, params: Config):
|
|
128
|
-
x_positions = result
|
|
129
|
-
x_attitude = result
|
|
130
|
-
subs_positions = result["sub_positions"]
|
|
129
|
+
def plot_initial_guess(result: OptimizationResults, params: Config):
|
|
130
|
+
x_positions = result.x.guess[:, 0:3].T
|
|
131
|
+
x_attitude = result.x.guess[:, 6:10].T
|
|
132
|
+
subs_positions = result.plotting_data["sub_positions"]
|
|
131
133
|
|
|
132
134
|
fig = go.Figure(go.Scatter3d(x=[], y=[], z=[], mode='lines+markers', line=dict(color='gray', width = 2)))
|
|
133
135
|
|
|
@@ -172,17 +174,17 @@ def plot_initial_guess(result, params: Config):
|
|
|
172
174
|
fig.add_trace(go.Scatter3d(x=sub_positions[:,0], y=sub_positions[:,1], z=sub_positions[:,2], mode='lines+markers', line=dict(color='red', width = 5), name='Subject'))
|
|
173
175
|
fig.show()
|
|
174
176
|
|
|
175
|
-
def plot_scp_animation(result:
|
|
177
|
+
def plot_scp_animation(result: OptimizationResults,
|
|
176
178
|
params = None,
|
|
177
179
|
path=""):
|
|
178
|
-
tof = result
|
|
180
|
+
tof = result.t_final
|
|
179
181
|
title = f'SCP Simulation: {tof} seconds'
|
|
180
|
-
drone_positions = result
|
|
181
|
-
drone_attitudes = result
|
|
182
|
-
drone_forces = result
|
|
183
|
-
scp_interp_trajs = scp_traj_interp(result
|
|
184
|
-
scp_ctcs_trajs = result
|
|
185
|
-
scp_multi_shoot = result
|
|
182
|
+
drone_positions = result.x_full[:, :3]
|
|
183
|
+
drone_attitudes = result.x_full[:, 6:10]
|
|
184
|
+
drone_forces = result.u_full[:, :3]
|
|
185
|
+
scp_interp_trajs = scp_traj_interp(result.x_history, params)
|
|
186
|
+
scp_ctcs_trajs = result.x_history
|
|
187
|
+
scp_multi_shoot = result.discretization_history
|
|
186
188
|
# obstacles = result_ctcs["obstacles"]
|
|
187
189
|
# gates = result_ctcs["gates"]
|
|
188
190
|
if "moving_subject" in result or "init_poses" in result:
|
|
@@ -289,14 +291,14 @@ def plot_scp_animation(result: dict,
|
|
|
289
291
|
fig.add_trace(go.Surface(x=points[:, 0].reshape(n,n), y=points[:, 1].reshape(n,n), z=points[:, 2].reshape(n,n), opacity = 0.5, showscale=False))
|
|
290
292
|
|
|
291
293
|
if "vertices" in result:
|
|
292
|
-
for vertices in result["vertices"]:
|
|
294
|
+
for vertices in result.plotting_data["vertices"]:
|
|
293
295
|
# Plot a line through the vertices of the gate
|
|
294
296
|
fig.add_trace(go.Scatter3d(x=[vertices[0][0], vertices[1][0], vertices[2][0], vertices[3][0], vertices[0][0]], y=[vertices[0][1], vertices[1][1], vertices[2][1], vertices[3][1], vertices[0][1]], z=[vertices[0][2], vertices[1][2], vertices[2][2], vertices[3][2], vertices[0][2]], mode='lines', showlegend=False, line=dict(color='blue', width=10)))
|
|
295
297
|
|
|
296
298
|
# Add the subject positions
|
|
297
|
-
if "n_subs" in result and result["n_subs"] != 0:
|
|
299
|
+
if "n_subs" in result and result.plotting_data["n_subs"] != 0:
|
|
298
300
|
if "moving_subject" in result:
|
|
299
|
-
if result["moving_subject"]:
|
|
301
|
+
if result.plotting_data["moving_subject"]:
|
|
300
302
|
for sub_positions in subs_positions:
|
|
301
303
|
fig.add_trace(go.Scatter3d(x=sub_positions[:,0], y=sub_positions[:,1], z=sub_positions[:,2], mode='lines', line=dict(color='red', width = 5), showlegend=False))
|
|
302
304
|
else:
|
|
@@ -419,202 +421,57 @@ def scp_traj_interp(scp_trajs, params: Config):
|
|
|
419
421
|
scp_prop_trajs.append(np.array(states))
|
|
420
422
|
return scp_prop_trajs
|
|
421
423
|
|
|
422
|
-
def plot_state(result, params: Config):
|
|
423
|
-
|
|
424
|
-
|
|
424
|
+
def plot_state(result: OptimizationResults, params: Config):
|
|
425
|
+
x_full = result.x_full
|
|
426
|
+
t_full = result.t_full
|
|
427
|
+
dis_history = result.discretization_history
|
|
428
|
+
|
|
429
|
+
n_x = params.sim.n_states
|
|
425
430
|
|
|
426
431
|
fig = make_subplots(rows=2, cols=7, subplot_titles=('X Position', 'Y Position', 'Z Position', 'X Velocity', 'Y Velocity', 'Z Velocity', 'CTCS Augmentation', 'Q1', 'Q2', 'Q3', 'Q4', 'X Angular Rate', 'Y Angular Rate', 'Z Angular Rate'))
|
|
427
432
|
fig.update_layout(title_text="State Trajectories", template='plotly_dark')
|
|
428
433
|
|
|
429
|
-
# Plot the
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
y_min = params.sim.min_state[1]
|
|
440
|
-
y_max = params.sim.max_state[1]
|
|
441
|
-
for traj in scp_trajs:
|
|
442
|
-
fig.add_trace(go.Scatter(y=traj[:,1], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=2)
|
|
443
|
-
fig.add_trace(go.Scatter(y=x_full[:,1], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=2)
|
|
444
|
-
fig.add_hline(y=y_min, line=dict(color='red', width=2), row = 1, col = 2)
|
|
445
|
-
fig.add_hline(y=y_max, line=dict(color='red', width=2), row = 1, col = 2)
|
|
446
|
-
|
|
447
|
-
z_min = params.sim.min_state[2]
|
|
448
|
-
z_max = params.sim.max_state[2]
|
|
449
|
-
for traj in scp_trajs:
|
|
450
|
-
fig.add_trace(go.Scatter(y=traj[:,2], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=3)
|
|
451
|
-
fig.add_trace(go.Scatter(y=x_full[:,2], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=3)
|
|
452
|
-
fig.add_hline(y=z_min, line=dict(color='red', width=2), row = 1, col = 3)
|
|
453
|
-
fig.add_hline(y=z_max, line=dict(color='red', width=2), row = 1, col = 3)
|
|
454
|
-
|
|
455
|
-
# Plot the velocity
|
|
456
|
-
vx_min = params.sim.min_state[3]
|
|
457
|
-
vx_max = params.sim.max_state[3]
|
|
458
|
-
for traj in scp_trajs:
|
|
459
|
-
fig.add_trace(go.Scatter(y=traj[:,3], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=4)
|
|
460
|
-
fig.add_trace(go.Scatter(y=x_full[:,3], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=4)
|
|
461
|
-
fig.add_hline(y=vx_min, line=dict(color='red', width=2), row = 1, col = 4)
|
|
462
|
-
fig.add_hline(y=vx_max, line=dict(color='red', width=2), row = 1, col = 4)
|
|
463
|
-
|
|
464
|
-
vy_min = params.sim.min_state[4]
|
|
465
|
-
vy_max = params.sim.max_state[4]
|
|
466
|
-
for traj in scp_trajs:
|
|
467
|
-
fig.add_trace(go.Scatter(y=traj[:,4], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=5)
|
|
468
|
-
fig.add_trace(go.Scatter(y=x_full[:,4], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=5)
|
|
469
|
-
fig.add_hline(y=vy_min, line=dict(color='red', width=2), row = 1, col = 5)
|
|
470
|
-
fig.add_hline(y=vy_max, line=dict(color='red', width=2), row = 1, col = 5)
|
|
471
|
-
|
|
472
|
-
vz_min = params.sim.min_state[5]
|
|
473
|
-
vz_max = params.sim.max_state[5]
|
|
474
|
-
for traj in scp_trajs:
|
|
475
|
-
fig.add_trace(go.Scatter(y=traj[:,5], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=6)
|
|
476
|
-
fig.add_trace(go.Scatter(y=x_full[:,5], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=6)
|
|
477
|
-
fig.add_hline(y=vz_min, line=dict(color='red', width=2), row = 1, col = 6)
|
|
478
|
-
fig.add_hline(y=vz_max, line=dict(color='red', width=2), row = 1, col = 6)
|
|
479
|
-
|
|
480
|
-
# # Plot the norm of the quaternion
|
|
481
|
-
# for traj in scp_trajs:
|
|
482
|
-
# fig.add_trace(go.Scatter(y=la.norm(traj[1:,6:10], axis = 1), mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=7)
|
|
483
|
-
# fig.add_trace(go.Scatter(y=la.norm(x_full[1:,6:10], axis = 1), mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=7)
|
|
484
|
-
for traj in scp_trajs:
|
|
485
|
-
fig.add_trace(go.Scatter(y=traj[:,-1], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=7)
|
|
486
|
-
fig.add_trace(go.Scatter(y=x_full[:,-1], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=7)
|
|
487
|
-
# fig.add_hline(y=vz_min, line=dict(color='red', width=2), row = 1, col = 6)
|
|
488
|
-
# fig.add_hline(y=vz_max, line=dict(color='red', width=2), row = 1, col = 6)
|
|
489
|
-
|
|
490
|
-
# Plot the attitude
|
|
491
|
-
q1_min = params.sim.min_state[6]
|
|
492
|
-
q1_max = params.sim.max_state[6]
|
|
493
|
-
for traj in scp_trajs:
|
|
494
|
-
fig.add_trace(go.Scatter(y=traj[:,6], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=1)
|
|
495
|
-
fig.add_trace(go.Scatter(y=x_full[:,6], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=1)
|
|
496
|
-
fig.add_hline(y=q1_min, line=dict(color='red', width=2), row = 2, col = 1)
|
|
497
|
-
fig.add_hline(y=q1_max, line=dict(color='red', width=2), row = 2, col = 1)
|
|
498
|
-
|
|
499
|
-
q2_min = params.sim.min_state[7]
|
|
500
|
-
q2_max = params.sim.max_state[7]
|
|
501
|
-
for traj in scp_trajs:
|
|
502
|
-
fig.add_trace(go.Scatter(y=traj[:,7], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=2)
|
|
503
|
-
fig.add_trace(go.Scatter(y=x_full[:,7], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=2)
|
|
504
|
-
fig.add_hline(y=q2_min, line=dict(color='red', width=2), row = 2, col = 2)
|
|
505
|
-
fig.add_hline(y=q2_max, line=dict(color='red', width=2), row = 2, col = 2)
|
|
506
|
-
|
|
507
|
-
q3_min = params.sim.min_state[8]
|
|
508
|
-
q3_max = params.sim.max_state[8]
|
|
509
|
-
for traj in scp_trajs:
|
|
510
|
-
fig.add_trace(go.Scatter(y=traj[:,8], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=3)
|
|
511
|
-
fig.add_trace(go.Scatter(y=x_full[:,8], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=3)
|
|
512
|
-
fig.add_hline(y=q3_min, line=dict(color='red', width=2), row = 2, col = 3)
|
|
513
|
-
fig.add_hline(y=q3_max, line=dict(color='red', width=2), row = 2, col = 3)
|
|
514
|
-
|
|
515
|
-
q4_min = params.sim.min_state[9]
|
|
516
|
-
q4_max = params.sim.max_state[9]
|
|
517
|
-
for traj in scp_trajs:
|
|
518
|
-
fig.add_trace(go.Scatter(y=traj[:,9], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=4)
|
|
519
|
-
fig.add_trace(go.Scatter(y=x_full[:,9], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=4)
|
|
520
|
-
fig.add_hline(y=q4_min, line=dict(color='red', width=2), row = 2, col = 4)
|
|
521
|
-
fig.add_hline(y=q4_max, line=dict(color='red', width=2), row = 2, col = 4)
|
|
522
|
-
|
|
523
|
-
# Plot the angular rate
|
|
524
|
-
wx_min = params.sim.min_state[10]
|
|
525
|
-
wx_max = params.sim.max_state[10]
|
|
526
|
-
for traj in scp_trajs:
|
|
527
|
-
fig.add_trace(go.Scatter(y=traj[:,10], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=5)
|
|
528
|
-
fig.add_trace(go.Scatter(y=x_full[:,10], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=5)
|
|
529
|
-
fig.add_hline(y=wx_min, line=dict(color='red', width=2), row = 2, col = 5)
|
|
530
|
-
fig.add_hline(y=wx_max, line=dict(color='red', width=2), row = 2, col = 5)
|
|
434
|
+
# Plot the State
|
|
435
|
+
# for traj in dis_history:
|
|
436
|
+
for i in range(n_x):
|
|
437
|
+
x_min = params.sim.x.min[i]
|
|
438
|
+
x_max = params.sim.x.max[i]
|
|
439
|
+
# fig.add_trace(go.Scatter(y=traj[i], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=(i // 7) + 1, col=(i % 7) + 1)
|
|
440
|
+
fig.add_trace(go.Scatter(x=t_full, y=x_full[:,i], mode='lines', showlegend=True, line=dict(color='green', width = 2)), row=(i // 7) + 1, col=(i % 7) + 1)
|
|
441
|
+
fig.add_trace(go.Scatter(x=params.sim.x.guess[:,7], y=params.sim.x.guess[:,i], mode='lines', showlegend=True, line=dict(color='blue', width = 0.5)), row=(i // 7) + 1, col=(i % 7) + 1)
|
|
442
|
+
fig.add_hline(y=x_min, line=dict(color='red', width=2), row = (i // 7) + 1, col = (i % 7) + 1)
|
|
443
|
+
fig.add_hline(y=x_max, line=dict(color='red', width=2), row = (i // 7) + 1, col = (i % 7) + 1)
|
|
531
444
|
|
|
532
|
-
|
|
533
|
-
wy_max = params.sim.max_state[11]
|
|
534
|
-
for traj in scp_trajs:
|
|
535
|
-
fig.add_trace(go.Scatter(y=traj[:,11], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=6)
|
|
536
|
-
fig.add_trace(go.Scatter(y=x_full[:,11], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=6)
|
|
537
|
-
fig.add_hline(y=wy_min, line=dict(color='red', width=2), row = 2, col = 6)
|
|
538
|
-
fig.add_hline(y=wy_max, line=dict(color='red', width=2), row = 2, col = 6)
|
|
445
|
+
return fig
|
|
539
446
|
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
fig.add_trace(go.Scatter(y=traj[:,12], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=7)
|
|
544
|
-
fig.add_trace(go.Scatter(y=x_full[:,12], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=7)
|
|
545
|
-
fig.add_hline(y=wz_min, line=dict(color='red', width=2), row = 2, col = 7)
|
|
546
|
-
fig.add_hline(y=wz_max, line=dict(color='red', width=2), row = 2, col = 7)
|
|
547
|
-
fig.show()
|
|
447
|
+
def plot_control(result: OptimizationResults, params: Config):
|
|
448
|
+
u_full = result.u_full
|
|
449
|
+
t_full = result.t_full
|
|
548
450
|
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
u = result["u"]
|
|
451
|
+
u = params.sim.u
|
|
452
|
+
x = params.sim.x
|
|
552
453
|
|
|
553
|
-
|
|
554
|
-
fx_max = params.sim.max_control[0]
|
|
555
|
-
|
|
556
|
-
fy_min = params.sim.min_control[1]
|
|
557
|
-
fy_max = params.sim.max_control[1]
|
|
558
|
-
|
|
559
|
-
fz_min = params.sim.min_control[2]
|
|
560
|
-
fz_max = params.sim.max_control[2]
|
|
561
|
-
|
|
562
|
-
tau_x_min = params.sim.max_control[3]
|
|
563
|
-
tau_x_max = params.sim.min_control[3]
|
|
564
|
-
|
|
565
|
-
tau_y_min = params.sim.max_control[4]
|
|
566
|
-
tau_y_max = params.sim.min_control[4]
|
|
567
|
-
|
|
568
|
-
tau_z_min = params.sim.max_control[5]
|
|
569
|
-
tau_z_max = params.sim.min_control[5]
|
|
454
|
+
n_u = params.sim.n_controls
|
|
570
455
|
|
|
571
456
|
fig = make_subplots(rows=2, cols=3, subplot_titles=('X Force', 'Y Force', 'Z Force', 'X Torque', 'Y Torque', 'Z Torque'))
|
|
572
457
|
fig.update_layout(title_text="Control Trajectories", template='plotly_dark')
|
|
573
458
|
|
|
574
|
-
for
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
fig.add_trace(go.Scatter(y=traj[1], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=2)
|
|
582
|
-
fig.add_trace(go.Scatter(y=u[1], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=2)
|
|
583
|
-
fig.add_hline(y=fy_min, line=dict(color='red', width=2), row = 1, col = 2)
|
|
584
|
-
fig.add_hline(y=fy_max, line=dict(color='red', width=2), row = 1, col = 2)
|
|
585
|
-
|
|
586
|
-
for traj in scp_controls:
|
|
587
|
-
fig.add_trace(go.Scatter(y=traj[2], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=3)
|
|
588
|
-
fig.add_trace(go.Scatter(y=u[2], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=3)
|
|
589
|
-
fig.add_hline(y=fz_min, line=dict(color='red', width=2), row = 1, col = 3)
|
|
590
|
-
fig.add_hline(y=fz_max, line=dict(color='red', width=2), row = 1, col = 3)
|
|
591
|
-
|
|
592
|
-
for traj in scp_controls:
|
|
593
|
-
fig.add_trace(go.Scatter(y=traj[3], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=1)
|
|
594
|
-
fig.add_trace(go.Scatter(y=u[3], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=1)
|
|
595
|
-
fig.add_hline(y=tau_x_min, line=dict(color='red', width=2), row = 2, col = 1)
|
|
596
|
-
fig.add_hline(y=tau_x_max, line=dict(color='red', width=2), row = 2, col = 1)
|
|
597
|
-
|
|
598
|
-
for traj in scp_controls:
|
|
599
|
-
fig.add_trace(go.Scatter(y=traj[4], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=2)
|
|
600
|
-
fig.add_trace(go.Scatter(y=u[4], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=2)
|
|
601
|
-
fig.add_hline(y=tau_y_min, line=dict(color='red', width=2), row = 2, col = 2)
|
|
602
|
-
fig.add_hline(y=tau_y_max, line=dict(color='red', width=2), row = 2, col = 2)
|
|
603
|
-
|
|
604
|
-
for traj in scp_controls:
|
|
605
|
-
fig.add_trace(go.Scatter(y=traj[5], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=3)
|
|
606
|
-
fig.add_trace(go.Scatter(y=u[5], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=3)
|
|
607
|
-
fig.add_hline(y=tau_z_min, line=dict(color='red', width=2), row = 2, col = 3)
|
|
608
|
-
fig.add_hline(y=tau_z_max, line=dict(color='red', width=2), row = 2, col = 3)
|
|
459
|
+
for i in range(n_u):
|
|
460
|
+
u_min = u.min[i]
|
|
461
|
+
u_max = u.max[i]
|
|
462
|
+
fig.add_trace(go.Scatter(x=t_full, y=u_full[:,i], mode='lines', showlegend=True, line=dict(color='green', width = 2)), row=(i // 3) + 1, col=(i % 3) + 1)
|
|
463
|
+
fig.add_trace(go.Scatter(x=x.guess[:,7], y=u.guess[:,i], mode='lines', showlegend=True, line=dict(color='blue', width = 0.5)), row=(i // 3) + 1, col=(i % 3) + 1)
|
|
464
|
+
fig.add_hline(y=u_min, line=dict(color='red', width=2), row = (i // 3) + 1, col = (i % 3) + 1)
|
|
465
|
+
fig.add_hline(y=u_max, line=dict(color='red', width=2), row = (i // 3) + 1, col = (i % 3) + 1)
|
|
609
466
|
|
|
610
|
-
fig
|
|
467
|
+
return fig
|
|
611
468
|
|
|
612
|
-
def plot_losses(result, params: Config):
|
|
469
|
+
def plot_losses(result: OptimizationResults, params: Config):
|
|
613
470
|
# Plot J_tr, J_vb, J_vc, J_vc_ctcs
|
|
614
|
-
J_tr = result
|
|
615
|
-
J_vb = result
|
|
616
|
-
J_vc = result
|
|
617
|
-
J_vc_ctcs = result["J_vc_ctcs_vec"]
|
|
471
|
+
J_tr = result.J_tr_history
|
|
472
|
+
J_vb = result.J_vb_history
|
|
473
|
+
J_vc = result.J_vc_history
|
|
474
|
+
J_vc_ctcs = result.plotting_data["J_vc_ctcs_vec"]
|
|
618
475
|
|
|
619
476
|
fig = make_subplots(rows=2, cols=2, subplot_titles=('J_tr', 'J_vb', 'J_vc', 'J_vc_ctcs'))
|
|
620
477
|
fig.update_layout(title_text="Losses", template='plotly_dark')
|
openscvx/post_processing.py
CHANGED
|
@@ -1,36 +1,77 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
+
import jax.numpy as jnp
|
|
2
3
|
|
|
3
4
|
from openscvx.propagation import s_to_t, t_to_tau, simulate_nonlinear_time
|
|
4
5
|
from openscvx.config import Config
|
|
6
|
+
from openscvx.results import OptimizationResults
|
|
5
7
|
|
|
6
8
|
|
|
7
|
-
def propagate_trajectory_results(params: Config, result:
|
|
8
|
-
|
|
9
|
-
|
|
9
|
+
def propagate_trajectory_results(params: dict, settings: Config, result: OptimizationResults, propagation_solver: callable) -> OptimizationResults:
|
|
10
|
+
"""Propagate the optimal trajectory and compute additional results.
|
|
11
|
+
|
|
12
|
+
This function takes the optimal control solution and propagates it through the
|
|
13
|
+
nonlinear dynamics to compute the actual state trajectory and other metrics.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
params (dict): System parameters.
|
|
17
|
+
settings (Config): Configuration settings.
|
|
18
|
+
result (OptimizationResults): Optimization results object.
|
|
19
|
+
propagation_solver (callable): Function for propagating the system state.
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
OptimizationResults: Updated results object containing:
|
|
23
|
+
- t_full: Full time vector
|
|
24
|
+
- x_full: Full state trajectory
|
|
25
|
+
- u_full: Full control trajectory
|
|
26
|
+
- cost: Computed cost
|
|
27
|
+
- ctcs_violation: CTCS constraint violation
|
|
28
|
+
"""
|
|
29
|
+
x = result.x
|
|
30
|
+
u = result.u
|
|
10
31
|
|
|
11
|
-
t = np.array(s_to_t(u,
|
|
32
|
+
t = np.array(s_to_t(x, u, settings)).squeeze()
|
|
12
33
|
|
|
13
|
-
t_full = np.arange(0, t[-1],
|
|
34
|
+
t_full = np.arange(t[0], t[-1], settings.prp.dt)
|
|
14
35
|
|
|
15
|
-
tau_vals, u_full = t_to_tau(u, t_full,
|
|
36
|
+
tau_vals, u_full = t_to_tau(u, t_full, t, settings)
|
|
16
37
|
|
|
17
|
-
|
|
38
|
+
# Match free values from initial state to the initial value from the result
|
|
39
|
+
mask = jnp.array([t == "Free" for t in x.initial_type], dtype=bool)
|
|
40
|
+
settings.sim.x_prop.initial = jnp.where(mask, x.guess[0,:], settings.sim.x_prop.initial)
|
|
18
41
|
|
|
19
|
-
|
|
42
|
+
x_full = simulate_nonlinear_time(params, x, u, tau_vals, t, settings, propagation_solver)
|
|
43
|
+
|
|
44
|
+
# Calculate cost
|
|
20
45
|
i = 0
|
|
21
|
-
cost = np.zeros_like(x[-1,
|
|
22
|
-
for type in
|
|
46
|
+
cost = np.zeros_like(x.guess[-1,i])
|
|
47
|
+
for type in x.initial_type:
|
|
23
48
|
if type == "Minimize":
|
|
24
|
-
cost += x[0, i]
|
|
49
|
+
cost += x.guess[0, i]
|
|
25
50
|
i += 1
|
|
26
51
|
i = 0
|
|
27
|
-
for type in
|
|
52
|
+
for type in x.final_type:
|
|
28
53
|
if type == "Minimize":
|
|
29
|
-
cost += x[-1, i]
|
|
54
|
+
cost += x.guess[-1, i]
|
|
55
|
+
i += 1
|
|
56
|
+
i=0
|
|
57
|
+
for type in x.initial_type:
|
|
58
|
+
if type == "Maximize":
|
|
59
|
+
cost -= x.guess[0, i]
|
|
30
60
|
i += 1
|
|
31
|
-
|
|
61
|
+
i = 0
|
|
62
|
+
for type in x.final_type:
|
|
63
|
+
if type == "Maximize":
|
|
64
|
+
cost -= x.guess[-1, i]
|
|
65
|
+
i += 1
|
|
66
|
+
|
|
67
|
+
# Calculate CTCS constraint violation
|
|
68
|
+
ctcs_violation = x_full[-1, settings.sim.idx_y_prop]
|
|
32
69
|
|
|
33
|
-
|
|
70
|
+
# Update the results object with post-processing data
|
|
71
|
+
result.t_full = t_full
|
|
72
|
+
result.x_full = x_full
|
|
73
|
+
result.u_full = u_full
|
|
74
|
+
result.cost = cost
|
|
75
|
+
result.ctcs_violation = ctcs_violation
|
|
34
76
|
|
|
35
|
-
result.update(more_result)
|
|
36
77
|
return result
|