haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,868 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Unit tests for the analyzer module (parameter counting, FLOP estimation, memory estimates).
6
+
7
+ These tests use programmatically-created tiny ONNX models to ensure deterministic,
8
+ reproducible test results without external dependencies.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ # Import the modules under test
14
+ import sys
15
+ import tempfile
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import onnx
20
+ import pytest
21
+ from onnx import TensorProto, helper
22
+
23
+ # Add parent directories to path for imports
24
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
25
+ from ..analyzer import MetricsEngine, ONNXGraphLoader
26
+
27
+
28
def create_simple_conv_model() -> onnx.ModelProto:
    """Build a single-node Conv graph, the smallest CNN fixture in this file.

    Graph layout:
        X [1, 3, 8, 8] --Conv(3x3, zero padding)--> Y [1, 16, 6, 6]

    Initializers:
        W: [16, 3, 3, 3] -> 16 * 3 * 3 * 3 = 432 values (random float32)
        B: [16]          -> 16 values (zeros)

    Returns:
        An opset-17 ``onnx.ModelProto`` whose initializers hold 448 values.
    """
    kernel_dims = [16, 3, 3, 3]
    # The weight values are random; only their count and shape are fixed.
    kernel_values = np.random.randn(*kernel_dims).astype(np.float32).flatten().tolist()

    initializers = [
        helper.make_tensor(
            name="W",
            data_type=TensorProto.FLOAT,
            dims=kernel_dims,
            vals=kernel_values,
        ),
        helper.make_tensor(
            name="B",
            data_type=TensorProto.FLOAT,
            dims=[16],
            vals=[0.0] * 16,
        ),
    ]

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 16, 6, 6])]

    # 8x8 input, 3x3 kernel, no padding -> 6x6 spatial output.
    conv = helper.make_node(
        "Conv",
        inputs=["X", "W", "B"],
        outputs=["Y"],
        kernel_shape=[3, 3],
        pads=[0, 0, 0, 0],
    )

    conv_graph = helper.make_graph([conv], "conv_test", graph_inputs, graph_outputs, initializers)
    return helper.make_model(conv_graph, opset_imports=[helper.make_opsetid("", 17)])
70
+
71
+
72
def create_matmul_model() -> onnx.ModelProto:
    """Build a single-node batched MatMul graph.

    Graph layout:
        A [2, 4, 8] @ B [8, 16] -> Y [2, 4, 16]

    ``B`` is the only initializer (8 * 16 = 128 random float32 values).

    Returns:
        An opset-17 ``onnx.ModelProto`` named ``matmul_test``.
    """
    rhs_dims = [8, 16]
    rhs_values = np.random.randn(*rhs_dims).astype(np.float32).flatten().tolist()
    rhs_weight = helper.make_tensor(
        name="B",
        data_type=TensorProto.FLOAT,
        dims=rhs_dims,
        vals=rhs_values,
    )

    lhs_info = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 4, 8])
    out_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4, 16])

    matmul = helper.make_node("MatMul", inputs=["A", "B"], outputs=["Y"])

    matmul_graph = helper.make_graph(
        [matmul],
        "matmul_test",
        [lhs_info],
        [out_info],
        [rhs_weight],
    )
    return helper.make_model(matmul_graph, opset_imports=[helper.make_opsetid("", 17)])
