openscvx 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

openscvx/plotting.py ADDED
@@ -0,0 +1,632 @@
1
+ from plotly.subplots import make_subplots
2
+ import random
3
+ import plotly.graph_objects as go
4
+ import numpy as np
5
+ import pickle
6
+
7
+ from openscvx.utils import qdcm, get_kp_pose
8
+ from openscvx.config import Config
9
+
10
def full_subject_traj_time(results, params):
    """Return subject trajectories in the world frame and in the sensor frame.

    Each trajectory is produced on two grids: one sample per row of
    ``results["x_full"]`` and one sample per row of ``results["x"]`` (the nodes).

    Args:
        results: Dictionary read for ``"x_full"``, ``"x"``, ``"t_full"``,
            ``"init_poses"`` and ``"R_sb"``; the presence of
            ``"moving_subject"`` selects how poses are generated.
        params: Object exposing ``params.sim.idx_t``, the column of
            ``results["x"]`` used as the node times.

    Returns:
        Tuple ``(subs_traj, subs_traj_sen, subs_traj_node, subs_traj_sen_node)``:
        lists of world-frame and sensor-frame trajectories on the full grid,
        followed by the same two lists on the node grid.

    Raises:
        ValueError: If ``"init_poses"`` is missing from ``results``, or if
            ``"R_sb"`` is missing from ``results``.
    """
    x_full = results["x_full"]
    x_nodes = results["x"]
    t_nodes = x_nodes[:, params.sim.idx_t]
    t_full = results['t_full']

    has_poses = "init_poses" in results
    # NOTE(review): only the presence of "moving_subject" is tested here, not
    # its truth value -- confirm that a falsy flag should still use get_kp_pose.
    if "moving_subject" in results and has_poses:
        init_poses = results["init_poses"]
        subs_traj = [get_kp_pose(t_full, init_poses)]
        subs_traj_node = [get_kp_pose(t_nodes, init_poses)]
    elif has_poses:
        subs_traj = []
        subs_traj_node = []
        for pose in results["init_poses"]:
            # A static pose is repeated once per sample of each grid.
            column = pose[:, np.newaxis]
            subs_traj.append(np.repeat(column, x_full.shape[0], axis=1).T)
            subs_traj_node.append(np.repeat(column, x_nodes.shape[0], axis=1).T)
    else:
        raise ValueError("No valid method to get keypoint poses.")

    if "R_sb" not in results:
        raise ValueError("`R_sb` not found in results dictionary. Cannot compute sensor frame.")
    R_sb = results["R_sb"]

    # Columns 0:3 are subtracted from the subject pose and columns 6:10 are
    # passed to qdcm -- presumably position and attitude quaternion.
    subs_traj_sen = [
        np.array([
            R_sb @ qdcm(x_full[i, 6:10]).T @ (sub_traj[i] - x_full[i, 0:3])
            for i in range(x_full.shape[0])
        ]).squeeze()
        for sub_traj in subs_traj
    ]

    # NOTE(review): unlike the full grid above, the offset is transposed here;
    # kept exactly as in the original because the shape of get_kp_pose's
    # output is not visible from this file.
    subs_traj_sen_node = [
        np.array([
            R_sb @ qdcm(x_nodes[i, 6:10]).T @ (sub_traj_node[i] - x_nodes[i, 0:3]).T
            for i in range(x_nodes.shape[0])
        ]).squeeze()
        for sub_traj_node in subs_traj_node
    ]
    return subs_traj, subs_traj_sen, subs_traj_node, subs_traj_sen_node
54
+
55
def save_gate_parameters(gates, params: Config):
    """Pickle the centers and vertices of ``gates`` to ``results/gate_params.pickle``.

    The ``results`` directory is created when it does not exist yet, so the
    call no longer fails with ``FileNotFoundError`` on a fresh working directory.

    Args:
        gates: Iterable of objects exposing ``center`` and ``vertices``.
        params: Unused; kept so existing callers keep working.
    """
    import os  # local import: the module header does not import ``os``

    gate_centers = []
    gate_vertices = []
    # Single pass so that one-shot iterables (generators) are handled too.
    for gate in gates:
        gate_centers.append(gate.center)
        gate_vertices.append(gate.vertices)
    gate_params = dict(
        gate_centers=gate_centers,
        gate_vertices=gate_vertices,
    )

    os.makedirs('results', exist_ok=True)
    # Use pickle to save the gate parameters
    with open('results/gate_params.pickle', 'wb') as f:
        pickle.dump(gate_params, f)
