haoline 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
# Copyright (c) 2025 HaoLine Contributors
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Unit tests for TensorFlow/Keras/JAX to ONNX conversion functionality.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import importlib.util
|
|
11
|
+
import logging
|
|
12
|
+
from unittest.mock import patch
|
|
13
|
+
|
|
14
|
+
import pytest
|
|
15
|
+
|
|
16
|
+
from ..cli import (
|
|
17
|
+
_convert_frozen_graph_to_onnx,
|
|
18
|
+
_convert_jax_to_onnx,
|
|
19
|
+
_convert_keras_to_onnx,
|
|
20
|
+
_convert_tensorflow_to_onnx,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Optional dependencies. The TensorFlow/Keras conversion tests need both
# TensorFlow and tf2onnx; the JAX conversion tests also need jax. find_spec
# only locates the packages, so probing here does not import them.
_TF_AVAILABLE = all(
    importlib.util.find_spec(_pkg_name) is not None
    for _pkg_name in ("tensorflow", "tf2onnx")
)
_JAX_AVAILABLE = importlib.util.find_spec("jax") is not None

# ``tf`` is only bound when TensorFlow is present; every test class that uses
# it is guarded by ``skipif(not _TF_AVAILABLE)``.
if _TF_AVAILABLE:
    import tensorflow as tf
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
def logger():
    """Provide the ``test_tensorflow`` logger passed to each conversion helper."""
    logger_name = "test_tensorflow"
    return logging.getLogger(logger_name)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TestTensorFlowConversionErrors:
    """Test error handling for TensorFlow conversion (no TF required)."""

    def test_nonexistent_savedmodel(self, tmp_path, logger):
        """Conversion should fail for non-existent SavedModel."""
        fake_path = tmp_path / "nonexistent_model"

        onnx_path, _ = _convert_tensorflow_to_onnx(
            fake_path,
            output_path=None,
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is None

    def test_tf2onnx_not_installed(self, tmp_path, logger):
        """Conversion should fail gracefully when tf2onnx cannot be imported."""
        # An existing (empty) directory gets past the "path not found" check,
        # so any failure has to come from a later stage of the converter.
        fake_model = tmp_path / "fake_saved_model"
        fake_model.mkdir()

        # A ``None`` entry in sys.modules makes ``import tf2onnx`` raise
        # ImportError inside this block, whether or not tf2onnx is installed.
        with patch.dict("sys.modules", {"tf2onnx": None}):
            onnx_path, _ = _convert_tensorflow_to_onnx(
                fake_model,
                output_path=None,
                opset_version=17,
                logger=logger,
            )

        # Conversion must not produce an ONNX path, whether it stops at the
        # blocked tf2onnx import or at the empty directory not being a real
        # SavedModel. Previously this test asserted nothing and could not fail.
        assert onnx_path is None
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class TestKerasConversionErrors:
    """Test error handling for Keras conversion (no TF required)."""

    def test_nonexistent_keras_file(self, tmp_path, logger):
        """Conversion should fail for non-existent Keras file."""
        fake_path = tmp_path / "nonexistent.h5"

        onnx_path, _ = _convert_keras_to_onnx(
            fake_path,
            output_path=None,
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is None

    def test_unexpected_extension_warning(self, tmp_path, logger, caplog):
        """Unexpected file extension should log a warning, then still fail cleanly."""
        # A file that exists but has a non-Keras extension and non-model content.
        fake_file = tmp_path / "model.wrongext"
        fake_file.write_text("fake content")

        with caplog.at_level(logging.WARNING):
            onnx_path, _ = _convert_keras_to_onnx(
                fake_file,
                output_path=None,
                opset_version=17,
                logger=logger,
            )

        # The content is not a real model, so no ONNX path may be returned.
        assert onnx_path is None
        # The ``test_tensorflow`` logger propagates to caplog's handler. At
        # least one record at WARNING or above must have been emitted: the
        # extension warning and/or the subsequent load failure.
        assert any(record.levelno >= logging.WARNING for record in caplog.records)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class TestFrozenGraphConversionErrors:
    """Error handling for frozen-graph (.pb) conversion."""

    def test_missing_inputs_outputs(self, tmp_path, logger):
        """Conversion should fail when inputs/outputs not specified."""
        pb_file = tmp_path / "model.pb"
        pb_file.write_bytes(b"fake")

        # First omit the input tensor names, then the output tensor names;
        # each omission on its own must abort the conversion.
        for input_names, output_names in ((None, "output:0"), ("input:0", None)):
            converted, _ = _convert_frozen_graph_to_onnx(
                pb_file,
                inputs=input_names,
                outputs=output_names,
                output_path=None,
                opset_version=17,
                logger=logger,
            )
            assert converted is None

    def test_nonexistent_pb_file(self, tmp_path, logger):
        """Conversion should fail for non-existent .pb file."""
        # Never written to disk.
        missing_pb = tmp_path / "nonexistent.pb"

        converted, _ = _convert_frozen_graph_to_onnx(
            missing_pb,
            inputs="input:0",
            outputs="output:0",
            output_path=None,
            opset_version=17,
            logger=logger,
        )
        assert converted is None
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class TestJAXConversionErrors:
    """Error handling for JAX conversion; every case must yield no ONNX path."""

    @staticmethod
    def _fake_params(tmp_path):
        """Write a placeholder ``params.pkl`` holding junk bytes and return its path."""
        params_file = tmp_path / "params.pkl"
        params_file.write_bytes(b"fake")
        return params_file

    @staticmethod
    def _onnx_result(params_path, logger, apply_fn_path, input_shape_str):
        """Run the JAX converter (opset 17, no explicit output) and return the ONNX path."""
        result, _ = _convert_jax_to_onnx(
            params_path,
            apply_fn_path=apply_fn_path,
            input_shape_str=input_shape_str,
            output_path=None,
            opset_version=17,
            logger=logger,
        )
        return result

    def test_missing_apply_fn(self, tmp_path, logger):
        """Conversion should fail when --jax-apply-fn not provided."""
        params_file = self._fake_params(tmp_path)
        result = self._onnx_result(params_file, logger, None, "1,3,224,224")
        assert result is None

    def test_missing_input_shape(self, tmp_path, logger):
        """Conversion should fail when --input-shape not provided."""
        params_file = self._fake_params(tmp_path)
        result = self._onnx_result(params_file, logger, "module:function", None)
        assert result is None

    def test_invalid_apply_fn_format(self, tmp_path, logger):
        """Conversion should fail for invalid apply_fn format."""
        params_file = self._fake_params(tmp_path)
        # The accepted form is "module:function"; this value has no colon.
        result = self._onnx_result(params_file, logger, "no_colon_separator", "1,3,224,224")
        assert result is None

    def test_invalid_input_shape_format(self, tmp_path, logger):
        """Conversion should fail for invalid input shape format."""
        params_file = self._fake_params(tmp_path)
        # Non-numeric components make this an invalid shape string.
        result = self._onnx_result(params_file, logger, "module:function", "not,valid,shape")
        assert result is None

    def test_nonexistent_params_file(self, tmp_path, logger):
        """Conversion should fail for non-existent params file."""
        # Deliberately never created on disk.
        missing_params = tmp_path / "nonexistent.pkl"
        result = self._onnx_result(missing_params, logger, "module:function", "1,3,224,224")
        assert result is None
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@pytest.mark.skipif(not _TF_AVAILABLE, reason="TensorFlow not installed")
class TestTensorFlowConversion:
    """Tests for TensorFlow to ONNX conversion (requires TensorFlow)."""

    @staticmethod
    def _save_savedmodel(model, path):
        """Write ``model`` to ``path`` as a TensorFlow SavedModel directory.

        tf.keras 2.x supports ``model.save(dir, save_format="tf")``. Keras 3
        (shipped as ``tf.keras`` from TF 2.16) raises ``ValueError`` when
        ``save_format`` is passed with a path lacking a ``.keras``/``.h5``
        suffix. It exposes SavedModel output as ``Model.export`` instead.
        """
        try:
            model.save(str(path), save_format="tf")
        except ValueError:
            # Keras 3: use the dedicated SavedModel export API.
            model.export(str(path))

    def test_savedmodel_conversion(self, tmp_path, logger):
        """SavedModel should convert successfully."""
        # Create a simple TF SavedModel
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(10, input_shape=(5,), activation="relu"),
                tf.keras.layers.Dense(2),
            ]
        )

        saved_model_path = tmp_path / "saved_model"
        self._save_savedmodel(model, saved_model_path)

        # Convert to ONNX
        onnx_path, _temp_file = _convert_tensorflow_to_onnx(
            saved_model_path,
            output_path=tmp_path / "output.onnx",
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is not None
        assert onnx_path.exists()
        assert onnx_path.suffix == ".onnx"

    def test_savedmodel_temp_file(self, tmp_path, logger):
        """SavedModel conversion to temp file should work."""
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(10, input_shape=(5,)),
            ]
        )

        saved_model_path = tmp_path / "saved_model"
        self._save_savedmodel(model, saved_model_path)

        onnx_path, _temp_file = _convert_tensorflow_to_onnx(
            saved_model_path,
            output_path=None,  # Let the converter choose a temp file
            opset_version=17,
            logger=logger,
        )

        try:
            assert onnx_path is not None
            assert onnx_path.exists()
        finally:
            # The converter chose this location (presumably outside tmp_path),
            # so remove it even when an assertion above fails.
            if onnx_path is not None and onnx_path.exists():
                onnx_path.unlink()
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@pytest.mark.skipif(not _TF_AVAILABLE, reason="TensorFlow not installed")
class TestKerasConversion:
    """Tests for Keras to ONNX conversion (requires TensorFlow)."""

    def test_h5_conversion(self, tmp_path, logger):
        """Keras .h5 model should convert successfully."""
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(10, input_shape=(5,), activation="relu"),
                tf.keras.layers.Dense(2),
            ]
        )

        h5_path = tmp_path / "model.h5"
        # The ``.h5`` suffix already selects the HDF5 format. An explicit
        # ``save_format`` argument is deprecated in Keras 3, so it is omitted.
        model.save(str(h5_path))

        onnx_path, _temp_file = _convert_keras_to_onnx(
            h5_path,
            output_path=tmp_path / "output.onnx",
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is not None
        assert onnx_path.exists()
        assert onnx_path.suffix == ".onnx"

    def test_keras_format_conversion(self, tmp_path, logger):
        """Keras .keras format should convert successfully."""
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(10, input_shape=(5,)),
            ]
        )

        keras_path = tmp_path / "model.keras"
        model.save(str(keras_path))

        onnx_path, _temp_file = _convert_keras_to_onnx(
            keras_path,
            output_path=tmp_path / "output.onnx",
            opset_version=17,
            logger=logger,
        )

        assert onnx_path is not None
        assert onnx_path.exists()
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@pytest.mark.skipif(not (_JAX_AVAILABLE and _TF_AVAILABLE), reason="JAX or TF not installed")
class TestJAXConversion:
    """JAX -> ONNX conversion tests that need both JAX and TensorFlow installed."""

    def test_unsupported_params_format(self, tmp_path, logger):
        """Unsupported params format should fail with clear error."""
        # ``.xyz`` is expected to be rejected as an unknown params format.
        params_file = tmp_path / "params.xyz"
        params_file.write_bytes(b"fake")

        converted, _ = _convert_jax_to_onnx(
            params_file,
            apply_fn_path="module:function",
            input_shape_str="1,3,224,224",
            output_path=None,
            opset_version=17,
            logger=logger,
        )
        assert converted is None
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class TestCLIValidation:
    """Test CLI argument validation."""

    # The previous body was empty, so this test always reported a pass.
    # Skipping makes the missing coverage visible in the test report.
    @pytest.mark.skip(
        reason="Not implemented: multiple conversion flags are validated in "
        "run_inspect(), which this module does not exercise yet."
    )
    def test_multiple_conversion_flags_error(self):
        """Using multiple conversion flags should error."""
        # TODO: drive run_inspect() (or the CLI entry point) with two conversion
        # flags set and assert that it reports an error.
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
if __name__ == "__main__":
    # The package-relative imports above mean this only runs as a module
    # (``python -m ...``). Pass pytest's exit status to the shell so a failing
    # run produces a non-zero exit code instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
