openscvx 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +0 -0
- openscvx/_version.py +21 -0
- openscvx/augmentation.py +44 -0
- openscvx/config.py +247 -0
- openscvx/discretization.py +169 -0
- openscvx/dynamics.py +24 -0
- openscvx/integrators.py +139 -0
- openscvx/io.py +81 -0
- openscvx/ocp.py +160 -0
- openscvx/plotting.py +632 -0
- openscvx/post_processing.py +36 -0
- openscvx/propagation.py +135 -0
- openscvx/ptr.py +149 -0
- openscvx/trajoptproblem.py +336 -0
- openscvx/utils.py +80 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.1.dist-info}/METADATA +2 -2
- openscvx-0.1.1.dist-info/RECORD +25 -0
- openscvx-0.1.1.dist-info/top_level.txt +1 -0
- openscvx-0.1.0.dist-info/RECORD +0 -10
- openscvx-0.1.0.dist-info/top_level.txt +0 -1
- {constraints → openscvx/constraints}/__init__.py +0 -0
- {constraints → openscvx/constraints}/boundary.py +0 -0
- {constraints → openscvx/constraints}/ctcs.py +0 -0
- {constraints → openscvx/constraints}/nodal.py +0 -0
- {constraints → openscvx/constraints}/violation.py +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.1.dist-info}/WHEEL +0 -0
- {openscvx-0.1.0.dist-info → openscvx-0.1.1.dist-info}/licenses/LICENSE +0 -0
openscvx/plotting.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
1
|
+
from plotly.subplots import make_subplots
|
|
2
|
+
import random
|
|
3
|
+
import plotly.graph_objects as go
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pickle
|
|
6
|
+
|
|
7
|
+
from openscvx.utils import qdcm, get_kp_pose
|
|
8
|
+
from openscvx.config import Config
|
|
9
|
+
|
|
10
|
+
def full_subject_traj_time(results, params):
    """Build subject (keypoint) trajectories in the world and sensor frames.

    Subject positions are produced both at the densely propagated samples
    (``results["x_full"]`` at times ``results["t_full"]``) and at the SCP nodes
    (``results["x"]``), then expressed in the sensor frame via
    ``R_sb @ qdcm(q).T @ (p_subject - p)`` where ``p`` / ``q`` are columns
    0:3 and 6:10 of each state row (presumably position and attitude
    quaternion - confirm against the dynamics definition).

    Args:
        results: Result dictionary. Reads ``"x_full"``, ``"x"``, ``"t_full"``,
            ``"init_poses"``, ``"R_sb"`` and optionally ``"moving_subject"``.
        params: Configuration; only ``params.sim.idx_t`` (the column of
            ``results["x"]`` holding the node times) is used.

    Returns:
        ``(subs_traj, subs_traj_sen, subs_traj_node, subs_traj_sen_node)``:
        world-frame subject trajectories at the full samples, the same in the
        sensor frame, world-frame trajectories at the nodes, and the same in
        the sensor frame.

    Raises:
        ValueError: If ``"init_poses"`` or ``"R_sb"`` is missing.
    """
    x_full = results["x_full"]
    x_nodes = results["x"]
    t_nodes = x_nodes[:, params.sim.idx_t]
    t_full = results['t_full']

    subs_traj = []
    subs_traj_node = []

    # NOTE(review): this checks for the *presence* of "moving_subject", not its
    # value, so a result with moving_subject=False also takes this branch.
    if "moving_subject" in results and "init_poses" in results:
        init_poses = results["init_poses"]
        subs_traj.append(get_kp_pose(t_full, init_poses))
        subs_traj_node.append(get_kp_pose(t_nodes, init_poses))
    elif "init_poses" in results:
        # Static subjects: hold each pose constant for every sample / node.
        for pose in results["init_poses"]:
            subs_traj.append(np.repeat(pose[:, np.newaxis], x_full.shape[0], axis=1).T)
            subs_traj_node.append(np.repeat(pose[:, np.newaxis], x_nodes.shape[0], axis=1).T)
    else:
        raise ValueError("No valid method to get keypoint poses.")

    if "R_sb" not in results:
        raise ValueError("`R_sb` not found in results dictionary. Cannot compute sensor frame.")
    R_sb = results["R_sb"]

    # Both rates share one transform so they cannot drift apart (the full-rate
    # path previously skipped the transpose applied at the nodes).
    subs_traj_sen = [_subject_in_sensor_frame(sub_traj, x_full, R_sb) for sub_traj in subs_traj]
    subs_traj_sen_node = [_subject_in_sensor_frame(sub_traj, x_nodes, R_sb) for sub_traj in subs_traj_node]
    return subs_traj, subs_traj_sen, subs_traj_node, subs_traj_sen_node