104
+
105
+
106
def create_gemm_model() -> onnx.ModelProto:
    """Build a single-node Gemm graph with a bias input.

    Graph layout:
        Gemm(A [4, 8], B [8, 16], C [16]) -> Y [4, 16]

    Initializers:
        B: [8, 16] -> 128 values (random float32)
        C: [16]    -> 16 values (zeros)

    Returns:
        An opset-17 ``onnx.ModelProto`` whose initializers hold 144 values.
    """
    weight_dims = [8, 16]
    weight_values = np.random.randn(*weight_dims).astype(np.float32).flatten().tolist()

    initializers = [
        helper.make_tensor(
            name="B",
            data_type=TensorProto.FLOAT,
            dims=weight_dims,
            vals=weight_values,
        ),
        helper.make_tensor(
            name="C",
            data_type=TensorProto.FLOAT,
            dims=[16],
            vals=[0.0] * 16,
        ),
    ]

    activation_info = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4, 8])
    result_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 16])

    gemm = helper.make_node("Gemm", inputs=["A", "B", "C"], outputs=["Y"])

    gemm_graph = helper.make_graph(
        [gemm],
        "gemm_test",
        [activation_info],
        [result_info],
        initializers,
    )
    return helper.make_model(gemm_graph, opset_imports=[helper.make_opsetid("", 17)])
146
+
147
+
148
def create_relu_model() -> onnx.ModelProto:
    """Build a single-node Relu graph that carries no initializers.

    Both the input ``X`` and the output ``Y`` are float tensors of shape
    [1, 3, 8, 8].

    Returns:
        An opset-17 ``onnx.ModelProto`` named ``relu_test``.
    """
    tensor_shape = [1, 3, 8, 8]
    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, tensor_shape)]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, tensor_shape)]

    relu = helper.make_node("Relu", inputs=["X"], outputs=["Y"])

    relu_graph = helper.make_graph([relu], "relu_test", graph_inputs, graph_outputs)
    return helper.make_model(relu_graph, opset_imports=[helper.make_opsetid("", 17)])
158
+
159
+
160
def create_conv_bn_relu_model() -> onnx.ModelProto:
    """Build a Conv -> BatchNormalization -> Relu chain for pattern tests.

    Graph layout:
        X [1, 3, 8, 8] -> Conv(3x3, no bias) -> conv_out
                       -> BatchNormalization -> bn_out
                       -> Relu -> Y [1, 16, 6, 6]

    Initializers:
        W: [16, 3, 3, 3] -> 432 values (random float32)
        scale, bias, mean, var: [16] each -> 64 values in total

    Returns:
        An opset-17 ``onnx.ModelProto`` named ``conv_bn_relu_test``.
    """
    channels = 16

    kernel_dims = [channels, 3, 3, 3]
    kernel = helper.make_tensor(
        name="W",
        data_type=TensorProto.FLOAT,
        dims=kernel_dims,
        vals=np.random.randn(*kernel_dims).astype(np.float32).flatten().tolist(),
    )

    # BatchNormalization statistics in the order the node consumes them;
    # each maps a tensor name to the constant it is filled with.
    bn_fill_values = {"scale": 1.0, "bias": 0.0, "mean": 0.0, "var": 1.0}
    bn_tensors = [
        helper.make_tensor(tensor_name, TensorProto.FLOAT, [channels], [fill] * channels)
        for tensor_name, fill in bn_fill_values.items()
    ]

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, channels, 6, 6])]

    # Names of the intermediate tensors linking the three nodes.
    conv_result, bn_result = "conv_out", "bn_out"

    nodes = [
        helper.make_node("Conv", ["X", "W"], [conv_result], kernel_shape=[3, 3]),
        helper.make_node(
            "BatchNormalization",
            [conv_result, "scale", "bias", "mean", "var"],
            [bn_result],
        ),
        helper.make_node("Relu", [bn_result], ["Y"]),
    ]

    chain_graph = helper.make_graph(
        nodes,
        "conv_bn_relu_test",
        graph_inputs,
        graph_outputs,
        [kernel, *bn_tensors],
    )
    return helper.make_model(chain_graph, opset_imports=[helper.make_opsetid("", 17)])