69
+
70
def frame_args(duration):
    """Build the options dictionary passed as ``args`` to "animate" controls.

    Args:
        duration: Value used for both the frame duration and the transition
            duration.

    Returns:
        Dictionary with the keys ``"frame"``, ``"mode"``, ``"fromcurrent"`` and
        ``"transition"``.
    """
    frame = dict(duration=duration)
    transition = dict(duration=duration, easing="linear")
    return dict(
        frame=frame,
        mode="immediate",
        fromcurrent=True,
        transition=transition,
    )
77
+
78
def _random_rgb():
    """Return a random ``'rgb(r, g, b)'`` colour string with channels in 10..255."""
    return f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'


def plot_constraint_violation(result, params: Config):
    """Show a 2x3 grid of constraint-violation traces read from ``result``.

    Every group of keys is optional; a message is printed for each group that
    is missing from ``result``.

    Args:
        result: Dictionary optionally holding ``"obs_vio"``, ``"sub_vp_vio"``,
            ``"sub_min_vio"``, ``"sub_max_vio"``, ``"sub_direc_vio"`` and
            ``"state_bound_vio"``.
        params: Configuration; only ``params.vp.tracking`` is read, to decide
            whether the min/max subject violations are drawn.
    """
    # One title per cell of the 2x3 grid, in row-major order.
    fig = make_subplots(rows=2, cols=3, subplot_titles=(
        r'$\text{Obstacle Violation}$',
        r'$\text{Sub VP Violation}$',
        r'$\text{Sub Min Violation}$',
        r'$\text{Sub Max Violation}$',
        r'$\text{Sub Direc Violation}$',
        r'$\text{State Bound Violation}$',
    ))
    fig.update_layout(template='plotly_dark', title=r'$\text{Constraint Violation}$')

    if "obs_vio" in result:
        obs_vio = result["obs_vio"]
        for i in range(obs_vio.shape[0]):
            color = _random_rgb()
            fig.add_trace(go.Scatter(y=obs_vio[i], mode='lines', showlegend=False, line=dict(color=color, width=2)), row=1, col=1)
    else:
        print("'obs_vio' not found in result dictionary.")

    # Make names of each state in the state vector
    state_names = ['x', 'y', 'z', 'vx', 'vy', 'vz', 'q0', 'q1', 'q2', 'q3', 'wx', 'wy', 'wz', 'ctcs']

    if "sub_vp_vio" in result and "sub_min_vio" in result and "sub_max_vio" in result:
        sub_vp_vio = result["sub_vp_vio"]
        sub_min_vio = result["sub_min_vio"]
        sub_max_vio = result["sub_max_vio"]
        for i in range(sub_vp_vio.shape[0]):
            color = _random_rgb()
            fig.add_trace(go.Scatter(y=sub_vp_vio[i], mode='lines', showlegend=True, name='LoS ' + str(i) + ' Error', line=dict(color=color, width=2)), row=1, col=2)
            if params.vp.tracking:
                min_series, max_series = sub_min_vio[i], sub_max_vio[i]
            else:
                # Empty traces are still added when tracking is disabled.
                min_series, max_series = [], []
            fig.add_trace(go.Scatter(y=min_series, mode='lines', showlegend=False, line=dict(color=color, width=2)), row=1, col=3)
            fig.add_trace(go.Scatter(y=max_series, mode='lines', showlegend=False, line=dict(color=color, width=2)), row=2, col=1)
    else:
        print("'sub_vp_vio', 'sub_min_vio', or 'sub_max_vio' not found in result dictionary.")

    # The direction violation is only checked for presence; its trace
    # (row 2, col 2) was already disabled in the original code.
    if "sub_direc_vio" not in result:
        print("'sub_direc_vio' not found in result dictionary.")

    if "state_bound_vio" in result:
        state_bound_vio = result["state_bound_vio"]
        # NOTE(review): the loop counts rows (shape[0]) but indexes columns
        # ([:, i]) and state_names[i] -- confirm the array layout; left as-is
        # because the intended axis cannot be determined from this file.
        for i in range(state_bound_vio.shape[0]):
            color = _random_rgb()
            fig.add_trace(go.Scatter(y=state_bound_vio[:, i], mode='lines', showlegend=True, name=state_names[i] + ' Error', line=dict(color=color, width=2)), row=2, col=3)
    else:
        print("'state_bound_vio' not found in result dictionary.")

    fig.show()
