tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Potentially problematic release: this version of tritonparse has been flagged as possibly problematic.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,48 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+import inspect
+import logging
+from collections.abc import Callable
+
+logger = logging.getLogger(__name__)
+
+
+def get_function_source(
+    func: Callable[[], None], with_invocation: bool = True
+) -> list[str]:
+    """
+    Extract function source code and optionally include invocation.
+
+    Args:
+        func: Function to extract source code from
+        with_invocation: Whether to include function invocation code
+
+    Returns:
+        List containing source code and optional invocation statement
+    """
+    source = inspect.getsource(func).rstrip()
+    result = [source]
+
+    if with_invocation:
+        result.append("")
+        result.append(f"{func.__name__}()")
+
+    return result
+
+
+def _disable_triton_autotune() -> None:
+    """
+    Monkey patch the triton.autotune decorator to skip autotuning entirely.
+    """
+    logger.info("Disabling triton autotune")
+
+    def dummy_autotune(configs, key=None, **kwargs):
+        def decorator(func):
+            return func  # Just pass through, let @triton.jit handle compilation
+
+        return decorator
+
+    import triton
+
+    triton.autotune = dummy_autotune
+    logger.info("Disabled triton autotune")
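A minimal usage sketch for the two helpers above, for illustration only: the module path is inferred from the +48 entry for tritonparse/reproducer/templates/utils.py in the file list, and toy_launcher is a hypothetical stand-in, not part of the package.

# Illustrative sketch, not part of the package. Module path inferred from the
# file list above (tritonparse/reproducer/templates/utils.py, +48 lines).
from tritonparse.reproducer.templates.utils import (
    _disable_triton_autotune,
    get_function_source,
)

def toy_launcher() -> None:
    # Hypothetical zero-argument launcher standing in for a real kernel wrapper.
    print("launching toy kernel")

# Make any @triton.autotune decorator a pass-through before compiling kernels.
_disable_triton_autotune()

# Returns [source] plus, with with_invocation=True, "" and "toy_launcher()".
lines = get_function_source(toy_launcher, with_invocation=True)
reproducer_script = "\n".join(lines)
print(reproducer_script)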
File without changes
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# pyre-strict
+
+if __name__ == "__main__":
+    import sys
+    from pathlib import Path
+
+    # Add tritonparse_root to sys.path for module resolution
+    tritonparse_root = Path(__file__).resolve().parents[4]
+    if str(tritonparse_root) not in sys.path:
+        sys.path.insert(0, str(tritonparse_root))
+
+import torch
+import triton
+import triton.language as tl
+from tritonparse.reproducer.tests.artifacts.triton_preprocess import scale_kernel
+
+
+@triton.jit
+def main_kernel(
+    input_ptr,
+    output_ptr,
+    n_elements,
+    scale: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    x = tl.load(input_ptr + offsets, mask=mask, other=0.0)
+    scaled = scale_kernel(x, scale)
+    result = scaled * 2.0
+    tl.store(output_ptr + offsets, result, mask=mask)
+
+
+def launch_main_kernel() -> None:
+    """Launch and test the main_kernel on GPU."""
+    if not torch.cuda.is_available():
+        print("CUDA not available - showing call graph only")
+        print(" main_kernel -> scale_kernel -> add_values")
+        return
+
+    size = 1024
+    scale_factor = 3.0
+    BLOCK_SIZE = 256
+
+    x = torch.randn(size, device="cuda", dtype=torch.float32)
+    output = torch.zeros_like(x)
+
+    main_kernel[(256,)](
+        input_ptr=x,
+        output_ptr=output,
+        n_elements=size,
+        scale=scale_factor,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    print("✅ Kernel executed successfully")
+
+
+if __name__ == "__main__":
+    launch_main_kernel()
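One note on the launch above: the grid is hard-coded to (256,), so with size=1024 and BLOCK_SIZE=256 only the first 4 programs do real work and the rest are fully masked out. The usual Triton idiom derives the grid from the element count; a small sketch of that pattern, not part of the artifact:

import triton

size = 1024
BLOCK_SIZE = 256

# One program per BLOCK_SIZE-sized chunk: ceil(1024 / 256) = 4 programs.
grid = (triton.cdiv(size, BLOCK_SIZE),)

# The launch would then mirror the artifact's keyword arguments:
# main_kernel[grid](input_ptr=x, output_ptr=output, n_elements=size,
#                   scale=scale_factor, BLOCK_SIZE=BLOCK_SIZE)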
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# pyre-strict
+
+import triton
+import triton.language as tl
+from tritonparse.reproducer.tests.artifacts.triton_utils import add_values
+
+
+@triton.jit
+def scale_kernel(
+    x: tl.tensor,
+    scale: tl.constexpr,
+) -> tl.tensor:
+    result = x * scale
+    return add_values(result, 0.0)
@@ -0,0 +1,164 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# pyre-strict
+
+import ast
+import unittest
+from pathlib import Path
+
+from tritonparse.reproducer.import_parser import ImportParser
+from tritonparse.reproducer.import_resolver import ImportResolver
+
+
+class TestImportParser(unittest.TestCase):
+    """Test ImportParser functionality."""
+
+    def setUp(self) -> None:
+        """Set up test fixtures."""
+        test_dir = Path(__file__).resolve().parent
+        reproducer_dir = test_dir.parent
+        self.project_root = str(reproducer_dir.parent)
+        self.resolver = ImportResolver(self.project_root)
+        self.parser = ImportParser(self.resolver)
+
+    def test_parse_simple_import(self) -> None:
+        """Test parsing simple 'import X' statement."""
+        code = """
+import os
+"""
+        tree = ast.parse(code, filename="test.py")
+        imports = self.parser.parse_imports(tree, "test.py")
+
+        self.assertEqual(len(imports), 1)
+        self.assertEqual(imports[0].import_type, "import")
+        self.assertEqual(imports[0].module, "os")
+        self.assertEqual(imports[0].names, ["os"])
+        self.assertTrue(imports[0].is_external)
+
+    def test_parse_import_with_alias(self) -> None:
+        """Test parsing 'import X as Y' statement."""
+        code = """
+import numpy as np
+"""
+        tree = ast.parse(code, filename="test.py")
+        imports = self.parser.parse_imports(tree, "test.py")
+
+        self.assertEqual(len(imports), 1)
+        self.assertEqual(imports[0].module, "numpy")
+        self.assertIn("np", imports[0].aliases)
+        self.assertEqual(imports[0].aliases["np"], "numpy")
+
+    def test_parse_from_import(self) -> None:
+        """Test parsing 'from X import Y' statement."""
+        code = """
+from typing import List, Dict
+"""
+        tree = ast.parse(code, filename="test.py")
+        imports = self.parser.parse_imports(tree, "test.py")
+
+        self.assertEqual(len(imports), 2)
+        self.assertEqual(imports[0].import_type, "from_import")
+        self.assertEqual(imports[0].module, "typing")
+        self.assertIn("List", imports[0].names)
+        self.assertIn("Dict", imports[1].names)
+
+    def test_parse_from_import_with_alias(self) -> None:
+        """Test parsing 'from X import Y as Z' statement."""
+        code = """
+from collections import OrderedDict as OD
+"""
+        tree = ast.parse(code, filename="test.py")
+        imports = self.parser.parse_imports(tree, "test.py")
+
+        self.assertEqual(len(imports), 1)
+        self.assertEqual(imports[0].names, ["OrderedDict"])
+        self.assertIn("OD", imports[0].aliases)
+        self.assertEqual(imports[0].aliases["OD"], "OrderedDict")
+
+    def test_parse_relative_import_level_1(self) -> None:
+        """Test parsing relative import 'from . import X'.
+
+        When level=1 (from . import X), we're importing from the current package.
+        With package="tritonparse.reproducer", level=1 means the current package,
+        so the full module is "tritonparse.reproducer".
+        """
+        code = """
+from . import utils
+"""
+        tree = ast.parse(code, filename="test.py")
+        package = "tritonparse.reproducer"
+        imports = self.parser.parse_imports(tree, "test.py", package)
+
+        self.assertEqual(len(imports), 1)
+        self.assertEqual(imports[0].level, 1)
+        self.assertEqual(imports[0].module, "tritonparse.reproducer")
+        self.assertEqual(imports[0].names, ["utils"])
+
+    def test_parse_relative_import_level_2(self) -> None:
+        """Test parsing relative import 'from .. import X'.
+
+        When level=2 (from ..utils import X), we're importing from the parent package.
+        With package="tritonparse.reproducer.submodule", level=2 removes one component,
+        giving us "tritonparse.reproducer", then we append "utils" to get
+        "tritonparse.reproducer.utils".
+        """
+        code = """
+from ..utils import helper
+"""
+        tree = ast.parse(code, filename="test.py")
+        package = "tritonparse.reproducer.submodule"
+        imports = self.parser.parse_imports(tree, "test.py", package)
+
+        self.assertEqual(len(imports), 1)
+        self.assertEqual(imports[0].level, 2)
+        self.assertEqual(imports[0].module, "tritonparse.reproducer.utils")
+        self.assertEqual(imports[0].names, ["helper"])
+
+    def test_parse_multiple_imports(self) -> None:
+        """Test parsing multiple import statements."""
+        code = """
+import os
+import sys
+from typing import List
+from collections import defaultdict
+"""
+        tree = ast.parse(code, filename="test.py")
+        imports = self.parser.parse_imports(tree, "test.py")
+
+        self.assertEqual(len(imports), 4)
+        self.assertEqual(imports[0].module, "os")
+        self.assertEqual(imports[1].module, "sys")
+        self.assertEqual(imports[2].module, "typing")
+        self.assertEqual(imports[3].module, "collections")
+
+    def test_parse_project_internal_import(self) -> None:
+        """Test parsing imports from within the project."""
+        code = """
+from tritonparse.reproducer import ast_analyzer
+"""
+        tree = ast.parse(code, filename="test.py")
+        imports = self.parser.parse_imports(tree, "test.py")
+
+        self.assertEqual(len(imports), 1)
+        self.assertFalse(imports[0].is_external)
+        self.assertIsNotNone(imports[0].resolved_path)
+        if imports[0].resolved_path:
+            self.assertTrue(imports[0].resolved_path.startswith(self.project_root))
+
+    def test_parse_lineno_tracking(self) -> None:
+        """Test that line numbers are correctly tracked."""
+        code = """
+import os
+
+from typing import List
+
+
+import sys
+"""
+        tree = ast.parse(code, filename="test.py")
+        imports = self.parser.parse_imports(tree, "test.py")
+
+        self.assertEqual(len(imports), 3)
+        self.assertEqual(imports[0].lineno, 2)  # import os
+        self.assertEqual(imports[1].lineno, 4)  # from typing import List
+        self.assertEqual(imports[2].lineno, 7)  # import sys
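The expected absolute module names in the two relative-import tests above can be cross-checked against the standard library's own resolution rules; a small illustrative snippet, not used by the tests themselves:

from importlib.util import resolve_name

# from . import utils, inside package "tritonparse.reproducer":
# level=1 resolves to the current package itself.
assert resolve_name(".", "tritonparse.reproducer") == "tritonparse.reproducer"

# from ..utils import helper, inside "tritonparse.reproducer.submodule":
# level=2 drops one trailing component, then appends "utils".
assert (
    resolve_name("..utils", "tritonparse.reproducer.submodule")
    == "tritonparse.reproducer.utils"
)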
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+"""
+Unit tests for ImportResolver.
+
+Tests the ImportResolver implementation to ensure it correctly:
+1. Resolves internal imports
+2. Detects external modules
+3. Handles non-existent modules gracefully
+"""
+
+import unittest
+from pathlib import Path
+
+from tritonparse.reproducer.import_resolver import ImportResolver
+
+
+class ImportResolverTest(unittest.TestCase):
+    """Unit tests for ImportResolver class."""
+
+    def setUp(self) -> None:
+        """Set up test fixtures."""
+        # Use the tritonparse directory as project root for testing
+        test_dir = Path(__file__).resolve().parent
+        reproducer_dir = test_dir.parent
+        self.project_root = str(reproducer_dir.parent)
+        self.resolver = ImportResolver(project_root=self.project_root)
+
+    def test_resolve_tritonparse_module(self) -> None:
+        """Test resolving a module within tritonparse."""
+        # Setup: module that exists in tritonparse
+        module_name = "reproducer.ast_analyzer"
+
+        # Execute: resolve the module
+        path, is_external = self.resolver.resolve_import(module_name)
+
+        # Assert: should resolve to tritonparse path
+        self.assertIsNotNone(path)
+        assert path is not None  # For type checker
+        self.assertTrue(path.startswith(self.project_root))
+        self.assertTrue(Path(path).exists())
+        self.assertFalse(is_external)
+
+    def test_external_module_torch(self) -> None:
+        """Test that torch is correctly identified as external."""
+        # Setup: torch is a known external module
+        module_name = "torch"
+
+        # Execute: resolve the module
+        path, is_external = self.resolver.resolve_import(module_name)
+
+        # Assert: should be external
+        self.assertIsNone(path)
+        self.assertTrue(is_external)
+
+    def test_nonexistent_module(self) -> None:
+        """Test handling of non-existent module."""
+        # Setup: a module that doesn't exist
+        module_name = "this_module_does_not_exist_anywhere"
+
+        # Execute: resolve the module
+        path, is_external = self.resolver.resolve_import(module_name)
+
+        # Assert: should handle gracefully as external
+        self.assertIsNone(path)
+        self.assertTrue(is_external)
+
+    def test_is_external_module_torch(self) -> None:
+        """Test is_external_module() for torch."""
+        # Setup: torch module name
+        module_name = "torch"
+
+        # Execute: check if external
+        result = self.resolver.is_external_module(module_name)
+
+        # Assert: should be true
+        self.assertTrue(result)
+
+    def test_is_external_module_internal(self) -> None:
+        """Test is_external_module() for internal module."""
+        # Setup: internal module name
+        module_name = "reproducer.ast_analyzer"
+
+        # Execute: check if external
+        result = self.resolver.is_external_module(module_name)
+
+        # Assert: should be false
+        self.assertFalse(result)
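Taken together, these tests pin down the contract of resolve_import: modules found under project_root come back as (path, False), while everything else, including modules that do not exist anywhere, comes back as (None, True). A rough standalone approximation of that contract, for illustration only and not ImportResolver's actual implementation:

from pathlib import Path

def resolve_like_importresolver(module_name: str, project_root: str):
    """Return (path, is_external) following the behavior the tests expect."""
    parts = module_name.split(".")
    root = Path(project_root)
    # A dotted name maps either to pkg/.../mod.py or to a package __init__.py.
    candidates = [
        root.joinpath(*parts).with_suffix(".py"),
        root.joinpath(*parts, "__init__.py"),
    ]
    for candidate in candidates:
        if candidate.exists():
            return str(candidate), False
    # Anything not found under project_root is treated as external,
    # which also covers modules that do not exist at all.
    return None, True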
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# pyre-strict
+
+"""
+Tests for MultiFileCallGraphAnalyzer.
+
+This test suite validates the multi-file call graph analysis functionality,
+including cross-file dependency tracking, import resolution, and result consolidation.
+"""
+
+import unittest
+from pathlib import Path
+
+from tritonparse.reproducer.multi_file_analyzer import MultiFileCallGraphAnalyzer
+
+
+class TestMultiFileCallGraphAnalyzer(unittest.TestCase):
+    """Test the MultiFileCallGraphAnalyzer with real Triton test files."""
+
+    def setUp(self) -> None:
+        """Set up test fixtures."""
+        # Get code root (4 levels up from this test file)
+        test_file = Path(__file__).resolve()
+        self.code_root = str(test_file.parents[4])
+
+        # Test artifacts directory
+        self.artifacts_dir = test_file.parent / "artifacts"
+        self.entry_file = str(self.artifacts_dir / "triton_fused_kernel.py")
+        self.preprocess_file = str(self.artifacts_dir / "triton_preprocess.py")
+        self.utils_file = str(self.artifacts_dir / "triton_utils.py")
+
+    def test_single_file_analysis(self) -> None:
+        """Test analysis of a single file without following imports."""
+        # Setup: Analyze only triton_utils.py (which has no imports to internal files)
+        analyzer = MultiFileCallGraphAnalyzer(
+            entry_file=self.utils_file,
+            entry_function="add_values",
+        )
+
+        # Execute: Run analysis
+        result = analyzer.analyze()
+
+        # Assert: Only one file analyzed (triton_utils.py)
+        self.assertEqual(result.stats.total_files_analyzed, 1)
+        self.assertIn(self.utils_file, result.analyzed_files)
+
+        # Assert: add_values is the entry point, so it should not appear in the
+        # dependent functions (dependent functions exclude the entry point itself)
+        self.assertNotIn(
+            "tritonparse.reproducer.tests.artifacts.triton_utils.add_values",
+            result.functions,
+        )
+
+        # Assert: All imports should be external (triton, tl)
+        self.assertTrue(all(imp.is_external for imp in result.imports))
+
+    def test_multi_file_traversal(self) -> None:
+        """Test that analyzer follows imports across multiple files."""
+        # Setup: Start from main_kernel in triton_fused_kernel.py
+        analyzer = MultiFileCallGraphAnalyzer(
+            entry_file=self.entry_file,
+            entry_function="main_kernel",
+        )
+
+        # Execute: Run analysis
+        result = analyzer.analyze()
+
+        # Assert: Should analyze all 3 files in the call chain
+        # main_kernel -> scale_kernel -> add_values
+        self.assertEqual(result.stats.total_files_analyzed, 3)
+
+        # Assert: All 3 kernel files should be included
+        self.assertIn(self.entry_file, result.analyzed_files)
+        self.assertIn(self.preprocess_file, result.analyzed_files)
+        self.assertIn(self.utils_file, result.analyzed_files)
+
+        # Assert: Verify scale_kernel and add_values are in result.functions
+        function_names = set(result.functions.keys())
+        self.assertEqual(len(result.functions.items()), 2)
+        self.assertTrue(
+            any("scale_kernel" in name for name in function_names),
+            f"scale_kernel should be in functions. Found: {function_names}",
+        )
+        self.assertTrue(
+            any("add_values" in name for name in function_names),
+            f"add_values should be in functions. Found: {function_names}",
+        )
+
+        # Assert: Verify source code was extracted with the expected minimum lengths
+        # Find the fully qualified names for these functions
+        scale_kernel_name = next(
+            name for name in function_names if "scale_kernel" in name
+        )
+        add_values_name = next(name for name in function_names if "add_values" in name)
+
+        # Verify source code meets the minimum expected lengths
+        scale_kernel_source = result.functions[scale_kernel_name]
+        add_values_source = result.functions[add_values_name]
+
+        self.assertGreaterEqual(
+            len(scale_kernel_source),
+            50,
+            f"scale_kernel source should be at least 50 chars. Got: {len(scale_kernel_source)} chars",
+        )
+        self.assertGreaterEqual(
+            len(add_values_source),
+            50,
+            f"add_values source should be at least 50 chars. Got: {len(add_values_source)} chars",
+        )
+
+        # Verify short names are correctly extracted
+        self.assertEqual(result.function_short_names[scale_kernel_name], "scale_kernel")
+        self.assertEqual(result.function_short_names[add_values_name], "add_values")
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from enum import Enum
+
+
+class KernelImportMode(str, Enum):
+    """
+    Kernel import strategy for reproducer generation.
+
+    Inherits from str to allow direct string comparison and use in argparse.
+
+    Attributes:
+        DEFAULT: Import kernel from original file (current behavior).
+        COPY: Embed kernel source code directly in reproducer.
+        OVERRIDE_TTIR: Use TTIR from compilation event with monkeypatch.
+    """
+
+    DEFAULT = "default"
+    COPY = "copy"
+    OVERRIDE_TTIR = "override-ttir"
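Because KernelImportMode mixes in str, its members compare equal to plain strings and plug directly into argparse, as the docstring notes. A minimal sketch of that pattern; the flag name below is illustrative and not taken from the package's CLI:

import argparse

from tritonparse.reproducer.types import KernelImportMode

parser = argparse.ArgumentParser()
parser.add_argument(
    "--kernel-import-mode",            # illustrative flag name
    type=KernelImportMode,             # argparse calls KernelImportMode("copy"), etc.
    choices=list(KernelImportMode),
    default=KernelImportMode.DEFAULT,
)
args = parser.parse_args(["--kernel-import-mode", "copy"])

assert args.kernel_import_mode is KernelImportMode.COPY
assert args.kernel_import_mode == "copy"   # str mixin: direct string comparison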