207
+
208
+
209
class TestONNXGraphLoader:
    """Loading checks for ONNXGraphLoader using the single-Conv fixture."""

    @staticmethod
    def _write_temp_model(model: onnx.ModelProto) -> Path:
        """Save *model* to a new ``.onnx`` temp file and return its path.

        The file is created with ``delete=False``, so the caller is
        responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
            onnx.save(model, handle.name)
        return Path(handle.name)

    def test_load_conv_model(self) -> None:
        """The loader reports node, I/O, initializer and op-type counts."""
        onnx_path = self._write_temp_model(create_simple_conv_model())

        try:
            _loaded, info = ONNXGraphLoader().load(onnx_path)

            assert info.num_nodes == 1
            assert len(info.inputs) == 1
            assert len(info.outputs) == 1
            assert len(info.initializers) == 2  # W and B
            assert "Conv" in info.op_type_counts
        finally:
            onnx_path.unlink()

    def test_load_extracts_shapes(self) -> None:
        """The loader exposes declared input and output shapes by name."""
        onnx_path = self._write_temp_model(create_simple_conv_model())

        try:
            _loaded, info = ONNXGraphLoader().load(onnx_path)

            assert "X" in info.input_shapes
            assert info.input_shapes["X"] == [1, 3, 8, 8]
            assert "Y" in info.output_shapes
            assert info.output_shapes["Y"] == [1, 16, 6, 6]
        finally:
            onnx_path.unlink()
250
+
251
+
252
class TestMetricsEngine:
    """Parameter, FLOP and memory checks for MetricsEngine."""

    @staticmethod
    def _write_temp_model(model: onnx.ModelProto) -> Path:
        """Save *model* to a new ``.onnx`` temp file and return its path.

        The file is created with ``delete=False``, so the caller is
        responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
            onnx.save(model, handle.name)
        return Path(handle.name)

    def test_count_parameters_conv(self) -> None:
        """Conv fixture: weight and bias are both counted."""
        onnx_path = self._write_temp_model(create_simple_conv_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            # W: 16*3*3*3 = 432, B: 16 -> 448 in total
            assert counts.total == 448
            assert "Conv" in counts.by_op_type
        finally:
            onnx_path.unlink()

    def test_count_parameters_matmul(self) -> None:
        """MatMul fixture: only the [8, 16] weight is counted."""
        onnx_path = self._write_temp_model(create_matmul_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            # B: 8*16 = 128
            assert counts.total == 128
        finally:
            onnx_path.unlink()

    def test_count_parameters_no_weights(self) -> None:
        """Relu fixture: a graph without initializers has zero parameters."""
        onnx_path = self._write_temp_model(create_relu_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            assert counts.total == 0
        finally:
            onnx_path.unlink()

    def test_estimate_flops_conv(self) -> None:
        """Conv fixture: FLOPs equal the multiply-add term plus the bias term."""
        onnx_path = self._write_temp_model(create_simple_conv_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            flop_counts = MetricsEngine().estimate_flops(info)

            # 2 * K_h * K_w * C_in * C_out * H_out * W_out, plus one add per
            # output element for the bias: 31,104 + 576 = 31,680.
            conv_term = 2 * 3 * 3 * 3 * 16 * 6 * 6
            bias_term = 16 * 6 * 6
            assert flop_counts.total == conv_term + bias_term
        finally:
            onnx_path.unlink()

    def test_estimate_flops_matmul(self) -> None:
        """MatMul fixture: FLOPs equal 2 * batch * M * N * K."""
        onnx_path = self._write_temp_model(create_matmul_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            flop_counts = MetricsEngine().estimate_flops(info)

            # 2 * batch(2) * M(4) * N(16) * K(8) = 2048
            assert flop_counts.total == 2 * 2 * 4 * 16 * 8
        finally:
            onnx_path.unlink()

    def test_estimate_memory(self) -> None:
        """Conv fixture: model size is the parameter count times four bytes."""
        onnx_path = self._write_temp_model(create_simple_conv_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            mem = MetricsEngine().estimate_memory(info)

            # 448 float32 parameters * 4 bytes = 1792 bytes
            assert mem.model_size_bytes == 448 * 4
            assert mem.peak_activation_bytes >= 0
        finally:
            onnx_path.unlink()
379
+
380
+
381
class TestMetricsEngineEdgeCases:
    """MetricsEngine checks for less common operator configurations."""

    @staticmethod
    def _write_temp_model(model: onnx.ModelProto) -> Path:
        """Save *model* to a new ``.onnx`` temp file and return its path.

        The file is created with ``delete=False``, so the caller is
        responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
            onnx.save(model, handle.name)
        return Path(handle.name)

    def test_gemm_with_bias(self) -> None:
        """Gemm fixture: the bias adds both parameters and FLOPs."""
        onnx_path = self._write_temp_model(create_gemm_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)

            metrics = MetricsEngine()
            counts = metrics.count_parameters(info)
            flop_counts = metrics.estimate_flops(info)

            # B: 8*16 = 128, C: 16 -> 144 parameters in total
            assert counts.total == 144

            # 2*M*N*K + M*N (bias) = 2*4*16*8 + 4*16 = 1024 + 64 = 1088
            matmul_term = 2 * 4 * 16 * 8
            bias_term = 4 * 16
            assert flop_counts.total == matmul_term + bias_term
        finally:
            onnx_path.unlink()
408
+
409
+
410
def create_transformer_like_model() -> onnx.ModelProto:
    """Build a small attention-flavoured graph used by the KV cache tests.

    Graph layout (input ``X`` is [1, 128, 768], read as batch, seq, hidden):
        Q = MatMul(X, Wq)
        K = MatMul(X, Wk)            # produced but not consumed downstream
        V = MatMul(X, Wv)
        attn_scores = Softmax(Q, axis=-1)
        Y = Add(attn_scores, V)      # [1, 128, 768]

    Each projection weight is a [768, 768] random float32 initializer
    (589,824 values).

    Returns:
        An opset-17 ``onnx.ModelProto`` named ``transformer_test``.
    """
    hidden = 768
    io_shape = [1, 128, hidden]

    projection_weights = []
    projection_nodes = []
    # Weights are generated in q, k, v order.
    for tag in ("q", "k", "v"):
        weight_name = f"W{tag}"
        projection_weights.append(
            helper.make_tensor(
                name=weight_name,
                data_type=TensorProto.FLOAT,
                dims=[hidden, hidden],
                vals=np.random.randn(hidden, hidden).astype(np.float32).flatten().tolist(),
            )
        )
        projection_nodes.append(helper.make_node("MatMul", ["X", weight_name], [tag.upper()]))

    # Softmax over the last axis of Q stands in for the attention scores.
    softmax = helper.make_node("Softmax", ["Q"], ["attn_scores"], axis=-1)

    # Simplified output stage: the softmax result is added to V.
    combine = helper.make_node("Add", ["attn_scores", "V"], ["Y"])

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, io_shape)]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, io_shape)]

    attention_graph = helper.make_graph(
        [*projection_nodes, softmax, combine],
        "transformer_test",
        graph_inputs,
        graph_outputs,
        projection_weights,
    )
    return helper.make_model(attention_graph, opset_imports=[helper.make_opsetid("", 17)])
458
+
459
+
460
class TestKVCacheEstimation:
    """Tests for KV cache estimation in transformer models."""

    @staticmethod
    def _write_temp_model(model: onnx.ModelProto) -> Path:
        """Save *model* to a new ``.onnx`` temp file and return its path.

        The file is created with ``delete=False``, so the caller is
        responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
            onnx.save(model, handle.name)
        return Path(handle.name)

    def test_kv_cache_detected_for_transformer(self) -> None:
        """Test that KV cache is estimated for transformer-like models."""
        model_path = self._write_temp_model(create_transformer_like_model())

        try:
            _, graph_info = ONNXGraphLoader().load(model_path)
            memory = MetricsEngine().estimate_memory(graph_info)

            # Should detect KV cache for transformer model
            assert memory.kv_cache_bytes_per_token > 0
            assert memory.kv_cache_bytes_full_context > 0
            assert "num_layers" in memory.kv_cache_config
            assert "hidden_dim" in memory.kv_cache_config
        finally:
            model_path.unlink()

    def test_kv_cache_not_detected_for_cnn(self) -> None:
        """Test that KV cache is not estimated for CNN models."""
        model_path = self._write_temp_model(create_simple_conv_model())

        try:
            _, graph_info = ONNXGraphLoader().load(model_path)
            memory = MetricsEngine().estimate_memory(graph_info)

            # Should NOT detect KV cache for CNN model
            assert memory.kv_cache_bytes_per_token == 0
            assert memory.kv_cache_bytes_full_context == 0
        finally:
            model_path.unlink()

    def test_kv_cache_formula(self) -> None:
        """Test KV cache per-token formula: 2 * num_layers * hidden_dim * bytes."""
        model_path = self._write_temp_model(create_transformer_like_model())

        try:
            _, graph_info = ONNXGraphLoader().load(model_path)
            memory = MetricsEngine().estimate_memory(graph_info)

            config = memory.kv_cache_config
            # Fail loudly instead of skipping the formula checks: an empty
            # config would otherwise let this test pass without verifying
            # anything.
            assert config, "expected a non-empty kv_cache_config for the transformer fixture"

            # Verify formula: 2 * layers * hidden * bytes_per_elem
            expected_per_token = (
                2 * config["num_layers"] * config["hidden_dim"] * config["bytes_per_elem"]
            )
            assert memory.kv_cache_bytes_per_token == expected_per_token

            # Full context = per_token * seq_len
            expected_full = expected_per_token * config["seq_len"]
            assert memory.kv_cache_bytes_full_context == expected_full
        finally:
            model_path.unlink()

    def test_memory_estimates_to_dict_includes_kv_cache(self) -> None:
        """Test that to_dict includes KV cache info when present."""
        model_path = self._write_temp_model(create_transformer_like_model())

        try:
            _, graph_info = ONNXGraphLoader().load(model_path)
            memory = MetricsEngine().estimate_memory(graph_info)

            result = memory.to_dict()
            assert "kv_cache_bytes_per_token" in result
            assert "kv_cache_bytes_full_context" in result
            assert "kv_cache_config" in result
        finally:
            model_path.unlink()
557
+
558
+
559
def create_shared_weights_model() -> onnx.ModelProto:
    """Build a graph in which two MatMul nodes read the same initializer.

    Graph layout:
        hidden = MatMul1(X, W_shared)
        Y      = MatMul2(hidden, W_shared)

    ``X`` is declared as [1, 8], ``Y`` as [1, 16] and ``W_shared`` as [8, 16]
    (128 random float32 values).

    Returns:
        An opset-17 ``onnx.ModelProto`` named ``shared_weights_test``.
    """
    weight_name = "W_shared"
    weight_dims = [8, 16]
    shared_weight = helper.make_tensor(
        name=weight_name,
        data_type=TensorProto.FLOAT,
        dims=weight_dims,
        vals=np.random.randn(*weight_dims).astype(np.float32).flatten().tolist(),
    )

    # Both nodes name the same initializer as their second input.
    first_matmul = helper.make_node("MatMul", ["X", weight_name], ["hidden"], name="MatMul1")
    second_matmul = helper.make_node("MatMul", ["hidden", weight_name], ["Y"], name="MatMul2")

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8])]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 16])]

    shared_graph = helper.make_graph(
        [first_matmul, second_matmul],
        "shared_weights_test",
        graph_inputs,
        graph_outputs,
        [shared_weight],
    )
    return helper.make_model(shared_graph, opset_imports=[helper.make_opsetid("", 17)])