def _subject_in_sensor_frame(sub_traj, states, R_sb):
    """Transform one subject trajectory into the sensor frame, row by row of ``states``."""
    sensor = []
    for i in range(states.shape[0]):
        # .T is a no-op for a single (3,) pose and turns an (n_kp, 3) block
        # into (3, n_kp) so the 3x3 rotation applies per keypoint.
        offset = (sub_traj[i] - states[i, 0:3]).T
        sensor.append(R_sb @ qdcm(states[i, 6:10]).T @ offset)
    return np.array(sensor).squeeze()
|
|
54
|
+
|
|
55
|
+
def save_gate_parameters(gates, params: Config, path='results/gate_params.pickle'):
    """Pickle the centers and vertices of ``gates`` to ``path``.

    The pickled object is a dict with keys ``"gate_centers"`` and
    ``"gate_vertices"``, each a list in the same order as ``gates``.

    Args:
        gates: Iterable of objects exposing ``.center`` and ``.vertices``.
            It is traversed exactly once, so generators are fine.
        params: Unused; kept for interface compatibility.
        path: Output file. Defaults to the previously hard-coded
            ``'results/gate_params.pickle'``. Its parent directory is created
            when missing (the old code raised ``FileNotFoundError``).
    """
    import os

    gate_centers = []
    gate_vertices = []
    for gate in gates:
        gate_centers.append(gate.center)
        gate_vertices.append(gate.vertices)
    gate_params = dict(
        gate_centers=gate_centers,
        gate_vertices=gate_vertices,
    )

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(path, 'wb') as f:
        pickle.dump(gate_params, f)
|
|
69
|
+
|
|
70
|
+
def frame_args(duration):
    """Return the Plotly ``animate`` argument dict for a given frame duration.

    The same ``duration`` (milliseconds in Plotly's convention) is used for
    both the frame display time and a linear transition. The animation
    starts from the current frame in ``"immediate"`` mode.
    """
    transition = {"duration": duration, "easing": "linear"}
    return dict(
        frame={"duration": duration},
        mode="immediate",
        fromcurrent=True,
        transition=transition,
    )