126
+
127
def plot_initial_guess(result, params: Config):
    """Show a 3D figure of the initial guess: path, attitude triads, subjects.

    Args:
        result: Dictionary read for ``"x"`` and ``"sub_positions"``. Rows 0:3
            of ``result["x"]`` are plotted as the path and rows 6:10 are
            passed column by column to ``qdcm``.
        params: Unused by this function.
    """
    # NOTE(review): "x" is sliced by rows here, i.e. treated as
    # (n_states, n_samples) -- confirm against the caller.
    x_positions = result["x"][0:3]
    x_attitude = result["x"][6:10]
    subs_positions = result["sub_positions"]

    fig = go.Figure(go.Scatter3d(x=[], y=[], z=[], mode='lines+markers', line=dict(color='gray', width=2)))

    # Plot the position of the drone
    fig.add_trace(go.Scatter3d(x=x_positions[0], y=x_positions[1], z=x_positions[2], mode='lines+markers', line=dict(color='green', width=5)))

    # Draw one red/green/blue triad per sample, each arm of length 2.
    body_axes = 2 * np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    triad_colors = ['#FF0000', '#00FF00', '#0000FF']
    for col in range(x_positions.shape[1]):
        # Convert the quaternion to a rotation matrix and rotate the triad.
        rotated_axes = np.dot(qdcm(x_attitude[:, col]), body_axes).T
        px = x_positions[0, col]
        py = x_positions[1, col]
        pz = x_positions[2, col]
        for k, color in enumerate(triad_colors):
            axis = rotated_axes[k]
            fig.add_trace(go.Scatter3d(
                x=[px, px + axis[0]],
                y=[py, py + axis[1]],
                z=[pz, pz + axis[2]],
                mode='lines+text',
                line=dict(color=color, width=4),
                showlegend=False
            ))

    fig.update_layout(template='plotly_dark')
    fig.update_layout(scene=dict(aspectmode='manual', aspectratio=dict(x=10, y=10, z=10)))
    fig.update_layout(scene=dict(xaxis=dict(range=[-200, 200]), yaxis=dict(range=[-200, 200]), zaxis=dict(range=[-200, 200])))

    # Plot the keypoint
    for sub_positions in subs_positions:
        fig.add_trace(go.Scatter3d(x=sub_positions[:, 0], y=sub_positions[:, 1], z=sub_positions[:, 2], mode='lines+markers', line=dict(color='red', width=5), name='Subject'))
    fig.show()
