haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70):
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,380 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Unit tests for TensorFlow/Keras/JAX to ONNX conversion functionality.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import importlib.util
11
+ import logging
12
+ from unittest.mock import patch
13
+
14
+ import pytest
15
+
16
+ from ..cli import (
17
+ _convert_frozen_graph_to_onnx,
18
+ _convert_jax_to_onnx,
19
+ _convert_keras_to_onnx,
20
+ _convert_tensorflow_to_onnx,
21
+ )
22
+
23
# Optional-dependency probes. ``find_spec`` only locates a module without
# importing it, so these checks are cheap and never raise when a package is
# absent. TensorFlow-based conversion needs both TensorFlow and tf2onnx.
_TF_AVAILABLE = all(
    importlib.util.find_spec(module_name) is not None
    for module_name in ("tensorflow", "tf2onnx")
)
_JAX_AVAILABLE = importlib.util.find_spec("jax") is not None

# Pull in TensorFlow only when the probe succeeded; tests that use ``tf`` are
# gated on ``_TF_AVAILABLE`` with ``pytest.mark.skipif``.
if _TF_AVAILABLE:
    import tensorflow as tf
33
+
34
+
35
@pytest.fixture
def logger():
    """Return the named logger passed to each conversion helper under test."""
    test_logger = logging.getLogger("test_tensorflow")
    return test_logger