|
|
77
|
+
|
|
78
|
+
def plot_constraint_violation(result, params: Config):
    """Plot constraint-violation histories in a 2x3 grid of subplots.

    Panels (row, col): obstacle (1,1), subject VP / line-of-sight (1,2),
    subject min (1,3), subject max (2,1), subject direction (2,2 - title only,
    the data is currently not drawn), state bounds (2,3). A missing entry in
    ``result`` is reported with ``print`` and its panel is left empty.

    Args:
        result: Result dictionary with optional keys ``"obs_vio"``,
            ``"sub_vp_vio"``, ``"sub_min_vio"``, ``"sub_max_vio"``,
            ``"sub_direc_vio"`` and ``"state_bound_vio"``. The first four are
            plotted one row per series. ``"state_bound_vio"`` is plotted one
            column per state.
        params: Configuration; when ``params.vp.tracking`` is false the
            min/max panels get empty traces instead of data.
    """
    # Exactly one title per panel (the old 7th title had no subplot).
    fig = make_subplots(rows=2, cols=3, subplot_titles=(
        r'$\text{Obstacle Violation}$',
        r'$\text{Sub VP Violation}$',
        r'$\text{Sub Min Violation}$',
        r'$\text{Sub Max Violation}$',
        r'$\text{Sub Direc Violation}$',
        r'$\text{State Bound Violation}$',
    ))
    fig.update_layout(template='plotly_dark', title=r'$\text{Constraint Violation}$')

    def random_color():
        # Components start at 10 to avoid near-black lines on the dark template.
        return f'rgb({random.randint(10,255)}, {random.randint(10,255)}, {random.randint(10,255)})'

    if "obs_vio" in result:
        obs_vio = result["obs_vio"]
        for i in range(obs_vio.shape[0]):
            fig.add_trace(go.Scatter(y=obs_vio[i], mode='lines', showlegend=False, line=dict(color=random_color(), width=2)), row=1, col=1)
    else:
        print("'obs_vio' not found in result dictionary.")

    # Names of each state in the state vector (used for legend labels).
    state_names = ['x', 'y', 'z', 'vx', 'vy', 'vz', 'q0', 'q1', 'q2', 'q3', 'wx', 'wy', 'wz', 'ctcs']

    if "sub_vp_vio" in result and "sub_min_vio" in result and "sub_max_vio" in result:
        sub_vp_vio = result["sub_vp_vio"]
        sub_min_vio = result["sub_min_vio"]
        sub_max_vio = result["sub_max_vio"]
        for i in range(sub_vp_vio.shape[0]):
            color = random_color()
            fig.add_trace(go.Scatter(y=sub_vp_vio[i], mode='lines', showlegend=True, name='LoS ' + str(i) + ' Error', line=dict(color=color, width=2)), row=1, col=2)
            if params.vp.tracking:
                min_y, max_y = sub_min_vio[i], sub_max_vio[i]
            else:
                min_y, max_y = [], []
            fig.add_trace(go.Scatter(y=min_y, mode='lines', showlegend=False, line=dict(color=color, width=2)), row=1, col=3)
            fig.add_trace(go.Scatter(y=max_y, mode='lines', showlegend=False, line=dict(color=color, width=2)), row=2, col=1)
    else:
        print("'sub_vp_vio', 'sub_min_vio', or 'sub_max_vio' not found in result dictionary.")

    # NOTE: "sub_direc_vio" is only checked for presence; panel (2,2) stays empty.
    if "sub_direc_vio" not in result:
        print("'sub_direc_vio' not found in result dictionary.")

    if "state_bound_vio" in result:
        state_bound_vio = result["state_bound_vio"]
        # Series are columns ([:, i]), so iterate over axis 1 - iterating
        # axis 0 raised IndexError or dropped/mislabelled states.
        for i in range(state_bound_vio.shape[1]):
            label = state_names[i] if i < len(state_names) else f'state {i}'
            fig.add_trace(go.Scatter(y=state_bound_vio[:, i], mode='lines', showlegend=True, name=label + ' Error', line=dict(color=random_color(), width=2)), row=2, col=3)
    else:
        print("'state_bound_vio' not found in result dictionary.")

    fig.show()
|
|
126
|
+
|
|
127
|
+
def plot_initial_guess(result, params: Config):
    """Show node positions, per-node attitude axes and subject paths in 3D.

    Args:
        result: Result dictionary.
            - ``result["x"]`` is indexed ``(node, state)``, the layout used
              throughout this module (``x[i, 0:3]``, ``x[:, idx_t]``, ``x[0]``
              as the initial state). Columns 0:3 are drawn as the path and
              columns 6:10 are passed to ``qdcm``.
            - ``result["sub_positions"]`` is an iterable of arrays indexed
              ``[:, 0:3]``, each drawn as a red "Subject" path.
        params: Unused; kept for interface compatibility.
    """
    x = np.asarray(result["x"])
    # Previously x[0:3] / x[6:10] selected *nodes* 0-2 / 6-9 instead of the
    # position / quaternion columns. Transpose to (coords, nodes).
    x_positions = x[:, 0:3].T
    x_attitude = x[:, 6:10].T
    subs_positions = result["sub_positions"]

    fig = go.Figure(go.Scatter3d(x=[], y=[], z=[], mode='lines+markers', line=dict(color='gray', width=2)))

    # Path through the nodes.
    fig.add_trace(go.Scatter3d(x=x_positions[0], y=x_positions[1], z=x_positions[2], mode='lines+markers', line=dict(color='green', width=5)))

    # At every node draw the columns of qdcm(q), scaled by 2 (presumably the
    # body axes), coloured red / green / blue.
    axes = 2 * np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    colors = ['#FF0000', '#00FF00', '#0000FF']
    for node in range(x_positions.shape[1]):
        origin = x_positions[:, node]
        rotated_axes = np.dot(qdcm(x_attitude[:, node]), axes).T
        for axis, color in zip(rotated_axes, colors):
            fig.add_trace(go.Scatter3d(
                x=[origin[0], origin[0] + axis[0]],
                y=[origin[1], origin[1] + axis[1]],
                z=[origin[2], origin[2] + axis[2]],
                mode='lines+text',
                line=dict(color=color, width=4),
                showlegend=False,
            ))

    fig.update_layout(template='plotly_dark')
    fig.update_layout(scene=dict(
        aspectmode='manual',
        aspectratio=dict(x=10, y=10, z=10),
        xaxis=dict(range=[-200, 200]),
        yaxis=dict(range=[-200, 200]),
        zaxis=dict(range=[-200, 200]),
    ))

    # Subject (keypoint) paths.
    for sub_positions in subs_positions:
        fig.add_trace(go.Scatter3d(x=sub_positions[:, 0], y=sub_positions[:, 1], z=sub_positions[:, 2], mode='lines+markers', line=dict(color='red', width=5), name='Subject'))
    fig.show()
