haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,90 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import unittest
5
+
6
+ from ..hardware import (
7
+ NVIDIA_JETSON_NANO,
8
+ NVIDIA_RTX_3060_12GB,
9
+ BatchSizeSweeper,
10
+ HardwareEstimator,
11
+ SystemRequirementsRecommender,
12
+ )
13
+
14
+
15
class TestSystemRequirements(unittest.TestCase):
    """Tests for SystemRequirementsRecommender's minimum-GPU selection."""

    def setUp(self):
        self.estimator = HardwareEstimator()
        self.recommender = SystemRequirementsRecommender(self.estimator)

    def test_recommend_tiny_model(self):
        """A ~10 MB FP32 model should be satisfied by a Jetson Nano."""
        reqs = self.recommender.recommend(
            model_params=2_500_000,  # 2.5M params * 4 bytes = ~10 MB at FP32
            model_flops=1_000_000_000,  # 1 GFLOP
            peak_activation_bytes=10_000_000,  # 10 MB
            target_batch_size=1,
        )

        self.assertEqual(reqs.minimum_gpu.name, NVIDIA_JETSON_NANO.name)
        self.assertLess(reqs.minimum_vram_gb, 4.0)

    def test_recommend_large_model(self):
        """A ~10 GB FP16 model must be routed to a GPU with enough VRAM for its weights."""
        reqs = self.recommender.recommend(
            model_params=5_000_000_000,  # 5B params * 2 bytes = ~10 GB at FP16
            model_flops=10_000_000_000_000,  # 10 TFLOPs
            peak_activation_bytes=1_000_000_000,  # 1 GB
            target_batch_size=1,
            precision="fp16",
        )

        # A Nano cannot hold these weights.
        self.assertNotEqual(reqs.minimum_gpu.name, NVIDIA_JETSON_NANO.name)
        # The FP16 weights alone are ~9.3 GiB, so the recommended card must
        # offer at least 10 GiB of VRAM.
        self.assertGreaterEqual(reqs.minimum_gpu.vram_bytes, 10 * 1024**3)
46
+
47
+
48
class TestBatchSizeSweep(unittest.TestCase):
    """Tests for BatchSizeSweeper against the RTX 3060 12 GB hardware profile."""

    def setUp(self):
        self.estimator = HardwareEstimator()
        self.sweeper = BatchSizeSweeper(self.estimator)
        self.gpu = NVIDIA_RTX_3060_12GB

    def test_sweep_basic(self):
        """A medium model should sweep every power-of-two batch size up to 8."""
        sweep = self.sweeper.sweep(
            model_params=100_000_000,  # ~400 MB of FP32 weights
            model_flops=10_000_000_000,  # 10 GFLOPs
            peak_activation_bytes=50_000_000,  # 50 MB
            hardware=self.gpu,
            max_batch_size=8,
        )

        self.assertEqual(len(sweep.batch_sizes), 4)  # 1, 2, 4, 8
        self.assertEqual(sweep.batch_sizes[-1], 8)
        # Per-element assertion so a failure reports the offending latency.
        for latency in sweep.latencies:
            self.assertGreater(latency, 0)

        # Batching should improve throughput: the largest batch must beat batch 1.
        self.assertGreater(sweep.throughputs[-1], sweep.throughputs[0])

    def test_sweep_oom(self):
        """Batch sizes whose activations exceed VRAM should not appear in the sweep."""
        sweep = self.sweeper.sweep(
            model_params=100_000_000,
            model_flops=10_000_000_000,
            peak_activation_bytes=4 * 1024**3,  # 4 GiB of activations per sample
            hardware=self.gpu,  # 12 GB VRAM
            max_batch_size=8,
        )

        # Activations alone: batch 1 -> 4 GiB, batch 2 -> 8 GiB, batch 4 -> 16 GiB,
        # batch 8 -> 32 GiB. Ideally only batches 1 and 2 fit, but weight and
        # overhead accounting belong to the estimator. So this only asserts that
        # the sweep stopped early and still produced at least one result.
        self.assertLess(len(sweep.batch_sizes), 4)
        self.assertGreaterEqual(len(sweep.batch_sizes), 1)