39
+
40
+
41
class TestTensorFlowConversionErrors:
    """Test error handling for TensorFlow conversion (no TF required).

    Each case hands ``_convert_tensorflow_to_onnx`` an input that cannot be
    converted and expects it to report failure by returning ``None`` as the
    ONNX path.
    """

    def test_nonexistent_savedmodel(self, tmp_path, logger):
        """Conversion should fail for non-existent SavedModel."""
        fake_path = tmp_path / "nonexistent_model"

        onnx_path, _ = _convert_tensorflow_to_onnx(
            fake_path,
            output_path=None,
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is None

    def test_tf2onnx_not_installed(self, tmp_path, logger):
        """Conversion should fail gracefully when tf2onnx cannot be imported."""
        # An existing (empty) directory avoids the "not found" error, so the
        # failure comes from the converter itself rather than path validation.
        fake_model = tmp_path / "fake_saved_model"
        fake_model.mkdir()

        # A ``None`` entry in sys.modules makes ``import tf2onnx`` raise
        # ImportError whether or not the package is installed. This relies on
        # the converter importing tf2onnx at call time rather than at module load.
        with patch.dict("sys.modules", {"tf2onnx": None}):
            onnx_path, _ = _convert_tensorflow_to_onnx(
                fake_model,
                output_path=None,
                opset_version=17,
                logger=logger,
            )

        # Either the blocked import or the empty (invalid) SavedModel directory
        # must make conversion fail. Without this assertion the test could
        # never fail.
        assert onnx_path is None
80
+
81
+
82
class TestKerasConversionErrors:
    """Test error handling for Keras conversion (no TF required).

    Each case passes ``_convert_keras_to_onnx`` a file it cannot convert and
    expects ``None`` as the returned ONNX path.
    """

    def test_nonexistent_keras_file(self, tmp_path, logger):
        """Conversion should fail for non-existent Keras file."""
        fake_path = tmp_path / "nonexistent.h5"

        onnx_path, _ = _convert_keras_to_onnx(
            fake_path,
            output_path=None,
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is None

    def test_unexpected_extension_warning(self, tmp_path, logger, caplog):
        """Unexpected file extension should log warning but proceed."""
        # An existing file with an unrecognised extension and non-model content.
        fake_file = tmp_path / "model.wrongext"
        fake_file.write_text("fake content")

        with caplog.at_level(logging.WARNING):
            onnx_path, _ = _convert_keras_to_onnx(
                fake_file,
                output_path=None,
                opset_version=17,
                logger=logger,
            )

        # The content is not a real Keras model, so after the extension
        # warning the conversion must still fail.
        # NOTE(review): the warning itself is not asserted because its wording,
        # and whether it is logged before any TensorFlow import check, is not
        # pinned down here. Consider asserting on ``caplog.records`` once it is.
        assert onnx_path is None
114
+
115
+
116
class TestFrozenGraphConversionErrors:
    """Failure paths of ``_convert_frozen_graph_to_onnx``."""

    def test_missing_inputs_outputs(self, tmp_path, logger):
        """Conversion should fail when inputs/outputs not specified."""
        pb_file = tmp_path / "model.pb"
        pb_file.write_bytes(b"fake")

        # First leave out the input tensor names, then the output tensor names.
        missing_cases = (
            (None, "output:0"),
            ("input:0", None),
        )
        for input_names, output_names in missing_cases:
            result_path, _ = _convert_frozen_graph_to_onnx(
                pb_file,
                inputs=input_names,
                outputs=output_names,
                output_path=None,
                opset_version=17,
                logger=logger,
            )
            assert result_path is None

    def test_nonexistent_pb_file(self, tmp_path, logger):
        """Conversion should fail for non-existent .pb file."""
        missing_pb = tmp_path / "nonexistent.pb"

        result_path, _ = _convert_frozen_graph_to_onnx(
            missing_pb,
            inputs="input:0",
            outputs="output:0",
            output_path=None,
            opset_version=17,
            logger=logger,
        )

        assert result_path is None
160
+
161
+
162
class TestJAXConversionErrors:
    """Argument-validation failures of ``_convert_jax_to_onnx``.

    Each case supplies one missing or malformed argument and expects the
    converter to return ``None`` for the ONNX path.
    """

    @staticmethod
    def _write_fake_params(directory):
        """Create a placeholder ``params.pkl`` (not a real pickle) and return it."""
        params_file = directory / "params.pkl"
        params_file.write_bytes(b"fake")
        return params_file

    @staticmethod
    def _convert(params_path, apply_fn_path, input_shape_str, logger):
        """Run the converter with the fixed output/opset settings; return the path."""
        result_path, _ = _convert_jax_to_onnx(
            params_path,
            apply_fn_path=apply_fn_path,
            input_shape_str=input_shape_str,
            output_path=None,
            opset_version=17,
            logger=logger,
        )
        return result_path

    def test_missing_apply_fn(self, tmp_path, logger):
        """Conversion should fail when --jax-apply-fn not provided."""
        params = self._write_fake_params(tmp_path)
        assert self._convert(params, None, "1,3,224,224", logger) is None

    def test_missing_input_shape(self, tmp_path, logger):
        """Conversion should fail when --input-shape not provided."""
        params = self._write_fake_params(tmp_path)
        assert self._convert(params, "module:function", None, logger) is None

    def test_invalid_apply_fn_format(self, tmp_path, logger):
        """Conversion should fail for invalid apply_fn format."""
        params = self._write_fake_params(tmp_path)
        # No "module:function" colon separator.
        assert self._convert(params, "no_colon_separator", "1,3,224,224", logger) is None

    def test_invalid_input_shape_format(self, tmp_path, logger):
        """Conversion should fail for invalid input shape format."""
        params = self._write_fake_params(tmp_path)
        # Shape components are not integers.
        assert self._convert(params, "module:function", "not,valid,shape", logger) is None

    def test_nonexistent_params_file(self, tmp_path, logger):
        """Conversion should fail for non-existent params file."""
        missing_params = tmp_path / "nonexistent.pkl"
        assert self._convert(missing_params, "module:function", "1,3,224,224", logger) is None
243
+
244
+
245
@pytest.mark.skipif(not _TF_AVAILABLE, reason="TensorFlow not installed")
class TestTensorFlowConversion:
    """Tests for TensorFlow to ONNX conversion (requires TensorFlow)."""

    @staticmethod
    def _build_model(*layers):
        """Build a small Sequential model that takes a (5,)-feature input.

        An explicit ``tf.keras.Input`` is used instead of the ``input_shape=``
        layer kwarg, which Keras 3 deprecates.
        """
        return tf.keras.Sequential([tf.keras.Input(shape=(5,)), *layers])

    @staticmethod
    def _save_saved_model(model, path):
        """Write *model* as a TensorFlow SavedModel directory at *path*.

        Legacy ``tf.keras`` (Keras 2) accepts ``save_format="tf"``. Keras 3
        (TensorFlow >= 2.16) rejects that argument for extension-less paths with
        a ValueError; there, ``model.export`` is the documented way to write a
        SavedModel.
        """
        try:
            model.save(str(path), save_format="tf")
        except ValueError:
            model.export(str(path))

    def test_savedmodel_conversion(self, tmp_path, logger):
        """SavedModel should convert successfully."""
        model = self._build_model(
            tf.keras.layers.Dense(10, activation="relu"),
            tf.keras.layers.Dense(2),
        )
        saved_model_path = tmp_path / "saved_model"
        self._save_saved_model(model, saved_model_path)

        onnx_path, _temp_file = _convert_tensorflow_to_onnx(
            saved_model_path,
            output_path=tmp_path / "output.onnx",
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is not None
        assert onnx_path.exists()
        assert onnx_path.suffix == ".onnx"

    def test_savedmodel_temp_file(self, tmp_path, logger):
        """SavedModel conversion to temp file should work."""
        model = self._build_model(tf.keras.layers.Dense(10))
        saved_model_path = tmp_path / "saved_model"
        self._save_saved_model(model, saved_model_path)

        onnx_path, _temp_file = _convert_tensorflow_to_onnx(
            saved_model_path,
            output_path=None,  # Let the converter pick a temp file
            opset_version=17,
            logger=logger,
        )

        try:
            assert onnx_path is not None
            assert onnx_path.exists()
        finally:
            # The converter-chosen temp file is not guaranteed to live under
            # tmp_path, so remove it even when an assertion above fails.
            if onnx_path is not None and onnx_path.exists():
                onnx_path.unlink()
296
+
297
+
298
@pytest.mark.skipif(not _TF_AVAILABLE, reason="TensorFlow not installed")
class TestKerasConversion:
    """Tests for Keras to ONNX conversion (requires TensorFlow)."""

    @staticmethod
    def _build_model(*layers):
        """Build a small Sequential model that takes a (5,)-feature input.

        An explicit ``tf.keras.Input`` replaces the ``input_shape=`` layer kwarg,
        which Keras 3 deprecates.
        """
        return tf.keras.Sequential([tf.keras.Input(shape=(5,)), *layers])

    def test_h5_conversion(self, tmp_path, logger):
        """Keras .h5 model should convert successfully."""
        model = self._build_model(
            tf.keras.layers.Dense(10, activation="relu"),
            tf.keras.layers.Dense(2),
        )

        h5_path = tmp_path / "model.h5"
        # The ".h5" extension already selects the HDF5 format. Passing
        # ``save_format="h5"`` is redundant and deprecated in Keras 3.
        model.save(str(h5_path))

        onnx_path, _temp_file = _convert_keras_to_onnx(
            h5_path,
            output_path=tmp_path / "output.onnx",
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is not None
        assert onnx_path.exists()
        assert onnx_path.suffix == ".onnx"

    def test_keras_format_conversion(self, tmp_path, logger):
        """Keras .keras format should convert successfully."""
        model = self._build_model(tf.keras.layers.Dense(10))

        keras_path = tmp_path / "model.keras"
        model.save(str(keras_path))

        onnx_path, _temp_file = _convert_keras_to_onnx(
            keras_path,
            output_path=tmp_path / "output.onnx",
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is not None
        assert onnx_path.exists()
345
+
346
+
347
@pytest.mark.skipif(not _JAX_AVAILABLE or not _TF_AVAILABLE, reason="JAX or TF not installed")
class TestJAXConversion:
    """JAX -> ONNX conversion checks that need both JAX and TensorFlow."""

    def test_unsupported_params_format(self, tmp_path, logger):
        """Unsupported params format should fail with clear error."""
        # ".xyz" is not a recognised parameter-file extension.
        unknown_params = tmp_path / "params.xyz"
        unknown_params.write_bytes(b"fake")

        result_path, _ = _convert_jax_to_onnx(
            unknown_params,
            apply_fn_path="module:function",
            input_shape_str="1,3,224,224",
            output_path=None,
            opset_version=17,
            logger=logger,
        )

        assert result_path is None
366
+
367
+
368
class TestCLIValidation:
    """Test CLI argument validation."""

    # The original body was empty and always reported a pass. Skip explicitly
    # so the missing coverage shows up in test output instead.
    @pytest.mark.skip(
        reason="Placeholder: conflicting conversion flags are validated in "
        "run_inspect(), which is not exercised by this module yet."
    )
    def test_multiple_conversion_flags_error(self):
        """Using multiple conversion flags should error (not implemented yet).

        The real check lives in ``run_inspect()``; this test should drive it
        with two conversion flags set and assert that it rejects them.
        """
377
+
378
+
379
if __name__ == "__main__":
    # Forward pytest's exit status so shells/CI see failures when this file is
    # run directly. The bare ``pytest.main(...)`` call always exited with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
@@ -0,0 +1,316 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Unit tests for the visualization module.
6
+
7
+ Tests chart generation, theming, and graceful fallback behavior.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import tempfile
13
+ from pathlib import Path
14
+
15
+ import pytest
16
+
17
+ from ..analyzer import FlopCounts, MemoryEstimates, ParamCounts
18
+ from ..report import GraphSummary, InspectionReport, ModelMetadata
19
+ from ..visualizations import (
20
+ THEME,
21
+ ChartTheme,
22
+ VisualizationGenerator,
23
+ _format_count,
24
+ generate_visualizations,
25
+ is_available,
26
+ )
27
+
28
+
29
class TestVisualizationAvailability:
    """Checks for the matplotlib availability probe and the module-level theme."""

    def test_is_available_returns_bool(self):
        """is_available() should return a boolean."""
        assert isinstance(is_available(), bool)

    def test_theme_has_required_colors(self):
        """Theme should have all required color properties."""
        for attribute in ("background", "text", "accent_primary", "palette"):
            assert hasattr(THEME, attribute)
        # Charts need enough distinct colors; require at least five.
        assert len(THEME.palette) >= 5
44
+
45
+
46
class TestChartTheme:
    """Construction tests for the ``ChartTheme`` dataclass."""

    def test_default_theme_values(self):
        """Default theme should have sensible values."""
        default = ChartTheme()
        # Colors are expected as hex strings.
        for color in (default.background, default.text):
            assert color.startswith("#")
        assert default.figure_dpi >= 72
        assert default.figure_width > 0
        assert default.figure_height > 0

    def test_custom_theme(self):
        """Should be able to create custom themes."""
        overrides = {
            "background": "#000000",
            "text": "#ffffff",
            "accent_primary": "#ff0000",
        }
        custom = ChartTheme(**overrides)
        for field_name, expected in overrides.items():
            assert getattr(custom, field_name) == expected
68
+
69
+
70
class TestFormatCount:
    """Expected outputs of the ``_format_count`` humanising helper."""

    def test_format_small_numbers(self):
        """Small numbers should be formatted as-is."""
        for value in (0, 1, 999):
            assert _format_count(value) == str(value)

    def test_format_thousands(self):
        """Thousands should use K suffix."""
        expectations = {
            1000: "1.0K",
            1500: "1.5K",
            999999: "1000.0K",  # just under 1M still uses the K suffix
        }
        for value, rendered in expectations.items():
            assert _format_count(value) == rendered

    def test_format_millions(self):
        """Millions should use M suffix."""
        expectations = {
            1_000_000: "1.0M",
            1_500_000: "1.5M",
            25_000_000: "25.0M",
        }
        for value, rendered in expectations.items():
            assert _format_count(value) == rendered

    def test_format_billions(self):
        """Billions should use B suffix."""
        expectations = {
            1_000_000_000: "1.0B",
            7_500_000_000: "7.5B",
        }
        for value, rendered in expectations.items():
            assert _format_count(value) == rendered
95
+
96
+
97
class TestVisualizationGenerator:
    """Construction tests for ``VisualizationGenerator``."""

    def test_generator_initialization(self):
        """Generator should initialize without errors."""
        generator = VisualizationGenerator()
        assert generator is not None
        assert generator.logger is not None

    def test_generator_with_custom_logger(self):
        """Generator should accept custom logger."""
        import logging

        custom_logger = logging.getLogger("test")
        assert VisualizationGenerator(logger=custom_logger).logger is custom_logger
113
+
114
+
115
@pytest.mark.skipif(not is_available(), reason="matplotlib not installed")
class TestChartGeneration:
    """Chart rendering tests; skipped when matplotlib is missing."""

    def test_operator_histogram_generation(self):
        """Should generate operator histogram PNG."""
        generator = VisualizationGenerator()
        counts_by_op = dict(Conv=50, Relu=48, Add=25, MatMul=12, Softmax=6)

        with tempfile.TemporaryDirectory() as scratch:
            png = generator.operator_histogram(counts_by_op, Path(scratch) / "op_histogram.png")

            assert png is not None
            assert png.exists()
            assert png.stat().st_size > 0  # something was actually written

    def test_param_distribution_generation(self):
        """Should generate parameter distribution chart."""
        generator = VisualizationGenerator()
        params_per_op = dict(Conv=5_000_000, Gemm=2_000_000, MatMul=1_500_000)

        with tempfile.TemporaryDirectory() as scratch:
            png = generator.param_distribution(params_per_op, Path(scratch) / "param_dist.png")

            assert png is not None
            assert png.exists()

    def test_flops_distribution_generation(self):
        """Should generate FLOPs distribution chart."""
        generator = VisualizationGenerator()
        flops_per_op = dict(Conv=500_000_000, MatMul=200_000_000, Gemm=150_000_000)

        with tempfile.TemporaryDirectory() as scratch:
            png = generator.flops_distribution(flops_per_op, Path(scratch) / "flops_dist.png")

            assert png is not None
            assert png.exists()

    def test_empty_data_returns_none(self):
        """Charts should return None for empty data."""
        generator = VisualizationGenerator()

        with tempfile.TemporaryDirectory() as scratch:
            out_dir = Path(scratch)
            # No operators at all.
            assert generator.operator_histogram({}, out_dir / "empty.png") is None
            # Operators present, but every parameter count is zero.
            zero_params = {"Conv": 0, "Relu": 0}
            assert generator.param_distribution(zero_params, out_dir / "zeros.png") is None

    def test_generate_all_creates_multiple_charts(self):
        """generate_all should create multiple chart files."""
        # A hand-built report roughly shaped like a small CNN.
        report = InspectionReport(
            metadata=ModelMetadata(
                path="test.onnx",
                ir_version=8,
                producer_name="test",
                producer_version="1.0",
                domain="",
                model_version=1,
                doc_string="",
                opsets={"ai.onnx": 17},
            ),
            graph_summary=GraphSummary(
                num_nodes=100,
                num_inputs=1,
                num_outputs=1,
                num_initializers=50,
                input_shapes={"input": [1, 3, 224, 224]},
                output_shapes={"output": [1, 1000]},
                op_type_counts={"Conv": 50, "Relu": 48, "Add": 25},
            ),
            param_counts=ParamCounts(
                total=25_000_000,
                trainable=25_000_000,
                by_op_type={"Conv": 20_000_000, "Gemm": 5_000_000},
            ),
            flop_counts=FlopCounts(
                total=4_000_000_000,
                by_op_type={"Conv": 3_500_000_000, "Gemm": 500_000_000},
            ),
            memory_estimates=MemoryEstimates(
                model_size_bytes=100_000_000,
                peak_activation_bytes=50_000_000,
            ),
        )

        generator = VisualizationGenerator()

        with tempfile.TemporaryDirectory() as scratch:
            charts = generator.generate_all(report, Path(scratch))

            assert len(charts) >= 3  # at least three charts expected
            for expected_chart in ("op_histogram", "param_distribution", "complexity_summary"):
                assert expected_chart in charts

            # Every reported chart must exist on disk.
            for chart_name, chart_path in charts.items():
                assert chart_path.exists(), f"{chart_name} file should exist"
244
+
245
+
246
@pytest.mark.skipif(not is_available(), reason="matplotlib not installed")
class TestConvenienceFunction:
    """Smoke test for the module-level ``generate_visualizations`` wrapper."""

    def test_generate_visualizations_function(self):
        """Convenience function should work correctly."""
        tiny_report = InspectionReport(
            metadata=ModelMetadata(
                path="test.onnx",
                ir_version=8,
                producer_name="test",
                producer_version="1.0",
                domain="",
                model_version=1,
                doc_string="",
                opsets={"ai.onnx": 17},
            ),
            graph_summary=GraphSummary(
                num_nodes=10,
                num_inputs=1,
                num_outputs=1,
                num_initializers=5,
                input_shapes={},
                output_shapes={},
                op_type_counts={"Conv": 5, "Relu": 5},
            ),
            param_counts=ParamCounts(total=1000, trainable=1000, by_op_type={"Conv": 1000}),
            flop_counts=FlopCounts(total=10000, by_op_type={"Conv": 10000}),
            memory_estimates=MemoryEstimates(model_size_bytes=4000, peak_activation_bytes=2000),
        )

        with tempfile.TemporaryDirectory() as scratch:
            # The output directory is passed as a plain string on purpose.
            charts = generate_visualizations(tiny_report, scratch)
            assert isinstance(charts, dict)
            assert len(charts) > 0
289
+
290
+
291
class TestGracefulDegradation:
    """Tests for graceful degradation when matplotlib is unavailable."""

    def test_generator_handles_missing_matplotlib(self):
        """Generator should not crash if matplotlib is unavailable."""
        # Runs with or without matplotlib. A report built from metadata alone
        # must still produce a dict from generate_all rather than raising.
        generator = VisualizationGenerator()

        bare_report = InspectionReport(
            metadata=ModelMetadata(
                path="test.onnx",
                ir_version=8,
                producer_name="test",
                producer_version="1.0",
                domain="",
                model_version=1,
                doc_string="",
                opsets={},
            )
        )

        with tempfile.TemporaryDirectory() as scratch:
            charts = generator.generate_all(bare_report, Path(scratch))
            assert isinstance(charts, dict)