|
|
174
|
+
|
|
175
|
+
def plot_scp_animation(result: dict,
                       params = None,
                       path=""):
    """Animate the SCP iterate history as an interactive 3D Plotly figure.

    Frames:
        One frame per entry of ``result["x_history"]``. Each frame holds
        the multiple-shooting traces built from ``result["discretization"]``
        (when an entry exists for that iteration) plus an axis triad at
        every node.

    Static traces, drawn once:
        - the propagated path ``result["x_full"][:, :3]``
        - ellipsoidal obstacles
        - gate outlines
        - subject positions
        - a ground plane

    Plotly applies frame data to the figure's first traces by position, so
    the figure starts with enough empty placeholder traces for the largest
    frame. The old fixed count of 200 let big frames overwrite static traces.

    Args:
        result: Result dictionary.
            - Required keys: ``"t_final"``, ``"x_full"``, ``"x_history"``,
              ``"discretization"``.
            - Optional keys: ``"obstacles_centers"``, ``"obstacles_axes"``,
              ``"obstacles_radii"``, ``"vertices"``, ``"n_subs"``,
              ``"moving_subject"``, ``"init_poses"``.
        params: Configuration. ``params.sim.n_states`` and
            ``params.sim.n_controls`` are required, and ``params`` is also
            forwarded to ``full_subject_traj_time``.
        path: Unused; kept for backward compatibility.

    Raises:
        ValueError: If ``params`` is ``None``.
    """
    if params is None:
        raise ValueError("`params` is required to build the SCP animation.")

    tof = result["t_final"]
    title = f'SCP Simulation: {tof} seconds'
    drone_positions = result["x_full"][:, :3]
    scp_ctcs_trajs = result["x_history"]
    scp_multi_shoot = result["discretization"]

    # Default to "no subjects" so the subject section below can never hit an
    # undefined name.
    subs_positions = []
    if "moving_subject" in result or "init_poses" in result:
        subs_positions, _, _, _ = full_subject_traj_time(result, params)

    # Row width of the stacked augmented vector in the discretization output:
    # n_x + n_x*n_x + 2*n_x*n_u + n_x. Presumably the state followed by its
    # sensitivity blocks - confirm against the discretization module.
    n_x = params.sim.n_states
    n_u = params.sim.n_controls
    aug_width = n_x + n_x * n_x + 2 * n_x * n_u + n_x

    frames = []
    for traj_iter, scp_traj in enumerate(scp_ctcs_trajs):
        data = _multishot_traces(scp_multi_shoot, traj_iter, aug_width)
        data += _attitude_axis_traces(scp_traj)
        frames.append(go.Frame(name=str(traj_iter), data=data))

    # Trace 0 plus the placeholders must cover the largest frame.
    largest_frame = max((len(frame.data) for frame in frames), default=0)
    n_placeholders = max(200, largest_frame - 1)

    fig = go.Figure(go.Scatter3d(x=[], y=[], z=[], mode='lines+markers', line=dict(color='gray', width=2), name='SCP Iterations'))
    for _ in range(n_placeholders):
        fig.add_trace(go.Scatter3d(x=[], y=[], z=[], mode='lines+markers', line=dict(color='gray', width=2)))

    fig.add_trace(go.Scatter3d(x=drone_positions[:, 0], y=drone_positions[:, 1], z=drone_positions[:, 2], mode='lines', line=dict(color='green', width=5), name='Nonlinear Propagation'))
    fig.update_layout(template='plotly_dark', title=title)
    fig.frames = frames

    if "obstacles_centers" in result:
        for center, axes, radius in zip(result['obstacles_centers'], result['obstacles_axes'], result['obstacles_radii']):
            fig.add_trace(_ellipsoid_surface(center, axes, radius))

    # Gate outlines: closed loop through the four vertices.
    for vertices in result.get("vertices", []):
        corners = [vertices[k] for k in (0, 1, 2, 3, 0)]
        fig.add_trace(go.Scatter3d(
            x=[c[0] for c in corners],
            y=[c[1] for c in corners],
            z=[c[2] for c in corners],
            mode='lines', showlegend=False, line=dict(color='blue', width=10),
        ))

    # Subjects: paths when "moving_subject" is truthy, otherwise points.
    if result.get("n_subs", 0) != 0:
        moving = result.get("moving_subject", False)
        for sub_positions in subs_positions:
            if moving:
                fig.add_trace(go.Scatter3d(x=sub_positions[:, 0], y=sub_positions[:, 1], z=sub_positions[:, 2], mode='lines', line=dict(color='red', width=5), showlegend=False))
            else:
                fig.add_trace(go.Scatter3d(x=sub_positions[:, 0], y=sub_positions[:, 1], z=sub_positions[:, 2], mode='markers', marker=dict(size=10, color='red'), showlegend=False))

    fig.add_trace(go.Surface(x=[-200, 200, 200, -200], y=[-200, -200, 200, 200], z=[[0, 0], [0, 0], [0, 0], [0, 0]], opacity=0.3, showscale=False, colorscale='Greys', showlegend=True, name='Ground Plane'))

    sliders = [{
        "pad": {"b": 10, "t": 60},
        "len": 0.8,
        "x": 0.15,
        "y": 0.32,
        "steps": [
            {"args": [[f.name], frame_args(0)], "label": f.name, "method": "animate"}
            for f in frames
        ],
    }]
    updatemenus = [{
        "buttons": [
            {"args": [None, frame_args(50)], "label": "Play", "method": "animate"},
            {"args": [[None], frame_args(0)], "label": "Pause", "method": "animate"},
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 70},
        "type": "buttons",
        "x": 0.15,
        "y": 0.32,
    }]
    fig.update_layout(
        updatemenus=updatemenus,
        sliders=sliders,
        scene=dict(
            aspectmode='manual',
            aspectratio=dict(x=10, y=10, z=10),
            xaxis=dict(range=[-200, 200]),
            yaxis=dict(range=[-200, 200]),
            zaxis=dict(range=[-200, 200]),
        ),
    )

    # Overlay the title, and overlay the legend (transparent background) on the plot.
    fig.update_layout(title_y=0.95, title_x=0.5)
    fig.update_layout(legend=dict(yanchor="top", y=0.9, xanchor="left", x=0.75, bgcolor='rgba(0,0,0,0)'))

    # Remove the border around the figure.
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))

    fig.update_xaxes(dtick=1.0, showline=False)
    fig.update_yaxes(scaleanchor="x", scaleratio=1, showline=False, dtick=1.0)

    # Rotate the camera view for scenes without a moving subject.
    if "moving_subject" not in result:
        fig.update_layout(scene_camera=dict(up=dict(x=0, y=0, z=90), center=dict(x=1, y=0.3, z=1), eye=dict(x=-1, y=2, z=1)))

    fig.show()


