openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,756 @@
|
|
|
1
|
+
import plotly.graph_objects as go
|
|
2
|
+
from plotly.subplots import make_subplots
|
|
3
|
+
|
|
4
|
+
from openscvx.algorithms import OptimizationResults
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _get_var(result: OptimizationResults, var_name: str, var_list: list):
|
|
8
|
+
"""Get a variable object by name from the metadata list."""
|
|
9
|
+
for var in var_list:
|
|
10
|
+
if var.name == var_name:
|
|
11
|
+
return var
|
|
12
|
+
raise ValueError(f"Variable '{var_name}' not found")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_var_dim(result: OptimizationResults, var_name: str, var_list: list) -> int:
    """Return how many scalar components a variable occupies.

    The width comes from the variable's ``_slice`` attribute. For a ``slice`` it
    is ``stop - start``: a ``None`` start counts as 0, and a ``None`` or 0 stop
    counts as 1. The slice ``step`` is not consulted. Any non-slice ``_slice``
    value is treated as a single component.

    Args:
        result: Optimization results (forwarded to ``_get_var``).
        var_name: Name of the variable to measure.
        var_list: Sequence of variable metadata objects to search.

    Returns:
        Number of components of the variable.

    Raises:
        ValueError: Propagated from ``_get_var`` when ``var_name`` is absent.
    """
    layout = _get_var(result, var_name, var_list)._slice
    if not isinstance(layout, slice):
        return 1
    stop = layout.stop or 1
    start = layout.start or 0
    return stop - start
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _add_component_traces(
    fig: go.Figure,
    result: OptimizationResults,
    var_name: str,
    component_idx: int,
    row: int,
    col: int,
    show_legend: bool,
    min_val: float | None = None,
    max_val: float | None = None,
):
    """Draw one component of a variable into a single subplot cell.

    Adds up to four artists to ``fig`` at ``(row, col)``, in this order:

    1. the propagated trajectory as a green line, if ``result.trajectory`` has
       ``var_name``;
    2. the optimization nodes as cyan markers, if ``result.nodes`` has
       ``var_name``;
    3. a dashed red line at ``min_val``, if it is given and finite;
    4. a dashed red line at ``max_val``, if it is given and finite.

    Args:
        fig: Plotly figure (with subplots) that receives the traces.
        result: Optimization results supplying ``nodes`` and ``trajectory``.
        var_name: Key of the variable in ``nodes``/``trajectory``.
        component_idx: Column to plot when the stored array is 2-D; 1-D arrays
            are plotted as-is.
        row: 1-based subplot row.
        col: 1-based subplot column.
        show_legend: Whether the trajectory and node traces appear in the legend.
        min_val: Optional lower bound for a horizontal reference line.
        max_val: Optional upper bound for a horizontal reference line.
    """
    import numpy as np

    def _component(series):
        # A 1-D array holds a scalar variable; otherwise select one column.
        return series if series.ndim == 1 else series[:, component_idx]

    node_times = result.nodes["time"].flatten()
    traj = result.trajectory

    if traj and var_name in traj:
        fig.add_trace(
            go.Scatter(
                x=traj["time"].flatten(),
                y=_component(traj[var_name]),
                mode="lines",
                name="Trajectory",
                showlegend=show_legend,
                legendgroup="trajectory",
                line={"color": "green", "width": 2},
            ),
            row=row,
            col=col,
        )

    if var_name in result.nodes:
        fig.add_trace(
            go.Scatter(
                x=node_times,
                y=_component(result.nodes[var_name]),
                mode="markers",
                name="Nodes",
                showlegend=show_legend,
                legendgroup="nodes",
                marker={"color": "cyan", "size": 6, "symbol": "circle"},
            ),
            row=row,
            col=col,
        )

    # Infinite bounds (e.g. +/-inf for "unbounded") are not drawn.
    for bound in (min_val, max_val):
        if bound is None or not np.isfinite(bound):
            continue
        fig.add_hline(
            y=bound,
            line={"color": "red", "width": 1.5, "dash": "dash"},
            row=row,
            col=col,
        )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# =============================================================================
