openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. See the package's registry listing for more details (the original link was not preserved).

openscvx/plotting.py CHANGED
@@ -6,12 +6,13 @@ import pickle
6
6
 
7
7
  from openscvx.utils import qdcm, get_kp_pose
8
8
  from openscvx.config import Config
9
+ from openscvx.results import OptimizationResults
9
10
 
10
- def full_subject_traj_time(results, params):
11
- x_full = results["x_full"]
12
- x_nodes = results["x"]
11
+ def full_subject_traj_time(results: OptimizationResults, params: Config):
12
+ x_full = results.x_full
13
+ x_nodes = results.x
13
14
  t_nodes = x_nodes[:,params.sim.idx_t]
14
- t_full = results['t_full']
15
+ t_full = results.t_full
15
16
  subs_traj = []
16
17
  subs_traj_node = []
17
18
  subs_traj_sen = []
@@ -19,11 +20,12 @@ def full_subject_traj_time(results, params):
19
20
 
20
21
  # if hasattr(params.dyn, 'get_kp_pose'):
21
22
  if "moving_subject" in results and "init_poses" in results:
22
- init_poses = results["init_poses"]
23
+ init_poses = results.plotting_data["init_poses"]
23
24
  subs_traj.append(get_kp_pose(t_full, init_poses))
24
25
  subs_traj_node.append(get_kp_pose(t_nodes, init_poses))
25
26
  elif "init_poses" in results:
26
- for pose in results["init_poses"]:
27
+ init_poses = results.plotting_data["init_poses"]
28
+ for pose in init_poses:
27
29
  # repeat the pose for all time steps
28
30
  pose_full = np.repeat(pose[:,np.newaxis], x_full.shape[0], axis=1).T
29
31
  subs_traj.append(pose_full)
@@ -34,7 +36,7 @@ def full_subject_traj_time(results, params):
34
36
  raise ValueError("No valid method to get keypoint poses.")
35
37
 
36
38
  if "R_sb" in results:
37
- R_sb = results["R_sb"]
39
+ R_sb = results.plotting_data["R_sb"]
38
40
  for sub_traj in subs_traj:
39
41
  sub_traj_sen = []
40
42
  for i in range(x_full.shape[0]):
@@ -50,7 +52,7 @@ def full_subject_traj_time(results, params):
50
52
  subs_traj_sen_node.append(np.array(sub_traj_sen_node).squeeze())
51
53
  return subs_traj, subs_traj_sen, subs_traj_node, subs_traj_sen_node
52
54
  else:
53
- raise ValueError("`R_sb` not found in results dictionary. Cannot compute sensor frame.")
55
+ raise ValueError("`R_sb` not found in results. Cannot compute sensor frame.")
54
56
 
55
57
  def save_gate_parameters(gates, params: Config):
56
58
  gate_centers = []
@@ -75,26 +77,26 @@ def frame_args(duration):
75
77
  "transition": {"duration": duration, "easing": "linear"},
76
78
  }
77
79
 
78
- def plot_constraint_violation(result, params: Config):
80
+ def plot_constraint_violation(result: OptimizationResults, params: Config):
79
81
  fig = make_subplots(rows=2, cols=3, subplot_titles=(r'$\text{Obstacle Violation}$', r'$\text{Sub VP Violation}$', r'$\text{Sub Min Violation}$', r'$\text{Sub Max Violation}$', r'$\text{Sub Direc Violation}$', r'$\text{State Bound Violation}$', r'$\text{Total Violation}$'))
80
82
  fig.update_layout(template='plotly_dark', title=r'$\text{Constraint Violation}$')
81
83
 
82
84
  if "obs_vio" in result:
83
- obs_vio = result["obs_vio"]
85
+ obs_vio = result.plotting_data["obs_vio"]
84
86
  for i in range(obs_vio.shape[0]):
85
87
  color = f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'
86
88
  fig.add_trace(go.Scatter(y=obs_vio[i], mode='lines', showlegend=False, line=dict(color=color, width = 2)), row=1, col=1)
87
89
  i = 0
88
90
  else:
89
- print("'obs_vio' not found in result dictionary.")
91
+ print("'obs_vio' not found in result.")
90
92
 
91
93
  # Make names of each state in the state vector
92
94
  state_names = ['x', 'y', 'z', 'vx', 'vy', 'vz', 'q0', 'q1', 'q2', 'q3', 'wx', 'wy', 'wz', 'ctcs']