174
+
175
def plot_scp_animation(result: dict,
                       params = None,
                       path=""):
    """Show an animated 3D figure with one animation frame per SCP iterate.

    Each frame holds the multiple-shooting position segments of that iterate
    (when available) and an attitude triad at every node. Static content
    (propagated path, obstacles, gates, subjects, ground plane) is drawn once.

    Args:
        result: Dictionary read for ``"t_final"``, ``"x_full"``,
            ``"x_history"`` and ``"discretization"``; optionally
            ``"moving_subject"``, ``"init_poses"``, ``"n_subs"``,
            ``"obstacles_centers"`` / ``"obstacles_axes"`` /
            ``"obstacles_radii"`` and ``"vertices"``.
        params: Configuration exposing ``params.sim.n_states`` and
            ``params.sim.n_controls``. Despite the ``None`` default it is
            dereferenced unconditionally, so it must be supplied.
        path: Unused by this function.
    """
    tof = result["t_final"]
    title = f'SCP Simulation: {tof} seconds'
    propagated_positions = result["x_full"][:, :3]
    scp_ctcs_trajs = result["x_history"]
    scp_multi_shoot = result["discretization"]

    subs_positions = []
    if "moving_subject" in result or "init_poses" in result:
        subs_positions, _, _, _ = full_subject_traj_time(result, params)

    scene_box = dict(
        aspectmode='manual',
        aspectratio=dict(x=10, y=10, z=10),
        xaxis=dict(range=[-200, 200]),
        yaxis=dict(range=[-200, 200]),
        zaxis=dict(range=[-200, 200]),
    )

    # 201 empty traces are created before anything else. NOTE(review): these
    # look like placeholder slots that the frame data below is animated into
    # (frame traces map onto figure traces by position) -- confirm before
    # changing the count or the order in which traces are added.
    placeholder = dict(x=[], y=[], z=[], mode='lines+markers', line=dict(color='gray', width=2))
    fig = go.Figure(go.Scatter3d(name='SCP Iterations', **placeholder))
    for _ in range(200):
        fig.add_trace(go.Scatter3d(**placeholder))

    fig.add_trace(go.Scatter3d(x=propagated_positions[:, 0], y=propagated_positions[:, 1], z=propagated_positions[:, 2], mode='lines', line=dict(color='green', width=5), name='Nonlinear Propagation'))
    fig.update_layout(template='plotly_dark', title=title, scene=scene_box)

    # Length of one row of the discretization arrays:
    # n_x + n_x*n_x + 2*n_x*n_u + n_x. Only the first three entries of each
    # row are plotted.
    n_x = params.sim.n_states
    n_u = params.sim.n_controls
    aug_len = n_x + n_x * n_x + 2 * n_x * n_u + n_x

    body_axes = 2 * np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    triad_colors = ['#FF0000', '#00FF00', '#0000FF']

    frames = []
    for traj_iter, scp_traj in enumerate(scp_ctcs_trajs):
        data = []

        # Plot the multiple shooting trajectories
        if traj_iter < len(scp_multi_shoot):
            shots = scp_multi_shoot[traj_iter]
            pos_traj = np.array([
                shots[:, i_multi].reshape(-1, aug_len)[:, 0:3]
                for i_multi in range(shots.shape[1])
            ])
            for j in range(pos_traj.shape[1]):
                # Only the first segment of an iterate gets a legend entry.
                if j == 0:
                    legend = dict(name='Multishot Trajectory ' + str(traj_iter), showlegend=True)
                else:
                    legend = dict(showlegend=False)
                data.append(go.Scatter3d(
                    x=pos_traj[:, j, 0], y=pos_traj[:, j, 1], z=pos_traj[:, j, 2],
                    mode='lines', legendgroup='Multishot Trajectory',
                    line=dict(color='blue', width=5), **legend))

        # Attitude triad at every node of this iterate.
        for i in range(scp_traj.shape[0]):
            rotated_axes = np.dot(qdcm(scp_traj[i, 6:10]), body_axes.T).T
            for k, color in enumerate(triad_colors):
                axis = rotated_axes[k]
                data.append(go.Scatter3d(
                    x=[scp_traj[i, 0], scp_traj[i, 0] + axis[0]],
                    y=[scp_traj[i, 1], scp_traj[i, 1] + axis[1]],
                    z=[scp_traj[i, 2], scp_traj[i, 2] + axis[2]],
                    mode='lines+text',
                    line=dict(color=color, width=4),
                    showlegend=False
                ))
        frames.append(go.Frame(data=data, name=str(traj_iter)))
    fig.frames = frames

    if "obstacles_centers" in result:
        # Points on the unit sphere, shared by every obstacle.
        n = 20
        u = np.linspace(0, 2 * np.pi, n)
        v = np.linspace(0, np.pi, n)
        sphere_x = np.outer(np.cos(u), np.sin(v))
        sphere_y = np.outer(np.sin(u), np.sin(v))
        sphere_z = np.outer(np.ones(np.size(u)), np.cos(v))

        for center, axes, radius in zip(result['obstacles_centers'], result['obstacles_axes'], result['obstacles_radii']):
            # Scale by the reciprocal of each radius entry, then rotate and translate.
            points = np.array([
                (1 / radius[0] * sphere_x).flatten(),
                (1 / radius[1] * sphere_y).flatten(),
                (1 / radius[2] * sphere_z).flatten(),
            ])
            points = (axes @ points).T + center
            fig.add_trace(go.Surface(x=points[:, 0].reshape(n, n), y=points[:, 1].reshape(n, n), z=points[:, 2].reshape(n, n), opacity=0.5, showscale=False))

    if "vertices" in result:
        for vertices in result["vertices"]:
            # Closed outline through the four vertices of the gate.
            outline = [vertices[k] for k in (0, 1, 2, 3, 0)]
            fig.add_trace(go.Scatter3d(
                x=[p[0] for p in outline],
                y=[p[1] for p in outline],
                z=[p[2] for p in outline],
                mode='lines', showlegend=False, line=dict(color='blue', width=10)))

    # Add the subject positions.
    # NOTE(review): the original nesting of the "points" branch was ambiguous;
    # here moving subjects are drawn as lines and every other subject as
    # markers -- confirm this matches the intended behaviour.
    if "n_subs" in result and result["n_subs"] != 0:
        if result.get("moving_subject"):
            for sub_positions in subs_positions:
                fig.add_trace(go.Scatter3d(x=sub_positions[:, 0], y=sub_positions[:, 1], z=sub_positions[:, 2], mode='lines', line=dict(color='red', width=5), showlegend=False))
        else:
            # Plot the subject positions as points
            for sub_positions in subs_positions:
                fig.add_trace(go.Scatter3d(x=sub_positions[:, 0], y=sub_positions[:, 1], z=sub_positions[:, 2], mode='markers', marker=dict(size=10, color='red'), showlegend=False))

    fig.add_trace(go.Surface(x=[-200, 200, 200, -200], y=[-200, -200, 200, 200], z=[[0, 0], [0, 0], [0, 0], [0, 0]], opacity=0.3, showscale=False, colorscale='Greys', showlegend=True, name='Ground Plane'))

    # One slider step per animation frame.
    sliders = [
        {
            "pad": {"b": 10, "t": 60},
            "len": 0.8,
            "x": 0.15,
            "y": 0.32,
            "steps": [
                {
                    "args": [[f.name], frame_args(0)],
                    "label": f.name,
                    "method": "animate",
                } for f in fig.frames
            ]
        }
    ]
    playback_menu = {
        "buttons": [
            {
                "args": [None, frame_args(50)],
                "label": "Play",
                "method": "animate",
            },
            {
                "args": [[None], frame_args(0)],
                "label": "Pause",
                "method": "animate",
            },
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 70},
        "type": "buttons",
        "x": 0.15,
        "y": 0.32,
    }

    # Overlay the sliders, buttons and title onto the plot
    fig.update_layout(updatemenus=[playback_menu], sliders=sliders)
    fig.update_layout(title_y=0.95, title_x=0.5)

    # Show the legend overlaid on the plot, without a background
    fig.update_layout(legend=dict(yanchor="top", y=0.9, xanchor="left", x=0.75, bgcolor='rgba(0,0,0,0)'))

    # Remove the black border around the fig
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))

    fig.update_xaxes(
        dtick=1.0,
        showline=False
    )
    fig.update_yaxes(
        scaleanchor="x",
        scaleratio=1,
        showline=False,
        dtick=1.0
    )

    # Rotate the camera view to the left
    if "moving_subject" not in result:
        fig.update_layout(scene_camera=dict(up=dict(x=0, y=0, z=90), center=dict(x=1, y=0.3, z=1), eye=dict(x=-1, y=2, z=1)))

    fig.show()
