haoline 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,1086 @@
|
|
|
1
|
+
# Copyright (c) 2025 HaoLine Contributors
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Visualization module for HaoLine.
|
|
6
|
+
|
|
7
|
+
Generates matplotlib-based charts for model architecture analysis:
|
|
8
|
+
- Operator type histogram
|
|
9
|
+
- Layer depth profile (cumulative params/FLOPs)
|
|
10
|
+
- Parameter distribution by layer type
|
|
11
|
+
- Shape evolution through the network
|
|
12
|
+
|
|
13
|
+
All charts use a consistent dark theme suitable for technical documentation.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import TYPE_CHECKING
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from .analyzer import FlopCounts, GraphInfo, ParamCounts
|
|
25
|
+
from .operational_profiling import BatchSizeSweep, ProfilingResult, ResolutionSweep
|
|
26
|
+
from .report import InspectionReport
|
|
27
|
+
|
|
28
|
+
# Attempt to import matplotlib with Agg backend (non-interactive)
|
|
29
|
+
_MATPLOTLIB_AVAILABLE = False
|
|
30
|
+
try:
|
|
31
|
+
import matplotlib
|
|
32
|
+
|
|
33
|
+
matplotlib.use("Agg") # Must be before importing pyplot
|
|
34
|
+
import matplotlib.pyplot as plt
|
|
35
|
+
from matplotlib.figure import Figure
|
|
36
|
+
|
|
37
|
+
_MATPLOTLIB_AVAILABLE = True
|
|
38
|
+
except ImportError:
|
|
39
|
+
plt = None # type: ignore
|
|
40
|
+
Figure = None # type: ignore
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def is_available() -> bool:
    """Report whether charts can be drawn.

    Returns:
        True if matplotlib was imported successfully when this module was
        loaded, False otherwise.
    """
    return bool(_MATPLOTLIB_AVAILABLE)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ChartTheme:
    """Shared visual settings (colors, fonts, figure geometry) for every chart.

    Every field has a default, so ``ChartTheme()`` yields the standard dark
    theme; any field can be overridden by keyword.
    """

    # --- Colors: dark backgrounds with bright accents ---
    background: str = "#1a1a2e"
    plot_background: str = "#16213e"  # a shade lighter, used behind plotted data
    text: str = "#eaeaea"
    grid: str = "#2d2d44"
    # NOTE: `accent` is its own field that merely defaults to the same value as
    # `accent_primary`; overriding one does not change the other.
    accent: str = "#00d9ff"
    accent_primary: str = "#00d9ff"  # cyan
    accent_secondary: str = "#ff6b6b"  # coral
    accent_tertiary: str = "#4ecdc4"  # teal
    accent_quaternary: str = "#ffe66d"  # yellow

    # Ordered series colors for multi-series charts: cyan, coral, teal, yellow,
    # lavender, pink, mint, pale yellow, light blue, salmon.
    palette: tuple[str, ...] = (
        "#00d9ff",
        "#ff6b6b",
        "#4ecdc4",
        "#ffe66d",
        "#a29bfe",
        "#fd79a8",
        "#55efc4",
        "#ffeaa7",
        "#74b9ff",
        "#ff7675",
    )

    # --- Typography ---
    font_family: str = "sans-serif"
    title_size: int = 14
    label_size: int = 11
    tick_size: int = 9

    # --- Figure geometry ---
    figure_dpi: int = 150
    figure_width: float = 10.0
    figure_height: float = 6.0

    @property
    def colors(self) -> tuple[str, ...]:
        """Read-only alias of ``palette`` for code that uses the name ``colors``."""
        return self.palette


# Module-wide default theme shared by every chart function in this module.
THEME = ChartTheme()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _apply_theme(fig: Figure, ax, title: str) -> None:
    """Style a figure/axes pair with the module-wide ``THEME``.

    Sets the figure and axes backgrounds, a bold title, thin grid-colored
    spines, tick and axis-label colors/sizes, and a dashed background grid.
    Does nothing when matplotlib could not be imported.

    Args:
        fig: Figure whose background patch is recolored.
        ax: Axes to style.
        title: Title text placed above the axes.
    """
    if _MATPLOTLIB_AVAILABLE:
        theme = THEME

        fig.patch.set_facecolor(theme.background)

        ax.set_facecolor(theme.background)
        ax.set_title(
            title,
            color=theme.text,
            fontsize=theme.title_size,
            fontweight="bold",
            pad=15,
        )

        for edge in ax.spines.values():
            edge.set_color(theme.grid)
            edge.set_linewidth(0.5)

        ax.tick_params(colors=theme.text, labelsize=theme.tick_size)
        for axis in (ax.xaxis, ax.yaxis):
            axis.label.set_color(theme.text)
            axis.label.set_fontsize(theme.label_size)

        # Keep the grid behind the plotted data.
        ax.grid(True, linestyle="--", alpha=0.3, color=theme.grid)
        ax.set_axisbelow(True)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _format_count(n: int) -> str:
|
|
128
|
+
"""Format large numbers with K/M/B suffixes."""
|
|
129
|
+
if n >= 1e9:
|
|
130
|
+
return f"{n / 1e9:.1f}B"
|
|
131
|
+
if n >= 1e6:
|
|
132
|
+
return f"{n / 1e6:.1f}M"
|
|
133
|
+
if n >= 1e3:
|
|
134
|
+
return f"{n / 1e3:.1f}K"
|
|
135
|
+
return str(n)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class VisualizationGenerator:
|
|
139
|
+
"""
|
|
140
|
+
Generate visualization assets for ONNX model reports.
|
|
141
|
+
|
|
142
|
+
Usage:
|
|
143
|
+
viz = VisualizationGenerator()
|
|
144
|
+
paths = viz.generate_all(report, output_dir=Path("assets"))
|
|
145
|
+
"""
|
|
146
|
+
|
|
147
|
+
def __init__(self, logger: logging.Logger | None = None):
    """Create a generator that reports progress and failures to *logger*.

    Args:
        logger: Logger to use. When omitted (or falsy), the
            ``"haoline.viz"`` logger is used.
    """
    fallback_name = "haoline.viz"
    self.logger = logger if logger else logging.getLogger(fallback_name)
|
|
149
|
+
|
|
150
|
+
def generate_all(
    self,
    report: InspectionReport,
    output_dir: Path,
) -> dict[str, Path]:
    """
    Generate all visualization assets for a report.

    Charts are produced independently and best-effort. A chart whose input
    data is absent is skipped. A chart that raises is reported with a WARNING
    log (traceback at DEBUG level), and the remaining charts are still
    attempted.

    Args:
        report: The inspection report containing analysis results.
        output_dir: Directory to save PNG files; created (with parents) if missing.

    Returns:
        Dict mapping chart name to file path, e.g.:
        {"op_histogram": Path("assets/op_histogram.png"), ...}
        Keys are a subset of "op_histogram", "param_distribution",
        "flops_distribution", "complexity_summary", "resolution_scaling" and
        "batch_scaling". The dict is empty when matplotlib is not installed.
    """
    if not _MATPLOTLIB_AVAILABLE:
        self.logger.warning("matplotlib not available, skipping visualizations")
        return {}

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    paths: dict[str, Path] = {}

    def attempt(key: str, description: str, build) -> None:
        # `build` is a zero-argument callable returning the written Path, or
        # None when the chart is skipped. The report-data checks live inside
        # `build` so that malformed report data is logged here, not raised.
        try:
            path = build()
        except Exception as e:  # best-effort: one failing chart must not abort the rest
            self.logger.warning("Failed to generate %s: %s", description, e)
            self.logger.debug("Traceback for failed %s:", description, exc_info=True)
            return
        if path:
            paths[key] = path

    attempt(
        "op_histogram",
        "operator histogram",
        lambda: (
            self.operator_histogram(
                report.graph_summary.op_type_counts,
                output_dir / "op_histogram.png",
            )
            if report.graph_summary and report.graph_summary.op_type_counts
            else None
        ),
    )
    attempt(
        "param_distribution",
        "param distribution",
        lambda: (
            self.param_distribution(
                report.param_counts.by_op_type,
                output_dir / "param_distribution.png",
            )
            if report.param_counts and report.param_counts.by_op_type
            else None
        ),
    )
    attempt(
        "flops_distribution",
        "FLOPs distribution",
        lambda: (
            self.flops_distribution(
                report.flop_counts.by_op_type,
                output_dir / "flops_distribution.png",
            )
            if report.flop_counts and report.flop_counts.by_op_type
            else None
        ),
    )
    attempt(
        "complexity_summary",
        "complexity summary",
        lambda: (
            self.complexity_summary(report, output_dir / "complexity_summary.png")
            if report.param_counts and report.flop_counts
            else None
        ),
    )
    # Sweep charts (resolution sweep: Story 6.8). hasattr() guards reports
    # that do not carry these attributes at all.
    attempt(
        "resolution_scaling",
        "resolution scaling chart",
        lambda: (
            self.resolution_scaling_chart(
                report.resolution_sweep,
                output_dir / "resolution_scaling.png",
            )
            if hasattr(report, "resolution_sweep") and report.resolution_sweep
            else None
        ),
    )
    attempt(
        "batch_scaling",
        "batch scaling chart",
        lambda: (
            self.batch_scaling_chart(
                report.batch_size_sweep,
                output_dir / "batch_scaling.png",
            )
            if hasattr(report, "batch_size_sweep") and report.batch_size_sweep
            else None
        ),
    )

    self.logger.info("Generated %d visualization assets in %s", len(paths), output_dir)
    return paths