93
95
 
94
96
  if "sub_vp_vio" in result and "sub_min_vio" in result and "sub_max_vio" in result:
95
- sub_vp_vio = result["sub_vp_vio"]
96
- sub_min_vio = result["sub_min_vio"]
97
- sub_max_vio = result["sub_max_vio"]
97
+ sub_vp_vio = result.plotting_data["sub_vp_vio"]
98
+ sub_min_vio = result.plotting_data["sub_min_vio"]
99
+ sub_max_vio = result.plotting_data["sub_max_vio"]
98
100
  for i in range(sub_vp_vio.shape[0]):
99
101
  color = f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'
100
102
  fig.add_trace(go.Scatter(y=sub_vp_vio[i], mode='lines', showlegend=True, name = 'LoS ' + str(i) + ' Error', line=dict(color=color, width = 2)), row=1, col=2)
@@ -106,28 +108,28 @@ def plot_constraint_violation(result, params: Config):
106
108
  fig.add_trace(go.Scatter(y=[], mode='lines', showlegend=False, line=dict(color=color, width = 2)), row=2, col=1)
107
109
  i = 0
108
110
  else:
109
- print("'sub_vp_vio', 'sub_min_vio', or 'sub_max_vio' not found in result dictionary.")
111
+ print("'sub_vp_vio', 'sub_min_vio', or 'sub_max_vio' not found in result.")
110
112
 
111
113
  if "sub_direc_vio" in result:
112
- sub_direc_vio = result["sub_direc_vio"]
114
+ sub_direc_vio = result.plotting_data["sub_direc_vio"]
113
115
  # fig.add_trace(go.Scatter(y=sub_direc_vio, mode='lines', showlegend=False, line=dict(color='red', width = 2)), row=2, col=2)
114
116
  else:
115
- print("'sub_direc_vio' not found in result dictionary.")
117
+ print("'sub_direc_vio' not found in result.")
116
118
 
117
119
  if "state_bound_vio" in result:
118
- state_bound_vio = result["state_bound_vio"]
120
+ state_bound_vio = result.plotting_data["state_bound_vio"]
119
121
  for i in range(state_bound_vio.shape[0]):
120
122
  color = f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'
121
123
  fig.add_trace(go.Scatter(y=state_bound_vio[:,i], mode='lines', showlegend=True, name = state_names[i] + ' Error', line=dict(color=color, width = 2)), row=2, col=3)
122
124
  else:
123
- print("'state_bound_vio' not found in result dictionary.")
125
+ print("'state_bound_vio' not found in result.")
124
126
 
125
127
  fig.show()
126
128
 
127
- def plot_initial_guess(result, params: Config):
128
- x_positions = result["x"][0:3]
129
- x_attitude = result["x"][6:10]
130
- subs_positions = result["sub_positions"]
129
+ def plot_initial_guess(result: OptimizationResults, params: Config):
130
+ x_positions = result.x.guess[:, 0:3].T
131
+ x_attitude = result.x.guess[:, 6:10].T
132
+ subs_positions = result.plotting_data["sub_positions"]
131
133
 
132
134
  fig = go.Figure(go.Scatter3d(x=[], y=[], z=[], mode='lines+markers', line=dict(color='gray', width = 2)))
133
135
 
@@ -172,17 +174,17 @@ def plot_initial_guess(result, params: Config):
172
174
  fig.add_trace(go.Scatter3d(x=sub_positions[:,0], y=sub_positions[:,1], z=sub_positions[:,2], mode='lines+markers', line=dict(color='red', width = 5), name='Subject'))
173
175
  fig.show()
174
176
 