410
+
411
def scp_traj_interp(scp_trajs, params: Config):
    """Repeat each of the first ``params.scp.n`` rows of every trajectory.

    Each row is repeated ``params.prp.inter_sample - 2`` times; no values
    between rows are computed.

    Args:
        scp_trajs: Iterable of trajectories indexable by row.
        params: Configuration exposing ``params.scp.n`` and
            ``params.prp.inter_sample``.

    Returns:
        List with one array per input trajectory, holding the repeated rows.
    """
    held_trajs = []
    for traj in scp_trajs:
        held_rows = []
        for node in range(params.scp.n):
            width = params.prp.inter_sample - 1
            tiled = np.repeat(np.expand_dims(traj[node], axis=1), width, axis=1)
            # Column 0 is skipped, so each row contributes width - 1 copies.
            held_rows.extend(tiled[:, col] for col in range(1, width))
        held_trajs.append(np.array(held_rows))
    return held_trajs
421
+
422
def plot_state(result, params: Config):
    """Show a 2x7 grid of state trajectories with their bound lines.

    Each panel shows the output of ``scp_traj_interp`` for every entry of
    ``result["x_history"]`` in gray and ``result["x_full"]`` in green. Panels
    for state columns 0-12 also get red horizontal lines at
    ``params.sim.min_state`` / ``params.sim.max_state``; the panel for the
    last state column (row 1, col 7) has no bound lines.

    Args:
        result: Dictionary read for ``"x_history"`` and ``"x_full"``.
        params: Configuration exposing ``sim.min_state`` and ``sim.max_state``,
            and whatever ``scp_traj_interp`` reads.
    """
    scp_trajs = scp_traj_interp(result["x_history"], params)
    x_full = result["x_full"]

    fig = make_subplots(rows=2, cols=7, subplot_titles=('X Position', 'Y Position', 'Z Position', 'X Velocity', 'Y Velocity', 'Z Velocity', 'CTCS Augmentation', 'Q1', 'Q2', 'Q3', 'Q4', 'X Angular Rate', 'Y Angular Rate', 'Z Angular Rate'))
    fig.update_layout(title_text="State Trajectories", template='plotly_dark')

    # (state column, subplot row, subplot col, draw bound lines?)
    panels = [(k, 1, k + 1, True) for k in range(6)]          # position, velocity
    panels.append((-1, 1, 7, False))                          # last state column
    panels.extend((k, 2, k - 5, True) for k in range(6, 13))  # attitude, angular rate

    for state, row, col, bounded in panels:
        for traj in scp_trajs:
            fig.add_trace(go.Scatter(y=traj[:, state], mode='lines', showlegend=False, line=dict(color='gray', width=0.5)), row=row, col=col)
        fig.add_trace(go.Scatter(y=x_full[:, state], mode='lines', showlegend=False, line=dict(color='green', width=2)), row=row, col=col)

        if not bounded:
            continue
        lower = params.sim.min_state[state]
        upper = params.sim.max_state[state]
        fig.add_hline(y=lower, line=dict(color='red', width=2), row=row, col=col)
        fig.add_hline(y=upper, line=dict(color='red', width=2), row=row, col=col)
        if state == 0:
            # Only the first panel has its y-range pinned to the bounds.
            fig.update_yaxes(range=[lower, upper], row=row, col=col)
    fig.show()