|
|
246
|
+
|
|
247
|
+
def resolution_scaling_chart(
    self,
    sweep: ResolutionSweep,
    output_path: Path,
) -> Path | None:
    """
    Generate resolution scaling chart.

    Shows how latency, throughput, and VRAM scale with resolution as three
    side-by-side bar panels. Latency entries equal to ``float("inf")`` are
    drawn as zero-height bars and annotated "OOM".

    Args:
        sweep: Resolution sweep results. Anything that is not a
            ``ResolutionSweep`` instance is ignored.
        output_path: Destination PNG path.

    Returns:
        ``output_path`` on success, or None if matplotlib is unavailable or
        ``sweep`` has the wrong type.
    """
    if not _MATPLOTLIB_AVAILABLE:
        return None

    # Import at runtime for isinstance check (avoid circular import)
    from .operational_profiling import ResolutionSweep

    if not isinstance(sweep, ResolutionSweep):
        return None

    resolutions = sweep.resolutions
    x = list(range(len(resolutions)))
    latencies = [lat if lat != float("inf") else 0 for lat in sweep.latencies]
    # (values, bar color, y-axis label, title) for each panel, left to right.
    panels = (
        (latencies, THEME.palette[0], "Latency (ms)", "Latency vs Resolution"),
        (sweep.throughputs, THEME.palette[1], "Throughput (inf/s)", "Throughput vs Resolution"),
        (sweep.vram_usage_gb, THEME.palette[2], "VRAM (GB)", "VRAM vs Resolution"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        fig.patch.set_facecolor(THEME.background)

        for ax, (values, color, ylabel, title) in zip(axes, panels, strict=True):
            ax.set_facecolor(THEME.plot_background)
            ax.bar(x, values, color=color, alpha=0.8)
            ax.set_xlabel("Resolution", color=THEME.text, fontsize=10)
            ax.set_ylabel(ylabel, color=THEME.text, fontsize=10)
            ax.set_title(title, color=THEME.text, fontsize=12, fontweight="bold")
            ax.set_xticks(x)
            ax.set_xticklabels(resolutions, rotation=45, ha="right", fontsize=8)
            ax.tick_params(colors=THEME.text)
            for spine in ax.spines.values():
                spine.set_color(THEME.grid)
            ax.grid(True, alpha=0.3, color=THEME.grid)

        # Mark OOM points on the latency panel (their bars were drawn at height 0).
        for i, lat in enumerate(sweep.latencies):
            if lat == float("inf"):
                axes[0].annotate(
                    "OOM",
                    (i, 0),
                    ha="center",
                    va="bottom",
                    color=THEME.accent,
                    fontsize=8,
                )

        # Operate on this figure explicitly rather than pyplot's global "current" one.
        fig.tight_layout()
        fig.savefig(output_path, dpi=THEME.figure_dpi, facecolor=THEME.background)
    finally:
        # Release the figure even when drawing fails; pyplot otherwise keeps it alive.
        plt.close(fig)

    return output_path
|
|
332
|
+
|
|
333
|
+
def batch_scaling_chart(
    self,
    sweep: BatchSizeSweep,
    output_path: Path,
) -> Path | None:
    """
    Generate batch size scaling chart.

    Shows how latency, throughput, and VRAM scale with batch size as three
    side-by-side line plots. The throughput panel also marks
    ``sweep.optimal_batch_size`` with a dashed vertical line. Latency entries
    equal to ``float("inf")`` (treated as OOM) are left as gaps in the latency
    line instead of being drawn at 0 ms.

    Args:
        sweep: Batch-size sweep results. Anything that is not a
            ``BatchSizeSweep`` instance is ignored.
        output_path: Destination PNG path.

    Returns:
        ``output_path`` on success, or None if matplotlib is unavailable or
        ``sweep`` has the wrong type.
    """
    if not _MATPLOTLIB_AVAILABLE:
        return None

    # Import at runtime for isinstance check (avoid circular import)
    from .operational_profiling import BatchSizeSweep

    if not isinstance(sweep, BatchSizeSweep):
        return None

    batch_sizes = sweep.batch_sizes
    # A NaN breaks the line at a failed run. Plotting 0 instead would make an
    # OOM batch size look like the fastest configuration.
    latencies = [lat if lat != float("inf") else float("nan") for lat in sweep.latencies]
    # (values, line color, y-axis label, title) for each panel, left to right.
    panels = (
        (latencies, THEME.palette[0], "Latency (ms)", "Latency vs Batch Size"),
        (sweep.throughputs, THEME.palette[1], "Throughput (inf/s)", "Throughput vs Batch Size"),
        (sweep.vram_usage_gb, THEME.palette[2], "VRAM (GB)", "VRAM vs Batch Size"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    try:
        fig.patch.set_facecolor(THEME.background)
        # No subplots_adjust(wspace=...) here: tight_layout() below recomputes
        # the inter-panel spacing and would override it anyway.

        for ax, (values, color, ylabel, title) in zip(axes, panels, strict=True):
            ax.set_facecolor(THEME.plot_background)
            ax.plot(
                batch_sizes,
                values,
                color=color,
                marker="o",
                linewidth=2,
                markersize=8,
            )
            ax.set_xlabel("Batch Size", color=THEME.text, fontsize=12)
            ax.set_ylabel(ylabel, color=THEME.text, fontsize=12)
            ax.set_title(title, color=THEME.text, fontsize=14, fontweight="bold")
            ax.tick_params(colors=THEME.text, labelsize=10)
            for spine in ax.spines.values():
                spine.set_color(THEME.grid)
            ax.grid(True, alpha=0.3, color=THEME.grid)

        # Highlight the optimal batch size on the throughput panel.
        throughput_ax = axes[1]
        throughput_ax.axvline(
            x=sweep.optimal_batch_size,
            color=THEME.accent_tertiary,
            linestyle="--",
            linewidth=2,
            alpha=0.8,
            label=f"Optimal: {sweep.optimal_batch_size}",
        )
        throughput_ax.legend(
            facecolor=THEME.plot_background,
            edgecolor=THEME.grid,
            labelcolor=THEME.text,
            fontsize=11,
        )

        # Operate on this figure explicitly rather than pyplot's global "current" one.
        fig.tight_layout()
        fig.savefig(output_path, dpi=THEME.figure_dpi, facecolor=THEME.background)
    finally:
        # Release the figure even when drawing fails; pyplot otherwise keeps it alive.
        plt.close(fig)

    return output_path
|
|
435
|
+
|
|
436
|
+
def operator_histogram(
    self,
    op_counts: dict[str, int],
    output_path: Path,
    max_ops: int = 15,
) -> Path | None:
    """
    Generate operator type histogram.

    Shows the distribution of operator types in the model as horizontal bars,
    sorted by frequency with the largest at the top. When there are more than
    ``max_ops`` types, the remainder is summed into an "Other" bar (added only
    if that sum is positive).

    Args:
        op_counts: Node count per operator type.
        output_path: Destination PNG path.
        max_ops: Number of most frequent operator types shown individually.

    Returns:
        ``output_path`` on success, or None if matplotlib is unavailable or
        ``op_counts`` is empty.
    """
    if not _MATPLOTLIB_AVAILABLE or not op_counts:
        return None

    # Sort by count and take top N
    sorted_ops = sorted(op_counts.items(), key=lambda x: -x[1])
    if len(sorted_ops) > max_ops:
        top_ops = sorted_ops[:max_ops]
        other_count = sum(count for _, count in sorted_ops[max_ops:])
        if other_count > 0:
            top_ops.append(("Other", other_count))
    else:
        top_ops = sorted_ops

    labels = [op for op, _ in top_ops]
    counts = [count for _, count in top_ops]
    # Value labels sit just past each bar end, offset by 1% of the longest bar.
    # default=0 keeps an empty bar list from raising.
    label_offset = max(counts, default=0) * 0.01

    fig, ax = plt.subplots(
        figsize=(THEME.figure_width, THEME.figure_height), dpi=THEME.figure_dpi
    )
    try:
        # Horizontal bar chart
        y_pos = range(len(labels))
        bars = ax.barh(
            y_pos,
            counts,
            color=THEME.accent_primary,
            edgecolor=THEME.background,
            height=0.7,
        )

        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels)
        ax.set_xlabel("Node Count")
        ax.invert_yaxis()  # Largest at top

        for bar, count in zip(bars, counts, strict=False):
            ax.text(
                bar.get_width() + label_offset,
                bar.get_y() + bar.get_height() / 2,
                str(count),
                va="center",
                color=THEME.text,
                fontsize=THEME.tick_size,
            )

        _apply_theme(fig, ax, "Operator Type Distribution")

        fig.tight_layout()
        fig.savefig(
            output_path,
            facecolor=THEME.background,
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Release the figure even when drawing fails; pyplot otherwise keeps it alive.
        plt.close(fig)

    return output_path
|
|
508
|
+
|
|
509
|
+
def param_distribution(
    self,
    params_by_op: dict[str, float],
    output_path: Path,
    max_slices: int = 6,
    min_pct_for_label: float = 3.0,
) -> Path | None:
    """
    Generate parameter distribution pie chart.

    Shows how parameters are distributed across operator types. Up to
    ``max_slices`` operator types that each hold at least
    ``min_pct_for_label`` percent of the parameters get their own slice. The
    rest are merged into an "Other" slice. If that would leave only "Other",
    the ``max_slices`` largest types are shown instead (plus "Other" for the
    remainder). A legend, rather than inline labels, names each slice to avoid
    overlap on small slices. Percentages are printed only inside slices of at
    least 5%.

    Args:
        params_by_op: Parameter count per operator type; non-positive entries are ignored.
        output_path: Destination PNG path.
        max_slices: Maximum number of named slices (excluding "Other").
        min_pct_for_label: Minimum share, in percent, for a type to get its own slice.

    Returns:
        ``output_path`` on success, or None if matplotlib is unavailable or
        there are no positive parameter counts.
    """
    if not _MATPLOTLIB_AVAILABLE or not params_by_op:
        return None

    # Filter out zero-param ops and sort
    nonzero_ops = {k: v for k, v in params_by_op.items() if v > 0}
    if not nonzero_ops:
        return None

    total = sum(nonzero_ops.values())
    sorted_ops = sorted(nonzero_ops.items(), key=lambda x: -x[1])

    # Group small slices into "Other" more aggressively
    top_ops: list[tuple[str, float]] = []
    other_count = 0.0
    for op, count in sorted_ops:
        pct = (count / total) * 100
        if len(top_ops) < max_slices and pct >= min_pct_for_label:
            top_ops.append((op, count))
        else:
            other_count += count

    if other_count > 0:
        top_ops.append(("Other", other_count))

    # If we only have "Other", show top ops anyway
    if len(top_ops) == 1 and top_ops[0][0] == "Other":
        top_ops = sorted_ops[:max_slices]
        if len(sorted_ops) > max_slices:
            other_count = sum(count for _, count in sorted_ops[max_slices:])
            if other_count > 0:
                top_ops.append(("Other", other_count))

    sizes = [count for _, count in top_ops]
    # Matplotlib's pie cycles through `colors` if there are more slices than entries.
    colors = THEME.palette[: len(sizes)]
    legend_labels = [f"{op} ({_format_count(int(count))})" for op, count in top_ops]

    # Create figure with space for legend
    fig, ax = plt.subplots(figsize=(10, 7), dpi=THEME.figure_dpi)
    try:
        # Create pie with no labels (use legend instead)
        wedges, _texts, autotexts = ax.pie(
            sizes,
            labels=None,  # No inline labels
            colors=colors,
            autopct=lambda pct: f"{pct:.1f}%" if pct >= 5 else "",
            startangle=90,
            pctdistance=0.75,
            textprops={"color": THEME.text, "fontsize": 11, "fontweight": "bold"},
        )

        # Style the percentage text
        for autotext in autotexts:
            autotext.set_color("#ffffff")
            autotext.set_fontweight("bold")

        # Legend with op names and param counts, placed to the right of the pie.
        ax.legend(
            wedges,
            legend_labels,
            title="Operator Type",
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
            fontsize=10,
            title_fontsize=11,
            frameon=True,
            facecolor=THEME.plot_background,
            edgecolor=THEME.grid,
            labelcolor=THEME.text,
        )

        ax.set_title(
            "Parameter Distribution by Operator Type",
            color=THEME.text,
            fontsize=THEME.title_size,
            fontweight="bold",
            pad=20,
        )

        fig.patch.set_facecolor(THEME.background)

        fig.tight_layout()
        fig.savefig(
            output_path,
            facecolor=THEME.background,
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Release the figure even when drawing fails; pyplot otherwise keeps it alive.
        plt.close(fig)

    return output_path
|
|
612
|
+
|
|
613
|
+
def flops_distribution(
    self,
    flops_by_op: dict[str, int],
    output_path: Path,
    max_ops: int = 10,
) -> Path | None:
    """
    Generate FLOPs distribution bar chart.

    Shows computational cost across operator types as vertical bars for the
    ``max_ops`` most expensive types, each labeled with its abbreviated FLOP
    count. Operator types with non-positive FLOPs are ignored.

    Args:
        flops_by_op: FLOP count per operator type.
        output_path: Destination PNG path.
        max_ops: Maximum number of operator types plotted.

    Returns:
        ``output_path`` on success, or None if matplotlib is unavailable or
        there are no positive FLOP counts.
    """
    if not _MATPLOTLIB_AVAILABLE or not flops_by_op:
        return None

    # Filter out zero-FLOP ops and sort
    nonzero_ops = {k: v for k, v in flops_by_op.items() if v > 0}
    if not nonzero_ops:
        return None

    sorted_ops = sorted(nonzero_ops.items(), key=lambda x: -x[1])[:max_ops]

    labels = [op for op, _ in sorted_ops]
    values = [flops for _, flops in sorted_ops]
    # Value labels sit just above each bar, offset by 2% of the tallest bar.
    # default=0 keeps an empty bar list (e.g. max_ops=0) from raising.
    label_offset = max(values, default=0) * 0.02

    fig, ax = plt.subplots(
        figsize=(THEME.figure_width, THEME.figure_height), dpi=THEME.figure_dpi
    )
    try:
        x_pos = range(len(labels))
        bars = ax.bar(
            x_pos,
            values,
            color=THEME.accent_secondary,
            edgecolor=THEME.background,
            width=0.7,
        )

        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylabel("FLOPs")

        for bar, value in zip(bars, values, strict=False):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + label_offset,
                _format_count(value),
                ha="center",
                va="bottom",
                color=THEME.text,
                fontsize=THEME.tick_size,
            )

        _apply_theme(fig, ax, "FLOPs Distribution by Operator Type")

        fig.tight_layout()
        fig.savefig(
            output_path,
            facecolor=THEME.background,
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Release the figure even when drawing fails; pyplot otherwise keeps it alive.
        plt.close(fig)

    return output_path
|
|
679
|
+
|
|
680
|
+
def complexity_summary(
    self,
    report: InspectionReport,
    output_path: Path,
) -> Path | None:
    """
    Generate complexity summary dashboard.

    Three-panel figure showing the report's total parameters
    (``report.param_counts.total``), total FLOPs (``report.flop_counts.total``)
    and model size in bytes (``report.memory_estimates.model_size_bytes``).
    Each value is shown large, abbreviated with ``_format_count``, and labeled
    below.

    Args:
        report: Inspection report; needs param counts, FLOP counts and memory estimates.
        output_path: Destination PNG path.

    Returns:
        ``output_path`` on success, or None if matplotlib is unavailable or
        any of the three required report sections is missing.
    """
    if not _MATPLOTLIB_AVAILABLE:
        return None
    if not report.param_counts or not report.flop_counts or not report.memory_estimates:
        return None

    # (label, value, color) for each panel, left to right.
    metrics = [
        ("Parameters", report.param_counts.total, THEME.accent_primary),
        ("FLOPs", report.flop_counts.total, THEME.accent_secondary),
        (
            "Memory (bytes)",
            report.memory_estimates.model_size_bytes,
            THEME.accent_tertiary,
        ),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(THEME.figure_width, 4), dpi=THEME.figure_dpi)
    try:
        fig.patch.set_facecolor(THEME.background)

        for ax, (label, value, color) in zip(axes, metrics, strict=False):
            ax.set_facecolor(THEME.background)

            # Large centered value
            ax.text(
                0.5,
                0.6,
                _format_count(value),
                ha="center",
                va="center",
                fontsize=28,
                fontweight="bold",
                color=color,
                transform=ax.transAxes,
            )

            # Label below
            ax.text(
                0.5,
                0.25,
                label,
                ha="center",
                va="center",
                fontsize=THEME.label_size,
                color=THEME.text,
                transform=ax.transAxes,
            )

            # Remove all axes elements
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

        # Main title
        fig.suptitle(
            "Model Complexity Summary",
            color=THEME.text,
            fontsize=THEME.title_size,
            fontweight="bold",
            y=0.95,
        )

        # Lay out panels in the lower 90% of the figure, leaving room for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.9])
        fig.savefig(
            output_path,
            facecolor=THEME.background,
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Release the figure even when drawing fails; pyplot otherwise keeps it alive.
        plt.close(fig)

    return output_path
|
|
761
|
+
|
|
762
|
+
def layer_depth_profile(
    self,
    graph_info: GraphInfo,
    param_counts: ParamCounts,
    flop_counts: FlopCounts,
    output_path: Path,
) -> Path | None:
    """
    Generate layer depth profile showing cumulative params/FLOPs.

    Shows how complexity accumulates through the network depth. Nodes are
    visited in ``graph_info.nodes`` order and only those with a non-zero
    parameter or FLOP count are plotted, so the x-axis is the position within
    that filtered sequence rather than the raw node index. Cumulative
    parameters use the left y-axis and cumulative FLOPs the right y-axis.

    Args:
        graph_info: Graph whose ``nodes`` are walked in order.
        param_counts: Per-node parameter counts, looked up in ``by_node`` by
            node name (missing nodes count as 0).
        flop_counts: Per-node FLOP counts, looked up in ``by_node`` by node
            name (missing nodes count as 0).
        output_path: File the chart is saved to.

    Returns:
        ``output_path`` on success, or None when matplotlib is unavailable
        or no node has any parameters or FLOPs.
    """
    if not _MATPLOTLIB_AVAILABLE:
        return None

    # Get ordered nodes with their metrics (skip nodes contributing nothing)
    nodes_with_metrics = []
    for node in graph_info.nodes:
        params = param_counts.by_node.get(node.name, 0)
        flops = flop_counts.by_node.get(node.name, 0)
        if params > 0 or flops > 0:
            nodes_with_metrics.append((node.name, params, flops))

    if not nodes_with_metrics:
        return None

    # Compute cumulative values
    cum_params: list[float] = []
    cum_flops: list[int] = []
    running_params: float = 0.0
    running_flops: int = 0
    for _, params, flops in nodes_with_metrics:
        running_params += params
        running_flops += flops
        cum_params.append(running_params)
        cum_flops.append(running_flops)

    x = range(len(nodes_with_metrics))

    # Create figure with two y-axes
    fig, ax1 = plt.subplots(
        figsize=(THEME.figure_width, THEME.figure_height), dpi=THEME.figure_dpi
    )
    try:
        ax2 = ax1.twinx()

        # Plot cumulative params (left axis)
        ax1.fill_between(x, cum_params, alpha=0.3, color=THEME.accent_primary)
        ax1.plot(x, cum_params, color=THEME.accent_primary, linewidth=2, label="Parameters")
        ax1.set_ylabel("Cumulative Parameters", color=THEME.accent_primary)

        # Plot cumulative FLOPs (right axis)
        ax2.fill_between(x, cum_flops, alpha=0.3, color=THEME.accent_secondary)
        ax2.plot(x, cum_flops, color=THEME.accent_secondary, linewidth=2, label="FLOPs")
        ax2.set_ylabel("Cumulative FLOPs", color=THEME.accent_secondary)

        ax1.set_xlabel("Layer Index")

        # Apply theme to primary axis
        fig.patch.set_facecolor(THEME.background)
        ax1.set_facecolor(THEME.background)
        ax1.set_title(
            "Layer Depth Profile",
            color=THEME.text,
            fontsize=THEME.title_size,
            fontweight="bold",
            pad=15,
        )

        for spine in ax1.spines.values():
            spine.set_color(THEME.grid)
        for spine in ax2.spines.values():
            spine.set_color(THEME.grid)

        # Shared tick styling must come BEFORE the per-axis y label colours:
        # tick_params(colors=...) sets both tick and label colours, and applied
        # last it used to overwrite ax1's accent-coloured y tick labels.
        ax1.tick_params(colors=THEME.text, labelsize=THEME.tick_size)
        ax1.tick_params(axis="y", labelcolor=THEME.accent_primary)
        ax2.tick_params(colors=THEME.text, labelsize=THEME.tick_size)
        ax2.tick_params(axis="y", labelcolor=THEME.accent_secondary)

        ax1.xaxis.label.set_color(THEME.text)
        ax1.xaxis.label.set_fontsize(THEME.label_size)

        ax1.grid(True, linestyle="--", alpha=0.3, color=THEME.grid)

        # Single legend covering both axes (proxy artists, one per series)
        lines = [
            plt.Line2D([0], [0], color=THEME.accent_primary, linewidth=2),
            plt.Line2D([0], [0], color=THEME.accent_secondary, linewidth=2),
        ]
        ax1.legend(
            lines,
            ["Parameters", "FLOPs"],
            loc="upper left",
            facecolor=THEME.background,
            labelcolor=THEME.text,
        )

        fig.tight_layout()
        fig.savefig(
            output_path,
            facecolor=THEME.background,
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)

    return output_path
|
|
866
|
+
|
|
867
|
+
def layer_timing_chart(
    self,
    profiling_result: ProfilingResult,
    output_path: Path,
    top_n: int = 15,
) -> Path | None:
    """
    Generate per-layer timing breakdown chart (Story 9.3.4).

    Shows execution time for the slowest N layers as a horizontal bar chart,
    with the slowest layer at the top. Bars are coloured by operator type and
    a legend maps colours to op types.

    Args:
        profiling_result: ProfilingResult from OperationalProfiler.profile_model()
        output_path: Path to save the chart
        top_n: Number of slowest layers to show

    Returns:
        Path to generated chart or None if unavailable (matplotlib missing,
        argument is not a ProfilingResult, or no layers were returned).
    """
    if not _MATPLOTLIB_AVAILABLE:
        return None

    # Import at runtime
    from .operational_profiling import ProfilingResult

    if not isinstance(profiling_result, ProfilingResult):
        return None

    slowest = profiling_result.get_slowest_layers(top_n)
    if not slowest:
        return None

    # One palette colour per op type, assigned slowest-first. dict.fromkeys
    # preserves first-seen order, so colours and legend order are identical
    # on every run (set() ordering of strings varies with hash randomisation).
    palette = THEME.palette
    color_for_op = {
        op: palette[i % len(palette)]
        for i, op in enumerate(dict.fromkeys(lp.op_type for lp in slowest))
    }

    # Prepare data (reversed for horizontal bar chart - slowest at top)
    ordered = list(reversed(slowest))
    layers = [lp.name[:30] for lp in ordered]  # Truncate long names
    times = [lp.duration_ms for lp in ordered]
    op_types = [lp.op_type for lp in ordered]
    colors = [color_for_op[op] for op in op_types]

    # Create figure (height grows with the number of bars)
    fig, ax = plt.subplots(figsize=(12, max(6, len(slowest) * 0.4)), dpi=THEME.figure_dpi)
    try:
        y_pos = range(len(layers))
        bars = ax.barh(y_pos, times, color=colors, edgecolor=THEME.background, height=0.7)

        ax.set_yticks(list(y_pos))
        ax.set_yticklabels(layers, fontsize=9)
        ax.set_xlabel("Time (ms)", color=THEME.text, fontsize=10)

        # Value labels sit just past each bar end; the gap (1% of the longest
        # bar) is loop-invariant.
        label_offset = max(times) * 0.01
        for bar, time_val, op_type in zip(bars, times, op_types, strict=False):
            ax.text(
                bar.get_width() + label_offset,
                bar.get_y() + bar.get_height() / 2,
                f"{time_val:.2f}ms ({op_type})",
                va="center",
                ha="left",
                color=THEME.text,
                fontsize=8,
            )

        _apply_theme(
            fig,
            ax,
            f"Top {len(slowest)} Slowest Layers (Total: {profiling_result.total_time_ms:.2f}ms)",
        )

        # Add legend for op types
        from matplotlib.patches import Patch

        legend_elements = [
            Patch(facecolor=color, label=op) for op, color in color_for_op.items()
        ]
        ax.legend(
            handles=legend_elements,
            loc="lower right",
            fontsize=8,
            facecolor=THEME.plot_background,
            edgecolor=THEME.grid,
            labelcolor=THEME.text,
        )

        fig.tight_layout()
        fig.savefig(
            output_path,
            facecolor=THEME.background,
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)

    return output_path
|
|
962
|
+
|
|
963
|
+
def op_time_distribution_chart(
    self,
    profiling_result: ProfilingResult,
    output_path: Path,
    max_slices: int = 8,
) -> Path | None:
    """
    Generate operator type time distribution pie chart.

    Shows percentage of execution time by operator type. The ``max_slices``
    largest operator types get their own wedge; any remaining time is merged
    into a single "Other" wedge.

    Args:
        profiling_result: ProfilingResult from profiling
        output_path: Path to save the chart
        max_slices: Number of largest op types shown individually (default 8,
            the previous fixed limit). Values below 0 are treated as 0.

    Returns:
        Path to generated chart or None if unavailable (matplotlib missing,
        argument is not a ProfilingResult, no per-op timings, or the total
        time is zero so no pie can be drawn).
    """
    if not _MATPLOTLIB_AVAILABLE:
        return None

    from .operational_profiling import ProfilingResult

    if not isinstance(profiling_result, ProfilingResult):
        return None

    time_by_op = profiling_result.get_time_by_op_type()
    if not time_by_op:
        return None

    # Largest first; sort stability keeps ties in their original order.
    sorted_ops = sorted(time_by_op.items(), key=lambda item: item[1], reverse=True)

    # Keep the top entries, aggregate the rest into "Other". Clamp so a
    # negative limit cannot turn the slice into "all but the last N".
    cutoff = max(max_slices, 0)
    head, tail = sorted_ops[:cutoff], sorted_ops[cutoff:]
    labels = [op for op, _ in head]
    values = [time_val for _, time_val in head]
    other_time = sum((time_val for _, time_val in tail), 0.0)

    if other_time > 0:
        labels.append("Other")
        values.append(other_time)

    # A pie of all-zero wedges would divide by zero during normalisation.
    total_time = sum(values)
    if total_time <= 0:
        return None

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 8), dpi=THEME.figure_dpi)
    try:
        ax.set_facecolor(THEME.plot_background)

        # pie() cycles through the colour list if it is shorter than values.
        colors = THEME.palette[: len(values)]
        wedges, _texts, autotexts = ax.pie(
            values,
            labels=None,
            # Hide percentage text on thin wedges to avoid overlapping labels.
            autopct=lambda pct: f"{pct:.1f}%" if pct > 3 else "",
            startangle=90,
            colors=colors,
            pctdistance=0.75,
            wedgeprops={"edgecolor": THEME.background, "linewidth": 1},
        )

        for autotext in autotexts:
            autotext.set_color(THEME.text)
            autotext.set_fontsize(9)

        # Legend carries the names and absolute times (wedges are unlabeled)
        legend_labels = [
            f"{op} ({time_val:.2f}ms)" for op, time_val in zip(labels, values, strict=False)
        ]
        ax.legend(
            wedges,
            legend_labels,
            title="Operator Type",
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
            fontsize=9,
            title_fontsize=10,
            frameon=True,
            facecolor=THEME.plot_background,
            edgecolor=THEME.grid,
            labelcolor=THEME.text,
        )

        ax.set_title(
            f"Execution Time Distribution (Total: {total_time:.2f}ms)",
            color=THEME.text,
            fontsize=THEME.title_size,
            fontweight="bold",
            pad=20,
        )

        fig.patch.set_facecolor(THEME.background)

        fig.tight_layout()
        fig.savefig(
            output_path,
            facecolor=THEME.background,
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)

    return output_path
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
def generate_visualizations(
    report: InspectionReport,
    output_dir: Path | str,
    logger: logging.Logger | None = None,
) -> dict[str, Path]:
    """
    Build every chart for ``report`` in a single call.

    Thin wrapper: constructs a ``VisualizationGenerator`` with ``logger`` and
    delegates to its ``generate_all`` method.

    Args:
        report: The inspection report to visualize.
        output_dir: Directory to save PNG files; a ``str`` is converted to
            ``Path`` before delegation.
        logger: Optional logger passed to the generator.

    Returns:
        Dict mapping chart name to file path, as returned by ``generate_all``.
    """
    viz = VisualizationGenerator(logger=logger)
    destination = Path(output_dir)
    return viz.generate_all(report, destination)
|