haoline-0.3.0-py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in their public registries.
Files changed (70)
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
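Entries 41–61 above show that the wheel ships its test suite inside the package (haoline/tests/); the two diffs reproduced below are the new test modules test_compare_visualizations.py and test_edge_analysis.py. As a hedged sketch only: after installing the wheel (plus the onnx, numpy, and optionally matplotlib dependencies the tests import), the packaged tests can likely be run through pytest's --pyargs option, since the modules use package-relative imports. This invocation is not documented by the package itself:

import pytest

# Hypothetical invocation: resolve the installed haoline.tests package and run it verbosely.
raise SystemExit(pytest.main(["--pyargs", "haoline.tests", "-v"]))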
haoline/tests/test_compare_visualizations.py
@@ -0,0 +1,293 @@
+#!/usr/bin/env python
+# Copyright (c) 2025 HaoLine Contributors
+# SPDX-License-Identifier: MIT
+
+"""Unit tests for compare_visualizations module."""
+
+import tempfile
+from pathlib import Path
+from unittest import TestCase
+
+from ..compare_visualizations import (
+    CalibrationRecommendation,
+    TradeoffPoint,
+    analyze_tradeoffs,
+    build_enhanced_markdown,
+    compute_tradeoff_points,
+    generate_calibration_recommendations,
+    generate_memory_savings_chart,
+    generate_tradeoff_chart,
+    is_available,
+)
+
+
+def create_sample_compare_json() -> dict:
+    """Create a sample comparison JSON for testing."""
+    return {
+        "model_family_id": "resnet50_test",
+        "baseline_precision": "fp32",
+        "architecture_compatible": True,
+        "compatibility_warnings": [],
+        "variants": [
+            {
+                "precision": "fp32",
+                "quantization_scheme": "none",
+                "model_path": "resnet_fp32.onnx",
+                "size_bytes": 102400000,
+                "total_params": 25000000,
+                "total_flops": 4100000000,
+                "memory_bytes": 102400000,
+                "metrics": {
+                    "f1_macro": 0.931,
+                    "latency_ms_p50": 14.5,
+                    "throughput_qps": 680,
+                },
+                "hardware_estimates": None,
+                "deltas_vs_baseline": None,
+            },
+            {
+                "precision": "fp16",
+                "quantization_scheme": "fp16",
+                "model_path": "resnet_fp16.onnx",
+                "size_bytes": 51200000,
+                "total_params": 25000000,
+                "total_flops": 4100000000,
+                "memory_bytes": 51200000,
+                "metrics": {
+                    "f1_macro": 0.929,
+                    "latency_ms_p50": 9.1,
+                    "throughput_qps": 1080,
+                },
+                "hardware_estimates": None,
+                "deltas_vs_baseline": {
+                    "size_bytes": -51200000,
+                    "f1_macro": -0.002,
+                    "latency_ms_p50": -5.4,
+                },
+            },
+            {
+                "precision": "int8",
+                "quantization_scheme": "int8",
+                "model_path": "resnet_int8.onnx",
+                "size_bytes": 25600000,
+                "total_params": 25000000,
+                "total_flops": 4100000000,
+                "memory_bytes": 25600000,
+                "metrics": {
+                    "f1_macro": 0.915,
+                    "latency_ms_p50": 5.2,
+                    "throughput_qps": 1850,
+                },
+                "hardware_estimates": None,
+                "deltas_vs_baseline": {
+                    "size_bytes": -76800000,
+                    "f1_macro": -0.016,
+                    "latency_ms_p50": -9.3,
+                },
+            },
+        ],
+    }
+
+
+class TestIsAvailable(TestCase):
+    """Test visualization availability check."""
+
+    def test_is_available_returns_bool(self):
+        """is_available should return a boolean."""
+        result = is_available()
+        self.assertIsInstance(result, bool)
+
+
+class TestComputeTradeoffPoints(TestCase):
+    """Tests for compute_tradeoff_points function."""
+
+    def test_basic_computation(self):
+        """Should compute tradeoff points for all variants."""
+        compare_json = create_sample_compare_json()
+        points = compute_tradeoff_points(compare_json)
+
+        self.assertEqual(len(points), 3)
+        self.assertIsInstance(points[0], TradeoffPoint)
+
+    def test_baseline_speedup_is_one(self):
+        """Baseline variant should have speedup of 1.0."""
+        compare_json = create_sample_compare_json()
+        points = compute_tradeoff_points(compare_json)
+
+        # Find fp32 (baseline)
+        fp32_point = next(p for p in points if p.precision == "fp32")
+        self.assertAlmostEqual(fp32_point.speedup, 1.0, places=2)
+        self.assertAlmostEqual(fp32_point.accuracy_delta, 0.0, places=5)
+
+    def test_faster_variant_has_higher_speedup(self):
+        """Faster variant should have speedup > 1.0."""
+        compare_json = create_sample_compare_json()
+        points = compute_tradeoff_points(compare_json)
+
+        int8_point = next(p for p in points if p.precision == "int8")
+        # int8 is faster: 14.5 / 5.2 ≈ 2.79x speedup
+        self.assertGreater(int8_point.speedup, 2.0)
+
+    def test_accuracy_delta_computed(self):
+        """Should compute accuracy delta relative to baseline."""
+        compare_json = create_sample_compare_json()
+        points = compute_tradeoff_points(compare_json)
+
+        fp16_point = next(p for p in points if p.precision == "fp16")
+        # fp16: 0.929 - 0.931 = -0.002
+        self.assertAlmostEqual(fp16_point.accuracy_delta, -0.002, places=5)
+
+    def test_empty_variants(self):
+        """Should handle empty variants list."""
+        compare_json = {"variants": []}
+        points = compute_tradeoff_points(compare_json)
+        self.assertEqual(len(points), 0)
+
+
+class TestAnalyzeTradeoffs(TestCase):
+    """Tests for analyze_tradeoffs function."""
+
+    def test_returns_recommendations(self):
+        """Should return analysis with recommendations."""
+        compare_json = create_sample_compare_json()
+        analysis = analyze_tradeoffs(compare_json)
+
+        self.assertIn("recommendations", analysis)
+        self.assertIsInstance(analysis["recommendations"], list)
+
+    def test_identifies_best_variants(self):
+        """Should identify best variants for different criteria."""
+        compare_json = create_sample_compare_json()
+        analysis = analyze_tradeoffs(compare_json)
+
+        self.assertIn("best_speed", analysis)
+        self.assertIn("smallest", analysis)
+
+    def test_tradeoff_points_included(self):
+        """Should include tradeoff points in analysis."""
+        compare_json = create_sample_compare_json()
+        analysis = analyze_tradeoffs(compare_json)
+
+        self.assertIn("tradeoff_points", analysis)
+        self.assertEqual(len(analysis["tradeoff_points"]), 3)
+
+
+class TestCalibrationRecommendations(TestCase):
+    """Tests for generate_calibration_recommendations function."""
+
+    def test_returns_recommendations(self):
+        """Should return list of CalibrationRecommendation objects."""
+        compare_json = create_sample_compare_json()
+        recs = generate_calibration_recommendations(compare_json)
+
+        self.assertIsInstance(recs, list)
+        for rec in recs:
+            self.assertIsInstance(rec, CalibrationRecommendation)
+
+    def test_int8_accuracy_warning(self):
+        """Should warn about INT8 accuracy drop > 2%."""
+        # Modify to have significant INT8 accuracy drop
+        compare_json = create_sample_compare_json()
+        # INT8 has 1.6% drop, which is below threshold, so let's make it worse
+        compare_json["variants"][2]["metrics"]["f1_macro"] = 0.90  # 3.1% drop
+
+        recs = generate_calibration_recommendations(compare_json)
+
+        # Should have warning about calibration
+        warnings = [r for r in recs if r.severity == "warning"]
+        self.assertGreater(len(warnings), 0)
+
+
+class TestBuildEnhancedMarkdown(TestCase):
+    """Tests for build_enhanced_markdown function."""
+
+    def test_basic_markdown_generation(self):
+        """Should generate valid Markdown."""
+        compare_json = create_sample_compare_json()
+        md = build_enhanced_markdown(compare_json, include_charts=False)
+
+        self.assertIn("# Quantization Impact Report", md)
+        self.assertIn("resnet50_test", md)
+        self.assertIn("FP32", md)
+        self.assertIn("FP16", md)
+        self.assertIn("INT8", md)
+
+    def test_includes_tradeoff_analysis(self):
+        """Should include trade-off analysis section."""
+        compare_json = create_sample_compare_json()
+        md = build_enhanced_markdown(compare_json, include_charts=False)
+
+        self.assertIn("## Trade-off Analysis", md)
+        self.assertIn("### Recommendations", md)
+
+    def test_includes_variant_table(self):
+        """Should include variant comparison table."""
+        compare_json = create_sample_compare_json()
+        md = build_enhanced_markdown(compare_json, include_charts=False)
+
+        self.assertIn("## Variant Comparison", md)
+        self.assertIn("| Precision |", md)
+
+    def test_compatibility_warnings_shown(self):
+        """Should show compatibility warnings if present."""
+        compare_json = create_sample_compare_json()
+        compare_json["architecture_compatible"] = False
+        compare_json["compatibility_warnings"] = ["Test warning"]
+
+        md = build_enhanced_markdown(compare_json, include_charts=False)
+        self.assertIn("⚠️ Compatibility Warnings", md)
+        self.assertIn("Test warning", md)
+
+
+class TestChartGeneration(TestCase):
+    """Tests for chart generation functions."""
+
+    def test_tradeoff_chart_no_matplotlib(self):
+        """Should return None if matplotlib not available."""
+        # This test will pass if matplotlib IS available, as it will
+        # generate the chart. We just verify it doesn't crash.
+        compare_json = create_sample_compare_json()
+        points = compute_tradeoff_points(compare_json)
+
+        result = generate_tradeoff_chart(points)
+        # Result is either bytes (matplotlib available) or None
+        if is_available():
+            self.assertIsInstance(result, bytes)
+            self.assertGreater(len(result), 0)
+        else:
+            self.assertIsNone(result)
+
+    def test_tradeoff_chart_to_file(self):
+        """Should save chart to file if path provided."""
+        if not is_available():
+            self.skipTest("matplotlib not available")
+
+        compare_json = create_sample_compare_json()
+        points = compute_tradeoff_points(compare_json)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output_path = Path(tmpdir) / "tradeoff.png"
+            result = generate_tradeoff_chart(points, output_path)
+
+            self.assertIsNotNone(result)
+            self.assertTrue(output_path.exists())
+            self.assertGreater(output_path.stat().st_size, 0)
+
+    def test_memory_savings_chart(self):
+        """Should generate memory savings chart."""
+        if not is_available():
+            self.skipTest("matplotlib not available")
+
+        compare_json = create_sample_compare_json()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output_path = Path(tmpdir) / "memory.png"
+            result = generate_memory_savings_chart(compare_json, output_path)
+
+            self.assertIsNotNone(result)
+            self.assertTrue(output_path.exists())
+
+    def test_empty_points_returns_none(self):
+        """Should return None for empty points list."""
+        result = generate_tradeoff_chart([])
+        self.assertIsNone(result)
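The suite above pins down enough of the haoline.compare_visualizations surface to sketch a standalone driver. The sketch below reuses only calls the tests exercise (compute_tradeoff_points, analyze_tradeoffs, build_enhanced_markdown, generate_tradeoff_chart, is_available); the compare_report.json input and the output file names are hypothetical stand-ins, and the input layout is assumed to match create_sample_compare_json() above.

import json
from pathlib import Path

from haoline.compare_visualizations import (
    analyze_tradeoffs,
    build_enhanced_markdown,
    compute_tradeoff_points,
    generate_tradeoff_chart,
    is_available,
)

# Hypothetical input: a comparison JSON shaped like create_sample_compare_json().
compare_json = json.loads(Path("compare_report.json").read_text())

# Each point carries precision, speedup, and accuracy_delta (see the tests above).
points = compute_tradeoff_points(compare_json)
for p in points:
    print(f"{p.precision}: {p.speedup:.2f}x speedup, accuracy delta {p.accuracy_delta:+.3f}")

# analyze_tradeoffs returns a dict with "recommendations", "best_speed",
# "smallest", and "tradeoff_points" keys.
analysis = analyze_tradeoffs(compare_json)
for rec in analysis["recommendations"]:
    print(rec)

# Markdown report with charts disabled, as in the tests.
Path("quantization_report.md").write_text(build_enhanced_markdown(compare_json, include_charts=False))

# generate_tradeoff_chart returns PNG bytes (optionally writing to a path) only
# when matplotlib is importable; it returns None otherwise.
if is_available():
    generate_tradeoff_chart(points, Path("tradeoff.png"))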
haoline/tests/test_edge_analysis.py
@@ -0,0 +1,243 @@
+# Copyright (c) 2025 HaoLine Contributors
+# SPDX-License-Identifier: MIT
+
+"""Tests for edge-centric analysis."""
+
+from __future__ import annotations
+
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import onnx
+import pytest
+from onnx import TensorProto, helper
+
+from ..analyzer import ONNXGraphLoader
+from ..edge_analysis import (
+    PRECISION_EDGE_COLORS,
+    EdgeAnalyzer,
+    EdgeInfo,
+    compute_edge_thickness,
+    format_tensor_shape,
+    format_tensor_size,
+    generate_edge_tooltip,
+    get_edge_color,
+)
+
+
+def create_simple_model() -> onnx.ModelProto:
+    """Create a simple model for edge testing."""
+    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 224, 224])
+    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 64, 112, 112])
+
+    weight = helper.make_tensor(
+        "conv_weight",
+        TensorProto.FLOAT,
+        [64, 3, 7, 7],
+        np.random.randn(64, 3, 7, 7).astype(np.float32).flatten().tolist(),
+    )
+
+    conv = helper.make_node(
+        "Conv",
+        ["X", "conv_weight"],
+        ["conv_out"],
+        kernel_shape=[7, 7],
+        strides=[2, 2],
+        pads=[3, 3, 3, 3],
+    )
+    relu = helper.make_node("Relu", ["conv_out"], ["Y"])
+
+    graph = helper.make_graph([conv, relu], "simple", [X], [Y], [weight])
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
+    return model
+
+
+def create_residual_model() -> onnx.ModelProto:
+    """Create a model with skip connection."""
+    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 64, 56, 56])
+    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 64, 56, 56])
+
+    weight1 = helper.make_tensor(
+        "w1",
+        TensorProto.FLOAT,
+        [64, 64, 3, 3],
+        np.random.randn(64, 64, 3, 3).astype(np.float32).flatten().tolist(),
+    )
+    weight2 = helper.make_tensor(
+        "w2",
+        TensorProto.FLOAT,
+        [64, 64, 3, 3],
+        np.random.randn(64, 64, 3, 3).astype(np.float32).flatten().tolist(),
+    )
+
+    # Conv -> BN -> ReLU -> Conv -> BN -> Add(skip) -> ReLU
+    nodes = [
+        helper.make_node("Conv", ["X", "w1"], ["c1"], kernel_shape=[3, 3], pads=[1, 1, 1, 1]),
+        helper.make_node("Relu", ["c1"], ["r1"]),
+        helper.make_node("Conv", ["r1", "w2"], ["c2"], kernel_shape=[3, 3], pads=[1, 1, 1, 1]),
+        helper.make_node("Add", ["X", "c2"], ["add_out"]),  # Skip connection
+        helper.make_node("Relu", ["add_out"], ["Y"]),
+    ]
+
+    graph = helper.make_graph(nodes, "residual", [X], [Y], [weight1, weight2])
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
+    return model
+
+
+class TestEdgeAnalysis:
+    """Tests for EdgeAnalyzer."""
+
+    def test_extract_edges(self):
+        """Test edge extraction from model."""
+        model = create_simple_model()
+
+        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
+            onnx.save(model, f.name)
+            model_path = Path(f.name)
+
+        try:
+            loader = ONNXGraphLoader()
+            _, graph_info = loader.load(model_path)
+
+            analyzer = EdgeAnalyzer()
+            result = analyzer.analyze(graph_info)
+
+            # Should have edges for: X, conv_weight, conv_out, Y
+            assert len(result.edges) >= 3
+
+            # Find the weight edge
+            weight_edges = [e for e in result.edges if e.is_weight]
+            assert len(weight_edges) >= 1
+        finally:
+            model_path.unlink()
+
+    def test_skip_connection_detection(self):
+        """Test skip connection detection."""
+        model = create_residual_model()
+
+        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
+            onnx.save(model, f.name)
+            model_path = Path(f.name)
+
+        try:
+            loader = ONNXGraphLoader()
+            _, graph_info = loader.load(model_path)
+
+            analyzer = EdgeAnalyzer()
+            result = analyzer.analyze(graph_info)
+
+            # Should detect skip connection for X -> Add
+            [e for e in result.edges if e.is_skip_connection]
+            # May or may not detect depending on topological distance
+            # At minimum the analysis should run without error
+            assert result is not None
+        finally:
+            model_path.unlink()
+
+    def test_memory_profile(self):
+        """Test memory profile calculation."""
+        model = create_simple_model()
+
+        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
+            onnx.save(model, f.name)
+            model_path = Path(f.name)
+
+        try:
+            loader = ONNXGraphLoader()
+            _, graph_info = loader.load(model_path)
+
+            analyzer = EdgeAnalyzer()
+            result = analyzer.analyze(graph_info)
+
+            # Memory profile should have entries
+            assert len(result.memory_profile) > 0
+
+            # Peak should be non-zero
+            assert result.peak_activation_bytes >= 0
+        finally:
+            model_path.unlink()
+
+
+class TestEdgeVisualization:
+    """Tests for edge visualization helpers."""
+
+    def test_edge_thickness_scaling(self):
+        """Test edge thickness computation."""
+        # Very small tensor
+        assert compute_edge_thickness(100) == 1  # Below min threshold
+
+        # 1KB
+        kb = compute_edge_thickness(1024)
+        assert 1 <= kb <= 10
+
+        # 1MB - should be thicker
+        mb = compute_edge_thickness(1024 * 1024)
+        assert mb > kb
+
+        # 1GB - should be even thicker
+        gb = compute_edge_thickness(1024 * 1024 * 1024)
+        assert gb > mb
+        assert gb <= 10
+
+    def test_format_tensor_shape(self):
+        """Test tensor shape formatting."""
+        assert format_tensor_shape([]) == "[]"
+        assert format_tensor_shape([1, 3, 224, 224]) == "[1, 3, 224, 224]"
+        assert format_tensor_shape([1, "batch", 768]) == "[1, batch, 768]"
+
+    def test_format_tensor_size(self):
+        """Test tensor size formatting."""
+        assert format_tensor_size(100) == "100 B"
+        assert format_tensor_size(1024) == "1.0 KB"
+        assert format_tensor_size(1024 * 1024) == "1.0 MB"
+        assert format_tensor_size(1024 * 1024 * 1024) == "1.00 GB"
+
+    def test_edge_colors(self):
+        """Test edge color assignment."""
+        # Normal edge
+        edge = EdgeInfo(
+            tensor_name="test",
+            source_node="node1",
+            target_nodes=["node2"],
+            shape=[1, 64, 56, 56],
+            dtype="float32",
+            size_bytes=1024 * 1024,
+            is_weight=False,
+            precision="fp32",
+        )
+        assert get_edge_color(edge) == PRECISION_EDGE_COLORS["fp32"]
+
+        # Bottleneck edge
+        edge.is_bottleneck = True
+        assert get_edge_color(edge) == "#E74C3C"  # Red
+
+        # Attention edge (without bottleneck)
+        edge.is_bottleneck = False
+        edge.is_attention_qk = True
+        assert get_edge_color(edge) == "#E67E22"  # Orange
+
+    def test_edge_tooltip(self):
+        """Test tooltip generation."""
+        edge = EdgeInfo(
+            tensor_name="attention_scores",
+            source_node="matmul_qk",
+            target_nodes=["softmax"],
+            shape=[1, 12, 512, 512],
+            dtype="float32",
+            size_bytes=12 * 512 * 512 * 4,
+            is_weight=False,
+            precision="fp32",
+            is_attention_qk=True,
+        )
+
+        tooltip = generate_edge_tooltip(edge)
+
+        assert "attention_scores" in tooltip
+        assert "[1, 12, 512, 512]" in tooltip
+        assert "fp32" in tooltip
+        assert "O(seq" in tooltip  # O(seq^2) warning
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
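The same load-then-analyze sequence the tests use works against any ONNX file on disk. Below is a minimal sketch under the assumptions the tests establish (ONNXGraphLoader.load returns a pair whose second element is the graph info; EdgeAnalyzer.analyze returns a result with edges, memory_profile, and peak_activation_bytes). The model.onnx path and the top-five report are illustrative only.

from pathlib import Path

from haoline.analyzer import ONNXGraphLoader
from haoline.edge_analysis import EdgeAnalyzer, format_tensor_size, generate_edge_tooltip

model_path = Path("model.onnx")  # hypothetical input file

# The tests discard the first element of load() and keep the graph info.
_, graph_info = ONNXGraphLoader().load(model_path)
result = EdgeAnalyzer().analyze(graph_info)

# List the largest activation edges; EdgeInfo exposes is_weight, size_bytes,
# and tensor_name, as exercised above.
activations = [e for e in result.edges if not e.is_weight]
for edge in sorted(activations, key=lambda e: e.size_bytes, reverse=True)[:5]:
    print(f"{edge.tensor_name}: {format_tensor_size(edge.size_bytes)}")
    print(generate_edge_tooltip(edge))

print("Peak activation memory:", format_tensor_size(result.peak_activation_bytes))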