haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,1086 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Visualization module for HaoLine.
6
+
7
+ Generates matplotlib-based charts for model architecture analysis:
8
+ - Operator type histogram
9
+ - Layer depth profile (cumulative params/FLOPs)
10
+ - Parameter distribution by layer type
11
+ - Shape evolution through the network
12
+
13
+ All charts use a consistent dark theme suitable for technical documentation.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+ from typing import TYPE_CHECKING
22
+
23
+ if TYPE_CHECKING:
24
+ from .analyzer import FlopCounts, GraphInfo, ParamCounts
25
+ from .operational_profiling import BatchSizeSweep, ProfilingResult, ResolutionSweep
26
+ from .report import InspectionReport
27
+
28
+ # Attempt to import matplotlib with Agg backend (non-interactive)
29
+ _MATPLOTLIB_AVAILABLE = False
30
+ try:
31
+ import matplotlib
32
+
33
+ matplotlib.use("Agg") # Must be before importing pyplot
34
+ import matplotlib.pyplot as plt
35
+ from matplotlib.figure import Figure
36
+
37
+ _MATPLOTLIB_AVAILABLE = True
38
+ except ImportError:
39
+ plt = None # type: ignore
40
+ Figure = None # type: ignore
41
+
42
+
43
def is_available() -> bool:
    """Report whether charts can be rendered, i.e. whether matplotlib was imported."""
    return _MATPLOTLIB_AVAILABLE
46
+
47
+
48
@dataclass
class ChartTheme:
    """Colours, font sizes and figure geometry shared by every chart."""

    # --- Base colours (dark theme) ---
    background: str = "#1a1a2e"
    plot_background: str = "#16213e"  # a shade lighter, used for plot areas
    text: str = "#eaeaea"
    grid: str = "#2d2d44"

    # --- Accent colours ---
    accent: str = "#00d9ff"  # same value as accent_primary
    accent_primary: str = "#00d9ff"  # cyan
    accent_secondary: str = "#ff6b6b"  # coral
    accent_tertiary: str = "#4ecdc4"  # teal
    accent_quaternary: str = "#ffe66d"  # yellow

    # --- Ordered palette for charts that draw several series ---
    palette: tuple[str, ...] = (
        "#00d9ff",  # cyan
        "#ff6b6b",  # coral
        "#4ecdc4",  # teal
        "#ffe66d",  # yellow
        "#a29bfe",  # lavender
        "#fd79a8",  # pink
        "#55efc4",  # mint
        "#ffeaa7",  # pale yellow
        "#74b9ff",  # light blue
        "#ff7675",  # salmon
    )

    # --- Typography ---
    font_family: str = "sans-serif"
    title_size: int = 14
    label_size: int = 11
    tick_size: int = 9

    # --- Figure geometry ---
    figure_dpi: int = 150
    figure_width: float = 10.0
    figure_height: float = 6.0

    @property
    def colors(self) -> tuple[str, ...]:
        """Alias for ``palette``, kept for code that uses the name ``colors``."""
        return self.palette


# Module-level theme instance read by the chart code
THEME = ChartTheme()
96
+
97
+
98
def _apply_theme(fig: Figure, ax, title: str) -> None:
    """
    Style *fig* and *ax* with the module theme and set the axes title.

    Does nothing when matplotlib is not available.
    """
    if not _MATPLOTLIB_AVAILABLE:
        return

    theme = THEME

    # Figure and axes share the same background colour
    fig.patch.set_facecolor(theme.background)
    ax.set_facecolor(theme.background)

    ax.set_title(title, color=theme.text, fontsize=theme.title_size, fontweight="bold", pad=15)

    # Thin, grid-coloured frame
    for border in ax.spines.values():
        border.set_color(theme.grid)
        border.set_linewidth(0.5)

    # Tick labels and axis labels use the theme text colour
    ax.tick_params(colors=theme.text, labelsize=theme.tick_size)
    for axis in (ax.xaxis, ax.yaxis):
        axis.label.set_color(theme.text)
        axis.label.set_fontsize(theme.label_size)

    # Dashed grid drawn underneath the plotted data
    ax.grid(True, linestyle="--", alpha=0.3, color=theme.grid)
    ax.set_axisbelow(True)