def _multishot_traces(scp_multi_shoot, traj_iter, aug_width):
    """Build the blue multiple-shooting line traces for SCP iteration ``traj_iter``.

    Each column of ``scp_multi_shoot[traj_iter]`` is reshaped to rows of
    ``aug_width`` values. Trace ``j`` connects the first three entries of row
    ``j`` across all columns. Only the first trace carries a legend entry.
    Returns ``[]`` when there is no entry for ``traj_iter`` or it has no
    columns.
    """
    if traj_iter >= len(scp_multi_shoot):
        return []
    shots = scp_multi_shoot[traj_iter]
    if shots.shape[1] == 0:
        return []
    # pos_traj[column, row, 0:3]
    pos_traj = np.array([shots[:, col].reshape(-1, aug_width)[:, 0:3] for col in range(shots.shape[1])])

    traces = []
    for j in range(pos_traj.shape[1]):
        style = dict(
            x=pos_traj[:, j, 0], y=pos_traj[:, j, 1], z=pos_traj[:, j, 2],
            mode='lines', legendgroup='Multishot Trajectory',
            line=dict(color='blue', width=5),
        )
        if j == 0:
            traces.append(go.Scatter3d(name='Multishot Trajectory ' + str(traj_iter), showlegend=True, **style))
        else:
            traces.append(go.Scatter3d(showlegend=False, **style))
    return traces