589
+
590
+
591
def create_int8_weights_model() -> onnx.ModelProto:
    """Build a graph whose MatMul weight is stored as INT8 and dequantized.

    Graph layout:
        W_float = DequantizeLinear(W_int8, scale, zero_point)
        Y       = MatMul(X, W_float)

    Initializers:
        W_int8:     [8, 16] INT8 -> 128 random values
        scale:      scalar FLOAT (0.01)
        zero_point: scalar INT8 (0)

    Returns:
        An opset-17 ``onnx.ModelProto`` named ``int8_weights_test``.
    """
    quantized_dims = (8, 16)
    quantized_values = (
        np.random.randint(-128, 127, quantized_dims, dtype=np.int8).flatten().tolist()
    )

    initializers = [
        helper.make_tensor(
            name="W_int8",
            data_type=TensorProto.INT8,
            dims=list(quantized_dims),
            vals=quantized_values,
        ),
        # Scalar dequantization parameters (empty dims).
        helper.make_tensor("scale", TensorProto.FLOAT, [], [0.01]),
        helper.make_tensor("zero_point", TensorProto.INT8, [], [0]),
    ]

    nodes = [
        helper.make_node("DequantizeLinear", ["W_int8", "scale", "zero_point"], ["W_float"]),
        helper.make_node("MatMul", ["X", "W_float"], ["Y"]),
    ]

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8])]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 16])]

    quantized_graph = helper.make_graph(
        nodes,
        "int8_weights_test",
        graph_inputs,
        graph_outputs,
        initializers,
    )
    return helper.make_model(quantized_graph, opset_imports=[helper.make_opsetid("", 17)])
