tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic; see the package's registry page for details.

Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,48 @@
1
+ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
+
3
+ import inspect
4
+ import logging
5
+ from collections.abc import Callable
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
def get_function_source(
    func: Callable[[], None], with_invocation: bool = True
) -> list[str]:
    """
    Extract function source code and optionally include invocation.

    The source is dedented, so a function defined inside another scope (for
    example nested in a test or another function) still yields code that is
    valid at module level, matching the column-0 invocation line appended
    below.

    Args:
        func: Function to extract source code from. Its source must be
            retrievable by ``inspect.getsource``.
        with_invocation: Whether to include function invocation code

    Returns:
        List containing the source code and, when ``with_invocation`` is
        True, an empty separator line followed by a ``<name>()`` call.

    Raises:
        OSError: If the source code cannot be retrieved.
        TypeError: If ``func`` has no Python source (e.g. a built-in).
    """
    import textwrap

    # getsource preserves the original indentation of nested definitions;
    # dedent so the snippet and the "<name>()" line share the same level.
    source = textwrap.dedent(inspect.getsource(func)).rstrip()
    result = [source]

    if with_invocation:
        result.append("")
        result.append(f"{func.__name__}()")

    return result
31
+
32
+
33
def _disable_triton_autotune() -> None:
    """
    Replace ``triton.autotune`` with a decorator factory that does nothing.

    After this call, ``@triton.autotune(...)`` (looked up through the
    ``triton`` module attribute) returns the decorated function unchanged,
    so no autotuning configurations are explored; compilation is left to
    the ``@triton.jit`` decorator underneath.

    NOTE(review): only lookups of ``triton.autotune`` made after this call
    see the replacement; names bound earlier (e.g. via
    ``from triton import autotune``) keep the original decorator.
    """
    logger.info("Disabling triton autotune")

    import triton

    def dummy_autotune(configs, key=None, **kwargs):
        # Every autotune argument is ignored; hand back an identity decorator.
        def passthrough(func):
            return func  # Just pass through, let @triton.jit handle compilation

        return passthrough

    triton.autotune = dummy_autotune
    logger.info("Disabled triton autotune")
File without changes
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # pyre-strict
4
+
5
+ """Test artifacts for multi-file call graph analyzer integration tests."""
@@ -0,0 +1,65 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # pyre-strict
4
+
5
if __name__ == "__main__":
    # Only when this artifact is executed directly as a script: the directory
    # containing the ``tritonparse`` package may not be on sys.path, and the
    # absolute ``tritonparse...`` import that follows this block needs it.
    import sys
    from pathlib import Path

    # Add tritonparse_root to sys.path for module resolution.
    # parents[4] walks up from
    # tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py:
    # artifacts -> tests -> reproducer -> tritonparse -> directory holding it.
    tritonparse_root = Path(__file__).resolve().parents[4]
    if str(tritonparse_root) not in sys.path:
        sys.path.insert(0, str(tritonparse_root))
13
+
14
+ import torch
15
+ import triton
16
+ import triton.language as tl
17
+ from tritonparse.reproducer.tests.artifacts.triton_preprocess import scale_kernel
18
+
19
+
20
@triton.jit
def main_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    scale: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``output = scale_kernel(input, scale) * 2.0`` element-wise.

    Entry point of the multi-file call-graph test artifact:
    main_kernel -> scale_kernel (triton_preprocess.py) -> add_values
    (triton_utils.py).

    Each program instance handles one contiguous chunk of BLOCK_SIZE
    elements; positions at or beyond ``n_elements`` are masked out on both
    the load and the store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked-off lanes read 0.0 and are never written back.
    x = tl.load(input_ptr + offsets, mask=mask, other=0.0)
    # Cross-file call into triton_preprocess.scale_kernel.
    scaled = scale_kernel(x, scale)
    result = scaled * 2.0
    tl.store(output_ptr + offsets, result, mask=mask)