def _attitude_axis_traces(scp_traj):
    """Build three line traces per row of ``scp_traj``, one per column of ``qdcm``.

    Each row's columns 0:3 are the origin. Its columns 6:10 go to ``qdcm``,
    whose columns (scaled by 2) are drawn red / green / blue.
    """
    axes = 2 * np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    colors = ['#FF0000', '#00FF00', '#0000FF']
    traces = []
    for row in scp_traj:
        rotated_axes = np.dot(qdcm(row[6:10]), axes.T).T
        for axis, color in zip(rotated_axes, colors):
            traces.append(go.Scatter3d(
                x=[row[0], row[0] + axis[0]],
                y=[row[1], row[1] + axis[1]],
                z=[row[2], row[2] + axis[2]],
                mode='lines+text',
                line=dict(color=color, width=4),
                showlegend=False,
            ))
    return traces


def _ellipsoid_surface(center, axes, radius, n=20):
    """Return a semi-transparent Surface for one ellipsoidal obstacle.

    A unit sphere sampled on an ``n`` x ``n`` grid is scaled per axis by
    ``1 / radius[k]``, rotated by ``axes`` and shifted by ``center``.
    NOTE(review): the reciprocal scaling suggests ``obstacles_radii`` holds
    inverse semi-axis lengths - confirm against the obstacle constraint.
    """
    u = np.linspace(0, 2 * np.pi, n)
    v = np.linspace(0, np.pi, n)
    x = 1 / radius[0] * np.outer(np.cos(u), np.sin(v))
    y = 1 / radius[1] * np.outer(np.sin(u), np.sin(v))
    z = 1 / radius[2] * np.outer(np.ones(np.size(u)), np.cos(v))

    points = axes @ np.array([x.flatten(), y.flatten(), z.flatten()])
    points = points.T + center
    return go.Surface(x=points[:, 0].reshape(n, n), y=points[:, 1].reshape(n, n), z=points[:, 2].reshape(n, n), opacity=0.5, showscale=False)
|
|
410
|
+
|
|
411
|
+
def scp_traj_interp(scp_trajs, params: Config):
    """Up-sample each SCP iterate by holding every node's state constant.

    For each trajectory, rows ``0 .. params.scp.n - 1`` are each emitted
    ``params.prp.inter_sample - 2`` times (zero copies when that is negative),
    in node order. Each trajectory becomes one stacked array; the list of
    arrays is returned.
    """
    n_nodes = params.scp.n
    width = params.prp.inter_sample - 1

    def hold(traj):
        samples = []
        for node in range(n_nodes):
            # `width` identical columns, each a copy of this node's state;
            # columns 1 .. width-1 are kept.
            tiled = np.repeat(np.expand_dims(traj[node], axis=1), width, axis=1)
            samples.extend(tiled[:, col] for col in range(1, width))
        return np.array(samples)

    return [hold(traj) for traj in scp_trajs]