87
+
88
+
89
if __name__ == "__main__":
    # Direct execution: hand this module's TestCase classes to the stdlib runner.
    unittest.main(module="__main__")
@@ -0,0 +1,326 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Tests for hierarchical graph view."""
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import tempfile
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import onnx
14
+ import pytest
15
+ from onnx import TensorProto, helper
16
+
17
+ from ..analyzer import ONNXGraphLoader
18
+ from ..hierarchical_graph import (
19
+ HierarchicalGraph,
20
+ HierarchicalGraphBuilder,
21
+ HierarchicalNode,
22
+ generate_summary,
23
+ )
24
+ from ..patterns import PatternAnalyzer
25
+
26
+
27
def create_test_model(seed: int = 0) -> onnx.ModelProto:
    """Create a small residual ONNX model for hierarchy tests.

    The graph is Conv -> Relu -> Conv -> Add(X, .) -> Relu. It maps input ``X``
    of shape [1, 64, 56, 56] to output ``Y`` of the same shape, and the Add
    forms a skip connection from the input.

    Args:
        seed: Seed for the random 64x64x3x3 conv weights. This keeps the
            fixture reproducible across runs instead of depending on the
            global NumPy RNG state.

    Returns:
        An opset-17 ``onnx.ModelProto`` with initializers ``w1`` and ``w2``.
    """
    rng = np.random.default_rng(seed)
    weight_shape = (64, 64, 3, 3)

    def _conv_weight(name: str) -> TensorProto:
        # Standard-normal FP32 weights flattened into the initializer payload.
        values = rng.standard_normal(weight_shape, dtype=np.float32)
        return helper.make_tensor(
            name, TensorProto.FLOAT, list(weight_shape), values.flatten().tolist()
        )

    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 64, 56, 56])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 64, 56, 56])

    nodes = [
        helper.make_node("Conv", ["X", "w1"], ["c1"], kernel_shape=[3, 3], pads=[1, 1, 1, 1]),
        helper.make_node("Relu", ["c1"], ["r1"]),
        helper.make_node("Conv", ["r1", "w2"], ["c2"], kernel_shape=[3, 3], pads=[1, 1, 1, 1]),
        helper.make_node("Add", ["X", "c2"], ["add_out"]),
        helper.make_node("Relu", ["add_out"], ["Y"]),
    ]

    graph = helper.make_graph(
        nodes, "test_model", [x_info], [y_info], [_conv_weight("w1"), _conv_weight("w2")]
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
56
+
57
+
58
class TestHierarchicalNode:
    """Unit tests for a single HierarchicalNode in isolation."""

    def test_basic_creation(self):
        """Constructor arguments are stored; a new childless node is a collapsed leaf."""
        leaf = HierarchicalNode(id="test_1", name="TestNode", node_type="op", op_type="Conv")

        assert (leaf.id, leaf.name) == ("test_1", "TestNode")
        assert leaf.is_leaf()
        assert leaf.is_collapsed

    def test_display_name_with_repeat(self):
        """A repeated block renders as '<name> x<count>'."""
        repeated = HierarchicalNode(
            id="block_1",
            name="TransformerLayer",
            node_type="block",
            repeat_count=12,
            is_repeated=True,
        )

        expected_label = "TransformerLayer x12"
        assert repeated.get_display_name() == expected_label

    def test_collapse_expand(self):
        """expand() opens a collapsed node and toggle() flips it back."""
        kids = [
            HierarchicalNode(id=f"child{idx}", name=f"Child{idx}", node_type="op")
            for idx in (1, 2)
        ]
        container = HierarchicalNode(
            id="parent", name="Parent", node_type="block", children=kids
        )

        assert container.is_collapsed
        container.expand()
        assert not container.is_collapsed
        container.toggle()
        assert container.is_collapsed

    def test_aggregate_stats(self):
        """aggregate_stats() sums child FLOPs/params/memory and counts the children."""
        child_specs = [
            ("c1", "C1", 1000, 100, 400),
            ("c2", "C2", 2000, 200, 800),
        ]
        kids = [
            HierarchicalNode(
                id=node_id,
                name=label,
                node_type="op",
                total_flops=flops,
                total_params=params,
                total_memory_bytes=mem,
            )
            for node_id, label, flops, params, mem in child_specs
        ]
        container = HierarchicalNode(
            id="parent", name="Parent", node_type="block", children=kids
        )

        container.aggregate_stats()

        assert container.total_flops == 1000 + 2000
        assert container.total_params == 100 + 200
        assert container.total_memory_bytes == 400 + 800
        assert container.node_count == len(kids)

    def test_aggregate_with_repeat(self):
        """repeat_count scales FLOPs and memory, while params are counted once."""
        only_child = HierarchicalNode(
            id="c1",
            name="C1",
            node_type="op",
            total_flops=1000,
            total_params=100,
            total_memory_bytes=400,
        )
        repeated = HierarchicalNode(
            id="parent",
            name="Parent",
            node_type="block",
            children=[only_child],
            repeat_count=12,
        )

        repeated.aggregate_stats()

        assert repeated.total_flops == 12 * 1000
        assert repeated.total_params == 100  # shared across repeats, not scaled
        assert repeated.total_memory_bytes == 12 * 400

    def test_to_dict(self):
        """to_dict() exports identity fields and omits 'children' for a leaf."""
        exported = HierarchicalNode(
            id="test", name="TestNode", node_type="op", op_type="Conv"
        ).to_dict()

        assert exported["id"] == "test"
        assert exported["name"] == "TestNode"
        assert exported["op_type"] == "Conv"
        assert "children" not in exported
173
+
174
+
175
class TestHierarchicalGraph:
    """Tests for HierarchicalGraph lookup, visibility, expansion and JSON export."""

    def test_graph_creation(self):
        """The graph exposes its root and resolves registered nodes by id."""
        root = HierarchicalNode(id="root", name="Model", node_type="model")
        graph = HierarchicalGraph(root=root, nodes_by_id={"root": root})

        assert graph.root.name == "Model"
        # Identity rather than equality: lookup must hand back the registered
        # object itself, not merely an equal-looking copy.
        assert graph.get_node("root") is root

    def test_visible_nodes_collapsed(self):
        """Children of a collapsed root are hidden."""
        child = HierarchicalNode(id="child", name="Child", node_type="op")
        root = HierarchicalNode(
            id="root",
            name="Root",
            node_type="model",
            children=[child],
            is_collapsed=True,
        )
        graph = HierarchicalGraph(
            root=root,
            nodes_by_id={"root": root, "child": child},
        )

        visible = graph.get_visible_nodes()
        assert len(visible) == 1
        assert visible[0].id == "root"

    def test_visible_nodes_expanded(self):
        """Children of an expanded root are visible alongside the root."""
        child = HierarchicalNode(id="child", name="Child", node_type="op")
        root = HierarchicalNode(
            id="root",
            name="Root",
            node_type="model",
            children=[child],
            is_collapsed=False,
        )
        graph = HierarchicalGraph(
            root=root,
            nodes_by_id={"root": root, "child": child},
        )

        visible = graph.get_visible_nodes()
        assert len(visible) == 2
        # Check which nodes are visible, not just how many (order-independent).
        assert {node.id for node in visible} == {"root", "child"}

    def test_expand_to_depth(self):
        """expand_to_depth(1) expands depth-0 and depth-1 nodes and leaves depth 2 collapsed."""
        grandchild = HierarchicalNode(id="gc", name="GC", node_type="op", depth=2)
        child = HierarchicalNode(
            id="child",
            name="Child",
            node_type="block",
            children=[grandchild],
            depth=1,
        )
        root = HierarchicalNode(
            id="root",
            name="Root",
            node_type="model",
            children=[child],
            depth=0,
        )
        graph = HierarchicalGraph(
            root=root,
            nodes_by_id={"root": root, "child": child, "gc": grandchild},
        )

        graph.expand_to_depth(1)

        assert not root.is_collapsed
        assert not child.is_collapsed
        assert grandchild.is_collapsed

    def test_to_json(self):
        """to_json() produces parseable JSON carrying the root and node count."""
        root = HierarchicalNode(id="root", name="Model", node_type="model")
        graph = HierarchicalGraph(root=root, total_nodes=1, depth=0)

        data = json.loads(graph.to_json())

        assert data["root"]["name"] == "Model"
        assert data["total_nodes"] == 1
261
+
262
+
263
class TestHierarchicalGraphBuilder:
    """Tests for building hierarchical graphs from ONNX."""

    def test_build_from_onnx(self):
        """Build a hierarchy from a saved ONNX model and check the root is populated."""
        model = create_test_model()

        # This uses a TemporaryDirectory instead of NamedTemporaryFile(delete=False)
        # for two reasons. Saving to an open NamedTemporaryFile by name fails on
        # Windows. The directory is also removed even if saving or loading raises.
        with tempfile.TemporaryDirectory() as tmpdir:
            model_path = Path(tmpdir) / "test_model.onnx"
            onnx.save(model, str(model_path))

            loader = ONNXGraphLoader()
            _, graph_info = loader.load(model_path)

            pattern_analyzer = PatternAnalyzer()
            blocks = pattern_analyzer.group_into_blocks(graph_info)

            builder = HierarchicalGraphBuilder()
            graph = builder.build(graph_info, blocks, "TestModel")

            assert graph.root.name == "TestModel"
            assert graph.total_nodes > 0
            assert len(graph.root.children) > 0
289
+
290
+
291
class TestSummaryGeneration:
    """Checks on the multi-level text produced by generate_summary()."""

    def test_generate_summary(self):
        """The summary mentions the model name, block types and xN repeat counts."""
        conv_block = HierarchicalNode(
            id="c1",
            name="ConvBlock",
            node_type="block",
            attributes={"block_type": "ConvRelu"},
        )
        repeated_layer = HierarchicalNode(
            id="c2",
            name="Layer",
            node_type="layer",
            is_repeated=True,
            repeat_count=12,
        )
        model_root = HierarchicalNode(
            id="root",
            name="Model",
            node_type="model",
            children=[conv_block, repeated_layer],
            is_collapsed=False,
        )

        text = generate_summary(HierarchicalGraph(root=model_root, total_nodes=3, depth=1))

        for fragment in ("Model", "ConvRelu", "x12"):
            assert fragment in text
323
+
324
+
325
if __name__ == "__main__":
    # Forward pytest's exit status to the process. Without this, a run with
    # failing tests would still exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
@@ -0,0 +1,180 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Tests for HTML export."""
5
+
6
+ from __future__ import annotations
7
+
8
+ import tempfile
9
+ from pathlib import Path
10
+
11
+ import pytest
12
+
13
+ from ..edge_analysis import EdgeAnalysisResult
14
+ from ..hierarchical_graph import HierarchicalGraph, HierarchicalNode
15
+ from ..html_export import HTMLExporter, generate_html
16
+
17
+
18
def create_test_graph() -> HierarchicalGraph:
    """Build a fixed three-level graph: model root -> ConvBlock -> (Conv1, Relu1).

    Both the root and the block are created expanded. The Conv1 leaf carries
    1,000,000 FLOPs and 1 MiB of memory. The graph declares 4 nodes and depth 2.
    """
    conv = HierarchicalNode(
        id="op_1",
        name="Conv1",
        node_type="op",
        op_type="Conv",
        total_flops=1_000_000,
        total_memory_bytes=1024 * 1024,
    )
    relu = HierarchicalNode(id="op_2", name="Relu1", node_type="op", op_type="Relu")
    conv_block = HierarchicalNode(
        id="block_1",
        name="ConvBlock",
        node_type="block",
        children=[conv, relu],
        is_collapsed=False,
        attributes={"block_type": "ConvRelu"},
    )
    model_root = HierarchicalNode(
        id="root",
        name="TestModel",
        node_type="model",
        children=[conv_block],
        is_collapsed=False,
    )

    lookup = {
        "root": model_root,
        "block_1": conv_block,
        "op_1": conv,
        "op_2": relu,
    }
    return HierarchicalGraph(root=model_root, nodes_by_id=lookup, total_nodes=4, depth=2)