175
- def plot_scp_animation(result: dict,
177
+ def plot_scp_animation(result: OptimizationResults,
176
178
  params = None,
177
179
  path=""):
178
- tof = result["t_final"]
180
+ tof = result.t_final
179
181
  title = f'SCP Simulation: {tof} seconds'
180
- drone_positions = result["x_full"][:, :3]
181
- drone_attitudes = result["x_full"][:, 6:10]
182
- drone_forces = result["u_full"][:, :3]
183
- scp_interp_trajs = scp_traj_interp(result["x_history"], params)
184
- scp_ctcs_trajs = result["x_history"]
185
- scp_multi_shoot = result["discretization"]
182
+ drone_positions = result.x_full[:, :3]
183
+ drone_attitudes = result.x_full[:, 6:10]
184
+ drone_forces = result.u_full[:, :3]
185
+ scp_interp_trajs = scp_traj_interp(result.x_history, params)
186
+ scp_ctcs_trajs = result.x_history
187
+ scp_multi_shoot = result.discretization_history
186
188
  # obstacles = result_ctcs["obstacles"]
187
189
  # gates = result_ctcs["gates"]
188
190
  if "moving_subject" in result or "init_poses" in result:
@@ -289,14 +291,14 @@ def plot_scp_animation(result: dict,
289
291
  fig.add_trace(go.Surface(x=points[:, 0].reshape(n,n), y=points[:, 1].reshape(n,n), z=points[:, 2].reshape(n,n), opacity = 0.5, showscale=False))
290
292
 
291
293
  if "vertices" in result:
292
- for vertices in result["vertices"]:
294
+ for vertices in result.plotting_data["vertices"]:
293
295
  # Plot a line through the vertices of the gate
294
296
  fig.add_trace(go.Scatter3d(x=[vertices[0][0], vertices[1][0], vertices[2][0], vertices[3][0], vertices[0][0]], y=[vertices[0][1], vertices[1][1], vertices[2][1], vertices[3][1], vertices[0][1]], z=[vertices[0][2], vertices[1][2], vertices[2][2], vertices[3][2], vertices[0][2]], mode='lines', showlegend=False, line=dict(color='blue', width=10)))
295
297
 
296
298
  # Add the subject positions
297
- if "n_subs" in result and result["n_subs"] != 0:
299
+ if "n_subs" in result and result.plotting_data["n_subs"] != 0:
298
300
  if "moving_subject" in result:
299
- if result["moving_subject"]:
301
+ if result.plotting_data["moving_subject"]:
300
302
  for sub_positions in subs_positions:
301
303
  fig.add_trace(go.Scatter3d(x=sub_positions[:,0], y=sub_positions[:,1], z=sub_positions[:,2], mode='lines', line=dict(color='red', width = 5), showlegend=False))
302
304
  else:
@@ -419,202 +421,57 @@ def scp_traj_interp(scp_trajs, params: Config):
419
421
  scp_prop_trajs.append(np.array(states))
420
422
  return scp_prop_trajs
421
423
 
422
- def plot_state(result, params: Config):
423
- scp_trajs = scp_traj_interp(result["x_history"], params)
424
- x_full = result["x_full"]
424
+ def plot_state(result: OptimizationResults, params: Config):
425
+ x_full = result.x_full
426
+ t_full = result.t_full
427
+ dis_history = result.discretization_history
428
+
429
+ n_x = params.sim.n_states
425
430
 
426
431
  fig = make_subplots(rows=2, cols=7, subplot_titles=('X Position', 'Y Position', 'Z Position', 'X Velocity', 'Y Velocity', 'Z Velocity', 'CTCS Augmentation', 'Q1', 'Q2', 'Q3', 'Q4', 'X Angular Rate', 'Y Angular Rate', 'Z Angular Rate'))
427
432
  fig.update_layout(title_text="State Trajectories", template='plotly_dark')
428
433
 
429
- # Plot the position
430
- x_min = params.sim.min_state[0]
431
- x_max = params.sim.max_state[0]
432
- for traj in scp_trajs:
433
- fig.add_trace(go.Scatter(y=traj[:,0], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=1)
434
- fig.add_trace(go.Scatter(y=x_full[:,0], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=1)
435
- fig.add_hline(y=x_min, line=dict(color='red', width=2), row = 1, col = 1)
436
- fig.add_hline(y=x_max, line=dict(color='red', width=2), row = 1, col = 1)
437
- fig.update_yaxes(range=[x_min, x_max], row=1, col=1)
438
-
439
- y_min = params.sim.min_state[1]
440
- y_max = params.sim.max_state[1]
441
- for traj in scp_trajs:
442
- fig.add_trace(go.Scatter(y=traj[:,1], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=2)
443
- fig.add_trace(go.Scatter(y=x_full[:,1], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=2)
444
- fig.add_hline(y=y_min, line=dict(color='red', width=2), row = 1, col = 2)
445
- fig.add_hline(y=y_max, line=dict(color='red', width=2), row = 1, col = 2)
446
-
447
- z_min = params.sim.min_state[2]
448
- z_max = params.sim.max_state[2]
449
- for traj in scp_trajs:
450
- fig.add_trace(go.Scatter(y=traj[:,2], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=3)
451
- fig.add_trace(go.Scatter(y=x_full[:,2], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=3)
452
- fig.add_hline(y=z_min, line=dict(color='red', width=2), row = 1, col = 3)
453
- fig.add_hline(y=z_max, line=dict(color='red', width=2), row = 1, col = 3)
454
-
455
- # Plot the velocity
456
- vx_min = params.sim.min_state[3]
457
- vx_max = params.sim.max_state[3]
458
- for traj in scp_trajs:
459
- fig.add_trace(go.Scatter(y=traj[:,3], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=4)
460
- fig.add_trace(go.Scatter(y=x_full[:,3], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=4)
461
- fig.add_hline(y=vx_min, line=dict(color='red', width=2), row = 1, col = 4)
462
- fig.add_hline(y=vx_max, line=dict(color='red', width=2), row = 1, col = 4)
463
-
464
- vy_min = params.sim.min_state[4]
465
- vy_max = params.sim.max_state[4]
466
- for traj in scp_trajs:
467
- fig.add_trace(go.Scatter(y=traj[:,4], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=5)
468
- fig.add_trace(go.Scatter(y=x_full[:,4], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=5)
469
- fig.add_hline(y=vy_min, line=dict(color='red', width=2), row = 1, col = 5)
470
- fig.add_hline(y=vy_max, line=dict(color='red', width=2), row = 1, col = 5)
471
-
472
- vz_min = params.sim.min_state[5]
473
- vz_max = params.sim.max_state[5]
474
- for traj in scp_trajs:
475
- fig.add_trace(go.Scatter(y=traj[:,5], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=6)
476
- fig.add_trace(go.Scatter(y=x_full[:,5], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=6)
477
- fig.add_hline(y=vz_min, line=dict(color='red', width=2), row = 1, col = 6)
478
- fig.add_hline(y=vz_max, line=dict(color='red', width=2), row = 1, col = 6)
479
-
480
- # # Plot the norm of the quaternion
481
- # for traj in scp_trajs:
482
- # fig.add_trace(go.Scatter(y=la.norm(traj[1:,6:10], axis = 1), mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=7)
483
- # fig.add_trace(go.Scatter(y=la.norm(x_full[1:,6:10], axis = 1), mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=7)
484
- for traj in scp_trajs:
485
- fig.add_trace(go.Scatter(y=traj[:,-1], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=7)
486
- fig.add_trace(go.Scatter(y=x_full[:,-1], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=7)
487
- # fig.add_hline(y=vz_min, line=dict(color='red', width=2), row = 1, col = 6)
488
- # fig.add_hline(y=vz_max, line=dict(color='red', width=2), row = 1, col = 6)
489
-
490
- # Plot the attitude
491
- q1_min = params.sim.min_state[6]
492
- q1_max = params.sim.max_state[6]
493
- for traj in scp_trajs:
494
- fig.add_trace(go.Scatter(y=traj[:,6], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=1)
495
- fig.add_trace(go.Scatter(y=x_full[:,6], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=1)
496
- fig.add_hline(y=q1_min, line=dict(color='red', width=2), row = 2, col = 1)
497
- fig.add_hline(y=q1_max, line=dict(color='red', width=2), row = 2, col = 1)
498
-
499
- q2_min = params.sim.min_state[7]
500
- q2_max = params.sim.max_state[7]
501
- for traj in scp_trajs:
502
- fig.add_trace(go.Scatter(y=traj[:,7], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=2)
503
- fig.add_trace(go.Scatter(y=x_full[:,7], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=2)
504
- fig.add_hline(y=q2_min, line=dict(color='red', width=2), row = 2, col = 2)
505
- fig.add_hline(y=q2_max, line=dict(color='red', width=2), row = 2, col = 2)
506
-
507
- q3_min = params.sim.min_state[8]
508
- q3_max = params.sim.max_state[8]
509
- for traj in scp_trajs:
510
- fig.add_trace(go.Scatter(y=traj[:,8], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=3)
511
- fig.add_trace(go.Scatter(y=x_full[:,8], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=3)
512
- fig.add_hline(y=q3_min, line=dict(color='red', width=2), row = 2, col = 3)
513
- fig.add_hline(y=q3_max, line=dict(color='red', width=2), row = 2, col = 3)
514
-
515
- q4_min = params.sim.min_state[9]
516
- q4_max = params.sim.max_state[9]
517
- for traj in scp_trajs:
518
- fig.add_trace(go.Scatter(y=traj[:,9], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=4)
519
- fig.add_trace(go.Scatter(y=x_full[:,9], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=4)
520
- fig.add_hline(y=q4_min, line=dict(color='red', width=2), row = 2, col = 4)
521
- fig.add_hline(y=q4_max, line=dict(color='red', width=2), row = 2, col = 4)
522
-
523
- # Plot the angular rate
524
- wx_min = params.sim.min_state[10]
525
- wx_max = params.sim.max_state[10]
526
- for traj in scp_trajs:
527
- fig.add_trace(go.Scatter(y=traj[:,10], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=5)
528
- fig.add_trace(go.Scatter(y=x_full[:,10], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=5)
529
- fig.add_hline(y=wx_min, line=dict(color='red', width=2), row = 2, col = 5)
530
- fig.add_hline(y=wx_max, line=dict(color='red', width=2), row = 2, col = 5)
434
+ # Plot the State
435
+ # for traj in dis_history:
436
+ for i in range(n_x):
437
+ x_min = params.sim.x.min[i]
438
+ x_max = params.sim.x.max[i]
439
+ # fig.add_trace(go.Scatter(y=traj[i], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=(i // 7) + 1, col=(i % 7) + 1)
440
+ fig.add_trace(go.Scatter(x=t_full, y=x_full[:,i], mode='lines', showlegend=True, line=dict(color='green', width = 2)), row=(i // 7) + 1, col=(i % 7) + 1)
441
+ fig.add_trace(go.Scatter(x=params.sim.x.guess[:,7], y=params.sim.x.guess[:,i], mode='lines', showlegend=True, line=dict(color='blue', width = 0.5)), row=(i // 7) + 1, col=(i % 7) + 1)
442
+ fig.add_hline(y=x_min, line=dict(color='red', width=2), row = (i // 7) + 1, col = (i % 7) + 1)
443
+ fig.add_hline(y=x_max, line=dict(color='red', width=2), row = (i // 7) + 1, col = (i % 7) + 1)
531
444
 
532
- wy_min = params.sim.min_state[11]
533
- wy_max = params.sim.max_state[11]
534
- for traj in scp_trajs:
535
- fig.add_trace(go.Scatter(y=traj[:,11], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=6)
536
- fig.add_trace(go.Scatter(y=x_full[:,11], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=6)
537
- fig.add_hline(y=wy_min, line=dict(color='red', width=2), row = 2, col = 6)
538
- fig.add_hline(y=wy_max, line=dict(color='red', width=2), row = 2, col = 6)
445
+ return fig
539
446
 
540
- wz_min = params.sim.min_state[12]
541
- wz_max = params.sim.max_state[12]
542
- for traj in scp_trajs:
543
- fig.add_trace(go.Scatter(y=traj[:,12], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=7)
544
- fig.add_trace(go.Scatter(y=x_full[:,12], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=7)
545
- fig.add_hline(y=wz_min, line=dict(color='red', width=2), row = 2, col = 7)
546
- fig.add_hline(y=wz_max, line=dict(color='red', width=2), row = 2, col = 7)
547
- fig.show()
447
+ def plot_control(result: OptimizationResults, params: Config):
448
+ u_full = result.u_full
449
+ t_full = result.t_full
548
450
 
549
- def plot_control(result, params: Config):
550
- scp_controls = result["u_history"]
551
- u = result["u"]
451
+ u = params.sim.u
452
+ x = params.sim.x
552
453
 
553
- fx_min = params.sim.min_control[0]
554
- fx_max = params.sim.max_control[0]
555
-
556
- fy_min = params.sim.min_control[1]
557
- fy_max = params.sim.max_control[1]
558
-
559
- fz_min = params.sim.min_control[2]
560
- fz_max = params.sim.max_control[2]
561
-
562
- tau_x_min = params.sim.max_control[3]
563
- tau_x_max = params.sim.min_control[3]
564
-
565
- tau_y_min = params.sim.max_control[4]
566
- tau_y_max = params.sim.min_control[4]
567
-
568
- tau_z_min = params.sim.max_control[5]
569
- tau_z_max = params.sim.min_control[5]
454
+ n_u = params.sim.n_controls
570
455
 
571
456
  fig = make_subplots(rows=2, cols=3, subplot_titles=('X Force', 'Y Force', 'Z Force', 'X Torque', 'Y Torque', 'Z Torque'))
572
457
  fig.update_layout(title_text="Control Trajectories", template='plotly_dark')
573
458
 
574
- for traj in scp_controls:
575
- fig.add_trace(go.Scatter(y=traj[0], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=1)
576
- fig.add_trace(go.Scatter(y=u[0], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=1)
577
- fig.add_hline(y=fx_min, line=dict(color='red', width=2), row = 1, col = 1)
578
- fig.add_hline(y=fx_max, line=dict(color='red', width=2), row = 1, col = 1)
579
-
580
- for traj in scp_controls:
581
- fig.add_trace(go.Scatter(y=traj[1], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=2)
582
- fig.add_trace(go.Scatter(y=u[1], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=2)
583
- fig.add_hline(y=fy_min, line=dict(color='red', width=2), row = 1, col = 2)
584
- fig.add_hline(y=fy_max, line=dict(color='red', width=2), row = 1, col = 2)
585
-
586
- for traj in scp_controls:
587
- fig.add_trace(go.Scatter(y=traj[2], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=1, col=3)
588
- fig.add_trace(go.Scatter(y=u[2], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=1, col=3)
589
- fig.add_hline(y=fz_min, line=dict(color='red', width=2), row = 1, col = 3)
590
- fig.add_hline(y=fz_max, line=dict(color='red', width=2), row = 1, col = 3)
591
-
592
- for traj in scp_controls:
593
- fig.add_trace(go.Scatter(y=traj[3], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=1)
594
- fig.add_trace(go.Scatter(y=u[3], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=1)
595
- fig.add_hline(y=tau_x_min, line=dict(color='red', width=2), row = 2, col = 1)
596
- fig.add_hline(y=tau_x_max, line=dict(color='red', width=2), row = 2, col = 1)
597
-
598
- for traj in scp_controls:
599
- fig.add_trace(go.Scatter(y=traj[4], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=2)
600
- fig.add_trace(go.Scatter(y=u[4], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=2)
601
- fig.add_hline(y=tau_y_min, line=dict(color='red', width=2), row = 2, col = 2)
602
- fig.add_hline(y=tau_y_max, line=dict(color='red', width=2), row = 2, col = 2)
603
-
604
- for traj in scp_controls:
605
- fig.add_trace(go.Scatter(y=traj[5], mode='lines', showlegend=False, line=dict(color='gray', width = 0.5)), row=2, col=3)
606
- fig.add_trace(go.Scatter(y=u[5], mode='lines', showlegend=False, line=dict(color='green', width = 2)), row=2, col=3)
607
- fig.add_hline(y=tau_z_min, line=dict(color='red', width=2), row = 2, col = 3)
608
- fig.add_hline(y=tau_z_max, line=dict(color='red', width=2), row = 2, col = 3)
459
+ for i in range(n_u):
460
+ u_min = u.min[i]
461
+ u_max = u.max[i]
462
+ fig.add_trace(go.Scatter(x=t_full, y=u_full[:,i], mode='lines', showlegend=True, line=dict(color='green', width = 2)), row=(i // 3) + 1, col=(i % 3) + 1)
463
+ fig.add_trace(go.Scatter(x=x.guess[:,7], y=u.guess[:,i], mode='lines', showlegend=True, line=dict(color='blue', width = 0.5)), row=(i // 3) + 1, col=(i % 3) + 1)
464
+ fig.add_hline(y=u_min, line=dict(color='red', width=2), row = (i // 3) + 1, col = (i % 3) + 1)
465
+ fig.add_hline(y=u_max, line=dict(color='red', width=2), row = (i // 3) + 1, col = (i % 3) + 1)
609
466
 
610
- fig.show()
467
+ return fig
611
468
 
612
- def plot_losses(result, params: Config):
469
+ def plot_losses(result: OptimizationResults, params: Config):
613
470
  # Plot J_tr, J_vb, J_vc, J_vc_ctcs
614
- J_tr = result["J_tr_history"]
615
- J_vb = result["J_vb_history"]
616
- J_vc = result["J_vc_history"]
617
- J_vc_ctcs = result["J_vc_ctcs_vec"]
471
+ J_tr = result.J_tr_history
472
+ J_vb = result.J_vb_history
473
+ J_vc = result.J_vc_history
474
+ J_vc_ctcs = result.plotting_data["J_vc_ctcs_vec"]
618
475
 
619
476
  fig = make_subplots(rows=2, cols=2, subplot_titles=('J_tr', 'J_vb', 'J_vc', 'J_vc_ctcs'))
620
477
  fig.update_layout(title_text="Losses", template='plotly_dark')
@@ -1,36 +1,77 @@
1
1
  import numpy as np
2
+ import jax.numpy as jnp
2
3
 
3
4
  from openscvx.propagation import s_to_t, t_to_tau, simulate_nonlinear_time
4
5
  from openscvx.config import Config
6
+ from openscvx.results import OptimizationResults
5
7
 
6
8
 
7
- def propagate_trajectory_results(params: Config, result: dict, propagation_solver: callable) -> dict:
8
- x = result["x"]
9
- u = result["u"]
9
+ def propagate_trajectory_results(params: dict, settings: Config, result: OptimizationResults, propagation_solver: callable) -> OptimizationResults:
10
+ """Propagate the optimal trajectory and compute additional results.
11
+
12
+ This function takes the optimal control solution and propagates it through the
13
+ nonlinear dynamics to compute the actual state trajectory and other metrics.
14
+
15
+ Args:
16
+ params (dict): System parameters.
17
+ settings (Config): Configuration settings.
18
+ result (OptimizationResults): Optimization results object.
19
+ propagation_solver (callable): Function for propagating the system state.
20
+
21
+ Returns:
22
+ OptimizationResults: Updated results object containing:
23
+ - t_full: Full time vector
24
+ - x_full: Full state trajectory
25
+ - u_full: Full control trajectory
26
+ - cost: Computed cost
27
+ - ctcs_violation: CTCS constraint violation
28
+ """
29
+ x = result.x
30
+ u = result.u
10
31
 
11
- t = np.array(s_to_t(u, params))
32
+ t = np.array(s_to_t(x, u, settings)).squeeze()
12
33
 
13
- t_full = np.arange(0, t[-1], params.prp.dt)
34
+ t_full = np.arange(t[0], t[-1], settings.prp.dt)
14
35
 
15
- tau_vals, u_full = t_to_tau(u, t_full, u, t, params)
36
+ tau_vals, u_full = t_to_tau(u, t_full, t, settings)
16
37
 
17
- x_full = simulate_nonlinear_time(x[0], u, tau_vals, t, params, propagation_solver)
38
+ # Match free values from initial state to the initial value from the result
39
+ mask = jnp.array([t == "Free" for t in x.initial_type], dtype=bool)
40
+ settings.sim.x_prop.initial = jnp.where(mask, x.guess[0,:], settings.sim.x_prop.initial)
18
41
 
19
- print("Total CTCS Constraint Violation:", x_full[-1, params.sim.idx_y])
42
+ x_full = simulate_nonlinear_time(params, x, u, tau_vals, t, settings, propagation_solver)
43
+
44
+ # Calculate cost
20
45
  i = 0
21
- cost = np.zeros_like(x[-1, i])
22
- for type in params.sim.initial_state.type:
46
+ cost = np.zeros_like(x.guess[-1,i])
47
+ for type in x.initial_type:
23
48
  if type == "Minimize":
24
- cost += x[0, i]
49
+ cost += x.guess[0, i]
25
50
  i += 1
26
51
  i = 0
27
- for type in params.sim.final_state.type:
52
+ for type in x.final_type:
28
53
  if type == "Minimize":
29
- cost += x[-1, i]
54
+ cost += x.guess[-1, i]
55
+ i += 1
56
+ i=0
57
+ for type in x.initial_type:
58
+ if type == "Maximize":
59
+ cost -= x.guess[0, i]
30
60
  i += 1
31
- print("Cost: ", cost)
61
+ i = 0
62
+ for type in x.final_type:
63
+ if type == "Maximize":
64
+ cost -= x.guess[-1, i]
65
+ i += 1
66
+
67
+ # Calculate CTCS constraint violation
68
+ ctcs_violation = x_full[-1, settings.sim.idx_y_prop]
32
69
 
33
- more_result = dict(t_full=t_full, x_full=x_full, u_full=u_full)
70
+ # Update the results object with post-processing data
71
+ result.t_full = t_full
72
+ result.x_full = x_full
73
+ result.u_full = u_full
74
+ result.cost = cost
75
+ result.ctcs_violation = ctcs_violation
34
76
 
35
- result.update(more_result)
36
77
  return result