haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,460 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Tests for format adapter system."""
5
+
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import pytest
10
+
11
+ from haoline.format_adapters import (
12
+ ConversionLevel,
13
+ FormatAdapter,
14
+ OnnxAdapter,
15
+ PyTorchAdapter,
16
+ can_convert,
17
+ get_adapter,
18
+ get_conversion_level,
19
+ list_adapters,
20
+ list_conversion_paths,
21
+ map_onnx_op_to_universal,
22
+ register_adapter,
23
+ )
24
+ from haoline.universal_ir import (
25
+ DataType,
26
+ GraphMetadata,
27
+ SourceFormat,
28
+ TensorOrigin,
29
+ UniversalGraph,
30
+ UniversalNode,
31
+ UniversalTensor,
32
+ )
33
+
34
+
35
class TestAdapterRegistry:
    """Adapter registry: default registrations, extension lookup, custom adapters."""

    def test_list_adapters_includes_onnx(self) -> None:
        """The ONNX adapter is part of the default registry listing."""
        registered = [entry["name"] for entry in list_adapters()]
        assert "onnx" in registered

    def test_list_adapters_includes_pytorch(self) -> None:
        """The PyTorch adapter is part of the default registry listing."""
        registered = [entry["name"] for entry in list_adapters()]
        assert "pytorch" in registered

    def test_get_adapter_onnx(self) -> None:
        """A ``.onnx`` filename resolves to an OnnxAdapter instance."""
        assert isinstance(get_adapter("model.onnx"), OnnxAdapter)

    def test_get_adapter_pytorch(self) -> None:
        """A ``.pt`` filename resolves to a PyTorchAdapter instance."""
        assert isinstance(get_adapter("model.pt"), PyTorchAdapter)

    def test_get_adapter_unknown_extension_raises(self) -> None:
        """An extension nobody registered raises ValueError."""
        with pytest.raises(ValueError, match="No adapter registered"):
            get_adapter("model.unknown")

    def test_register_custom_adapter(self) -> None:
        """A FormatAdapter subclass decorated with register_adapter is discoverable."""

        # NOTE(review): nothing here unregisters the adapter afterwards. If
        # register_adapter mutates a process-wide registry (as the lookup below
        # suggests), ".testfmt" stays registered for the rest of the session.
        @register_adapter
        class TestAdapter(FormatAdapter):
            name = "test_format"
            extensions = [".testfmt"]
            source_format = SourceFormat.UNKNOWN

            def can_read(self, path: Path) -> bool:
                return path.suffix.lower() == ".testfmt"

            def read(self, path: Path) -> UniversalGraph:
                return UniversalGraph()

        # The freshly registered adapter is returned for its extension.
        resolved = get_adapter("model.testfmt")
        assert resolved.name == "test_format"
83
+
84
+
85
class TestOnnxAdapter:
    """Tests for the ONNX format adapter."""

    def test_can_read_onnx_file(self) -> None:
        """OnnxAdapter.can_read accepts .onnx paths and rejects other extensions."""
        adapter = OnnxAdapter()
        assert adapter.can_read(Path("model.onnx")) is True
        assert adapter.can_read(Path("model.pt")) is False

    def test_can_write(self) -> None:
        """OnnxAdapter should support writing."""
        adapter = OnnxAdapter()
        assert adapter.can_write() is True

    def test_read_nonexistent_file_raises(self, tmp_path: Path) -> None:
        """Reading a path that does not exist raises FileNotFoundError."""
        # Build the missing path under pytest's per-test temp dir so a stray
        # file in the current working directory cannot affect the outcome.
        missing = tmp_path / "nonexistent_model.onnx"
        adapter = OnnxAdapter()
        with pytest.raises(FileNotFoundError):
            adapter.read(missing)

    # Fixtures are located relative to this test module, not the CWD, so the
    # check works both from a source checkout and from an installed package.
    # Path.glob on a missing directory yields nothing, so this skips cleanly.
    @pytest.mark.skipif(
        not any((Path(__file__).parent / "fixtures").glob("*.onnx")),
        reason="No ONNX test fixtures available",
    )
    def test_read_simple_onnx_model(self) -> None:
        """Reading a real ONNX fixture yields a non-empty UniversalGraph.

        Previously this test body was ``pass``, so it could never fail.
        """
        fixture = sorted((Path(__file__).parent / "fixtures").glob("*.onnx"))[0]
        graph = OnnxAdapter().read(fixture)
        assert isinstance(graph, UniversalGraph)
        assert graph.num_nodes > 0
113
+
114
+
115
class TestPyTorchAdapter:
    """Behaviour of the PyTorch checkpoint adapter."""

    def test_can_read_pytorch_files(self) -> None:
        """Both ``.pt`` and ``.pth`` are readable; ``.onnx`` is not."""
        pytorch_adapter = PyTorchAdapter()
        for accepted in ("model.pt", "model.pth"):
            assert pytorch_adapter.can_read(Path(accepted)) is True
        assert pytorch_adapter.can_read(Path("model.onnx")) is False

    def test_cannot_write(self) -> None:
        """The PyTorch adapter offers no direct write path."""
        assert PyTorchAdapter().can_write() is False
129
+
130
+
131
class TestOpMapping:
    """Translation of ONNX op-type names into the universal IR vocabulary."""

    def test_conv_mapping(self) -> None:
        """``Conv`` becomes ``Conv2D``."""
        mapped = map_onnx_op_to_universal("Conv")
        assert mapped == "Conv2D"

    def test_gemm_mapping(self) -> None:
        """``Gemm`` becomes ``MatMul``."""
        mapped = map_onnx_op_to_universal("Gemm")
        assert mapped == "MatMul"

    def test_relu_mapping(self) -> None:
        """``Relu`` keeps its name."""
        mapped = map_onnx_op_to_universal("Relu")
        assert mapped == "Relu"

    def test_unknown_op_passthrough(self) -> None:
        """Ops without a mapping are returned verbatim."""
        mapped = map_onnx_op_to_universal("CustomOp")
        assert mapped == "CustomOp"

    def test_pooling_mapping(self) -> None:
        """Max, average and global-average pooling map to their universal names."""
        expected = {
            "MaxPool": "MaxPool2D",
            "AveragePool": "AvgPool2D",
            "GlobalAveragePool": "GlobalAvgPool",
        }
        for onnx_op, universal_op in expected.items():
            assert map_onnx_op_to_universal(onnx_op) == universal_op
155
+
156
+
157
class TestUniversalGraphCreation:
    """Building UniversalGraph, UniversalNode and UniversalTensor objects by hand."""

    def test_create_empty_graph(self) -> None:
        """A default-constructed graph contains neither nodes nor tensors."""
        empty = UniversalGraph()
        assert empty.num_nodes == 0
        assert empty.num_tensors == 0

    def test_create_graph_with_metadata(self) -> None:
        """Metadata handed to the constructor is exposed on ``graph.metadata``."""
        graph = UniversalGraph(
            metadata=GraphMetadata(
                name="test_model",
                source_format=SourceFormat.ONNX,
                opset_version=17,
            )
        )
        assert graph.metadata.name == "test_model"
        assert graph.metadata.source_format == SourceFormat.ONNX

    def test_create_node(self) -> None:
        """A Conv2D node keeps its id/op type and counts as a compute op."""
        conv = UniversalNode(
            id="conv1",
            op_type="Conv2D",
            inputs=["input", "weight"],
            outputs=["output"],
            attributes={"kernel_shape": [3, 3]},
        )
        assert conv.id == "conv1"
        assert conv.op_type == "Conv2D"
        assert conv.is_compute_op is True

    def test_create_tensor(self) -> None:
        """A float32 weight tensor reports its element count and byte size."""
        weight = UniversalTensor(
            name="weight",
            shape=[64, 3, 3, 3],
            dtype=DataType.FLOAT32,
            origin=TensorOrigin.WEIGHT,
        )
        element_count = 64 * 3 * 3 * 3
        assert weight.name == "weight"
        assert weight.num_elements == element_count
        # float32 occupies 4 bytes per element.
        assert weight.size_bytes == element_count * 4

    def test_graph_with_nodes_and_tensors(self) -> None:
        """A single-conv graph exposes consistent counts and tensor groupings."""
        conv_weight = np.random.randn(64, 3, 3, 3).astype(np.float32)

        input_tensor = UniversalTensor(
            name="input",
            shape=[1, 3, 224, 224],
            dtype=DataType.FLOAT32,
            origin=TensorOrigin.INPUT,
        )
        weight_tensor = UniversalTensor(
            name="conv1.weight",
            shape=[64, 3, 3, 3],
            dtype=DataType.FLOAT32,
            origin=TensorOrigin.WEIGHT,
            data=conv_weight,
        )
        # 3x3 kernel without padding: 224 - 3 + 1 = 222 spatial output.
        output_tensor = UniversalTensor(
            name="output",
            shape=[1, 64, 222, 222],
            dtype=DataType.FLOAT32,
            origin=TensorOrigin.OUTPUT,
        )

        conv_node = UniversalNode(
            id="conv1",
            op_type="Conv2D",
            inputs=["input", "conv1.weight"],
            outputs=["output"],
            attributes={"kernel_shape": [3, 3]},
        )

        graph = UniversalGraph(
            nodes=[conv_node],
            tensors={
                "input": input_tensor,
                "conv1.weight": weight_tensor,
                "output": output_tensor,
            },
            metadata=GraphMetadata(name="simple_conv"),
        )

        assert graph.num_nodes == 1
        assert graph.num_tensors == 3
        assert graph.total_parameters == 64 * 3 * 3 * 3
        assert len(graph.weight_tensors) == 1
        assert len(graph.input_tensors) == 1
        assert len(graph.output_tensors) == 1
250
+
251
+
252
class TestGraphComparison:
    """Structural equality checks between UniversalGraph instances."""

    def test_empty_graphs_equal(self) -> None:
        """Two empty graphs are structurally equal."""
        left, right = UniversalGraph(), UniversalGraph()
        assert left.is_structurally_equal(right) is True

    def test_different_node_count_not_equal(self) -> None:
        """A one-node graph differs from an empty graph."""
        one_node = UniversalGraph(nodes=[UniversalNode(id="n1", op_type="Conv2D")])
        empty = UniversalGraph()
        assert one_node.is_structurally_equal(empty) is False

    def test_same_structure_equal(self) -> None:
        """Renaming node ids and tensor names does not break structural equality."""
        first = UniversalGraph(
            nodes=[
                UniversalNode(id="conv1", op_type="Conv2D", inputs=["a"], outputs=["b"]),
                UniversalNode(id="relu1", op_type="Relu", inputs=["b"], outputs=["c"]),
            ]
        )
        # Same op sequence and chaining as `first`, but every identifier differs.
        second = UniversalGraph(
            nodes=[
                UniversalNode(id="c1", op_type="Conv2D", inputs=["x"], outputs=["y"]),
                UniversalNode(id="r1", op_type="Relu", inputs=["y"], outputs=["z"]),
            ]
        )
        assert first.is_structurally_equal(second) is True

    def test_different_op_types_not_equal(self) -> None:
        """Graphs whose only node differs in op type are not equal."""
        conv_graph = UniversalGraph(nodes=[UniversalNode(id="n1", op_type="Conv2D")])
        matmul_graph = UniversalGraph(nodes=[UniversalNode(id="n1", op_type="MatMul")])
        assert conv_graph.is_structurally_equal(matmul_graph) is False
288
+
289
+
290
class TestGraphDiff:
    """The dictionary produced by ``UniversalGraph.diff``."""

    def test_diff_empty_graphs(self) -> None:
        """Diffing two empty graphs reports equality and zero nodes on both sides."""
        result = UniversalGraph().diff(UniversalGraph())
        assert result["structurally_equal"] is True
        assert result["node_count_diff"] == (0, 0)

    def test_diff_shows_node_count(self) -> None:
        """Node counts are reported as a (self, other) pair."""
        populated = UniversalGraph(nodes=[UniversalNode(id="n1", op_type="Conv2D")])
        result = populated.diff(UniversalGraph())
        assert result["structurally_equal"] is False
        assert result["node_count_diff"] == (1, 0)

    def test_diff_shows_op_type_diff(self) -> None:
        """Per-op-type counts are reported as (self, other) pairs."""
        two_convs = UniversalGraph(
            nodes=[
                UniversalNode(id="c1", op_type="Conv2D"),
                UniversalNode(id="c2", op_type="Conv2D"),
            ]
        )
        one_conv = UniversalGraph(nodes=[UniversalNode(id="c1", op_type="Conv2D")])

        op_counts = two_convs.diff(one_conv)["op_type_diff"]
        assert "Conv2D" in op_counts
        assert op_counts["Conv2D"] == (2, 1)
325
+
326
+
327
class TestGraphSerialization:
    """Dictionary and JSON serialization of UniversalGraph."""

    def test_to_dict_empty_graph(self) -> None:
        """Even an empty graph serializes all top-level sections."""
        serialized = UniversalGraph().to_dict()
        for section in ("metadata", "nodes", "tensors", "summary"):
            assert section in serialized

    def test_to_dict_with_tensors(self) -> None:
        """``include_weights=False`` drops raw tensor data from the output."""
        weight = UniversalTensor(
            name="w",
            shape=[3, 3],
            dtype=DataType.FLOAT32,
            origin=TensorOrigin.WEIGHT,
            data=np.ones((3, 3), dtype=np.float32),
        )
        serialized = UniversalGraph(tensors={"w": weight}).to_dict(include_weights=False)
        assert serialized["tensors"]["w"]["data"] is None

    def test_round_trip_json(self, tmp_path: Path) -> None:
        """Writing to JSON and reading back preserves nodes, metadata and structure."""
        source_graph = UniversalGraph(
            nodes=[UniversalNode(id="n1", op_type="Conv2D", inputs=["a"], outputs=["b"])],
            tensors={
                "a": UniversalTensor(name="a", shape=[1, 3, 224, 224], origin=TensorOrigin.INPUT)
            },
            metadata=GraphMetadata(name="test", source_format=SourceFormat.ONNX),
        )

        target = tmp_path / "graph.json"
        source_graph.to_json(target)
        restored = UniversalGraph.from_json(target)

        assert restored.num_nodes == 1
        assert restored.metadata.name == "test"
        assert restored.is_structurally_equal(source_graph)
373
+
374
+
375
class TestDataType:
    """Properties and constructors of the DataType enum."""

    def test_bytes_per_element(self) -> None:
        """Each listed dtype reports its storage width in bytes."""
        widths = {
            DataType.FLOAT32: 4,
            DataType.FLOAT16: 2,
            DataType.INT8: 1,
            DataType.FLOAT64: 8,
        }
        for dtype, width in widths.items():
            assert dtype.bytes_per_element == width

    def test_from_numpy_dtype(self) -> None:
        """NumPy dtypes translate to the matching DataType member."""
        pairs = (
            (np.float32, DataType.FLOAT32),
            (np.float16, DataType.FLOAT16),
            (np.int8, DataType.INT8),
        )
        for numpy_type, expected in pairs:
            assert DataType.from_numpy_dtype(np.dtype(numpy_type)) == expected
390
+
391
+
392
class TestConversionMatrix:
    """Tests for conversion matrix functionality."""

    def test_conversion_level_enum(self) -> None:
        """ConversionLevel enum should have expected values."""
        assert ConversionLevel.FULL.value == "full"
        assert ConversionLevel.PARTIAL.value == "partial"
        assert ConversionLevel.LOSSY.value == "lossy"
        assert ConversionLevel.NONE.value == "none"

    def test_identity_conversion(self) -> None:
        """Converting to same format should be FULL."""
        assert get_conversion_level(SourceFormat.ONNX, SourceFormat.ONNX) == ConversionLevel.FULL
        assert get_conversion_level("pytorch", "pytorch") == ConversionLevel.FULL

    def test_pytorch_to_onnx(self) -> None:
        """PyTorch to ONNX should be FULL."""
        level = get_conversion_level(SourceFormat.PYTORCH, SourceFormat.ONNX)
        assert level == ConversionLevel.FULL

    def test_onnx_to_tensorrt(self) -> None:
        """ONNX to TensorRT should be PARTIAL."""
        level = get_conversion_level(SourceFormat.ONNX, SourceFormat.TENSORRT)
        assert level == ConversionLevel.PARTIAL

    def test_tensorrt_to_onnx(self) -> None:
        """TensorRT to ONNX should be NONE (no export)."""
        level = get_conversion_level(SourceFormat.TENSORRT, SourceFormat.ONNX)
        assert level == ConversionLevel.NONE

    def test_unknown_conversion(self) -> None:
        """Unknown conversions should return NONE."""
        level = get_conversion_level(SourceFormat.GGUF, SourceFormat.TENSORRT)
        assert level == ConversionLevel.NONE

    def test_string_format_input(self) -> None:
        """Should accept string format names."""
        level = get_conversion_level("onnx", "openvino")
        assert level == ConversionLevel.FULL

    def test_can_convert_true(self) -> None:
        """can_convert should return True for valid conversions."""
        assert can_convert("pytorch", "onnx") is True
        assert can_convert("onnx", "tflite") is True

    def test_can_convert_false(self) -> None:
        """can_convert should return False for impossible conversions."""
        assert can_convert("safetensors", "onnx") is False
        assert can_convert("tensorrt", "onnx") is False

    def test_list_conversion_paths(self) -> None:
        """list_conversion_paths should return available conversions."""
        paths = list_conversion_paths()
        assert len(paths) > 0
        # Each path should have source, target, level
        for path in paths:
            assert "source" in path
            assert "target" in path
            assert "level" in path

    def test_list_conversion_paths_filtered_source(self) -> None:
        """list_conversion_paths should filter by source."""
        paths = list_conversion_paths(source="onnx")
        # all() is True for an empty list, so first require at least one
        # result. ONNX has outgoing conversions (e.g. onnx -> openvino is FULL
        # per test_string_format_input).
        assert paths, "expected at least one conversion path from onnx"
        assert all(p["source"] == "onnx" for p in paths)

    def test_list_conversion_paths_filtered_target(self) -> None:
        """list_conversion_paths should filter by target."""
        paths = list_conversion_paths(target="onnx")
        # Guard against a vacuous pass on an empty result; pytorch -> onnx is
        # FULL per test_pytorch_to_onnx, so the list must not be empty.
        assert paths, "expected at least one conversion path into onnx"
        assert all(p["target"] == "onnx" for p in paths)
@@ -0,0 +1,237 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Unit tests for the hardware module (profiles, detection, estimation).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ import pytest
14
+
15
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
16
+ from ..hardware import (
17
+ HARDWARE_PROFILES,
18
+ NVIDIA_A100_80GB,
19
+ NVIDIA_JETSON_NANO,
20
+ NVIDIA_RTX_4090,
21
+ HardwareEstimates,
22
+ HardwareEstimator,
23
+ get_profile,
24
+ list_available_profiles,
25
+ )
26
+
27
+
28
class TestHardwareProfiles:
    """Tests for hardware profile definitions."""

    def test_profile_registry_not_empty(self):
        """Verify hardware profile registry has entries."""
        assert len(HARDWARE_PROFILES) > 0

    def test_get_profile_by_name(self):
        """Test retrieving profiles by name."""
        profile = get_profile("a100")
        assert profile is not None
        assert "A100" in profile.name

        profile = get_profile("rtx4090")
        assert profile is not None
        assert "4090" in profile.name

    def test_get_profile_case_insensitive(self):
        """Profile lookup should be case-insensitive.

        Different casings of the same key must resolve to the *same* profile,
        not merely to some non-None profile.
        """
        upper_a100 = get_profile("A100")
        lower_a100 = get_profile("a100")
        assert upper_a100 is not None
        assert lower_a100 is not None
        assert upper_a100.name == lower_a100.name

        upper_rtx = get_profile("RTX4090")
        lower_rtx = get_profile("rtx4090")
        assert upper_rtx is not None
        assert lower_rtx is not None
        assert upper_rtx.name == lower_rtx.name

    def test_get_unknown_profile_returns_none(self):
        """Unknown profile names should return None."""
        assert get_profile("nonexistent_gpu") is None

    def test_list_available_profiles(self):
        """Test listing available profiles."""
        profiles = list_available_profiles()
        assert len(profiles) > 0
        assert any("4090" in p for p in profiles)
        assert any("A100" in p for p in profiles)

    def test_profile_has_required_fields(self):
        """All profiles should have required fields."""
        for name, profile in HARDWARE_PROFILES.items():
            assert profile.name, f"{name} missing name"
            assert profile.vendor, f"{name} missing vendor"
            assert profile.device_type, f"{name} missing device_type"
            assert profile.vram_bytes > 0, f"{name} missing vram"
            assert profile.memory_bandwidth_bytes_per_s > 0, f"{name} missing bandwidth"

    def test_jetson_profiles_exist(self):
        """Verify Jetson edge profiles are available."""
        assert get_profile("jetson-nano") is not None
        assert get_profile("jetson-orin-nano-4gb") is not None
        assert get_profile("jetson-agx-orin") is not None
76
+
77
+
78
class TestHardwareProfileDataclass:
    """Serialization and metadata of individual HardwareProfile objects."""

    def test_to_dict(self):
        """to_dict exposes the A100 80GB profile's identity and headline specs."""
        serialized = NVIDIA_A100_80GB.to_dict()

        expected = {
            "name": "NVIDIA A100 80GB SXM",
            "vendor": "nvidia",
            "device_type": "gpu",
            "vram_gb": 80.0,
            "peak_fp32_tflops": 19.5,
        }
        for key, value in expected.items():
            assert serialized[key] == value, key

    def test_compute_capability(self):
        """NVIDIA profiles carry their CUDA compute capability string."""
        for profile, capability in ((NVIDIA_A100_80GB, "8.0"), (NVIDIA_RTX_4090, "8.9")):
            assert profile.compute_capability == capability