37
+
38
+
39
def launch_main_kernel() -> None:
    """Launch main_kernel on the GPU and check its output against PyTorch.

    When CUDA is unavailable, only the expected call graph is printed, so the
    artifact can still be executed on CPU-only machines.

    Raises:
        AssertionError: If the kernel output does not match
            ``x * scale_factor * 2.0``.
    """
    if not torch.cuda.is_available():
        print("CUDA not available - showing call graph only")
        print(" main_kernel -> scale_kernel -> add_values")
        return

    size = 1024
    scale_factor = 3.0
    BLOCK_SIZE = 256

    x = torch.randn(size, device="cuda", dtype=torch.float32)
    output = torch.zeros_like(x)

    # Launch exactly enough programs to cover `size` elements; each program
    # processes BLOCK_SIZE elements and masks the tail.
    grid = (triton.cdiv(size, BLOCK_SIZE),)
    main_kernel[grid](
        input_ptr=x,
        output_ptr=output,
        n_elements=size,
        scale=scale_factor,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # main_kernel computes (x * scale + 0.0) * 2.0; verify the result rather
    # than only reporting that the launch did not crash.
    torch.testing.assert_close(output, x * scale_factor * 2.0)

    print("✅ Kernel executed successfully")


if __name__ == "__main__":
    launch_main_kernel()
@@ -0,0 +1,16 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # pyre-strict
4
+
5
+ import triton
6
+ import triton.language as tl
7
+ from tritonparse.reproducer.tests.artifacts.triton_utils import add_values
8
+
9
+
10
@triton.jit
def scale_kernel(
    x: tl.tensor,
    scale: tl.constexpr,
) -> tl.tensor:
    """Return ``x * scale``, routed through ``add_values`` from triton_utils.

    Adding ``0.0`` is effectively a no-op on the value; the call gives this
    test artifact a second cross-file hop in the chain
    main_kernel -> scale_kernel -> add_values used by the multi-file
    call-graph analyzer tests.
    """
    result = x * scale
    return add_values(result, 0.0)
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # pyre-strict
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
@triton.jit
def add_values(
    a: tl.tensor,
    b: tl.tensor,
) -> tl.tensor:
    """Return ``a + b``.

    Leaf of the test artifact call chain
    main_kernel -> scale_kernel -> add_values.

    NOTE(review): ``b`` is annotated ``tl.tensor``, but scale_kernel in
    triton_preprocess.py passes the Python float ``0.0``; the annotation is
    narrower than actual usage.
    """
    return a + b
@@ -0,0 +1,164 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # pyre-strict
4
+
5
+ import ast
6
+ import unittest
7
+ from pathlib import Path
8
+
9
+ from tritonparse.reproducer.import_parser import ImportParser
10
+ from tritonparse.reproducer.import_resolver import ImportResolver
11
+
12
+
13
class TestImportParser(unittest.TestCase):
    """Test ImportParser functionality.

    Each test parses a small source snippet with ``ast.parse`` and checks the
    import records returned by ``ImportParser.parse_imports``.
    """

    def setUp(self) -> None:
        """Set up test fixtures.

        The project root is the parent of the ``reproducer`` package, i.e. the
        ``tritonparse`` package directory.
        """
        test_dir = Path(__file__).resolve().parent
        reproducer_dir = test_dir.parent
        self.project_root = str(reproducer_dir.parent)
        self.resolver = ImportResolver(self.project_root)
        self.parser = ImportParser(self.resolver)

    def test_parse_simple_import(self) -> None:
        """Test parsing simple 'import X' statement."""
        code = """
import os
"""
        tree = ast.parse(code, filename="test.py")
        imports = self.parser.parse_imports(tree, "test.py")

        self.assertEqual(len(imports), 1)
        self.assertEqual(imports[0].import_type, "import")
        self.assertEqual(imports[0].module, "os")
        self.assertEqual(imports[0].names, ["os"])
        self.assertTrue(imports[0].is_external)

    def test_parse_import_with_alias(self) -> None:
        """Test parsing 'import X as Y' statement."""
        code = """
import numpy as np
"""
        tree = ast.parse(code, filename="test.py")
        imports = self.parser.parse_imports(tree, "test.py")

        self.assertEqual(len(imports), 1)
        self.assertEqual(imports[0].module, "numpy")
        self.assertIn("np", imports[0].aliases)
        self.assertEqual(imports[0].aliases["np"], "numpy")

    def test_parse_from_import(self) -> None:
        """Test parsing 'from X import Y' statement.

        Each name in a multi-name 'from' import yields its own record.
        """
        code = """
from typing import List, Dict
"""
        tree = ast.parse(code, filename="test.py")
        imports = self.parser.parse_imports(tree, "test.py")

        self.assertEqual(len(imports), 2)
        self.assertEqual(imports[0].import_type, "from_import")
        self.assertEqual(imports[0].module, "typing")
        self.assertIn("List", imports[0].names)
        self.assertIn("Dict", imports[1].names)

    def test_parse_from_import_with_alias(self) -> None:
        """Test parsing 'from X import Y as Z' statement."""
        code = """
from collections import OrderedDict as OD
"""
        tree = ast.parse(code, filename="test.py")
        imports = self.parser.parse_imports(tree, "test.py")

        self.assertEqual(len(imports), 1)
        self.assertEqual(imports[0].names, ["OrderedDict"])
        self.assertIn("OD", imports[0].aliases)
        self.assertEqual(imports[0].aliases["OD"], "OrderedDict")

    def test_parse_relative_import_level_1(self) -> None:
        """Test parsing relative import 'from . import X'.

        When level=1 (from . import X), we're importing from the current
        package. With package="tritonparse.reproducer", level=1 means the
        current package, so the full module is "tritonparse.reproducer".
        """
        code = """
from . import utils
"""
        tree = ast.parse(code, filename="test.py")
        package = "tritonparse.reproducer"
        imports = self.parser.parse_imports(tree, "test.py", package)

        self.assertEqual(len(imports), 1)
        self.assertEqual(imports[0].level, 1)
        self.assertEqual(imports[0].module, "tritonparse.reproducer")
        self.assertEqual(imports[0].names, ["utils"])

    def test_parse_relative_import_level_2(self) -> None:
        """Test parsing relative import 'from ..X import Y'.

        When level=2 (from ..utils import helper), we're importing relative to
        the parent package. With package="tritonparse.reproducer.submodule",
        level=2 drops one trailing component, giving "tritonparse.reproducer";
        appending "utils" yields "tritonparse.reproducer.utils".
        """
        code = """
from ..utils import helper
"""
        tree = ast.parse(code, filename="test.py")
        package = "tritonparse.reproducer.submodule"
        imports = self.parser.parse_imports(tree, "test.py", package)

        self.assertEqual(len(imports), 1)
        self.assertEqual(imports[0].level, 2)
        self.assertEqual(imports[0].module, "tritonparse.reproducer.utils")
        self.assertEqual(imports[0].names, ["helper"])

    def test_parse_multiple_imports(self) -> None:
        """Test parsing multiple import statements, preserving source order."""
        code = """
import os
import sys
from typing import List
from collections import defaultdict
"""
        tree = ast.parse(code, filename="test.py")
        imports = self.parser.parse_imports(tree, "test.py")

        self.assertEqual(len(imports), 4)
        self.assertEqual(imports[0].module, "os")
        self.assertEqual(imports[1].module, "sys")
        self.assertEqual(imports[2].module, "typing")
        self.assertEqual(imports[3].module, "collections")

    def test_parse_project_internal_import(self) -> None:
        """Test parsing imports from within the project."""
        code = """
from tritonparse.reproducer import ast_analyzer
"""
        tree = ast.parse(code, filename="test.py")
        imports = self.parser.parse_imports(tree, "test.py")

        self.assertEqual(len(imports), 1)
        self.assertFalse(imports[0].is_external)
        resolved_path = imports[0].resolved_path
        self.assertIsNotNone(resolved_path)
        assert resolved_path is not None  # For type checker
        self.assertTrue(resolved_path.startswith(self.project_root))

    def test_parse_lineno_tracking(self) -> None:
        """Test that line numbers are correctly tracked.

        The snippet starts with a newline, so its first statement is line 2.
        """
        code = """
import os

from typing import List


import sys
"""
        tree = ast.parse(code, filename="test.py")
        imports = self.parser.parse_imports(tree, "test.py")

        self.assertEqual(len(imports), 3)
        self.assertEqual(imports[0].lineno, 2)  # import os
        self.assertEqual(imports[1].lineno, 4)  # from typing import List
        self.assertEqual(imports[2].lineno, 7)  # import sys
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """
4
+ Unit tests for ImportResolver.
5
+
6
+ Tests the ImportResolver implementation to ensure it correctly:
7
+ 1. Resolves internal imports
8
+ 2. Detects external modules
9
+ 3. Handles non-existent modules gracefully
10
+ """
11
+
12
+ import unittest
13
+ from pathlib import Path
14
+
15
+ from tritonparse.reproducer.import_resolver import ImportResolver
16
+
17
+
18
class ImportResolverTest(unittest.TestCase):
    """Unit tests for ImportResolver class.

    Covers internal module resolution, external-module detection, and
    graceful handling of modules that do not exist.
    """

    def setUp(self) -> None:
        """Set up test fixtures."""
        # Use the tritonparse directory (parent of reproducer/) as project root.
        test_dir = Path(__file__).resolve().parent
        reproducer_dir = test_dir.parent
        self.project_root = str(reproducer_dir.parent)
        self.resolver = ImportResolver(project_root=self.project_root)

    def test_resolve_tritonparse_module(self) -> None:
        """Test resolving a module within tritonparse."""
        # Setup: module that exists in tritonparse (dotted path relative to root)
        module_name = "reproducer.ast_analyzer"

        # Execute: resolve the module
        path, is_external = self.resolver.resolve_import(module_name)

        # Assert: should resolve to an existing file under the project root
        self.assertIsNotNone(path)
        assert path is not None  # For type checker
        self.assertTrue(path.startswith(self.project_root))
        self.assertTrue(Path(path).exists())
        self.assertFalse(is_external)

    def test_external_module_torch(self) -> None:
        """Test that torch is correctly identified as external."""
        # Setup: torch is a known external module
        module_name = "torch"

        # Execute: resolve the module
        path, is_external = self.resolver.resolve_import(module_name)

        # Assert: should be external, with no local path
        self.assertIsNone(path)
        self.assertTrue(is_external)

    def test_nonexistent_module(self) -> None:
        """Test handling of non-existent module."""
        # Setup: a module that doesn't exist
        module_name = "this_module_does_not_exist_anywhere"

        # Execute: resolve the module
        path, is_external = self.resolver.resolve_import(module_name)

        # Assert: should handle gracefully as external
        self.assertIsNone(path)
        self.assertTrue(is_external)

    def test_is_external_module_torch(self) -> None:
        """Test is_external_module() for torch."""
        # Setup: torch module name
        module_name = "torch"

        # Execute: check if external
        result = self.resolver.is_external_module(module_name)

        # Assert: should be true
        self.assertTrue(result)

    def test_is_external_module_internal(self) -> None:
        """Test is_external_module() for internal module."""
        # Setup: internal module name
        module_name = "reproducer.ast_analyzer"

        # Execute: check if external
        result = self.resolver.is_external_module(module_name)

        # Assert: should be false
        self.assertFalse(result)


if __name__ == "__main__":
    # Allow running this test file directly, like the sibling test modules.
    unittest.main()
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # pyre-strict
4
+
5
+ """
6
+ Tests for MultiFileCallGraphAnalyzer.
7
+
8
+ This test suite validates the multi-file call graph analysis functionality,
9
+ including cross-file dependency tracking, import resolution, and result consolidation.
10
+ """
11
+
12
+ import unittest
13
+ from pathlib import Path
14
+
15
+ from tritonparse.reproducer.multi_file_analyzer import MultiFileCallGraphAnalyzer
16
+
17
+
18
class TestMultiFileCallGraphAnalyzer(unittest.TestCase):
    """Test the MultiFileCallGraphAnalyzer with real Triton test files.

    The fixtures in ``tests/artifacts`` form the call chain
    main_kernel (triton_fused_kernel.py) -> scale_kernel (triton_preprocess.py)
    -> add_values (triton_utils.py).
    """

    def setUp(self) -> None:
        """Set up test fixtures."""
        test_file = Path(__file__).resolve()
        # Get code root (4 levels up from this test file).
        # NOTE(review): not used by the tests below; for this file parents[4] is
        # one directory above the one containing the ``tritonparse`` package —
        # confirm the intended root before relying on it.
        self.code_root = str(test_file.parents[4])

        # Test artifacts directory
        self.artifacts_dir = test_file.parent / "artifacts"
        self.entry_file = str(self.artifacts_dir / "triton_fused_kernel.py")
        self.preprocess_file = str(self.artifacts_dir / "triton_preprocess.py")
        self.utils_file = str(self.artifacts_dir / "triton_utils.py")

    def test_single_file_analysis(self) -> None:
        """Test analysis of a single file without following imports."""
        # Setup: Analyze only triton_utils.py (which has no imports to internal files)
        analyzer = MultiFileCallGraphAnalyzer(
            entry_file=self.utils_file,
            entry_function="add_values",
        )

        # Execute: Run analysis
        result = analyzer.analyze()

        # Assert: Only one file analyzed (triton_utils.py)
        self.assertEqual(result.stats.total_files_analyzed, 1)
        self.assertIn(self.utils_file, result.analyzed_files)

        # Assert: add_values is the backend, should not be in dependent functions
        # (dependent functions excludes the backend itself)
        self.assertNotIn(
            "tritonparse.reproducer.tests.artifacts.triton_utils.add_values",
            result.functions,
        )

        # Assert: All imports should be external (triton, tl)
        self.assertTrue(all(imp.is_external for imp in result.imports))

    def test_multi_file_traversal(self) -> None:
        """Test that analyzer follows imports across multiple files."""
        # Setup: Start from main_kernel in triton_fused_kernel.py
        analyzer = MultiFileCallGraphAnalyzer(
            entry_file=self.entry_file,
            entry_function="main_kernel",
        )

        # Execute: Run analysis
        result = analyzer.analyze()

        # Assert: Should analyze all 3 files in the call chain
        # main_kernel -> scale_kernel -> add_values
        self.assertEqual(result.stats.total_files_analyzed, 3)

        # Assert: All 3 kernel files should be included
        self.assertIn(self.entry_file, result.analyzed_files)
        self.assertIn(self.preprocess_file, result.analyzed_files)
        self.assertIn(self.utils_file, result.analyzed_files)

        # Assert: exactly scale_kernel and add_values are in result.functions
        function_names = set(result.functions.keys())
        self.assertEqual(len(result.functions), 2)
        self.assertTrue(
            any("scale_kernel" in name for name in function_names),
            f"scale_kernel should be in functions. Found: {function_names}",
        )
        self.assertTrue(
            any("add_values" in name for name in function_names),
            f"add_values should be in functions. Found: {function_names}",
        )

        # Find the fully qualified names for these functions
        scale_kernel_name = next(
            name for name in function_names if "scale_kernel" in name
        )
        add_values_name = next(name for name in function_names if "add_values" in name)

        # Assert: extracted source code is non-trivial (at least 50 chars each)
        scale_kernel_source = result.functions[scale_kernel_name]
        add_values_source = result.functions[add_values_name]

        self.assertGreaterEqual(
            len(scale_kernel_source),
            50,
            "scale_kernel source code should be at least 50 chars. "
            f"Got: {len(scale_kernel_source)} chars",
        )
        self.assertGreaterEqual(
            len(add_values_source),
            50,
            "add_values source code should be at least 50 chars. "
            f"Got: {len(add_values_source)} chars",
        )

        # Verify short names are correctly extracted
        self.assertEqual(result.function_short_names[scale_kernel_name], "scale_kernel")
        self.assertEqual(result.function_short_names[add_values_name], "add_values")


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,20 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from enum import Enum
4
+
5
+
6
class KernelImportMode(str, Enum):
    """
    Kernel import strategy for reproducer generation.

    Inherits from str to allow direct string comparison and use in argparse.
    ``str()`` and ``format()`` return the plain value (e.g. ``"copy"``) on
    every supported Python version, so members render as their CLI spelling
    in help text, f-strings, and log messages.

    Attributes:
        DEFAULT: Import kernel from original file (current behavior).
        COPY: Embed kernel source code directly in reproducer.
        OVERRIDE_TTIR: Use TTIR from compilation event with monkeypatch.
    """

    DEFAULT = "default"
    COPY = "copy"
    OVERRIDE_TTIR = "override-ttir"

    def __str__(self) -> str:
        # Newer Pythons (3.11+) make format() of str-mixin enums follow
        # Enum.__str__ ("KernelImportMode.COPY"); pin the plain value so the
        # rendering is identical across versions.
        return self.value
+ OVERRIDE_TTIR = "override-ttir"