|
|
421
|
+
|
|
422
|
+
def plot_state(result, params: Config):
    """Plot every state component: held SCP iterates versus the propagation.

    Each panel overlays the held SCP iterates from
    ``scp_traj_interp(result["x_history"], params)`` (thin gray) and
    ``result["x_full"]`` (green).
    - Bounded panels also draw ``params.sim.min_state`` /
      ``params.sim.max_state`` as red horizontal lines.
    - The X-position panel's y-range is clamped to those bounds.
    - The last state column (``[:, -1]``) is drawn without bounds in the
      "CTCS Augmentation" panel.
    """
    scp_trajs = scp_traj_interp(result["x_history"], params)
    x_full = result["x_full"]

    fig = make_subplots(rows=2, cols=7, subplot_titles=('X Position', 'Y Position', 'Z Position', 'X Velocity', 'Y Velocity', 'Z Velocity', 'CTCS Augmentation', 'Q1', 'Q2', 'Q3', 'Q4', 'X Angular Rate', 'Y Angular Rate', 'Z Angular Rate'))
    fig.update_layout(title_text="State Trajectories", template='plotly_dark')

    # (state column, subplot row, subplot col, draw min/max bound lines)
    panels = [(column, 1, column + 1, True) for column in range(6)]
    panels.append((-1, 1, 7, False))
    panels += [(column, 2, column - 5, True) for column in range(6, 13)]

    for column, row, col, bounded in panels:
        for traj in scp_trajs:
            fig.add_trace(go.Scatter(y=traj[:, column], mode='lines', showlegend=False, line=dict(color='gray', width=0.5)), row=row, col=col)
        fig.add_trace(go.Scatter(y=x_full[:, column], mode='lines', showlegend=False, line=dict(color='green', width=2)), row=row, col=col)
        if not bounded:
            continue

        lower = params.sim.min_state[column]
        upper = params.sim.max_state[column]
        fig.add_hline(y=lower, line=dict(color='red', width=2), row=row, col=col)
        fig.add_hline(y=upper, line=dict(color='red', width=2), row=row, col=col)
        if (row, col) == (1, 1):
            fig.update_yaxes(range=[lower, upper], row=1, col=1)

    fig.show()
|
|
548
|
+
|
|
549
|
+
def plot_control(result, params: Config):
    """Plot the six force/torque control channels with their bounds.

    Each panel overlays every SCP iterate from ``result["u_history"]``
    (thin gray) and the solution ``result["u"]`` (green). It also draws
    ``params.sim.min_control[k]`` / ``params.sim.max_control[k]`` as red
    horizontal lines.

    NOTE(review): controls are indexed ``(node, control)``, so channel ``k``
    is column ``k``. This matches ``s_to_t`` reading ``u[k, -1]`` and
    ``u_full[:, :3]`` elsewhere. The previous ``u[k]`` plotted node ``k``
    instead of channel ``k``.
    """
    titles = ('X Force', 'Y Force', 'Z Force', 'X Torque', 'Y Torque', 'Z Torque')
    scp_controls = result["u_history"]
    u = np.asarray(result["u"])

    fig = make_subplots(rows=2, cols=3, subplot_titles=titles)
    fig.update_layout(title_text="Control Trajectories", template='plotly_dark')

    for channel in range(len(titles)):
        row, col = channel // 3 + 1, channel % 3 + 1
        # min_control is the lower bound, max_control the upper bound
        # (the torque panels previously had the two swapped).
        lower = params.sim.min_control[channel]
        upper = params.sim.max_control[channel]

        for traj in scp_controls:
            fig.add_trace(go.Scatter(y=np.asarray(traj)[:, channel], mode='lines', showlegend=False, line=dict(color='gray', width=0.5)), row=row, col=col)
        fig.add_trace(go.Scatter(y=u[:, channel], mode='lines', showlegend=False, line=dict(color='green', width=2)), row=row, col=col)
        fig.add_hline(y=lower, line=dict(color='red', width=2), row=row, col=col)
        fig.add_hline(y=upper, line=dict(color='red', width=2), row=row, col=col)

    fig.show()