96
+
97
+
98
class TestHardwareEstimator:
    """Tests for HardwareEstimator class."""

    def test_estimate_fits_in_vram(self):
        """A small model easily fits in an A100 80GB's memory."""
        estimator = HardwareEstimator()

        # Small model: 1M params, 1 GFLOP, 10 MB of activations
        estimates = estimator.estimate(
            model_params=1_000_000,
            model_flops=1_000_000_000,
            peak_activation_bytes=10_000_000,  # 10 MB
            hardware=NVIDIA_A100_80GB,
            batch_size=1,
            precision="fp32",
        )

        assert estimates.fits_in_vram
        assert estimates.vram_required_bytes > 0
        assert estimates.theoretical_latency_ms > 0
        assert estimates.bottleneck in ("compute", "memory_bandwidth")

    def test_estimate_does_not_fit_in_vram(self):
        """A model larger than device memory is reported as VRAM-bound."""
        estimator = HardwareEstimator()

        # Huge model: 100B params (400GB at fp32)
        estimates = estimator.estimate(
            model_params=100_000_000_000,
            model_flops=1_000_000_000_000,
            peak_activation_bytes=10_000_000_000,
            hardware=NVIDIA_A100_80GB,
            batch_size=1,
            precision="fp32",
        )

        assert not estimates.fits_in_vram
        assert estimates.bottleneck == "vram"

    def test_precision_affects_vram(self):
        """Lower precision reduces the estimated VRAM requirement."""
        estimator = HardwareEstimator()

        params = 10_000_000_000  # 10B params

        fp32_est = estimator.estimate(
            model_params=params,
            model_flops=1_000_000_000,
            peak_activation_bytes=100_000_000,
            hardware=NVIDIA_A100_80GB,
            precision="fp32",
        )

        fp16_est = estimator.estimate(
            model_params=params,
            model_flops=1_000_000_000,
            peak_activation_bytes=100_000_000,
            hardware=NVIDIA_A100_80GB,
            precision="fp16",
        )

        # FP16 should require less VRAM
        assert fp16_est.vram_required_bytes < fp32_est.vram_required_bytes

    def test_batch_size_affects_vram(self):
        """Larger batches increase the estimated VRAM requirement."""
        estimator = HardwareEstimator()

        batch1 = estimator.estimate(
            model_params=1_000_000,
            model_flops=1_000_000_000,
            peak_activation_bytes=100_000_000,
            hardware=NVIDIA_A100_80GB,
            batch_size=1,
        )

        batch8 = estimator.estimate(
            model_params=1_000_000,
            model_flops=1_000_000_000,
            peak_activation_bytes=100_000_000,
            hardware=NVIDIA_A100_80GB,
            batch_size=8,
        )

        # Larger batch should require more VRAM
        assert batch8.vram_required_bytes > batch1.vram_required_bytes

    def test_jetson_nano_constraints(self):
        """A ~100M-parameter fp32 model fits on the memory-limited Jetson Nano."""
        estimator = HardwareEstimator()

        # 100M params * 4 bytes ~= 400 MB of fp32 weights, plus 500 MB of
        # activations: under 1 GB before any estimator overhead. This is
        # expected to fit comfortably on the (4 GB) Jetson Nano, not borderline.
        estimates = estimator.estimate(
            model_params=100_000_000,  # 100M params = 400MB at fp32
            model_flops=1_000_000_000,
            peak_activation_bytes=500_000_000,  # 500MB activations
            hardware=NVIDIA_JETSON_NANO,
            batch_size=1,
            precision="fp32",
        )

        assert estimates.fits_in_vram
        # Since the model fits, the bottleneck must not be "vram". The test
        # deliberately does not pin which of the two remaining limits dominates.
        assert estimates.bottleneck in ("memory_bandwidth", "compute")