627
+
628
+
629
def create_mixed_precision_model() -> onnx.ModelProto:
    """Build a graph that mixes an FP32 weight with an FP16 weight.

    Graph layout:
        hidden        = MatMul(X, W_fp32)
        W_fp16_casted = Cast(W_fp16, to=FLOAT)
        Y             = MatMul(hidden, W_fp16_casted)

    Initializers:
        W_fp32: [8, 16] FLOAT   -> 128 random values
        W_fp16: [16, 8] FLOAT16 -> 128 random values

    Returns:
        An opset-17 ``onnx.ModelProto`` named ``mixed_precision_test``.
    """
    # The FP32 weight is generated before the FP16 one.
    fp32_values = np.random.randn(8, 16).astype(np.float32).flatten().tolist()
    fp32_weight = helper.make_tensor(
        name="W_fp32",
        data_type=TensorProto.FLOAT,
        dims=[8, 16],
        vals=fp32_values,
    )

    fp16_values = np.random.randn(16, 8).astype(np.float16).flatten().tolist()
    fp16_weight = helper.make_tensor(
        name="W_fp16",
        data_type=TensorProto.FLOAT16,
        dims=[16, 8],
        vals=fp16_values,
    )

    nodes = [
        helper.make_node("MatMul", ["X", "W_fp32"], ["hidden"]),
        # The FP16 weight is cast to FLOAT before the second MatMul reads it.
        helper.make_node("Cast", ["W_fp16"], ["W_fp16_casted"], to=TensorProto.FLOAT),
        helper.make_node("MatMul", ["hidden", "W_fp16_casted"], ["Y"]),
    ]

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8])]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 8])]

    mixed_graph = helper.make_graph(
        nodes,
        "mixed_precision_test",
        graph_inputs,
        graph_outputs,
        [fp32_weight, fp16_weight],
    )
    return helper.make_model(mixed_graph, opset_imports=[helper.make_opsetid("", 17)])