|
|
109
|
+
# State Plotting
|
|
110
|
+
# =============================================================================
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def plot_state_component(
    result: OptimizationResults,
    state_name: str,
    component: int = 0,
) -> go.Figure:
    """Plot one scalar component of a state variable against time.

    This builds a standalone (non-subplot) figure. To plot every component of
    one or more states at once, use plot_states().

    Args:
        result: Optimization results holding the state data
        state_name: Name of the state to plot
        component: 0-based component index. Scalar states only accept 0.

    Returns:
        Plotly figure containing a single plot

    Raises:
        ValueError: If the state is unknown or ``component`` is outside
            ``[0, dim)``.

    Example:
        >>> plot_state_component(result, "position", 2)  # Plot z-component
    """
    known_states = {state.name for state in result._states}
    if state_name not in known_states:
        raise ValueError(f"State '{state_name}' not found. Available: {sorted(known_states)}")

    dim = _get_var_dim(result, state_name, result._states)
    if component < 0 or dim <= component:
        raise ValueError(f"Component {component} out of range for '{state_name}' (dim={dim})")

    node_times = result.nodes["time"].flatten()
    traj = result.trajectory
    show_traj = bool(traj) and state_name in traj
    traj_times = traj["time"].flatten() if show_traj else None

    # Vector states get an index suffix; scalar states keep their bare name.
    label = state_name if dim <= 1 else f"{state_name}_{component}"

    fig = go.Figure()
    fig.update_layout(title_text=label, template="plotly_dark")

    def _pick(series):
        return series if series.ndim == 1 else series[:, component]

    if show_traj:
        fig.add_trace(
            go.Scatter(
                x=traj_times,
                y=_pick(traj[state_name]),
                mode="lines",
                name="Trajectory",
                line={"color": "green", "width": 2},
            )
        )

    if state_name in result.nodes:
        fig.add_trace(
            go.Scatter(
                x=node_times,
                y=_pick(result.nodes[state_name]),
                mode="markers",
                name="Nodes",
                marker={"color": "cyan", "size": 6},
            )
        )

    fig.update_xaxes(title_text="Time (s)")
    fig.update_yaxes(title_text=label)
    return fig
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def plot_states(
    result: OptimizationResults,
    state_names: list[str] | None = None,
    include_private: bool = False,
    cols: int = 4,
) -> go.Figure:
    """Plot state variables in a subplot grid.

    Each component of each state gets its own subplot with individual y-axis
    scaling. This is the primary function for visualizing state trajectories.
    Finite per-component bounds (``min``/``max``) are drawn as dashed lines.

    Args:
        result: Optimization results containing state trajectories
        state_names: List of state names to plot. If None, plots all states.
            When given, subplots follow the order of this list.
        include_private: Whether to include private states (names starting with '_')
        cols: Maximum number of columns in subplot grid (must be >= 1)

    Returns:
        Plotly figure with subplot grid

    Raises:
        ValueError: If a requested state is unavailable (private states count
            as unavailable unless ``include_private`` is True), if there is
            nothing to plot, or if ``cols`` is less than 1.

    Examples:
        >>> plot_states(result, ["position"])  # 3 subplots for x, y, z
        >>> plot_states(result, ["position", "velocity"])  # 6 subplots
        >>> plot_states(result)  # All states
    """
    import numpy as np

    all_states = result._states
    states = all_states
    if not include_private:
        states = [s for s in states if not s.name.startswith("_")]

    if state_names is not None:
        available = {s.name for s in states}
        missing = set(state_names) - available
        if missing:
            # Names that exist but were filtered out can only be private ones.
            hidden = missing & {s.name for s in all_states}
            hint = " (private states require include_private=True)" if hidden else ""
            raise ValueError(f"States not found in result: {missing}{hint}")
        # Preserve order from state_names
        state_order = {name: i for i, name in enumerate(state_names)}
        states = sorted(
            [s for s in states if s.name in state_names],
            key=lambda s: state_order[s.name],
        )

    # Build list of (display_name, var_name, component_idx)
    components = []
    for s in states:
        dim = _get_var_dim(result, s.name, result._states)
        if dim == 1:
            components.append((s.name, s.name, 0))
        else:
            components.extend((f"{s.name}_{i}", s.name, i) for i in range(dim))

    if not components:
        raise ValueError("No state components to plot")

    # Guard the grid arithmetic below (cols=0 would divide by zero).
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols}")

    n_components = len(components)
    n_cols = min(cols, n_components)
    n_rows = (n_components + n_cols - 1) // n_cols

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[c[0] for c in components])
    fig.update_layout(title_text="State Trajectories", template="plotly_dark")

    for idx, (_, var_name, comp_idx) in enumerate(components):
        row_offset, col_offset = divmod(idx, n_cols)

        # Get bounds for this component; atleast_1d also accepts scalar bounds.
        var = _get_var(result, var_name, result._states)
        min_val = np.atleast_1d(var.min)[comp_idx] if var.min is not None else None
        max_val = np.atleast_1d(var.max)[comp_idx] if var.max is not None else None

        _add_component_traces(
            fig,
            result,
            var_name,
            comp_idx,
            row_offset + 1,
            col_offset + 1,
            show_legend=(idx == 0),
            min_val=min_val,
            max_val=max_val,
        )

    # Label the bottom-most *populated* subplot of each column. When the last
    # row is only partly filled, those plots sit one row higher.
    for col_idx in range(n_cols):
        bottom_row = (n_components - 1 - col_idx) // n_cols + 1
        fig.update_xaxes(title_text="Time (s)", row=bottom_row, col=col_idx + 1)

    return fig
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
# =============================================================================
|
|
272
|
+
# Control Plotting
|
|
273
|
+
# =============================================================================
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def plot_control_component(
    result: OptimizationResults,
    control_name: str,
    component: int = 0,
) -> go.Figure:
    """Plot one scalar component of a control variable against time.

    This builds a standalone (non-subplot) figure. To plot every component of
    one or more controls at once, use plot_controls().

    Args:
        result: Optimization results holding the control data
        control_name: Name of the control to plot
        component: 0-based component index. Scalar controls only accept 0.

    Returns:
        Plotly figure containing a single plot

    Raises:
        ValueError: If the control is unknown or ``component`` is outside
            ``[0, dim)``.

    Example:
        >>> plot_control_component(result, "thrust", 0)  # Plot thrust_x
    """
    known_controls = {ctrl.name for ctrl in result._controls}
    if control_name not in known_controls:
        raise ValueError(
            f"Control '{control_name}' not found. Available: {sorted(known_controls)}"
        )

    dim = _get_var_dim(result, control_name, result._controls)
    if component < 0 or dim <= component:
        raise ValueError(f"Component {component} out of range for '{control_name}' (dim={dim})")

    node_times = result.nodes["time"].flatten()
    traj = result.trajectory
    show_traj = bool(traj) and control_name in traj
    traj_times = traj["time"].flatten() if show_traj else None

    # Vector controls get an index suffix; scalar controls keep their bare name.
    label = control_name if dim <= 1 else f"{control_name}_{component}"

    fig = go.Figure()
    fig.update_layout(title_text=label, template="plotly_dark")

    def _pick(series):
        return series if series.ndim == 1 else series[:, component]

    if show_traj:
        fig.add_trace(
            go.Scatter(
                x=traj_times,
                y=_pick(traj[control_name]),
                mode="lines",
                name="Trajectory",
                line={"color": "green", "width": 2},
            )
        )

    if control_name in result.nodes:
        fig.add_trace(
            go.Scatter(
                x=node_times,
                y=_pick(result.nodes[control_name]),
                mode="markers",
                name="Nodes",
                marker={"color": "cyan", "size": 6},
            )
        )

    fig.update_xaxes(title_text="Time (s)")
    fig.update_yaxes(title_text=label)
    return fig
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def plot_controls(
    result: OptimizationResults,
    control_names: list[str] | None = None,
    include_private: bool = False,
    cols: int = 3,
) -> go.Figure:
    """Plot control variables in a subplot grid.

    Each component of each control gets its own subplot with individual y-axis
    scaling. This is the primary function for visualizing control trajectories.
    Finite per-component bounds (``min``/``max``) are drawn as dashed lines.

    Args:
        result: Optimization results containing control trajectories
        control_names: List of control names to plot. If None, plots all controls.
            When given, subplots follow the order of this list.
        include_private: Whether to include private controls (names starting with '_')
        cols: Maximum number of columns in subplot grid (must be >= 1)

    Returns:
        Plotly figure with subplot grid

    Raises:
        ValueError: If a requested control is unavailable (private controls
            count as unavailable unless ``include_private`` is True), if there
            is nothing to plot, or if ``cols`` is less than 1.

    Examples:
        >>> plot_controls(result, ["thrust"])  # 3 subplots for x, y, z
        >>> plot_controls(result)  # All controls
    """
    import numpy as np

    all_controls = result._controls
    controls = all_controls
    if not include_private:
        controls = [c for c in controls if not c.name.startswith("_")]

    if control_names is not None:
        available = {c.name for c in controls}
        missing = set(control_names) - available
        if missing:
            # Names that exist but were filtered out can only be private ones.
            hidden = missing & {c.name for c in all_controls}
            hint = " (private controls require include_private=True)" if hidden else ""
            raise ValueError(f"Controls not found in result: {missing}{hint}")
        # Preserve order from control_names
        control_order = {name: i for i, name in enumerate(control_names)}
        controls = sorted(
            [c for c in controls if c.name in control_names],
            key=lambda c: control_order[c.name],
        )

    # Build list of (display_name, var_name, component_idx)
    components = []
    for c in controls:
        dim = _get_var_dim(result, c.name, result._controls)
        if dim == 1:
            components.append((c.name, c.name, 0))
        else:
            components.extend((f"{c.name}_{i}", c.name, i) for i in range(dim))

    if not components:
        raise ValueError("No control components to plot")

    # Guard the grid arithmetic below (cols=0 would divide by zero).
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols}")

    n_components = len(components)
    n_cols = min(cols, n_components)
    n_rows = (n_components + n_cols - 1) // n_cols

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[c[0] for c in components])
    fig.update_layout(title_text="Control Trajectories", template="plotly_dark")

    for idx, (_, var_name, comp_idx) in enumerate(components):
        row_offset, col_offset = divmod(idx, n_cols)

        # Get bounds for this component; atleast_1d also accepts scalar bounds.
        var = _get_var(result, var_name, result._controls)
        min_val = np.atleast_1d(var.min)[comp_idx] if var.min is not None else None
        max_val = np.atleast_1d(var.max)[comp_idx] if var.max is not None else None

        _add_component_traces(
            fig,
            result,
            var_name,
            comp_idx,
            row_offset + 1,
            col_offset + 1,
            show_legend=(idx == 0),
            min_val=min_val,
            max_val=max_val,
        )

    # Label the bottom-most *populated* subplot of each column. When the last
    # row is only partly filled, those plots sit one row higher.
    for col_idx in range(n_cols):
        bottom_row = (n_components - 1 - col_idx) // n_cols + 1
        fig.update_xaxes(title_text="Time (s)", row=bottom_row, col=col_idx + 1)

    return fig
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def plot_trust_region_heatmap(result: OptimizationResults):
    """Plot heatmap of the final trust-region deltas (TR_history[-1]).

    Rows are labelled with every state component followed by every control
    component. Columns are labelled with node times when their count matches
    ``result.nodes["time"]``, and with integer indices otherwise. If the matrix
    has components along its second axis instead of its first, it is
    transposed.

    Args:
        result: Optimization results with a non-empty ``TR_history``.

    Returns:
        Plotly figure containing the heatmap.

    Raises:
        ValueError: If there is no TR history, the last entry is not 2-D, or
            neither of its axes matches the number of state/control components.
    """
    import numpy as np

    if not result.TR_history:
        raise ValueError("Result has no TR_history to plot")

    # asarray gives lists / non-numpy arrays a uniform .ndim/.shape interface.
    tr_mat = np.asarray(result.TR_history[-1])

    # Build variable names list
    var_names = []
    for var_list in (result._states, result._controls):
        for var in var_list:
            dim = _get_var_dim(result, var.name, var_list)
            if dim == 1:
                var_names.append(var.name)
            else:
                var_names.extend(f"{var.name}_{i}" for i in range(dim))

    # Without this check a 1-D entry would raise IndexError on shape[1] or
    # silently produce a degenerate 1-D heatmap.
    if tr_mat.ndim != 2:
        raise ValueError(f"Expected a 2-D TR matrix, got shape {tr_mat.shape}")

    # Orient so rows = variable components, cols = nodes.
    if tr_mat.shape[0] == len(var_names):
        z = tr_mat
    elif tr_mat.shape[1] == len(var_names):
        z = tr_mat.T
    else:
        raise ValueError("TR matrix dimensions do not align with state/control components")

    x_len = z.shape[1]
    t_nodes = result.nodes["time"].flatten()
    x_labels = t_nodes if len(t_nodes) == x_len else list(range(x_len))

    fig = go.Figure(data=go.Heatmap(z=z, x=x_labels, y=var_names, colorscale="Viridis"))
    fig.update_layout(
        title="Trust Region Delta Magnitudes (last iteration)", template="plotly_dark"
    )
    fig.update_xaxes(title_text="Node / Time", side="bottom")
    fig.update_yaxes(title_text="State / Control component", side="left")
    return fig
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def plot_projections_2d(
    result: OptimizationResults,
    var_name: str = "position",
    velocity_var_name: str | None = None,
    cmap: str = "viridis",
):
    """Plot XY, XZ, YZ projections of a 3D variable.

    Useful for visualizing 3D trajectories in 2D plane views. Columns 0, 1 and
    2 of the variable are used as X, Y and Z respectively.

    Args:
        result: Optimization results containing trajectories
        var_name: Name of the 3D variable to plot (default: "position")
        velocity_var_name: Optional name of velocity variable for coloring by speed.
            If provided, trajectory points are colored by velocity magnitude.
            Where the velocity is absent from the trajectory or the nodes, those
            points are drawn without velocity coloring.
        cmap: Plotly colorscale name for velocity coloring (default: "viridis")

    Returns:
        Plotly figure with three subplots (XY, XZ, YZ planes)

    Raises:
        ValueError: If ``var_name`` is in neither the trajectory nor the nodes,
            or if its data is not a 2-D array with at least 3 columns.
    """
    import numpy as np

    has_trajectory = bool(result.trajectory) and var_name in result.trajectory
    has_nodes = var_name in result.nodes

    if not has_trajectory and not has_nodes:
        available_traj = set(result.trajectory.keys()) if result.trajectory else set()
        available_nodes = set(result.nodes.keys())
        raise ValueError(
            f"Variable '{var_name}' not found. "
            f"Available in trajectory: {sorted(available_traj)}, nodes: {sorted(available_nodes)}"
        )

    def _as_xyz(values, source):
        # Projections index columns 0..2. Validate up front instead of failing
        # with an opaque IndexError halfway through building the figure.
        arr = np.asarray(values)
        if arr.ndim != 2 or arr.shape[1] < 3:
            raise ValueError(
                f"Variable '{var_name}' in {source} must be a 2-D array with at least "
                f"3 columns, got shape {arr.shape}"
            )
        return arr

    def _speed(values):
        # A 1-D velocity is a scalar per sample, so its magnitude is |v|.
        arr = np.asarray(values)
        return np.abs(arr) if arr.ndim == 1 else np.linalg.norm(arr, axis=1)

    traj_data = _as_xyz(result.trajectory[var_name], "trajectory") if has_trajectory else None
    node_data = _as_xyz(result.nodes[var_name], "nodes") if has_nodes else None

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=("XY Plane", "XZ Plane", "YZ Plane"),
        specs=[[{}, {}], [{}, None]],
    )

    # Subplot positions: (x_idx, y_idx, row, col)
    subplots = [(0, 1, 1, 1), (0, 2, 1, 2), (1, 2, 2, 1)]

    # Compute velocity norms if velocity variable is provided
    traj_vel_norm = None
    node_vel_norm = None
    if velocity_var_name is not None:
        if has_trajectory and velocity_var_name in result.trajectory:
            traj_vel_norm = _speed(result.trajectory[velocity_var_name])
        if has_nodes and velocity_var_name in result.nodes:
            node_vel_norm = _speed(result.nodes[velocity_var_name])

    # Colorbar config (only shown once)
    colorbar_cfg = {"title": "‖velocity‖", "x": 1.02, "y": 0.5, "len": 0.9}

    # Plot trajectory if available
    if traj_data is not None:
        for i, (xi, yi, row, col) in enumerate(subplots):
            if traj_vel_norm is not None:
                trace = go.Scatter(
                    x=traj_data[:, xi],
                    y=traj_data[:, yi],
                    mode="markers",
                    marker={
                        "size": 4,
                        "color": traj_vel_norm,
                        "colorscale": cmap,
                        "showscale": (i == 0),
                        "colorbar": colorbar_cfg if i == 0 else None,
                    },
                    name="Trajectory",
                    legendgroup="trajectory",
                    showlegend=(i == 0),
                )
            else:
                trace = go.Scatter(
                    x=traj_data[:, xi],
                    y=traj_data[:, yi],
                    mode="lines",
                    line={"color": "green", "width": 2},
                    name="Trajectory",
                    legendgroup="trajectory",
                    showlegend=(i == 0),
                )
            fig.add_trace(trace, row=row, col=col)

    # Plot nodes if available
    if node_data is not None:
        # Only show colorbar on nodes if trajectory doesn't have one
        show_node_colorbar = (traj_vel_norm is None) and (node_vel_norm is not None)
        for i, (xi, yi, row, col) in enumerate(subplots):
            if node_vel_norm is not None:
                marker = {
                    "size": 8,
                    "color": node_vel_norm,
                    "colorscale": cmap,
                    "showscale": show_node_colorbar and (i == 0),
                    "colorbar": colorbar_cfg if (show_node_colorbar and i == 0) else None,
                    "line": {"color": "white", "width": 1},
                }
            else:
                marker = {"color": "cyan", "size": 6}
            fig.add_trace(
                go.Scatter(
                    x=node_data[:, xi],
                    y=node_data[:, yi],
                    mode="markers",
                    marker=marker,
                    name="Nodes",
                    legendgroup="nodes",
                    showlegend=(i == 0),
                ),
                row=row,
                col=col,
            )

    # Set axis titles
    fig.update_xaxes(title_text="X", row=1, col=1)
    fig.update_yaxes(title_text="Y", row=1, col=1)
    fig.update_xaxes(title_text="X", row=1, col=2)
    fig.update_yaxes(title_text="Z", row=1, col=2)
    fig.update_xaxes(title_text="Y", row=2, col=1)
    fig.update_yaxes(title_text="Z", row=2, col=1)

    # Set equal aspect ratio for each subplot
    layout_opts = {
        "title": f"{var_name} - XY, XZ, YZ Projections",
        "template": "plotly_dark",
        "xaxis": {"scaleanchor": "y"},
        "xaxis2": {"scaleanchor": "y2"},
        "xaxis3": {"scaleanchor": "y3"},
    }
    # Move legend to bottom-right when using colorbar to avoid overlap
    if velocity_var_name is not None:
        layout_opts["legend"] = {
            "orientation": "h",
            "yanchor": "bottom",
            "y": -0.15,
            "xanchor": "center",
            "x": 0.5,
        }
    fig.update_layout(**layout_opts)

    return fig
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def plot_vector_norm(
    result: OptimizationResults,
    var_name: str,
    bounds: tuple[float | None, float | None] | None = None,
):
    """Plot the 2-norm of a vector variable over time.

    Useful for visualizing thrust magnitude, velocity magnitude, etc. A 1-D
    (scalar) variable is plotted as its absolute value.

    Args:
        result: Optimization results containing trajectories
        var_name: Name of the vector variable (state or control)
        bounds: Optional (min, max) bounds to show as horizontal dashed lines.
            A bound that is None or non-finite (e.g. ``np.inf``) is skipped.

    Returns:
        Plotly figure

    Raises:
        ValueError: If ``var_name`` is in neither the trajectory nor the nodes.
    """
    import numpy as np

    def _row_norms(values):
        # 2-D input (samples, components) is reduced per row; a 1-D input has
        # one scalar per sample, where norm(axis=1) would raise.
        arr = np.asarray(values)
        return np.abs(arr) if arr.ndim == 1 else np.linalg.norm(arr, axis=1)

    has_trajectory = bool(result.trajectory) and var_name in result.trajectory
    has_nodes = var_name in result.nodes

    if not has_trajectory and not has_nodes:
        available_traj = set(result.trajectory.keys()) if result.trajectory else set()
        available_nodes = set(result.nodes.keys())
        raise ValueError(
            f"Variable '{var_name}' not found. "
            f"Available in trajectory: {sorted(available_traj)}, nodes: {sorted(available_nodes)}"
        )

    fig = go.Figure()

    # Plot trajectory norm if available
    if has_trajectory:
        fig.add_trace(
            go.Scatter(
                x=result.trajectory["time"].flatten(),
                y=_row_norms(result.trajectory[var_name]),
                mode="lines",
                line={"color": "green", "width": 2},
                name="Trajectory",
                legendgroup="trajectory",
            )
        )

    # Plot node norms if available
    if has_nodes:
        fig.add_trace(
            go.Scatter(
                x=result.nodes["time"].flatten(),
                y=_row_norms(result.nodes[var_name]),
                mode="markers",
                marker={"color": "cyan", "size": 6},
                name="Nodes",
                legendgroup="nodes",
            )
        )

    # Add bounds if provided; missing or infinite bounds are skipped, in the
    # same way as _add_component_traces.
    if bounds is not None:
        min_bound, max_bound = bounds
        for value, label in ((min_bound, "Min"), (max_bound, "Max")):
            if value is None or not np.isfinite(value):
                continue
            fig.add_hline(
                y=value,
                line={"color": "red", "width": 2, "dash": "dash"},
                annotation_text=label,
                annotation_position="right",
            )

    fig.update_layout(
        title=f"‖{var_name}‖₂",
        xaxis_title="Time (s)",
        yaxis_title="Norm",
        template="plotly_dark",
    )

    return fig
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
def plot_virtual_control_heatmap(result: OptimizationResults):
    """Plot heatmap of the final virtual control magnitudes (VC_history[-1]).

    Rows are labelled with state components. The matrix is transposed when its
    second axis matches the number of state components; that layout is checked
    first. Column labels are:

    - the node times minus the last one, when there is one more node than
      columns;
    - all node times, when the counts match;
    - integer indices otherwise.

    Args:
        result: Optimization results with a non-empty ``VC_history``.

    Returns:
        Plotly figure containing the heatmap.

    Raises:
        ValueError: If there is no VC history, the last entry is not 2-D, or
            neither of its axes matches the number of state components.
    """
    import numpy as np

    if not result.VC_history:
        raise ValueError("Result has no VC_history to plot")

    # asarray gives lists / non-numpy arrays a uniform .ndim/.shape interface.
    vc_mat = np.asarray(result.VC_history[-1])

    # Build state names list
    state_names = []
    for var in result._states:
        dim = _get_var_dim(result, var.name, result._states)
        if dim == 1:
            state_names.append(var.name)
        else:
            state_names.extend(f"{var.name}_{i}" for i in range(dim))

    # Without this check a 1-D entry would raise IndexError on shape[1].
    if vc_mat.ndim != 2:
        raise ValueError(f"Expected a 2-D VC matrix, got shape {vc_mat.shape}")

    # Align so rows = states, cols = nodes
    if vc_mat.shape[1] == len(state_names):
        z = vc_mat.T
    elif vc_mat.shape[0] == len(state_names):
        z = vc_mat
    else:
        raise ValueError("VC matrix shape does not align with state components")

    x_len = z.shape[1]
    t_nodes = result.nodes["time"].flatten()

    # Virtual control uses N-1 intervals
    if len(t_nodes) == x_len + 1:
        x_labels = t_nodes[:-1]
    elif len(t_nodes) == x_len:
        x_labels = t_nodes
    else:
        x_labels = list(range(x_len))

    fig = go.Figure(data=go.Heatmap(z=z, x=x_labels, y=state_names, colorscale="Magma"))
    fig.update_layout(title="Virtual Control Magnitudes (last iteration)", template="plotly_dark")
    fig.update_xaxes(title_text="Node Interval (N-1)")
    fig.update_yaxes(title_text="State component")
    return fig
|