openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,756 @@
1
+ import plotly.graph_objects as go
2
+ from plotly.subplots import make_subplots
3
+
4
+ from openscvx.algorithms import OptimizationResults
5
+
6
+
7
def _get_var(result: OptimizationResults, var_name: str, var_list: list):
    """Return the first entry of ``var_list`` whose ``name`` equals ``var_name``.

    Args:
        result: Optimization results (not consulted by this lookup).
        var_name: Name of the variable to look up.
        var_list: Sequence of variable metadata objects exposing ``.name``.

    Returns:
        The matching metadata object.

    Raises:
        ValueError: If no entry in ``var_list`` carries the requested name.
    """
    for candidate in var_list:
        if candidate.name == var_name:
            break
    else:
        # Loop finished without a match (this also covers an empty var_list).
        raise ValueError(f"Variable '{var_name}' not found")
    return candidate
13
+
14
+
15
def _get_var_dim(result: OptimizationResults, var_name: str, var_list: list) -> int:
    """Return how many flat components a variable occupies.

    The size is derived from the variable's ``_slice`` attribute. A ``slice`` is
    measured with ``range`` semantics, so an explicit ``0`` start/stop and a
    non-unit step are honoured. Anything that is not a ``slice`` (for example a
    plain integer index) is treated as a scalar of dimension 1.

    Args:
        result: Optimization results (forwarded to ``_get_var``).
        var_name: Name of the variable.
        var_list: Sequence of variable metadata objects exposing ``.name`` and
            ``._slice``.

    Returns:
        Number of components covered by the variable's slice (never negative),
        or 1 when ``_slice`` is not a ``slice``.

    Raises:
        ValueError: If ``var_name`` is not present in ``var_list``.
    """
    var = _get_var(result, var_name, var_list)
    s = var._slice
    if not isinstance(s, slice):
        return 1
    # Use explicit ``is None`` checks: ``s.stop or 1`` would misread a stop of 0.
    start = 0 if s.start is None else s.start
    # An open-ended stop cannot be resolved here without the full vector length;
    # fall back to a stop of 1 (so ``slice(None)`` still reports a scalar).
    stop = 1 if s.stop is None else s.stop
    step = 1 if s.step is None else s.step
    return len(range(start, stop, step))
22
+
23
+
24
def _add_component_traces(
    fig: go.Figure,
    result: OptimizationResults,
    var_name: str,
    component_idx: int,
    row: int,
    col: int,
    show_legend: bool,
    min_val: float | None = None,
    max_val: float | None = None,
):
    """Draw one scalar component of a variable into a given subplot cell.

    In order, this adds:

    * the propagated trajectory as a green line, when ``result.trajectory`` is
      non-empty and contains ``var_name``;
    * the optimization nodes as cyan circle markers, when ``result.nodes``
      contains ``var_name``;
    * a red dashed horizontal line for each bound that is given and finite.

    Args:
        fig: Plotly figure (built with subplots) that receives the traces.
        result: Optimization results providing ``nodes`` and ``trajectory``.
        var_name: Name of the variable to draw.
        component_idx: Column to draw when the stored data is 2-D.
        row: Subplot row (1-based).
        col: Subplot column (1-based).
        show_legend: Whether these traces create legend entries.
        min_val: Optional lower bound to draw as a horizontal line.
        max_val: Optional upper bound to draw as a horizontal line.
    """
    import numpy as np

    def pick(series):
        # 1-D data holds a scalar variable; otherwise select the requested column.
        return series if series.ndim == 1 else series[:, component_idx]

    node_times = result.nodes["time"].flatten()
    traj = result.trajectory
    show_traj = bool(traj) and var_name in traj
    traj_times = traj["time"].flatten() if show_traj else None

    if show_traj:
        fig.add_trace(
            go.Scatter(
                x=traj_times,
                y=pick(traj[var_name]),
                mode="lines",
                line={"color": "green", "width": 2},
                name="Trajectory",
                legendgroup="trajectory",
                showlegend=show_legend,
            ),
            row=row,
            col=col,
        )

    if var_name in result.nodes:
        fig.add_trace(
            go.Scatter(
                x=node_times,
                y=pick(result.nodes[var_name]),
                mode="markers",
                marker={"color": "cyan", "size": 6, "symbol": "circle"},
                name="Nodes",
                legendgroup="nodes",
                showlegend=show_legend,
            ),
            row=row,
            col=col,
        )

    # Missing or non-finite bounds (e.g. -inf/inf) are skipped.
    for bound in (min_val, max_val):
        if bound is None or not np.isfinite(bound):
            continue
        fig.add_hline(
            y=bound,
            line={"color": "red", "width": 1.5, "dash": "dash"},
            row=row,
            col=col,
        )