# Copyright (c) 2025 HaoLine Contributors
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Unit tests for the visualization module.
|
|
6
|
+
|
|
7
|
+
Tests chart generation, theming, and graceful fallback behavior.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import tempfile
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
import pytest
|
|
16
|
+
|
|
17
|
+
from ..analyzer import FlopCounts, MemoryEstimates, ParamCounts
|
|
18
|
+
from ..report import GraphSummary, InspectionReport, ModelMetadata
|
|
19
|
+
from ..visualizations import (
|
|
20
|
+
THEME,
|
|
21
|
+
ChartTheme,
|
|
22
|
+
VisualizationGenerator,
|
|
23
|
+
_format_count,
|
|
24
|
+
generate_visualizations,
|
|
25
|
+
is_available,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TestVisualizationAvailability:
    """matplotlib detection and the module-level THEME object."""

    def test_is_available_returns_bool(self):
        """is_available() should return a boolean."""
        assert isinstance(is_available(), bool)

    def test_theme_has_required_colors(self):
        """Theme should have all required color properties."""
        for attr_name in ("background", "text", "accent_primary", "palette"):
            assert hasattr(THEME, attr_name)
        # Charts need at least five distinct palette colours.
        assert len(THEME.palette) >= 5
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TestChartTheme:
    """Defaults and overrides of the ChartTheme dataclass."""

    def test_default_theme_values(self):
        """Default theme should have sensible values."""
        theme = ChartTheme()
        # Colours are hex strings.
        for colour in (theme.background, theme.text):
            assert colour.startswith("#")
        assert theme.figure_dpi >= 72
        assert theme.figure_width > 0
        assert theme.figure_height > 0

    def test_custom_theme(self):
        """Should be able to create custom themes."""
        overrides = {
            "background": "#000000",
            "text": "#ffffff",
            "accent_primary": "#ff0000",
        }
        theme = ChartTheme(**overrides)
        for field_name, expected in overrides.items():
            assert getattr(theme, field_name) == expected
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TestFormatCount:
    """Suffix formatting performed by the _format_count helper."""

    @staticmethod
    def _check(cases):
        """Assert ``_format_count(value) == expected`` for each (value, expected) pair."""
        for value, expected in cases:
            assert _format_count(value) == expected

    def test_format_small_numbers(self):
        """Small numbers should be formatted as-is."""
        self._check([(0, "0"), (1, "1"), (999, "999")])

    def test_format_thousands(self):
        """Thousands should use K suffix."""
        # 999999 renders as "1000.0K" rather than rolling over to "1.0M";
        # this pins the current behaviour just below one million.
        self._check([(1000, "1.0K"), (1500, "1.5K"), (999999, "1000.0K")])

    def test_format_millions(self):
        """Millions should use M suffix."""
        self._check([(1000000, "1.0M"), (1500000, "1.5M"), (25000000, "25.0M")])

    def test_format_billions(self):
        """Billions should use B suffix."""
        self._check([(1000000000, "1.0B"), (7500000000, "7.5B")])
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class TestVisualizationGenerator:
    """Construction of VisualizationGenerator."""

    def test_generator_initialization(self):
        """Generator should initialize without errors."""
        generator = VisualizationGenerator()
        assert generator is not None
        # A logger is present even when none is passed in.
        assert generator.logger is not None

    def test_generator_with_custom_logger(self):
        """Generator should accept custom logger."""
        import logging

        custom_logger = logging.getLogger("test")
        assert VisualizationGenerator(logger=custom_logger).logger is custom_logger
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@pytest.mark.skipif(not is_available(), reason="matplotlib not installed")
class TestChartGeneration:
    """Chart rendering to PNG files (only runs when matplotlib is installed)."""

    def test_operator_histogram_generation(self):
        """Should generate operator histogram PNG."""
        generator = VisualizationGenerator()
        op_counts = dict(Conv=50, Relu=48, Add=25, MatMul=12, Softmax=6)

        with tempfile.TemporaryDirectory() as tmpdir:
            chart = generator.operator_histogram(op_counts, Path(tmpdir) / "op_histogram.png")

            assert chart is not None
            assert chart.exists()
            # The written file must not be empty.
            assert chart.stat().st_size > 0

    def test_param_distribution_generation(self):
        """Should generate parameter distribution chart."""
        generator = VisualizationGenerator()
        params_by_op = dict(Conv=5000000, Gemm=2000000, MatMul=1500000)

        with tempfile.TemporaryDirectory() as tmpdir:
            chart = generator.param_distribution(params_by_op, Path(tmpdir) / "param_dist.png")

            assert chart is not None
            assert chart.exists()

    def test_flops_distribution_generation(self):
        """Should generate FLOPs distribution chart."""
        generator = VisualizationGenerator()
        flops_by_op = dict(Conv=500000000, MatMul=200000000, Gemm=150000000)

        with tempfile.TemporaryDirectory() as tmpdir:
            chart = generator.flops_distribution(flops_by_op, Path(tmpdir) / "flops_dist.png")

            assert chart is not None
            assert chart.exists()

    def test_empty_data_returns_none(self):
        """Charts should return None for empty data."""
        generator = VisualizationGenerator()

        with tempfile.TemporaryDirectory() as tmpdir:
            out_dir = Path(tmpdir)
            # No operators at all.
            assert generator.operator_histogram({}, out_dir / "empty.png") is None
            # Operators present, but every count is zero.
            all_zero = {"Conv": 0, "Relu": 0}
            assert generator.param_distribution(all_zero, out_dir / "zeros.png") is None

    def test_generate_all_creates_multiple_charts(self):
        """generate_all should create multiple chart files."""
        # A synthetic report with every metric section filled in.
        report = InspectionReport(
            metadata=ModelMetadata(
                path="test.onnx",
                ir_version=8,
                producer_name="test",
                producer_version="1.0",
                domain="",
                model_version=1,
                doc_string="",
                opsets={"ai.onnx": 17},
            ),
            graph_summary=GraphSummary(
                num_nodes=100,
                num_inputs=1,
                num_outputs=1,
                num_initializers=50,
                input_shapes={"input": [1, 3, 224, 224]},
                output_shapes={"output": [1, 1000]},
                op_type_counts={"Conv": 50, "Relu": 48, "Add": 25},
            ),
            param_counts=ParamCounts(
                total=25000000,
                trainable=25000000,
                by_op_type={"Conv": 20000000, "Gemm": 5000000},
            ),
            flop_counts=FlopCounts(
                total=4000000000,
                by_op_type={"Conv": 3500000000, "Gemm": 500000000},
            ),
            memory_estimates=MemoryEstimates(
                model_size_bytes=100000000,
                peak_activation_bytes=50000000,
            ),
        )

        generator = VisualizationGenerator()

        with tempfile.TemporaryDirectory() as tmpdir:
            charts = generator.generate_all(report, Path(tmpdir))

            # At least three charts, including these three named ones.
            assert len(charts) >= 3
            for expected_key in ("op_histogram", "param_distribution", "complexity_summary"):
                assert expected_key in charts

            # Every returned path must point at a file that was written.
            for chart_name, chart_path in charts.items():
                assert chart_path.exists(), f"{chart_name} file should exist"
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
@pytest.mark.skipif(not is_available(), reason="matplotlib not installed")
class TestConvenienceFunction:
    """The module-level generate_visualizations() wrapper."""

    def test_generate_visualizations_function(self):
        """Convenience function should work correctly."""
        report = InspectionReport(
            metadata=ModelMetadata(
                path="test.onnx",
                ir_version=8,
                producer_name="test",
                producer_version="1.0",
                domain="",
                model_version=1,
                doc_string="",
                opsets={"ai.onnx": 17},
            ),
            graph_summary=GraphSummary(
                num_nodes=10,
                num_inputs=1,
                num_outputs=1,
                num_initializers=5,
                input_shapes={},
                output_shapes={},
                op_type_counts={"Conv": 5, "Relu": 5},
            ),
            param_counts=ParamCounts(total=1000, trainable=1000, by_op_type={"Conv": 1000}),
            flop_counts=FlopCounts(total=10000, by_op_type={"Conv": 10000}),
            memory_estimates=MemoryEstimates(model_size_bytes=4000, peak_activation_bytes=2000),
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            # The wrapper is handed the directory as a plain string.
            charts = generate_visualizations(report, tmpdir)
            assert isinstance(charts, dict)
            assert len(charts) > 0
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class TestGracefulDegradation:
    """Tests for graceful degradation when matplotlib is unavailable."""

    def test_generator_handles_missing_matplotlib(self):
        """Generator should not crash if matplotlib is unavailable."""
        # Runs with or without matplotlib installed. NOTE(review): matplotlib is
        # not actually hidden here; the test only checks that generate_all
        # accepts a report carrying nothing but metadata.
        generator = VisualizationGenerator()

        bare_report = InspectionReport(
            metadata=ModelMetadata(
                path="test.onnx",
                ir_version=8,
                producer_name="test",
                producer_version="1.0",
                domain="",
                model_version=1,
                doc_string="",
                opsets={},
            )
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            charts = generator.generate_all(bare_report, Path(tmpdir))
            # Only the return type is checked, not that the dict is empty.
            assert isinstance(charts, dict)
|