56
+
57
+
58
def create_test_edge_result() -> EdgeAnalysisResult:
    """Return a minimal EdgeAnalysisResult whose single 5 MiB peak is at ``Conv1``.

    It has no edges, 10 MiB of total activations, and one bottleneck edge
    (``conv_out``).
    """
    mib = 1024 * 1024
    peak_bytes = 5 * mib
    return EdgeAnalysisResult(
        edges=[],
        total_activation_bytes=10 * mib,
        peak_activation_bytes=peak_bytes,
        peak_activation_node="Conv1",
        bottleneck_edges=["conv_out"],
        attention_edges=[],
        skip_connection_edges=[],
        memory_profile=[("Conv1", peak_bytes)],
    )
70
+
71
+
72
class TestHTMLGeneration:
    """Checks on the HTML document returned (and optionally written) by generate_html()."""

    def test_generate_html_basic(self):
        """The page has a doctype, a titled <title>, the D3 script and embedded graph data."""
        page = generate_html(create_test_graph(), title="Test Model")

        required = (
            "<!DOCTYPE html>",
            "<title>Test Model - Neural Architecture</title>",
            "d3.v7.min.js",
            "graphData",
        )
        for snippet in required:
            assert snippet in page

    def test_generate_html_with_edges(self):
        """Passing an edge analysis result embeds edge data in the page."""
        page = generate_html(create_test_graph(), create_test_edge_result(), title="Test Model")

        for snippet in ("edgeData", "peak_activation_bytes"):
            assert snippet in page

    def test_generate_html_to_file(self):
        """With output_path, the returned HTML is also written verbatim to disk."""
        graph = create_test_graph()

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir, "test_model.html")
            page = generate_html(graph, title="Test", output_path=target)

            assert target.exists()
            assert target.read_text(encoding="utf-8") == page

    def test_html_contains_controls(self):
        """The page wires up the expand/collapse/zoom/fit controls."""
        page = generate_html(create_test_graph())

        for handler in ("expandAll()", "collapseAll()", "resetZoom()", "fitToScreen()"):
            assert handler in page

    def test_html_contains_legend(self):
        """The page includes the op-type legend."""
        page = generate_html(create_test_graph())

        # The legend heading reads "Op Types" (not "Legend"); "legend-item"
        # is the CSS class on each entry.
        for snippet in ("Op Types", "Convolution", "Attention", "legend-item"):
            assert snippet in page

    def test_html_contains_stats_panel(self):
        """The page includes the node/edge/peak-memory stats panel."""
        page = generate_html(create_test_graph())

        for element_id in ("node-count", "edge-count", "peak-memory"):
            assert element_id in page
136
+
137
+
138
class TestHTMLExporter:
    """Checks on HTMLExporter.export() writing HTML files to disk."""

    @staticmethod
    def _export_and_read(*args, **kwargs) -> str:
        """Export into a throwaway directory and return the written HTML text."""
        exporter = HTMLExporter()
        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "model.html"
            exporter.export(*args, output_path=target, **kwargs)
            return target.read_text(encoding="utf-8")

    def test_export_basic(self):
        """export() writes the file and returns the path it was given."""
        graph = create_test_graph()
        exporter = HTMLExporter()

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "model.html"
            returned_path = exporter.export(graph, output_path=target)

            assert returned_path == target
            assert target.exists()

    def test_export_with_edges(self):
        """Edge analysis data ends up in the exported file."""
        html_text = self._export_and_read(create_test_graph(), create_test_edge_result())

        assert "peak_activation_bytes" in html_text

    def test_export_custom_title(self):
        """A custom title appears in the exported file."""
        html_text = self._export_and_read(create_test_graph(), title="My Custom Model")

        assert "My Custom Model" in html_text
177
+
178
+
179
if __name__ == "__main__":
    # Forward pytest's exit status to the process. Without this, a run with
    # failing tests would still exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))