|
|
611
|
+
|
|
612
|
+
def plot_losses(result, params: Config):
    """Plot the SCP loss histories on a 2x2 grid of log-scaled subplots and show it.

    Panels are J_tr, J_vb, J_vc and J_vc_ctcs. Previously the J_vc_ctcs panel
    was titled but never populated; it is now drawn whenever the result
    provides that data.

    Args:
        result: Solver result dict. The "J_tr_history", "J_vb_history" and
            "J_vc_history" entries are required. "J_vc_ctcs_vec" is optional:
            when it is missing, its panel is left empty instead of raising
            KeyError.
        params: Problem configuration. Not read by this function; kept so the
            signature matches the other plotting helpers.
    """
    # (subplot title, series, row, col) for each panel of the 2x2 grid.
    # NOTE(review): "J_vc_ctcs_vec" does not follow the "*_history" naming
    # used by the other losses; confirm the key the solver actually writes.
    panels = (
        ("J_tr", result["J_tr_history"], 1, 1),
        ("J_vb", result["J_vb_history"], 1, 2),
        ("J_vc", result["J_vc_history"], 2, 1),
        ("J_vc_ctcs", result.get("J_vc_ctcs_vec"), 2, 2),
    )

    fig = make_subplots(rows=2, cols=2, subplot_titles=tuple(title for title, _, _, _ in panels))
    fig.update_layout(title_text="Losses", template='plotly_dark')

    for _, series, row, col in panels:
        if series is not None:
            fig.add_trace(
                go.Scatter(y=series, mode='lines', showlegend=False, line=dict(color='green', width=2)),
                row=row,
                col=col,
            )
        # Every panel uses a log-scale y-axis.
        fig.update_yaxes(type='log', row=row, col=col)

    fig.show()
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from openscvx.propagation import s_to_t, t_to_tau, simulate_nonlinear_time
|
|
4
|
+
from openscvx.config import Config
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def propagate_trajectory_results(params: Config, result: dict, propagation_solver: callable) -> dict:
    """Re-simulate the optimized trajectory on a fine time grid and attach it to ``result``.

    Steps:
      1. Converts the solution controls to time stamps (``s_to_t``).
      2. Builds a uniform grid ``np.arange(0, t[-1], params.prp.dt)``. This
         starts at 0, uses step ``params.prp.dt`` and excludes ``t[-1]``.
      3. Maps that grid through ``t_to_tau``.
      4. Propagates from ``x[0]`` with ``simulate_nonlinear_time``.

    Along the way it prints the final value of the ``params.sim.idx_y``
    state(s) of the propagated trajectory. It also prints the sum of the
    boundary states whose type is "Minimize": initial-state entries are taken
    from ``x[0]`` and final-state entries from ``x[-1]``.

    Args:
        params: Problem configuration. Reads ``prp.dt``, ``sim.idx_y``,
            ``sim.initial_state.type`` and ``sim.final_state.type``.
        result: Solver output containing "x" and "u". ``x`` is indexed as a
            2-D array (presumably nodes x states -- TODO confirm).
        propagation_solver: Passed through unchanged to
            ``simulate_nonlinear_time``.

    Returns:
        The same ``result`` dict, mutated in place to add "t_full", "x_full"
        and "u_full".
    """
    x = result["x"]
    u = result["u"]

    t = np.array(s_to_t(u, params))
    t_full = np.arange(0, t[-1], params.prp.dt)

    tau_vals, u_full = t_to_tau(u, t_full, u, t, params)
    x_full = simulate_nonlinear_time(x[0], u, tau_vals, t, params, propagation_solver)

    print("Total CTCS Constraint Violation:", x_full[-1, params.sim.idx_y])

    # Accumulate in the dtype of the state array.
    cost = np.zeros_like(x[-1, 0])
    # Initial-state "Minimize" entries come from the first node and
    # final-state ones from the last node, summed in that order.
    boundaries = (
        (0, params.sim.initial_state.type),
        (-1, params.sim.final_state.type),
    )
    for node, boundary_types in boundaries:
        for i, kind in enumerate(boundary_types):
            if kind == "Minimize":
                cost += x[node, i]
    print("Cost: ", cost)

    result.update(dict(t_full=t_full, x_full=x_full, u_full=u_full))
    return result
|