125
+
126
+
127
+ def _format_count(n: int) -> str:
128
+ """Format large numbers with K/M/B suffixes."""
129
+ if n >= 1e9:
130
+ return f"{n / 1e9:.1f}B"
131
+ if n >= 1e6:
132
+ return f"{n / 1e6:.1f}M"
133
+ if n >= 1e3:
134
+ return f"{n / 1e3:.1f}K"
135
+ return str(n)
136
+
137
+
138
class VisualizationGenerator:
    """
    Generate visualization assets for ONNX model reports.

    Each chart method renders one figure with matplotlib, writes it to the
    given path and returns that path. ``None`` is returned when matplotlib is
    not installed or when the inputs contain nothing to plot. Every figure is
    closed in a ``finally`` clause, so a chart that fails part-way through does
    not stay registered with pyplot.

    Usage:
        viz = VisualizationGenerator()
        paths = viz.generate_all(report, output_dir=Path("assets"))
    """

    def __init__(self, logger: logging.Logger | None = None):
        """Store *logger*, falling back to the ``haoline.viz`` logger."""
        self.logger = logger or logging.getLogger("haoline.viz")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _rank_with_other(values_by_key: dict[str, float], limit: int) -> list[tuple[str, float]]:
        """Return the *limit* largest items, plus an "Other" item summing the rest if positive."""
        ranked = sorted(values_by_key.items(), key=lambda item: -item[1])
        head = ranked[:limit]
        remainder = sum(value for _, value in ranked[limit:])
        if remainder > 0:
            head.append(("Other", remainder))
        return head

    @staticmethod
    def _palette_by_op(op_types, palette) -> dict[str, str]:
        """Map each distinct op type to a palette colour, in first-seen order (palette wraps)."""
        # dict.fromkeys de-duplicates while keeping order, unlike set(), whose
        # iteration order for strings changes from one interpreter run to the next.
        return {
            op: palette[index % len(palette)]
            for index, op in enumerate(dict.fromkeys(op_types))
        }

    @staticmethod
    def _style_sweep_axis(
        ax,
        xlabel: str,
        ylabel: str,
        title: str,
        *,
        label_size: int,
        title_size: int,
        tick_label_size: int | None = None,
    ) -> None:
        """Apply the sweep-chart panel styling (background, labels, ticks, spines, grid)."""
        ax.set_facecolor(THEME.plot_background)
        ax.set_xlabel(xlabel, color=THEME.text, fontsize=label_size)
        ax.set_ylabel(ylabel, color=THEME.text, fontsize=label_size)
        ax.set_title(title, color=THEME.text, fontsize=title_size, fontweight="bold")
        if tick_label_size is None:
            # Leave tick label sizes as already set by the caller
            ax.tick_params(colors=THEME.text)
        else:
            ax.tick_params(colors=THEME.text, labelsize=tick_label_size)
        for spine in ax.spines.values():
            spine.set_color(THEME.grid)
        ax.grid(True, alpha=0.3, color=THEME.grid)

    # ------------------------------------------------------------------
    # Report-level entry point
    # ------------------------------------------------------------------

    def generate_all(
        self,
        report: InspectionReport,
        output_dir: Path,
    ) -> dict[str, Path]:
        """
        Generate all visualization assets for a report.

        Charts are attempted independently: a chart whose data is missing is
        skipped, and a chart that raises is logged as a warning and skipped,
        so one failure does not prevent the remaining charts.

        Args:
            report: The inspection report containing analysis results.
            output_dir: Directory to save PNG files (created if missing).

        Returns:
            Dict mapping chart name to file path, e.g.:
            {"op_histogram": Path("assets/op_histogram.png"), ...}
            Empty when matplotlib is not available.
        """
        if not _MATPLOTLIB_AVAILABLE:
            self.logger.warning("matplotlib not available, skipping visualizations")
            return {}

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Each builder checks its own precondition and returns None when the
        # report has no data for that chart. They run inside the try below so
        # that attribute errors on partial reports are caught as well.
        def op_histogram() -> Path | None:
            summary = report.graph_summary
            if summary and summary.op_type_counts:
                return self.operator_histogram(
                    summary.op_type_counts, output_dir / "op_histogram.png"
                )
            return None

        def param_distribution() -> Path | None:
            params = report.param_counts
            if params and params.by_op_type:
                return self.param_distribution(
                    params.by_op_type, output_dir / "param_distribution.png"
                )
            return None

        def flops_distribution() -> Path | None:
            flops = report.flop_counts
            if flops and flops.by_op_type:
                return self.flops_distribution(
                    flops.by_op_type, output_dir / "flops_distribution.png"
                )
            return None

        def complexity_summary() -> Path | None:
            if report.param_counts and report.flop_counts:
                return self.complexity_summary(report, output_dir / "complexity_summary.png")
            return None

        def resolution_scaling() -> Path | None:
            # Resolution sweep chart (Story 6.8); the attribute may be absent
            sweep = getattr(report, "resolution_sweep", None)
            if sweep:
                return self.resolution_scaling_chart(
                    sweep, output_dir / "resolution_scaling.png"
                )
            return None

        def batch_scaling() -> Path | None:
            # Batch size sweep chart; the attribute may be absent
            sweep = getattr(report, "batch_size_sweep", None)
            if sweep:
                return self.batch_scaling_chart(sweep, output_dir / "batch_scaling.png")
            return None

        charts = (
            ("op_histogram", "operator histogram", op_histogram),
            ("param_distribution", "param distribution", param_distribution),
            ("flops_distribution", "FLOPs distribution", flops_distribution),
            ("complexity_summary", "complexity summary", complexity_summary),
            ("resolution_scaling", "resolution scaling chart", resolution_scaling),
            ("batch_scaling", "batch scaling chart", batch_scaling),
        )

        paths: dict[str, Path] = {}
        for key, description, build in charts:
            try:
                path = build()
            except Exception as e:
                # Best-effort: report the failure and carry on with the next chart
                self.logger.warning(f"Failed to generate {description}: {e}")
                continue
            if path:
                paths[key] = path

        self.logger.info(f"Generated {len(paths)} visualization assets in {output_dir}")
        return paths

    # ------------------------------------------------------------------
    # Sweep charts
    # ------------------------------------------------------------------

    def resolution_scaling_chart(
        self,
        sweep: ResolutionSweep,
        output_path: Path,
    ) -> Path | None:
        """
        Generate resolution scaling chart.

        Shows how latency, throughput, and VRAM scale with resolution as three
        side-by-side bar charts. Infinite latencies are drawn as zero-height
        bars and annotated "OOM".

        Args:
            sweep: Resolution sweep to plot.
            output_path: Path to save the chart.

        Returns:
            ``output_path``, or None if matplotlib is unavailable or *sweep*
            is not a ``ResolutionSweep``.
        """
        if not _MATPLOTLIB_AVAILABLE:
            return None

        # Import at runtime for isinstance check (avoid circular import)
        from .operational_profiling import ResolutionSweep

        if not isinstance(sweep, ResolutionSweep):
            return None

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        try:
            fig.patch.set_facecolor(THEME.background)

            resolutions = sweep.resolutions
            positions = list(range(len(resolutions)))
            latencies = [lat if lat != float("inf") else 0 for lat in sweep.latencies]

            panels = (
                (latencies, "Latency (ms)", "Latency vs Resolution"),
                (sweep.throughputs, "Throughput (inf/s)", "Throughput vs Resolution"),
                (sweep.vram_usage_gb, "VRAM (GB)", "VRAM vs Resolution"),
            )
            for index, (values, ylabel, title) in enumerate(panels):
                ax = axes[index]
                ax.bar(positions, values, color=THEME.palette[index], alpha=0.8)
                ax.set_xticks(positions)
                ax.set_xticklabels(resolutions, rotation=45, ha="right", fontsize=8)
                self._style_sweep_axis(
                    ax, "Resolution", ylabel, title, label_size=10, title_size=12
                )

            # Mark OOM points on the latency panel
            for i, lat in enumerate(sweep.latencies):
                if lat == float("inf"):
                    axes[0].annotate(
                        "OOM",
                        (i, 0),
                        ha="center",
                        va="bottom",
                        color=THEME.accent,
                        fontsize=8,
                    )

            fig.tight_layout()
            fig.savefig(output_path, dpi=THEME.figure_dpi, facecolor=THEME.background)
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    def batch_scaling_chart(
        self,
        sweep: BatchSizeSweep,
        output_path: Path,
    ) -> Path | None:
        """
        Generate batch size scaling chart.

        Shows how latency, throughput, and VRAM scale with batch size as three
        side-by-side line charts. The throughput panel marks
        ``sweep.optimal_batch_size`` with a dashed vertical line. Infinite
        latencies are plotted as zero.

        Args:
            sweep: Batch size sweep to plot.
            output_path: Path to save the chart.

        Returns:
            ``output_path``, or None if matplotlib is unavailable or *sweep*
            is not a ``BatchSizeSweep``.
        """
        if not _MATPLOTLIB_AVAILABLE:
            return None

        # Import at runtime for isinstance check (avoid circular import)
        from .operational_profiling import BatchSizeSweep

        if not isinstance(sweep, BatchSizeSweep):
            return None

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        try:
            fig.patch.set_facecolor(THEME.background)
            fig.subplots_adjust(wspace=0.3)  # Add spacing between subplots

            batch_sizes = sweep.batch_sizes
            latencies = [lat if lat != float("inf") else 0 for lat in sweep.latencies]

            panels = (
                (latencies, "Latency (ms)", "Latency vs Batch Size"),
                (sweep.throughputs, "Throughput (inf/s)", "Throughput vs Batch Size"),
                (sweep.vram_usage_gb, "VRAM (GB)", "VRAM vs Batch Size"),
            )
            for index, (values, ylabel, title) in enumerate(panels):
                ax = axes[index]
                ax.plot(
                    batch_sizes,
                    values,
                    color=THEME.palette[index],
                    marker="o",
                    linewidth=2,
                    markersize=8,
                )
                self._style_sweep_axis(
                    ax,
                    "Batch Size",
                    ylabel,
                    title,
                    label_size=12,
                    title_size=14,
                    tick_label_size=10,
                )

            # Highlight the optimal batch size on the throughput panel
            throughput_ax = axes[1]
            throughput_ax.axvline(
                x=sweep.optimal_batch_size,
                color=THEME.accent_tertiary,
                linestyle="--",
                linewidth=2,
                alpha=0.8,
                label=f"Optimal: {sweep.optimal_batch_size}",
            )
            throughput_ax.legend(
                facecolor=THEME.plot_background,
                edgecolor=THEME.grid,
                labelcolor=THEME.text,
                fontsize=11,
            )

            fig.tight_layout()
            fig.savefig(output_path, dpi=THEME.figure_dpi, facecolor=THEME.background)
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    # ------------------------------------------------------------------
    # Static-analysis charts
    # ------------------------------------------------------------------

    def operator_histogram(
        self,
        op_counts: dict[str, int],
        output_path: Path,
        max_ops: int = 15,
    ) -> Path | None:
        """
        Generate operator type histogram.

        Shows distribution of operator types in the model as a horizontal bar
        chart sorted by frequency. Operators beyond the *max_ops* most frequent
        are merged into a single "Other" bar.

        Args:
            op_counts: Node count per operator type.
            output_path: Path to save the chart.
            max_ops: Maximum number of individual operator bars.

        Returns:
            ``output_path``, or None if matplotlib is unavailable or
            *op_counts* is empty.
        """
        if not _MATPLOTLIB_AVAILABLE or not op_counts:
            return None

        top_ops = self._rank_with_other(op_counts, max_ops)
        labels = [op for op, _ in top_ops]
        counts = [count for _, count in top_ops]

        fig, ax = plt.subplots(
            figsize=(THEME.figure_width, THEME.figure_height), dpi=THEME.figure_dpi
        )
        try:
            # Horizontal bar chart
            y_pos = list(range(len(labels)))
            bars = ax.barh(
                y_pos,
                counts,
                color=THEME.accent_primary,
                edgecolor=THEME.background,
                height=0.7,
            )

            ax.set_yticks(y_pos)
            ax.set_yticklabels(labels)
            ax.set_xlabel("Node Count")
            ax.invert_yaxis()  # Largest at top

            # Value labels sit just past the end of each bar; the gap is 1% of
            # the longest bar (computed once; default covers an empty chart).
            label_gap = max(counts, default=0) * 0.01
            for bar, count in zip(bars, counts, strict=False):
                ax.text(
                    bar.get_width() + label_gap,
                    bar.get_y() + bar.get_height() / 2,
                    str(count),
                    va="center",
                    color=THEME.text,
                    fontsize=THEME.tick_size,
                )

            _apply_theme(fig, ax, "Operator Type Distribution")

            fig.tight_layout()
            fig.savefig(
                output_path,
                facecolor=THEME.background,
                edgecolor="none",
                bbox_inches="tight",
            )
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    def param_distribution(
        self,
        params_by_op: dict[str, float],
        output_path: Path,
        max_slices: int = 6,
        min_pct_for_label: float = 3.0,
    ) -> Path | None:
        """
        Generate parameter distribution pie chart.

        Shows how parameters are distributed across operator types.
        Uses a legend to avoid label overlap on small slices.

        Args:
            params_by_op: Parameter count per operator type; non-positive
                entries are ignored.
            output_path: Path to save the chart.
            max_slices: Maximum number of individual slices before "Other".
            min_pct_for_label: Operators below this share (in percent) are
                merged into "Other".

        Returns:
            ``output_path``, or None if matplotlib is unavailable or there is
            no operator with a positive parameter count.
        """
        if not _MATPLOTLIB_AVAILABLE or not params_by_op:
            return None

        # Filter out zero-param ops and sort
        nonzero_ops = {k: v for k, v in params_by_op.items() if v > 0}
        if not nonzero_ops:
            return None

        total = sum(nonzero_ops.values())
        sorted_ops = sorted(nonzero_ops.items(), key=lambda x: -x[1])

        # Keep a slice only while there is room AND it is large enough;
        # everything else is accumulated into "Other".
        top_ops: list[tuple[str, float]] = []
        other_count = 0.0
        for op, count in sorted_ops:
            pct = (count / total) * 100
            if len(top_ops) < max_slices and pct >= min_pct_for_label:
                top_ops.append((op, count))
            else:
                other_count += count

        if other_count > 0:
            top_ops.append(("Other", other_count))

        # If we only have "Other", show the largest ops anyway
        if len(top_ops) == 1 and top_ops[0][0] == "Other":
            top_ops = self._rank_with_other(nonzero_ops, max_slices)

        sizes = [count for _, count in top_ops]
        colors = THEME.palette[: len(sizes)]

        # Figure is wider than tall to leave space for the legend
        fig, ax = plt.subplots(figsize=(10, 7), dpi=THEME.figure_dpi)
        try:
            # Pie with no inline labels (the legend carries the names)
            wedges, _texts, autotexts = ax.pie(
                sizes,
                labels=None,
                colors=colors,
                autopct=lambda pct: f"{pct:.1f}%" if pct >= 5 else "",
                startangle=90,
                pctdistance=0.75,
                textprops={"color": THEME.text, "fontsize": 11, "fontweight": "bold"},
            )

            # Style the percentage text
            for autotext in autotexts:
                autotext.set_color("#ffffff")
                autotext.set_fontweight("bold")

            # Legend with op names and param counts
            legend_labels = [f"{op} ({_format_count(int(count))})" for op, count in top_ops]
            ax.legend(
                wedges,
                legend_labels,
                title="Operator Type",
                loc="center left",
                bbox_to_anchor=(1.0, 0.5),
                fontsize=10,
                title_fontsize=11,
                frameon=True,
                facecolor=THEME.plot_background,
                edgecolor=THEME.grid,
                labelcolor=THEME.text,
            )

            ax.set_title(
                "Parameter Distribution by Operator Type",
                color=THEME.text,
                fontsize=THEME.title_size,
                fontweight="bold",
                pad=20,
            )

            fig.patch.set_facecolor(THEME.background)

            fig.tight_layout()
            fig.savefig(
                output_path,
                facecolor=THEME.background,
                edgecolor="none",
                bbox_inches="tight",
            )
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    def flops_distribution(
        self,
        flops_by_op: dict[str, int],
        output_path: Path,
        max_ops: int = 10,
    ) -> Path | None:
        """
        Generate FLOPs distribution bar chart.

        Shows computational cost distribution across operator types, limited
        to the *max_ops* most expensive ones (the rest are not drawn).

        Args:
            flops_by_op: FLOPs per operator type; non-positive entries are
                ignored.
            output_path: Path to save the chart.
            max_ops: Maximum number of operator bars.

        Returns:
            ``output_path``, or None if matplotlib is unavailable or there is
            no operator with a positive FLOP count.
        """
        if not _MATPLOTLIB_AVAILABLE or not flops_by_op:
            return None

        # Filter out zero-FLOP ops and sort
        nonzero_ops = {k: v for k, v in flops_by_op.items() if v > 0}
        if not nonzero_ops:
            return None

        sorted_ops = sorted(nonzero_ops.items(), key=lambda x: -x[1])[:max_ops]

        labels = [op for op, _ in sorted_ops]
        values = [flops for _, flops in sorted_ops]

        fig, ax = plt.subplots(
            figsize=(THEME.figure_width, THEME.figure_height), dpi=THEME.figure_dpi
        )
        try:
            x_pos = list(range(len(labels)))
            bars = ax.bar(
                x_pos,
                values,
                color=THEME.accent_secondary,
                edgecolor=THEME.background,
                width=0.7,
            )

            ax.set_xticks(x_pos)
            ax.set_xticklabels(labels, rotation=45, ha="right")
            ax.set_ylabel("FLOPs")

            # Value labels sit just above each bar; the gap is 2% of the
            # tallest bar (computed once; default covers an empty chart).
            label_gap = max(values, default=0) * 0.02
            for bar, value in zip(bars, values, strict=False):
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + label_gap,
                    _format_count(value),
                    ha="center",
                    va="bottom",
                    color=THEME.text,
                    fontsize=THEME.tick_size,
                )

            _apply_theme(fig, ax, "FLOPs Distribution by Operator Type")

            fig.tight_layout()
            fig.savefig(
                output_path,
                facecolor=THEME.background,
                edgecolor="none",
                bbox_inches="tight",
            )
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    def complexity_summary(
        self,
        report: InspectionReport,
        output_path: Path,
    ) -> Path | None:
        """
        Generate complexity summary dashboard.

        Three-panel figure showing total parameters, total FLOPs and model
        size in bytes as large formatted numbers.

        Args:
            report: Report providing ``param_counts``, ``flop_counts`` and
                ``memory_estimates``.
            output_path: Path to save the chart.

        Returns:
            ``output_path``, or None if matplotlib is unavailable or any of
            the three report sections is missing.
        """
        if not _MATPLOTLIB_AVAILABLE:
            return None
        if not report.param_counts or not report.flop_counts or not report.memory_estimates:
            return None

        metrics = [
            ("Parameters", report.param_counts.total, THEME.accent_primary),
            ("FLOPs", report.flop_counts.total, THEME.accent_secondary),
            (
                "Memory (bytes)",
                report.memory_estimates.model_size_bytes,
                THEME.accent_tertiary,
            ),
        ]

        fig, axes = plt.subplots(1, 3, figsize=(THEME.figure_width, 4), dpi=THEME.figure_dpi)
        try:
            fig.patch.set_facecolor(THEME.background)

            for ax, (label, value, color) in zip(axes, metrics, strict=False):
                ax.set_facecolor(THEME.background)

                # Large centered value
                ax.text(
                    0.5,
                    0.6,
                    _format_count(value),
                    ha="center",
                    va="center",
                    fontsize=28,
                    fontweight="bold",
                    color=color,
                    transform=ax.transAxes,
                )

                # Label below
                ax.text(
                    0.5,
                    0.25,
                    label,
                    ha="center",
                    va="center",
                    fontsize=THEME.label_size,
                    color=THEME.text,
                    transform=ax.transAxes,
                )

                # Remove all axes elements so only the two texts remain
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

            # Main title
            fig.suptitle(
                "Model Complexity Summary",
                color=THEME.text,
                fontsize=THEME.title_size,
                fontweight="bold",
                y=0.95,
            )

            # Reserve the top 10% of the figure for the suptitle
            fig.tight_layout(rect=[0, 0, 1, 0.9])
            fig.savefig(
                output_path,
                facecolor=THEME.background,
                edgecolor="none",
                bbox_inches="tight",
            )
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    def layer_depth_profile(
        self,
        graph_info: GraphInfo,
        param_counts: ParamCounts,
        flop_counts: FlopCounts,
        output_path: Path,
    ) -> Path | None:
        """
        Generate layer depth profile showing cumulative params/FLOPs.

        Shows how complexity accumulates through the network depth. Nodes are
        taken in the order of ``graph_info.nodes``; nodes with neither
        parameters nor FLOPs are left out. Parameters and FLOPs are drawn on
        two separate y-axes.

        Args:
            graph_info: Graph whose ``nodes`` define the layer order.
            param_counts: Per-node parameter counts (``by_node``).
            flop_counts: Per-node FLOP counts (``by_node``).
            output_path: Path to save the chart.

        Returns:
            ``output_path``, or None if matplotlib is unavailable or no node
            has parameters or FLOPs.
        """
        if not _MATPLOTLIB_AVAILABLE:
            return None

        # Get ordered nodes with their metrics
        nodes_with_metrics = []
        for node in graph_info.nodes:
            params = param_counts.by_node.get(node.name, 0)
            flops = flop_counts.by_node.get(node.name, 0)
            if params > 0 or flops > 0:
                nodes_with_metrics.append((node.name, params, flops))

        if not nodes_with_metrics:
            return None

        # Compute cumulative values
        cum_params: list[float] = []
        cum_flops: list[int] = []
        running_params: float = 0.0
        running_flops: int = 0
        for _, params, flops in nodes_with_metrics:
            running_params += params
            running_flops += flops
            cum_params.append(running_params)
            cum_flops.append(running_flops)

        x = range(len(nodes_with_metrics))

        # Figure with two y-axes sharing the layer-index x-axis
        fig, ax1 = plt.subplots(
            figsize=(THEME.figure_width, THEME.figure_height), dpi=THEME.figure_dpi
        )
        try:
            ax2 = ax1.twinx()

            # Plot cumulative params
            ax1.fill_between(x, cum_params, alpha=0.3, color=THEME.accent_primary)
            ax1.plot(x, cum_params, color=THEME.accent_primary, linewidth=2, label="Parameters")
            ax1.set_ylabel("Cumulative Parameters", color=THEME.accent_primary)
            ax1.tick_params(axis="y", labelcolor=THEME.accent_primary)

            # Plot cumulative FLOPs
            ax2.fill_between(x, cum_flops, alpha=0.3, color=THEME.accent_secondary)
            ax2.plot(x, cum_flops, color=THEME.accent_secondary, linewidth=2, label="FLOPs")
            ax2.set_ylabel("Cumulative FLOPs", color=THEME.accent_secondary)
            ax2.tick_params(axis="y", labelcolor=THEME.accent_secondary)

            ax1.set_xlabel("Layer Index")

            # Apply theme to primary axis
            fig.patch.set_facecolor(THEME.background)
            ax1.set_facecolor(THEME.background)
            ax1.set_title(
                "Layer Depth Profile",
                color=THEME.text,
                fontsize=THEME.title_size,
                fontweight="bold",
                pad=15,
            )

            for spine in ax1.spines.values():
                spine.set_color(THEME.grid)
            for spine in ax2.spines.values():
                spine.set_color(THEME.grid)

            ax1.tick_params(colors=THEME.text, labelsize=THEME.tick_size)
            ax1.xaxis.label.set_color(THEME.text)
            ax1.xaxis.label.set_fontsize(THEME.label_size)

            ax1.grid(True, linestyle="--", alpha=0.3, color=THEME.grid)

            # One legend for both axes, built from proxy lines
            lines = [
                plt.Line2D([0], [0], color=THEME.accent_primary, linewidth=2),
                plt.Line2D([0], [0], color=THEME.accent_secondary, linewidth=2),
            ]
            ax1.legend(
                lines,
                ["Parameters", "FLOPs"],
                loc="upper left",
                facecolor=THEME.background,
                labelcolor=THEME.text,
            )

            fig.tight_layout()
            fig.savefig(
                output_path,
                facecolor=THEME.background,
                edgecolor="none",
                bbox_inches="tight",
            )
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    # ------------------------------------------------------------------
    # Profiling charts
    # ------------------------------------------------------------------

    def layer_timing_chart(
        self,
        profiling_result: ProfilingResult,
        output_path: Path,
        top_n: int = 15,
    ) -> Path | None:
        """
        Generate per-layer timing breakdown chart (Story 9.3.4).

        Shows execution time for the slowest N layers as a horizontal bar chart.
        Bars are coloured by operator type. Colours are assigned in the order
        operator types first appear in ``get_slowest_layers()``, so the same
        profile always produces the same colours and legend order.

        Args:
            profiling_result: ProfilingResult from OperationalProfiler.profile_model()
            output_path: Path to save the chart
            top_n: Number of slowest layers to show

        Returns:
            Path to generated chart or None if unavailable
        """
        if not _MATPLOTLIB_AVAILABLE:
            return None

        # Import at runtime
        from .operational_profiling import ProfilingResult

        if not isinstance(profiling_result, ProfilingResult):
            return None

        slowest = profiling_result.get_slowest_layers(top_n)
        if not slowest:
            return None

        # barh draws index 0 at the bottom, so reverse to put the first entry
        # of `slowest` at the top of the chart.
        ordered = list(reversed(slowest))
        layers = [lp.name[:30] for lp in ordered]  # Truncate long names
        times = [lp.duration_ms for lp in ordered]
        op_types = [lp.op_type for lp in ordered]

        # Stable op-type -> colour mapping (also drives the legend order)
        color_by_op = self._palette_by_op([lp.op_type for lp in slowest], THEME.palette)

        fig, ax = plt.subplots(figsize=(12, max(6, len(slowest) * 0.4)), dpi=THEME.figure_dpi)
        try:
            y_pos = list(range(len(layers)))
            bars = ax.barh(
                y_pos,
                times,
                color=[color_by_op[op] for op in op_types],
                edgecolor=THEME.background,
                height=0.7,
            )

            ax.set_yticks(y_pos)
            ax.set_yticklabels(layers, fontsize=9)
            ax.set_xlabel("Time (ms)", color=THEME.text, fontsize=10)

            # Value labels sit just past the end of each bar (1% of the longest)
            label_gap = max(times) * 0.01
            for bar, time_val, op_type in zip(bars, times, op_types, strict=False):
                ax.text(
                    bar.get_width() + label_gap,
                    bar.get_y() + bar.get_height() / 2,
                    f"{time_val:.2f}ms ({op_type})",
                    va="center",
                    ha="left",
                    color=THEME.text,
                    fontsize=8,
                )

            _apply_theme(
                fig,
                ax,
                f"Top {len(slowest)} Slowest Layers (Total: {profiling_result.total_time_ms:.2f}ms)",
            )

            # Add legend for op types
            from matplotlib.patches import Patch

            legend_elements = [
                Patch(facecolor=color, label=op) for op, color in color_by_op.items()
            ]
            ax.legend(
                handles=legend_elements,
                loc="lower right",
                fontsize=8,
                facecolor=THEME.plot_background,
                edgecolor=THEME.grid,
                labelcolor=THEME.text,
            )

            fig.tight_layout()
            fig.savefig(
                output_path,
                facecolor=THEME.background,
                edgecolor="none",
                bbox_inches="tight",
            )
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path

    def op_time_distribution_chart(
        self,
        profiling_result: ProfilingResult,
        output_path: Path,
    ) -> Path | None:
        """
        Generate operator type time distribution pie chart.

        Shows percentage of execution time by operator type. Only the 8 most
        expensive operator types get their own slice; the rest are merged into
        "Other".

        Args:
            profiling_result: ProfilingResult from profiling
            output_path: Path to save the chart

        Returns:
            Path to generated chart, or None if matplotlib is unavailable,
            the input is not a ProfilingResult, or no time was recorded
        """
        if not _MATPLOTLIB_AVAILABLE:
            return None

        from .operational_profiling import ProfilingResult

        if not isinstance(profiling_result, ProfilingResult):
            return None

        time_by_op = profiling_result.get_time_by_op_type()
        if not time_by_op:
            return None

        # Limit to top 8 ops, aggregate rest into "Other"
        top_ops = self._rank_with_other(time_by_op, 8)
        labels = [op for op, _ in top_ops]
        values = [time_val for _, time_val in top_ops]

        total_time = sum(values)
        if total_time <= 0:
            # A pie needs a positive total to derive slice fractions from
            return None

        fig, ax = plt.subplots(figsize=(10, 8), dpi=THEME.figure_dpi)
        try:
            ax.set_facecolor(THEME.plot_background)

            colors = THEME.palette[: len(values)]
            wedges, _texts, autotexts = ax.pie(
                values,
                labels=None,
                autopct=lambda pct: f"{pct:.1f}%" if pct > 3 else "",
                startangle=90,
                colors=colors,
                pctdistance=0.75,
                wedgeprops={"edgecolor": THEME.background, "linewidth": 1},
            )

            for autotext in autotexts:
                autotext.set_color(THEME.text)
                autotext.set_fontsize(9)

            # Legend carries the operator names and absolute times
            legend_labels = [
                f"{op} ({time_val:.2f}ms)" for op, time_val in zip(labels, values, strict=False)
            ]
            ax.legend(
                wedges,
                legend_labels,
                title="Operator Type",
                loc="center left",
                bbox_to_anchor=(1.0, 0.5),
                fontsize=9,
                title_fontsize=10,
                frameon=True,
                facecolor=THEME.plot_background,
                edgecolor=THEME.grid,
                labelcolor=THEME.text,
            )

            ax.set_title(
                f"Execution Time Distribution (Total: {total_time:.2f}ms)",
                color=THEME.text,
                fontsize=THEME.title_size,
                fontweight="bold",
                pad=20,
            )

            fig.patch.set_facecolor(THEME.background)

            fig.tight_layout()
            fig.savefig(
                output_path,
                facecolor=THEME.background,
                edgecolor="none",
                bbox_inches="tight",
            )
        finally:
            # Close even when drawing/saving failed, so the figure is not leaked
            plt.close(fig)

        return output_path
1067
+
1068
+
1069
def generate_visualizations(
    report: InspectionReport,
    output_dir: Path | str,
    logger: logging.Logger | None = None,
) -> dict[str, Path]:
    """
    Build a ``VisualizationGenerator`` and render every chart for *report*.

    Args:
        report: The inspection report.
        output_dir: Directory to save PNG files; converted to ``Path``.
        logger: Optional logger passed to the generator.

    Returns:
        Dict mapping chart name to file path.
    """
    target_dir = Path(output_dir)
    return VisualizationGenerator(logger=logger).generate_all(report, target_dir)