548
+
549
def plot_control(result, params: Config):
    """Show a 2x3 grid of control trajectories with their bound lines.

    Panel ``k`` (row-major) shows entry ``k`` of every item of
    ``result["u_history"]`` in gray and entry ``k`` of ``result["u"]`` in
    green, with red horizontal lines at ``params.sim.min_control[k]`` and
    ``params.sim.max_control[k]``.

    The original assigned the torque bounds to swapped ``*_min`` / ``*_max``
    names; the bounds are now read consistently for all six controls. The
    drawn lines are unchanged because both bounds are drawn in the same style.

    Args:
        result: Dictionary read for ``"u_history"`` and ``"u"``.
        params: Configuration exposing ``sim.min_control`` and
            ``sim.max_control``.
    """
    scp_controls = result["u_history"]
    u = result["u"]

    fig = make_subplots(rows=2, cols=3, subplot_titles=('X Force', 'Y Force', 'Z Force', 'X Torque', 'Y Torque', 'Z Torque'))
    fig.update_layout(title_text="Control Trajectories", template='plotly_dark')

    for k in range(6):
        row, col = k // 3 + 1, k % 3 + 1
        lower = params.sim.min_control[k]
        upper = params.sim.max_control[k]

        # NOTE(review): controls are indexed by their first axis here
        # (traj[k], u[k]) while other functions in this file slice states by
        # column -- confirm the layout of "u" and "u_history".
        for traj in scp_controls:
            fig.add_trace(go.Scatter(y=traj[k], mode='lines', showlegend=False, line=dict(color='gray', width=0.5)), row=row, col=col)
        fig.add_trace(go.Scatter(y=u[k], mode='lines', showlegend=False, line=dict(color='green', width=2)), row=row, col=col)
        fig.add_hline(y=lower, line=dict(color='red', width=2), row=row, col=col)
        fig.add_hline(y=upper, line=dict(color='red', width=2), row=row, col=col)

    fig.show()