668
+
669
+
670
class TestSharedWeights:
    """Checks for how shared initializers are reported (Task 2.2.4)."""

    @staticmethod
    def _write_temp_model(model: onnx.ModelProto) -> Path:
        """Save *model* to a new ``.onnx`` temp file and return its path.

        The file is created with ``delete=False``, so the caller is
        responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
            onnx.save(model, handle.name)
        return Path(handle.name)

    def test_shared_weights_detected(self) -> None:
        """A weight read by two nodes is reported once, with both users."""
        onnx_path = self._write_temp_model(create_shared_weights_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            # Exactly one shared weight, referenced by two entries.
            assert counts.num_shared_weights == 1
            assert "W_shared" in counts.shared_weights
            assert len(counts.shared_weights["W_shared"]) == 2
        finally:
            onnx_path.unlink()

    def test_shared_weights_fractional_attribution(self) -> None:
        """Sharing a weight must not inflate the total or per-op counts."""
        onnx_path = self._write_temp_model(create_shared_weights_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            # The total stays at 8*16 = 128 despite two consumers.
            assert counts.total == 128

            # Per-op attribution sums back to 128 (floating point tolerance).
            attributed = sum(counts.by_op_type.values())
            assert abs(attributed - 128) < 0.01

            # MatMul receives the full 128 (64 + 64 from the two nodes).
            assert "MatMul" in counts.by_op_type
            assert abs(counts.by_op_type["MatMul"] - 128) < 0.01
        finally:
            onnx_path.unlink()

    def test_no_shared_weights_normal_model(self) -> None:
        """A model without weight reuse reports no shared weights."""
        onnx_path = self._write_temp_model(create_simple_conv_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            assert counts.num_shared_weights == 0
            assert len(counts.shared_weights) == 0
        finally:
            onnx_path.unlink()
743
+
744
+
745
class TestQuantizedParams:
    """Checks for quantization and precision reporting (Task 2.2.4)."""

    @staticmethod
    def _write_temp_model(model: onnx.ModelProto) -> Path:
        """Save *model* to a new ``.onnx`` temp file and return its path.

        The file is created with ``delete=False``, so the caller is
        responsible for unlinking it.
        """
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
            onnx.save(model, handle.name)
        return Path(handle.name)

    def test_int8_weights_detected(self) -> None:
        """The INT8 fixture is flagged as quantized via DequantizeLinear."""
        onnx_path = self._write_temp_model(create_int8_weights_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            assert counts.is_quantized is True
            assert "DequantizeLinear" in counts.quantized_ops
        finally:
            onnx_path.unlink()

    def test_precision_breakdown(self) -> None:
        """The INT8 fixture reports its int8 element count."""
        onnx_path = self._write_temp_model(create_int8_weights_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            assert len(counts.precision_breakdown) > 0
            # 8*16 = 128 weight values plus the scalar zero_point -> 129
            assert "int8" in counts.precision_breakdown
            assert counts.precision_breakdown["int8"] == 129
        finally:
            onnx_path.unlink()

    def test_mixed_precision_breakdown(self) -> None:
        """The mixed fixture reports fp32 and fp16 counts separately."""
        onnx_path = self._write_temp_model(create_mixed_precision_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            # Both precisions appear, each with 128 values.
            assert "fp32" in counts.precision_breakdown
            assert "fp16" in counts.precision_breakdown
            assert counts.precision_breakdown["fp32"] == 128  # 8*16
            assert counts.precision_breakdown["fp16"] == 128  # 16*8

            # 128 + 128
            assert counts.total == 256
        finally:
            onnx_path.unlink()

    def test_non_quantized_model(self) -> None:
        """A plain float Conv model is not flagged as quantized."""
        onnx_path = self._write_temp_model(create_simple_conv_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            assert counts.is_quantized is False
            assert len(counts.quantized_ops) == 0
        finally:
            onnx_path.unlink()

    def test_param_counts_to_dict_includes_new_fields(self) -> None:
        """``to_dict`` exposes the shared-weight and quantization fields."""
        onnx_path = self._write_temp_model(create_int8_weights_model())

        try:
            _, info = ONNXGraphLoader().load(onnx_path)
            counts = MetricsEngine().count_parameters(info)

            exported = counts.to_dict()

            # Shared-weight section and its two sub-keys.
            assert "shared_weights" in exported
            assert "count" in exported["shared_weights"]
            assert "details" in exported["shared_weights"]
            # Precision and quantization fields.
            assert "precision_breakdown" in exported
            assert "is_quantized" in exported
            assert "quantized_ops" in exported
        finally:
            onnx_path.unlink()
865
+
866
+
867
# Running this file directly hands it to pytest in verbose mode.
if __name__ == "__main__":
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)