106
+
107
+
108
+ # =============================================================================
109
+ # State Plotting
110
+ # =============================================================================
111
+
112
+
113
def plot_state_component(
    result: OptimizationResults,
    state_name: str,
    component: int = 0,
) -> go.Figure:
    """Plot one scalar component of a state variable against time.

    This is the single-panel building block; use plot_states() to lay out every
    component of one or more states in a grid.

    Args:
        result: Optimization results containing state trajectories
        state_name: Name of the state variable
        component: Zero-based component index. Scalar states use 0.

    Returns:
        Plotly figure with a single plot. It contains a green "Trajectory" line
        when propagated data exists and cyan "Nodes" markers when node data
        exists.

    Raises:
        ValueError: If the state is unknown or ``component`` is out of range.

    Example:
        >>> plot_state_component(result, "position", 2) # Plot z-component
    """
    known_states = {s.name for s in result._states}
    if state_name not in known_states:
        raise ValueError(f"State '{state_name}' not found. Available: {sorted(known_states)}")

    dim = _get_var_dim(result, state_name, result._states)
    if not 0 <= component < dim:
        raise ValueError(f"Component {component} out of range for '{state_name}' (dim={dim})")

    node_time = result.nodes["time"].flatten()
    traj_present = bool(result.trajectory) and state_name in result.trajectory
    traj_time = result.trajectory["time"].flatten() if traj_present else None

    label = state_name if dim <= 1 else f"{state_name}_{component}"

    def column(series):
        # 1-D data holds a scalar state; otherwise take the requested column.
        return series if series.ndim == 1 else series[:, component]

    fig = go.Figure()
    fig.update_layout(title_text=label, template="plotly_dark")

    pending = []
    if traj_present:
        pending.append(
            go.Scatter(
                x=traj_time,
                y=column(result.trajectory[state_name]),
                mode="lines",
                name="Trajectory",
                line={"color": "green", "width": 2},
            )
        )
    if state_name in result.nodes:
        pending.append(
            go.Scatter(
                x=node_time,
                y=column(result.nodes[state_name]),
                mode="markers",
                name="Nodes",
                marker={"color": "cyan", "size": 6},
            )
        )
    for trace in pending:
        fig.add_trace(trace)

    fig.update_xaxes(title_text="Time (s)")
    fig.update_yaxes(title_text=label)
    return fig