611
+
612
def plot_losses(result, params: Config):
    """Show a 2x2 grid of loss histories on logarithmic y-axes.

    The original read ``result["J_vc_ctcs_vec"]``, titled a panel for it and
    set that panel's axis to log scale, but never added the trace; it is now
    drawn in row 2, col 2.

    Args:
        result: Dictionary read for ``"J_tr_history"``, ``"J_vb_history"``,
            ``"J_vc_history"`` and ``"J_vc_ctcs_vec"``.
        params: Unused by this function.
    """
    # (subplot title, result key), laid out in row-major order
    panels = [
        ('J_tr', "J_tr_history"),
        ('J_vb', "J_vb_history"),
        ('J_vc', "J_vc_history"),
        ('J_vc_ctcs', "J_vc_ctcs_vec"),
    ]
    # Read every series up front so a missing key fails before plotting.
    series = [result[key] for _, key in panels]

    fig = make_subplots(rows=2, cols=2, subplot_titles=tuple(title for title, _ in panels))
    fig.update_layout(title_text="Losses", template='plotly_dark')

    for k, values in enumerate(series):
        row, col = k // 2 + 1, k % 2 + 1
        fig.add_trace(go.Scatter(y=values, mode='lines', showlegend=False, line=dict(color='green', width=2)), row=row, col=col)
        # Set y-axis to log scale for each subplot
        fig.update_yaxes(type='log', row=row, col=col)

    fig.show()
@@ -0,0 +1,36 @@
1
+ import numpy as np
2
+
3
+ from openscvx.propagation import s_to_t, t_to_tau, simulate_nonlinear_time
4
+ from openscvx.config import Config
5
+
6
+
7
def propagate_trajectory_results(params: Config, result: dict, propagation_solver: callable) -> dict:
    """Propagate the solved trajectory on a fine time grid and extend ``result``.

    Prints the value of ``x_full[-1, params.sim.idx_y]`` and the summed cost,
    then stores ``t_full``, ``x_full`` and ``u_full`` in ``result``.

    Args:
        params: Configuration exposing ``prp.dt``, ``sim.idx_y``,
            ``sim.initial_state.type`` and ``sim.final_state.type``, plus
            whatever the propagation helpers read.
        result: Dictionary read for ``"x"`` and ``"u"``; updated in place.
        propagation_solver: Passed through to ``simulate_nonlinear_time``.

    Returns:
        The same ``result`` dictionary, with ``"t_full"``, ``"x_full"`` and
        ``"u_full"`` added.
    """
    x = result["x"]
    u = result["u"]

    t = np.array(s_to_t(u, params))

    # Grid from 0 up to (not including) the last node time, in steps of prp.dt.
    t_full = np.arange(0, t[-1], params.prp.dt)

    tau_vals, u_full = t_to_tau(u, t_full, u, t, params)

    x_full = simulate_nonlinear_time(x[0], u, tau_vals, t, params, propagation_solver)

    print("Total CTCS Constraint Violation:", x_full[-1, params.sim.idx_y])

    # Cost: sum of the first-row entries tagged "Minimize" in the initial
    # state and the last-row entries tagged "Minimize" in the final state.
    cost = np.zeros_like(x[-1, 0])
    for i, kind in enumerate(params.sim.initial_state.type):
        if kind == "Minimize":
            cost += x[0, i]
    for i, kind in enumerate(params.sim.final_state.type):
        if kind == "Minimize":
            cost += x[-1, i]
    print("Cost: ", cost)

    more_result = dict(t_full=t_full, x_full=x_full, u_full=u_full)

    result.update(more_result)
    return result