203
+
204
+
205
class TestHardwareEstimatesDataclass:
    """Serialization of the HardwareEstimates result object."""

    def test_to_dict(self):
        """to_dict reports the fields it was built from, with VRAM converted to GB."""
        one_gib = 1024 * 1024 * 1024
        estimates = HardwareEstimates(
            device="NVIDIA A100 80GB",
            precision="fp16",
            batch_size=8,
            vram_required_bytes=one_gib,
            fits_in_vram=True,
            theoretical_latency_ms=5.5,
            compute_utilization_estimate=0.75,
            gpu_saturation=0.000001,  # a tiny model barely loads a large GPU
            bottleneck="compute",
            model_flops=1_000_000_000,
            hardware_peak_tflops=312.0,
        )

        serialized = estimates.to_dict()

        for key, value in (
            ("device", "NVIDIA A100 80GB"),
            ("precision", "fp16"),
            ("batch_size", 8),
            ("vram_required_gb", 1.0),
        ):
            assert serialized[key] == value
        # Identity check: the flag must be the bool True, not merely truthy.
        assert serialized["fits_in_vram"] is True
        for key, value in (
            ("theoretical_latency_ms", 5.5),
            ("bottleneck", "compute"),
            ("gpu_saturation", 0.000001),
        ):
            assert serialized[key] == value
234
+
235
+
236
if __name__ == "__main__":
    # Propagate pytest's exit status. Otherwise running this module directly
    # would exit 0 even when tests fail.
    sys.exit(pytest.main([__file__, "-v"]))