180
+
181
+
182
def plot_states(
    result: OptimizationResults,
    state_names: list[str] | None = None,
    include_private: bool = False,
    cols: int = 4,
) -> go.Figure:
    """Plot state variables in a subplot grid.

    Each component of each state gets its own subplot with individual y-axis
    scaling. Per-component bounds taken from ``var.min`` / ``var.max`` are passed
    to the subplot helper to be drawn as horizontal lines. This is the primary
    function for visualizing state trajectories.

    Args:
        result: Optimization results containing state trajectories
        state_names: List of state names to plot, in display order. If None,
            plots all states (private ones only when ``include_private``).
        include_private: Whether to include private states (names starting with '_')
        cols: Maximum number of columns in subplot grid (must be >= 1)

    Returns:
        Plotly figure with subplot grid

    Raises:
        ValueError: If ``cols`` < 1. Also raised if a requested state is missing
            (a private state counts as missing unless ``include_private``), or if
            there are no components to plot.

    Examples:
        >>> plot_states(result, ["position"]) # 3 subplots for x, y, z
        >>> plot_states(result, ["position", "velocity"]) # 6 subplots
        >>> plot_states(result) # All states
    """
    if cols < 1:
        raise ValueError(f"cols must be at least 1, got {cols}")

    states = result._states
    if not include_private:
        states = [s for s in states if not s.name.startswith("_")]

    if state_names is not None:
        available = {s.name for s in states}
        missing = set(state_names) - available
        if missing:
            raise ValueError(f"States not found in result: {missing}")
        # Preserve the order given by the caller.
        state_order = {name: i for i, name in enumerate(state_names)}
        states = sorted(
            (s for s in states if s.name in state_order),
            key=lambda s: state_order[s.name],
        )

    # Build list of (display_name, var_name, component_idx)
    components = []
    for s in states:
        dim = _get_var_dim(result, s.name, result._states)
        if dim == 1:
            components.append((s.name, s.name, 0))
        else:
            components.extend((f"{s.name}_{i}", s.name, i) for i in range(dim))

    if not components:
        raise ValueError("No state components to plot")

    n_components = len(components)
    n_cols = min(cols, n_components)
    n_rows = (n_components + n_cols - 1) // n_cols

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[c[0] for c in components])
    fig.update_layout(title_text="State Trajectories", template="plotly_dark")

    for idx, (_, var_name, comp_idx) in enumerate(components):
        row = (idx // n_cols) + 1
        col = (idx % n_cols) + 1

        # Get bounds for this component
        var = _get_var(result, var_name, result._states)
        min_val = var.min[comp_idx] if var.min is not None else None
        max_val = var.max[comp_idx] if var.max is not None else None

        _add_component_traces(
            fig,
            result,
            var_name,
            comp_idx,
            row,
            col,
            show_legend=(idx == 0),
            min_val=min_val,
            max_val=max_val,
        )

    # Label the x-axis of the lowest populated subplot in every column. The last
    # n_cols components are exactly those subplots. With a partially filled last
    # row, some columns end one row higher, so labelling row=n_rows everywhere
    # would target empty grid cells instead.
    for idx in range(n_components - n_cols, n_components):
        fig.update_xaxes(title_text="Time (s)", row=idx // n_cols + 1, col=idx % n_cols + 1)

    return fig
269
+
270
+
271
+ # =============================================================================
272
+ # Control Plotting
273
+ # =============================================================================
274
+
275
+
276
def plot_control_component(
    result: OptimizationResults,
    control_name: str,
    component: int = 0,
) -> go.Figure:
    """Plot one scalar component of a control variable against time.

    This is the single-panel building block; use plot_controls() to lay out
    every component of one or more controls in a grid.

    Args:
        result: Optimization results containing control trajectories
        control_name: Name of the control variable
        component: Zero-based component index. Scalar controls use 0.

    Returns:
        Plotly figure with a single plot. It contains a green "Trajectory" line
        when propagated data exists and cyan "Nodes" markers when node data
        exists.

    Raises:
        ValueError: If the control is unknown or ``component`` is out of range.

    Example:
        >>> plot_control_component(result, "thrust", 0) # Plot thrust_x
    """
    control_set = {c.name for c in result._controls}
    if control_name not in control_set:
        raise ValueError(f"Control '{control_name}' not found. Available: {sorted(control_set)}")

    width = _get_var_dim(result, control_name, result._controls)
    in_range = 0 <= component < width
    if not in_range:
        raise ValueError(f"Component {component} out of range for '{control_name}' (dim={width})")

    times_at_nodes = result.nodes["time"].flatten()
    propagated = bool(result.trajectory) and control_name in result.trajectory
    times_full = result.trajectory["time"].flatten() if propagated else None

    if width > 1:
        label = f"{control_name}_{component}"
    else:
        label = control_name

    fig = go.Figure()
    fig.update_layout(title_text=label, template="plotly_dark")

    if propagated:
        series = result.trajectory[control_name]
        fig.add_trace(
            go.Scatter(
                x=times_full,
                y=series[:, component] if series.ndim != 1 else series,
                mode="lines",
                name="Trajectory",
                line={"color": "green", "width": 2},
            )
        )

    if control_name in result.nodes:
        series = result.nodes[control_name]
        fig.add_trace(
            go.Scatter(
                x=times_at_nodes,
                y=series[:, component] if series.ndim != 1 else series,
                mode="markers",
                name="Nodes",
                marker={"color": "cyan", "size": 6},
            )
        )

    fig.update_xaxes(title_text="Time (s)")
    fig.update_yaxes(title_text=label)
    return fig
343
+
344
+
345
def plot_controls(
    result: OptimizationResults,
    control_names: list[str] | None = None,
    include_private: bool = False,
    cols: int = 3,
) -> go.Figure:
    """Plot control variables in a subplot grid.

    Each component of each control gets its own subplot with individual y-axis
    scaling. Per-component bounds taken from ``var.min`` / ``var.max`` are passed
    to the subplot helper to be drawn as horizontal lines. This is the primary
    function for visualizing control trajectories.

    Args:
        result: Optimization results containing control trajectories
        control_names: List of control names to plot, in display order. If None,
            plots all controls (private ones only when ``include_private``).
        include_private: Whether to include private controls (names starting with '_')
        cols: Maximum number of columns in subplot grid (must be >= 1)

    Returns:
        Plotly figure with subplot grid

    Raises:
        ValueError: If ``cols`` < 1. Also raised if a requested control is missing
            (a private control counts as missing unless ``include_private``), or
            if there are no components to plot.

    Examples:
        >>> plot_controls(result, ["thrust"]) # 3 subplots for x, y, z
        >>> plot_controls(result) # All controls
    """
    if cols < 1:
        raise ValueError(f"cols must be at least 1, got {cols}")

    controls = result._controls
    if not include_private:
        controls = [c for c in controls if not c.name.startswith("_")]

    if control_names is not None:
        available = {c.name for c in controls}
        missing = set(control_names) - available
        if missing:
            raise ValueError(f"Controls not found in result: {missing}")
        # Preserve the order given by the caller.
        control_order = {name: i for i, name in enumerate(control_names)}
        controls = sorted(
            (c for c in controls if c.name in control_order),
            key=lambda c: control_order[c.name],
        )

    # Build list of (display_name, var_name, component_idx)
    components = []
    for c in controls:
        dim = _get_var_dim(result, c.name, result._controls)
        if dim == 1:
            components.append((c.name, c.name, 0))
        else:
            components.extend((f"{c.name}_{i}", c.name, i) for i in range(dim))

    if not components:
        raise ValueError("No control components to plot")

    n_components = len(components)
    n_cols = min(cols, n_components)
    n_rows = (n_components + n_cols - 1) // n_cols

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[c[0] for c in components])
    fig.update_layout(title_text="Control Trajectories", template="plotly_dark")

    for idx, (_, var_name, comp_idx) in enumerate(components):
        row = (idx // n_cols) + 1
        col = (idx % n_cols) + 1

        # Get bounds for this component
        var = _get_var(result, var_name, result._controls)
        min_val = var.min[comp_idx] if var.min is not None else None
        max_val = var.max[comp_idx] if var.max is not None else None

        _add_component_traces(
            fig,
            result,
            var_name,
            comp_idx,
            row,
            col,
            show_legend=(idx == 0),
            min_val=min_val,
            max_val=max_val,
        )

    # Label the x-axis of the lowest populated subplot in every column. The last
    # n_cols components are exactly those subplots. With a partially filled last
    # row, some columns end one row higher, so labelling row=n_rows everywhere
    # would target empty grid cells instead.
    for idx in range(n_components - n_cols, n_components):
        fig.update_xaxes(title_text="Time (s)", row=idx // n_cols + 1, col=idx % n_cols + 1)

    return fig
431
+
432
+
433
def plot_trust_region_heatmap(result: OptimizationResults):
    """Plot a heatmap of the final trust-region deltas (``TR_history[-1]``).

    Rows are labelled with every state component followed by every control
    component, including private variables. Columns are labelled with the node
    times when their count matches the number of columns, and with integer
    indices otherwise.

    Args:
        result: Optimization results with a non-empty ``TR_history``.

    Returns:
        Plotly figure containing a single heatmap.

    Raises:
        ValueError: If there is no TR history, or if its last entry is not a 2-D
            matrix. Also raised if neither dimension of that matrix matches the
            number of state + control components.
    """
    import numpy as np

    history = result.TR_history
    # Explicit length check: ``not history`` is ambiguous for array-valued histories.
    if history is None or len(history) == 0:
        raise ValueError("Result has no TR_history to plot")

    # Normalize lists / JAX arrays to an ndarray so ``.shape``/``.T`` are available.
    tr_mat = np.asarray(history[-1])
    if tr_mat.ndim != 2:
        raise ValueError(f"Expected a 2-D TR matrix, got shape {tr_mat.shape}")

    # Build variable names list
    var_names = []
    for var_list in [result._states, result._controls]:
        for var in var_list:
            dim = _get_var_dim(result, var.name, var_list)
            if dim == 1:
                var_names.append(var.name)
            else:
                var_names.extend(f"{var.name}_{i}" for i in range(dim))

    # Prefer rows = variables, cols = nodes; transpose if stored the other way round.
    if tr_mat.shape[0] == len(var_names):
        z = tr_mat
    elif tr_mat.shape[1] == len(var_names):
        z = tr_mat.T
    else:
        raise ValueError("TR matrix dimensions do not align with state/control components")

    x_len = z.shape[1]
    t_nodes = result.nodes["time"].flatten()
    x_labels = t_nodes if len(t_nodes) == x_len else list(range(x_len))

    fig = go.Figure(data=go.Heatmap(z=z, x=x_labels, y=var_names, colorscale="Viridis"))
    fig.update_layout(
        title="Trust Region Delta Magnitudes (last iteration)", template="plotly_dark"
    )
    fig.update_xaxes(title_text="Node / Time", side="bottom")
    fig.update_yaxes(title_text="State / Control component", side="left")
    return fig
469
+
470
+
471
def plot_projections_2d(
    result: OptimizationResults,
    var_name: str = "position",
    velocity_var_name: str | None = None,
    cmap: str = "viridis",
):
    """Plot XY, XZ, YZ projections of a 3D variable.

    Useful for visualizing 3D trajectories in 2D plane views. The first three
    columns of the variable are interpreted as X, Y and Z.

    Args:
        result: Optimization results containing trajectories
        var_name: Name of the variable to plot (default: "position"). Its data
            must be a 2-D array with at least 3 columns.
        velocity_var_name: Optional name of a velocity variable for coloring by
            speed. Where that variable is present, points are colored by its
            row-wise 2-norm. Where it is absent, the plot falls back to uniform
            colors.
        cmap: Plotly colorscale name used for velocity coloring (default: "viridis")

    Returns:
        Plotly figure with three subplots (XY, XZ, YZ planes)

    Raises:
        ValueError: If ``var_name`` is in neither ``result.trajectory`` nor
            ``result.nodes``, or if its data is not 2-D with at least 3 columns.
    """
    import numpy as np

    has_trajectory = bool(result.trajectory) and var_name in result.trajectory
    has_nodes = var_name in result.nodes

    if not has_trajectory and not has_nodes:
        available_traj = set(result.trajectory.keys()) if result.trajectory else set()
        available_nodes = set(result.nodes.keys())
        raise ValueError(
            f"Variable '{var_name}' not found. "
            f"Available in trajectory: {sorted(available_traj)}, nodes: {sorted(available_nodes)}"
        )

    # Fail early with a clear message rather than an IndexError while plotting.
    to_check = []
    if has_trajectory:
        to_check.append(("trajectory", result.trajectory[var_name]))
    if has_nodes:
        to_check.append(("nodes", result.nodes[var_name]))
    for where, values in to_check:
        shape = np.shape(values)
        if len(shape) != 2 or shape[1] < 3:
            raise ValueError(
                f"Variable '{var_name}' in {where} must be 2-D with at least 3 columns "
                f"to plot projections, got shape {shape}"
            )

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=("XY Plane", "XZ Plane", "YZ Plane"),
        specs=[[{}, {}], [{}, None]],
    )

    # Subplot positions: (x_idx, y_idx, row, col)
    subplots = [(0, 1, 1, 1), (0, 2, 1, 2), (1, 2, 2, 1)]

    # Compute velocity norms if velocity variable is provided
    traj_vel_norm = None
    node_vel_norm = None
    if velocity_var_name is not None:
        if has_trajectory and velocity_var_name in result.trajectory:
            traj_vel_norm = np.linalg.norm(result.trajectory[velocity_var_name], axis=1)
        if has_nodes and velocity_var_name in result.nodes:
            node_vel_norm = np.linalg.norm(result.nodes[velocity_var_name], axis=1)

    # Colorbar config (only shown once)
    colorbar_cfg = {"title": "‖velocity‖", "x": 1.02, "y": 0.5, "len": 0.9}

    # Plot trajectory if available
    if has_trajectory:
        data = result.trajectory[var_name]
        for i, (xi, yi, row, col) in enumerate(subplots):
            if traj_vel_norm is not None:
                marker = {
                    "size": 4,
                    "color": traj_vel_norm,
                    "colorscale": cmap,
                    "showscale": (i == 0),
                    "colorbar": colorbar_cfg if i == 0 else None,
                }
                fig.add_trace(
                    go.Scatter(
                        x=data[:, xi],
                        y=data[:, yi],
                        mode="markers",
                        marker=marker,
                        name="Trajectory",
                        legendgroup="trajectory",
                        showlegend=(i == 0),
                    ),
                    row=row,
                    col=col,
                )
            else:
                fig.add_trace(
                    go.Scatter(
                        x=data[:, xi],
                        y=data[:, yi],
                        mode="lines",
                        line={"color": "green", "width": 2},
                        name="Trajectory",
                        legendgroup="trajectory",
                        showlegend=(i == 0),
                    ),
                    row=row,
                    col=col,
                )

    # Plot nodes if available
    if has_nodes:
        data = result.nodes[var_name]
        # Only show colorbar on nodes if trajectory doesn't have one
        show_node_colorbar = (traj_vel_norm is None) and (node_vel_norm is not None)
        for i, (xi, yi, row, col) in enumerate(subplots):
            if node_vel_norm is not None:
                marker = {
                    "size": 8,
                    "color": node_vel_norm,
                    "colorscale": cmap,
                    "showscale": show_node_colorbar and (i == 0),
                    "colorbar": colorbar_cfg if (show_node_colorbar and i == 0) else None,
                    "line": {"color": "white", "width": 1},
                }
            else:
                marker = {"color": "cyan", "size": 6}
            fig.add_trace(
                go.Scatter(
                    x=data[:, xi],
                    y=data[:, yi],
                    mode="markers",
                    marker=marker,
                    name="Nodes",
                    legendgroup="nodes",
                    showlegend=(i == 0),
                ),
                row=row,
                col=col,
            )

    # Set axis titles
    fig.update_xaxes(title_text="X", row=1, col=1)
    fig.update_yaxes(title_text="Y", row=1, col=1)
    fig.update_xaxes(title_text="X", row=1, col=2)
    fig.update_yaxes(title_text="Z", row=1, col=2)
    fig.update_xaxes(title_text="Y", row=2, col=1)
    fig.update_yaxes(title_text="Z", row=2, col=1)

    # Set equal aspect ratio for each subplot
    layout_opts = {
        "title": f"{var_name} - XY, XZ, YZ Projections",
        "template": "plotly_dark",
        "xaxis": {"scaleanchor": "y"},
        "xaxis2": {"scaleanchor": "y2"},
        "xaxis3": {"scaleanchor": "y3"},
    }
    # Move the legend to the bottom only when a colorbar is actually drawn
    # (i.e. a velocity norm was found for the trajectory or the nodes).
    if traj_vel_norm is not None or node_vel_norm is not None:
        layout_opts["legend"] = {
            "orientation": "h",
            "yanchor": "bottom",
            "y": -0.15,
            "xanchor": "center",
            "x": 0.5,
        }
    fig.update_layout(**layout_opts)

    return fig
625
+
626
+
627
def plot_vector_norm(
    result: OptimizationResults,
    var_name: str,
    bounds: tuple[float | None, float | None] | None = None,
):
    """Plot the 2-norm of a vector variable over time.

    Useful for visualizing thrust magnitude, velocity magnitude, etc. Each
    sample's norm is taken along axis 1 of the stored data. A 1-D variable is
    treated as a scalar per sample, so its absolute value is plotted.

    Args:
        result: Optimization results containing trajectories
        var_name: Name of the vector variable (state or control)
        bounds: Optional (min, max) bounds to show as horizontal dashed lines.
            Either side may be None or infinite to omit that line, e.g.
            ``(None, 10.0)`` or ``(0.0, np.inf)``.

    Returns:
        Plotly figure

    Raises:
        ValueError: If ``var_name`` is in neither ``result.trajectory`` nor
            ``result.nodes``.
    """
    import numpy as np

    def _row_norms(values):
        arr = np.asarray(values)
        # A 1-D series is a scalar per sample; its 2-norm is the absolute value.
        if arr.ndim == 1:
            return np.abs(arr)
        return np.linalg.norm(arr, axis=1)

    has_trajectory = bool(result.trajectory) and var_name in result.trajectory
    has_nodes = var_name in result.nodes

    if not has_trajectory and not has_nodes:
        available_traj = set(result.trajectory.keys()) if result.trajectory else set()
        available_nodes = set(result.nodes.keys())
        raise ValueError(
            f"Variable '{var_name}' not found. "
            f"Available in trajectory: {sorted(available_traj)}, nodes: {sorted(available_nodes)}"
        )

    fig = go.Figure()

    # Plot trajectory norm if available
    if has_trajectory:
        t_full = result.trajectory["time"].flatten()
        fig.add_trace(
            go.Scatter(
                x=t_full,
                y=_row_norms(result.trajectory[var_name]),
                mode="lines",
                line={"color": "green", "width": 2},
                name="Trajectory",
                legendgroup="trajectory",
            )
        )

    # Plot node norms if available
    if has_nodes:
        t_nodes = result.nodes["time"].flatten()
        fig.add_trace(
            go.Scatter(
                x=t_nodes,
                y=_row_norms(result.nodes[var_name]),
                mode="markers",
                marker={"color": "cyan", "size": 6},
                name="Nodes",
                legendgroup="nodes",
            )
        )

    # Add bounds if provided; skip None / infinite sides so one-sided limits work.
    if bounds is not None:
        min_bound, max_bound = bounds
        for value, text in ((min_bound, "Min"), (max_bound, "Max")):
            if value is None or not np.isfinite(value):
                continue
            fig.add_hline(
                y=value,
                line={"color": "red", "width": 2, "dash": "dash"},
                annotation_text=text,
                annotation_position="right",
            )

    fig.update_layout(
        title=f"‖{var_name}‖₂",
        xaxis_title="Time (s)",
        yaxis_title="Norm",
        template="plotly_dark",
    )

    return fig
715
+
716
+
717
def plot_virtual_control_heatmap(result: OptimizationResults):
    """Plot a heatmap of the final virtual control magnitudes (``VC_history[-1]``).

    Rows are labelled with every state component, including private states.
    Columns are labelled with node times when their count matches the node
    times exactly, or with all node times but the last when there is one fewer
    column (N-1 intervals). Integer indices are used otherwise.

    Args:
        result: Optimization results with a non-empty ``VC_history``.

    Returns:
        Plotly figure containing a single heatmap.

    Raises:
        ValueError: If there is no VC history, or if its last entry is not a 2-D
            matrix. Also raised if neither dimension of that matrix matches the
            number of state components.
    """
    import numpy as np

    history = result.VC_history
    # Explicit length check: ``not history`` is ambiguous for array-valued histories.
    if history is None or len(history) == 0:
        raise ValueError("Result has no VC_history to plot")

    # Normalize lists / JAX arrays to an ndarray so ``.shape``/``.T`` are available.
    vc_mat = np.asarray(history[-1])
    if vc_mat.ndim != 2:
        raise ValueError(f"Expected a 2-D VC matrix, got shape {vc_mat.shape}")

    # Build state names list
    state_names = []
    for var in result._states:
        dim = _get_var_dim(result, var.name, result._states)
        if dim == 1:
            state_names.append(var.name)
        else:
            state_names.extend(f"{var.name}_{i}" for i in range(dim))

    # Align so rows = states, cols = nodes
    if vc_mat.shape[1] == len(state_names):
        z = vc_mat.T
    elif vc_mat.shape[0] == len(state_names):
        z = vc_mat
    else:
        raise ValueError("VC matrix shape does not align with state components")

    x_len = z.shape[1]
    t_nodes = result.nodes["time"].flatten()

    # Virtual control uses N-1 intervals
    if len(t_nodes) == x_len + 1:
        x_labels = t_nodes[:-1]
    elif len(t_nodes) == x_len:
        x_labels = t_nodes
    else:
        x_labels = list(range(x_len))

    fig = go.Figure(data=go.Heatmap(z=z, x=x_labels, y=state_names, colorscale="Magma"))
    fig.update_layout(title="Virtual Control Magnitudes (last iteration)", template="plotly_dark")
    fig.update_xaxes(title_text="Node Interval (N-1)")
    fig.update_yaxes(title_text="State component")
    return fig