tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from tritonparse.info.kernel_query import find_launch_index_by_kernel
|
|
7
|
+
from tritonparse.reproducer.ingestion.ndjson import build_context_bundle
|
|
8
|
+
from tritonparse.reproducer.placeholder_replacer import (
|
|
9
|
+
DefaultPlaceholderReplacer,
|
|
10
|
+
PlaceholderReplacer,
|
|
11
|
+
)
|
|
12
|
+
from tritonparse.reproducer.templates.loader import load_template_code
|
|
13
|
+
from tritonparse.reproducer.types import KernelImportMode
|
|
14
|
+
from tritonparse.reproducer.utils import determine_output_paths, format_python_code
|
|
15
|
+
from tritonparse.tools.prettify_ndjson import load_ndjson, save_prettified_json
|
|
16
|
+
from tritonparse.tp_logger import logger
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def reproduce(
    input_path: str,
    line_index: int,
    out_dir: str,
    template: str,
    kernel_name: Optional[str] = None,
    launch_id: int = 0,
    replacer: Optional[PlaceholderReplacer] = None,
    kernel_import: KernelImportMode = KernelImportMode.DEFAULT,
) -> dict[str, str]:
    """
    Generate a standalone reproducer script for one kernel launch in a trace.

    The launch is selected either directly by ``line_index`` or by
    ``kernel_name`` + ``launch_id``. When ``kernel_name`` is given, the
    ``line_index`` argument is discarded and recomputed from the lookup.

    Args:
        input_path: Path to the NDJSON trace. Supports uncompressed (.ndjson),
            gzip compressed (.ndjson.gz), and gzip member concatenation
            (.bin.ndjson) formats.
        line_index: 0-based index into the loaded events list. Ignored when
            ``kernel_name`` is provided.
        out_dir: Output directory for the reproducer files.
        template: Template name (or path) handed to the template loader.
        kernel_name: Exact, case-sensitive kernel name to look up.
        launch_id: 0-based launch index for ``kernel_name`` (default: first).
        replacer: Custom PlaceholderReplacer; a DefaultPlaceholderReplacer is
            created when this is None.
        kernel_import: Kernel import mode. In OVERRIDE_TTIR mode the
            compilation event is additionally written next to the launch
            context as ``<context stem>_compilation.json``.

    Returns:
        A dict with the keys ``kernel_src_path``, ``kernel``, ``repro_script``
        and ``repro_context``.
    """
    events = load_ndjson(Path(input_path))
    logger.debug(f"Loaded {len(events)} events")

    # A kernel name takes precedence over whatever line_index was passed in.
    if kernel_name is not None:
        logger.debug(
            f"Looking up kernel '{kernel_name}' launch_id={launch_id} in {input_path}"
        )
        line_index = find_launch_index_by_kernel(events, kernel_name, launch_id)
        logger.debug(
            f"Found kernel '{kernel_name}' launch_id={launch_id} at line {line_index}"
        )

    logger.debug(f"Building bundle from {input_path} at line {line_index}")

    # Gather everything known about the selected launch event.
    context_bundle = build_context_bundle(events, line_index)
    logger.debug(
        f"Built context bundle for kernel: {context_bundle.kernel_info.function_name}"
    )
    out_py_path, temp_json_path = determine_output_paths(
        out_dir, context_bundle.kernel_info.function_name, template
    )
    save_prettified_json(context_bundle.raw_launch_event, temp_json_path)

    # Only the TTIR-override mode needs the compilation event on disk.
    if kernel_import == KernelImportMode.OVERRIDE_TTIR:
        comp_json_path = (
            temp_json_path.parent / f"{temp_json_path.stem}_compilation.json"
        )
        save_prettified_json(context_bundle.raw_comp_event, comp_json_path)
    else:
        comp_json_path = None
    comp_json_filename = comp_json_path.name if comp_json_path else None

    logger.debug("Loading reproducer template.")
    template_code = load_template_code(template)

    # Fall back to the stock replacer when the caller did not supply one.
    if replacer is None:
        replacer = DefaultPlaceholderReplacer()
    rendered = replacer.replace(
        template_code,
        context_bundle,
        temp_json_path=temp_json_path,
        kernel_import=kernel_import,
        comp_json_filename=comp_json_filename,
    )

    # Format the generated code, then persist it.
    out_py_path.write_text(format_python_code(rendered), encoding="utf-8")

    # NOTE(review): this drops the first five "/"-separated components of the
    # kernel source path (the empty string before a leading "/" counts as
    # one); presumably it strips an environment-specific prefix -- confirm.
    path_parts = context_bundle.kernel_info.file_path.split("/")
    ret = {
        "kernel_src_path": "/".join(path_parts[5:]),
        "kernel": context_bundle.kernel_info.function_name,
        "repro_script": str(out_py_path.resolve()),
        "repro_context": str(temp_json_path.resolve()),
    }
    logger.info("REPRODUCER_OUTPUT\n%s", ret)

    return ret
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Any, Dict, Optional, Protocol
|
|
5
|
+
|
|
6
|
+
from tritonparse.reproducer.function_extractor import extract_utility_functions
|
|
7
|
+
from tritonparse.reproducer.ingestion.ndjson import ContextBundle
|
|
8
|
+
from tritonparse.reproducer.templates.utils import (
|
|
9
|
+
_disable_triton_autotune,
|
|
10
|
+
get_function_source,
|
|
11
|
+
)
|
|
12
|
+
from tritonparse.reproducer.types import KernelImportMode
|
|
13
|
+
from tritonparse.reproducer.utils import (
|
|
14
|
+
_generate_import_statements,
|
|
15
|
+
_generate_invocation_snippet,
|
|
16
|
+
_parse_kernel_signature,
|
|
17
|
+
)
|
|
18
|
+
from tritonparse.tp_logger import logger
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class HandlerProtocol(Protocol):
    """
    Structural type for placeholder handlers.

    A handler receives the current code text and the kernel context (plus any
    extra keyword arguments) and returns the updated code text.
    """

    def __call__(
        self,
        code: str,
        context_bundle: ContextBundle,
        **kwargs: Any,
    ) -> str:
        ...
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PlaceholderReplacer(ABC):
    """
    Base class for template placeholder replacement.

    Subclasses populate the handler table from their ``__init__`` by calling
    ``self.register(placeholder, handler_function)``.

    Every handler is called as:
        handler(code: str, context_bundle: ContextBundle, **kwargs) -> str
    """

    def __init__(self):
        # placeholder string -> handler; dict order is the application order
        self.handlers: Dict[str, HandlerProtocol] = {}

    def register(self, placeholder: str, handler: HandlerProtocol):
        """
        Associate a handler with a placeholder.

        Registering the same placeholder again replaces its handler.

        Args:
            placeholder: The placeholder string to replace (e.g., "{{JSON_FILE_NAME_PLACEHOLDER}}")
            handler: Callable taking (code, context_bundle, **kwargs) and returning the modified code
        """
        self.handlers.update({placeholder: handler})

    def replace(
        self, template_code: str, context_bundle: ContextBundle, **kwargs: Any
    ) -> str:
        """
        Run every registered handler over the template, in registration order.

        Args:
            template_code: The template code containing placeholders
            context_bundle: Context information about the kernel
            **kwargs: Extra keyword arguments forwarded unchanged to each handler

        Returns:
            The code after the last handler has been applied
        """
        rendered = template_code
        for apply_handler in self.handlers.values():
            # Each handler sees the output of the previous one.
            rendered = apply_handler(rendered, context_bundle, **kwargs)
        return rendered
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class DefaultPlaceholderReplacer(PlaceholderReplacer):
    """
    Default implementation of PlaceholderReplacer.

    Handles the following placeholders (handlers run in registration order):
    - {{JSON_FILE_NAME_PLACEHOLDER}}: Replaced with the JSON file name
    - # {{IR_OVERRIDE_SETUP_PLACEHOLDER}}: Replaced with TTIR override setup
      code in OVERRIDE_TTIR mode, otherwise removed
    - # {{KERNEL_SYSPATH_PLACEHOLDER}}: Replaced with sys.path setup code
    - # {{KERNEL_IMPORT_PLACEHOLDER}}: Replaced with kernel import statement
    - # {{UTILITY_FUNCTIONS_PLACEHOLDER}}: Replaced with extracted utility functions
    - # {{KERNEL_INVOCATION_PLACEHOLDER}}: Replaced with kernel invocation code
    - {{KERNEL_NAME_PLACEHOLDER}}: Replaced with the kernel function name

    Handlers read these optional keyword arguments passed to ``replace``:
    ``temp_json_path``, ``kernel_import`` and ``comp_json_filename``.
    """

    KERNEL_NAME_PLACEHOLDER = "{{KERNEL_NAME_PLACEHOLDER}}"
    JSON_FILE_NAME_PLACEHOLDER = "{{JSON_FILE_NAME_PLACEHOLDER}}"
    IR_OVERRIDE_SETUP_PLACEHOLDER = "# {{IR_OVERRIDE_SETUP_PLACEHOLDER}}"
    KERNEL_SYSPATH_PLACEHOLDER = "# {{KERNEL_SYSPATH_PLACEHOLDER}}"
    KERNEL_IMPORT_PLACEHOLDER = "# {{KERNEL_IMPORT_PLACEHOLDER}}"
    UTILITY_FUNCTIONS_PLACEHOLDER = "# {{UTILITY_FUNCTIONS_PLACEHOLDER}}"
    KERNEL_INVOCATION_PLACEHOLDER = "# {{KERNEL_INVOCATION_PLACEHOLDER}}"

    def __init__(self):
        super().__init__()
        # Register all default handlers. The order matters: replace() applies
        # handlers in registration order, so the kernel-name substitution runs
        # last and also reaches text inserted by the earlier handlers.
        self.register(self.JSON_FILE_NAME_PLACEHOLDER, self._replace_json_filename)
        self.register(
            self.IR_OVERRIDE_SETUP_PLACEHOLDER, self._replace_ir_override_setup
        )
        self.register(self.KERNEL_SYSPATH_PLACEHOLDER, self._replace_kernel_syspath)
        self.register(self.KERNEL_IMPORT_PLACEHOLDER, self._replace_kernel_import)
        self.register(
            self.UTILITY_FUNCTIONS_PLACEHOLDER, self._replace_utility_functions
        )
        self.register(
            self.KERNEL_INVOCATION_PLACEHOLDER, self._replace_kernel_invocation
        )
        self.register(self.KERNEL_NAME_PLACEHOLDER, self._replace_kernel_name)

    def _replace_kernel_name(
        self, code: str, context_bundle: ContextBundle, **kwargs
    ) -> str:
        """
        Replace the kernel name placeholder.

        Raises:
            ValueError: If the context bundle carries no kernel function name.
        """
        kernel_name = context_bundle.kernel_info.function_name
        if not kernel_name:
            raise ValueError("Kernel function name is not available")
        return code.replace(self.KERNEL_NAME_PLACEHOLDER, kernel_name)

    def _replace_json_filename(
        self, code: str, context_bundle: ContextBundle, **kwargs
    ) -> str:
        """
        Replace the JSON file name placeholder with ``temp_json_path.name``.

        Raises:
            ValueError: If ``temp_json_path`` was not passed as a keyword argument.
        """
        temp_json_path = kwargs.get("temp_json_path")
        if temp_json_path is None:
            raise ValueError("temp_json_path is required for JSON filename replacement")
        return code.replace(self.JSON_FILE_NAME_PLACEHOLDER, temp_json_path.name)

    def _replace_ir_override_setup(
        self, code: str, context_bundle: ContextBundle, **kwargs
    ) -> str:
        """
        Replace the IR override setup placeholder.

        Outside OVERRIDE_TTIR mode the placeholder is simply removed. In
        OVERRIDE_TTIR mode it is replaced by a snippet that extracts the TTIR
        from the saved compilation event and monkeypatches ``triton.autotune``.

        Raises:
            ValueError: If OVERRIDE_TTIR mode is requested without
                ``comp_json_filename``.
        """
        kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)

        if kernel_import != KernelImportMode.OVERRIDE_TTIR:
            return code.replace(self.IR_OVERRIDE_SETUP_PLACEHOLDER, "")

        comp_json_filename = kwargs.get("comp_json_filename")
        if not comp_json_filename:
            raise ValueError("comp_json_filename is required for OVERRIDE_TTIR mode")

        # The snippet runs at module level where the placeholder sits, which
        # may be before the template (or the injected utility code) imports
        # json / tempfile / Path. It therefore imports what it needs itself.
        # Doubled braces are literal braces in the generated script.
        setup_code = f'''
def create_ttir_tempfile():
    """Extract TTIR from compilation event and create temporary file."""
    import json
    import tempfile
    from pathlib import Path

    script_dir = Path(__file__).resolve().parent
    comp_json_file = script_dir / "{comp_json_filename}"

    with open(comp_json_file, 'r') as f:
        comp_data = json.load(f)

    # Extract TTIR content
    kernel_name = comp_data['payload']['metadata']['name']
    ttir_key = f"{{kernel_name}}.ttir"
    ttir_content = comp_data['payload']['file_content'][ttir_key]

    # Create temporary file
    temp_file = tempfile.NamedTemporaryFile(
        mode='w',
        suffix='.ttir',
        delete=False,
        prefix=f'{{kernel_name}}_'
    )
    temp_file.write(ttir_content)
    temp_file.close()
    return temp_file.name


# Monkeypatch triton.autotune to use our TTIR
_ttir_file = create_ttir_tempfile()
_original_autotune = None

def _patched_autotune(configs, key=None, **kwargs):
    """Patched autotune that uses our TTIR file."""
    import triton
    # Replace configs with our single config using ir_override
    new_configs = [triton.Config(kwargs={{}}, ir_override=_ttir_file)]
    # Call original autotune with our config
    return _original_autotune(new_configs, key=[], **kwargs)

# Apply the monkeypatch before importing the kernel
import triton
_original_autotune = triton.autotune
triton.autotune = _patched_autotune
'''

        return code.replace(self.IR_OVERRIDE_SETUP_PLACEHOLDER, setup_code)

    def _replace_kernel_syspath(
        self, code: str, context_bundle: ContextBundle, **kwargs
    ) -> str:
        """
        Replace the kernel sys.path placeholder.

        DEFAULT mode inserts the generated sys.path statement; COPY and
        OVERRIDE_TTIR insert an explanatory comment instead.

        Raises:
            ValueError: If ``kernel_import`` is not a known mode.
        """
        kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)

        if kernel_import == KernelImportMode.DEFAULT:
            sys_stmt, _ = _generate_import_statements(context_bundle.kernel_info)
            return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, sys_stmt)
        elif kernel_import == KernelImportMode.COPY:
            comment = (
                "# Kernel sys.path setup skipped - kernel source code embedded below"
            )
            return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, comment)
        elif kernel_import == KernelImportMode.OVERRIDE_TTIR:
            comment = "# Kernel sys.path setup skipped - using IR override mode"
            return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, comment)
        else:
            raise ValueError(f"Unknown kernel_import mode: {kernel_import}")

    def _replace_kernel_import(
        self, code: str, context_bundle: ContextBundle, **kwargs
    ) -> str:
        """
        Replace the kernel import placeholder.

        DEFAULT mode inserts the generated import statement, COPY mode embeds
        the kernel source (plus any dependent functions that could be
        extracted), and OVERRIDE_TTIR mode inserts an explanatory comment.
        DEFAULT and COPY also append the source of ``_disable_triton_autotune``.

        Raises:
            ValueError: If COPY mode lacks kernel source or a function name,
                or if ``kernel_import`` is not a known mode.
        """
        kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)

        if kernel_import == KernelImportMode.DEFAULT:
            _, import_statement = _generate_import_statements(
                context_bundle.kernel_info
            )

            final_stmt = "\n".join(
                [import_statement, ""] + get_function_source(_disable_triton_autotune)
            )
            return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, final_stmt)
        elif kernel_import == KernelImportMode.COPY:
            source_code = context_bundle.kernel_info.source_code
            func_name = context_bundle.kernel_info.function_name

            if not source_code or not source_code.strip():
                raise ValueError("Kernel source code is empty, cannot use 'copy' mode")
            if not func_name:
                raise ValueError(
                    "Cannot determine kernel function name for 'copy' mode"
                )

            dependent_source_map = get_dependent_source_map(
                context_bundle.kernel_info.function_name,
                context_bundle.kernel_info.file_path,
            )
            # Only add dependent functions if extraction was successful
            if dependent_source_map:
                # Add separator and dependent functions
                dependent_code = (
                    "\n\n# Dependent functions extracted from source file\n\n"
                )
                dependent_code += "\n\n".join(dependent_source_map.values())
                source_code += "\n\n" + dependent_code
                logger.debug("Appended dependent functions to kernel source code")

            # Add common imports needed for most Triton kernels
            import_lines = [
                "import torch",
                "import numpy as np",
                "import triton",
                "import triton.language as tl",
                "from typing import List, Tuple",
                "",
            ] + get_function_source(_disable_triton_autotune)

            # Combine: imports + kernel source code + alias
            embedded_code = "\n".join(import_lines)
            embedded_code += "\n" + source_code
            embedded_code += f"\n\n# Use kernel function directly\nimported_kernel_function = {func_name}"

            return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, embedded_code)
        elif kernel_import == KernelImportMode.OVERRIDE_TTIR:
            comment = "# Kernel import skipped - using IR override mode with TTIR"
            return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, comment)
        else:
            raise ValueError(f"Unknown kernel_import mode: {kernel_import}")

    def _replace_utility_functions(
        self, code: str, context_bundle: ContextBundle, **kwargs
    ) -> str:
        """Replace the utility functions placeholder with extracted functions."""
        utility_code = extract_utility_functions()
        return code.replace(self.UTILITY_FUNCTIONS_PLACEHOLDER, utility_code)

    def _replace_kernel_invocation(
        self, code: str, context_bundle: ContextBundle, **kwargs
    ) -> str:
        """
        Replace the kernel invocation placeholder.

        The invocation snippet is generated from the positional and keyword
        arguments parsed out of the kernel's source code.
        """
        source_code = context_bundle.kernel_info.source_code
        pos_args, kw_args = _parse_kernel_signature(source_code)
        invocation_snippet = _generate_invocation_snippet(pos_args, kw_args)
        return code.replace(self.KERNEL_INVOCATION_PLACEHOLDER, invocation_snippet)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def get_dependent_source_map(
    function_name: str, file_path: str
) -> Optional[dict[str, str]]:
    """
    Extract the source of the functions that a kernel depends on.

    Args:
        function_name: Name of the entry function to start the analysis from.
        file_path: Path of the source file that defines ``function_name``.

    Returns:
        The analyzer's ``functions`` mapping (function name -> source code),
        or None when ``file_path`` is missing, is not an existing file, or
        the analysis fails.
    """
    from pathlib import Path

    # No usable source file means nothing to analyze. A None path would make
    # Path() raise, and an empty string would resolve to the current
    # directory, so both are rejected up front.
    if not file_path or not Path(file_path).is_file():
        return None

    try:
        # Use MultiFileCallGraphAnalyzer for multi-file analysis
        from tritonparse.reproducer.multi_file_analyzer import (
            MultiFileCallGraphAnalyzer,
        )

        analyzer = MultiFileCallGraphAnalyzer(
            entry_file=file_path,
            entry_function=function_name,
        )
        result = analyzer.analyze()

        logger.info(
            f"Extracted {result.stats.total_functions_found} dependent functions "
            f"from {result.stats.total_files_analyzed} files with "
            f"{result.stats.total_imports} imports"
        )

        # Print dependent functions' short names
        logger.info("\nDependent functions (short names):")
        for func_name in sorted(result.function_short_names.keys()):
            short_name = result.function_short_names[func_name]
            # This listing is purely informational, so a missing or empty
            # source entry must not discard an otherwise successful result.
            source = result.functions.get(func_name) or ""
            first_line = next(iter(source.splitlines()), "")
            logger.info(" - %s. %s", short_name, first_line)

        return result.functions

    except Exception as e:
        # If AST analysis fails, continue without dependent functions
        logger.warning(f"Failed to extract dependent functions: {e}", exc_info=True)
        return None
|
|
File without changes
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file is automatically generated by TritonParse reproducer.
|
|
3
|
+
It contains a smallest testing example for a Triton kernel.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
# {{IR_OVERRIDE_SETUP_PLACEHOLDER}}
|
|
13
|
+
|
|
14
|
+
# {{KERNEL_SYSPATH_PLACEHOLDER}}
|
|
15
|
+
|
|
16
|
+
# {{KERNEL_IMPORT_PLACEHOLDER}}
|
|
17
|
+
|
|
18
|
+
# {{UTILITY_FUNCTIONS_PLACEHOLDER}}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def launch_kernel():
    """
    Rebuild the recorded launch arguments, print them, and run the kernel.

    This function is template text: the reproducer fills in the JSON file
    name and swaps the marker comment in the body for the generated launch
    code before the script is written out.
    """
    # The JSON context file is looked up next to this script.
    script_dir = Path(__file__).resolve().parent  # noqa: F821
    json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}"
    # Path and create_args_from_json_file are not defined in this template
    # (hence the noqa markers); presumably the injected utility code
    # provides them -- confirm.
    grid, args_dict = create_args_from_json_file(str(json_file))  # noqa: F821

    print("Generated kernel arguments dictionary:")
    for name, arg in args_dict.items():
        print(f" {name}: {arg}")
    print(f"Grid: {grid}")

    # The marker on the next line is matched textually by the reproducer and
    # must stay exactly as written.
    # {{KERNEL_INVOCATION_PLACEHOLDER}}

    # Block until queued CUDA work has finished before reporting completion.
    torch.cuda.synchronize()
    print("Kernel execution finished.")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
if __name__ == "__main__":
|
|
38
|
+
launch_kernel()
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
from importlib.resources import files as pkg_files
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
BUILTIN_TEMPLATES_PACKAGE = "tritonparse.reproducer.templates"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _is_path_like(template_arg: str) -> bool:
|
|
12
|
+
return "/" in template_arg or "\\" in template_arg or template_arg.endswith(".py")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _read_file_text(path: Path) -> str:
|
|
16
|
+
p = path.expanduser().resolve()
|
|
17
|
+
if not p.exists() or not p.is_file():
|
|
18
|
+
raise FileNotFoundError(f"Template not found: {p}")
|
|
19
|
+
return p.read_text(encoding="utf-8")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _read_builtin_template_text(name: str) -> str:
    """
    Read the builtin template ``<name>.py`` from the templates package.

    Raises:
        FileNotFoundError: If no such builtin template exists; the message
            lists the template names that are available.
    """
    templates_root = pkg_files(BUILTIN_TEMPLATES_PACKAGE)
    resource = templates_root.joinpath(f"{name}.py")
    # An unknown name only surfaces when the resource is opened.
    try:
        with resource.open("r", encoding="utf-8") as stream:
            text = stream.read()
            return text
    except FileNotFoundError as exc:
        available = ", ".join(list_builtin_templates())
        message = f"Builtin template '{name}' not found. Available: {available}"
        raise FileNotFoundError(message) from exc
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def list_builtin_templates() -> List[str]:
    """
    Return the sorted builtin template names (file names without ".py").

    Files whose name starts with "__" (such as ``__init__.py``) and entries
    that are not files are left out.
    """
    found: List[str] = []
    for entry in pkg_files(BUILTIN_TEMPLATES_PACKAGE).iterdir():
        try:
            if not entry.is_file():
                continue
            filename = entry.name
            if filename.startswith("__") or not filename.endswith(".py"):
                continue
            # Strip the ".py" suffix.
            found.append(filename[:-3])
        except OSError:
            # Defensive: entry access may fail in some environments.
            # FileNotFoundError is a subclass of OSError.
            continue
    return sorted(found)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def load_template_code(template_arg: str) -> str:
    """
    Return the template source for a builtin name or a filesystem path.

    Path-like arguments are read from disk; everything else is resolved as a
    builtin template name (given without the ".py" suffix).
    """
    if not _is_path_like(template_arg):
        return _read_builtin_template_text(template_arg)
    return _read_file_text(Path(template_arg))
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Callable, Dict, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from tritonbench.utils.triton_op import (
|
|
9
|
+
BenchmarkOperator,
|
|
10
|
+
register_benchmark,
|
|
11
|
+
REGISTERED_X_VALS,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
imported_kernel_function: Optional[Callable[[Tuple[int], Dict[str, Any]], None]] = None
|
|
18
|
+
|
|
19
|
+
# {{IR_OVERRIDE_SETUP_PLACEHOLDER}}
|
|
20
|
+
|
|
21
|
+
# {{KERNEL_SYSPATH_PLACEHOLDER}}
|
|
22
|
+
|
|
23
|
+
# {{KERNEL_IMPORT_PLACEHOLDER}}
|
|
24
|
+
|
|
25
|
+
# {{UTILITY_FUNCTIONS_PLACEHOLDER}}
|
|
26
|
+
|
|
27
|
+
assert imported_kernel_function is not None, "imported_kernel_function is missing"
|
|
28
|
+
|
|
29
|
+
KERNEL_NAME = "{{KERNEL_NAME_PLACEHOLDER}}"
|
|
30
|
+
REPRO_CONTEXT_FILE_NAME = "{{JSON_FILE_NAME_PLACEHOLDER}}"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _get_launch_kernel_args() -> Tuple[Tuple[int], Dict[str, Any]]:
    """
    Load the recorded launch grid and kernel arguments and print a summary.

    Returns:
        ``(grid, args_dict)``, with the grid converted to a tuple.
    """
    # The context file is looked up next to this script.
    script_dir = Path(__file__).resolve().parent  # noqa: F821
    json_file = script_dir / REPRO_CONTEXT_FILE_NAME

    # create_args_from_json_file is not defined in this template (hence the
    # noqa); presumably the injected utility code provides it -- confirm.
    grid, args_dict = create_args_from_json_file(json_file)  # noqa: F821, F841

    print("Recorded kernel arguments dictionary:")
    for name, arg in args_dict.items():
        if isinstance(arg, torch.Tensor):
            # Tensors are summarized by metadata rather than printed in full.
            print(
                f" {name}: Tensor: {arg.shape} {arg.dtype} stride: {arg.stride()}, is_contiguous: {arg.is_contiguous()}"
            )
        else:
            print(f" {name}: {arg}")
    print(f"Grid: {grid}")

    return tuple(grid), args_dict
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
grid, args_dict = _get_launch_kernel_args()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _launch_kernel(grid: tuple[int], args_dict: dict[str, Any]):
    """
    Launch the kernel once with the recorded grid and arguments.

    Any exception raised by the launch is printed and not re-raised, so this
    function returns normally even when the launch fails.
    """
    try:
        assert grid is not None
        assert args_dict is not None

        # The marker on the next line is matched textually by the reproducer,
        # which swaps it for the generated launch code; keep it as written.
        # {{KERNEL_INVOCATION_PLACEHOLDER}}

    except Exception as e:
        print(f"Error: {e}")
        print("Failed to launch kernel!")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# HACK: @register_x_val doesn't allow us to pass `operator_name` as a parameter
|
|
68
|
+
tensor_args = {k: v for k, v in args_dict.items() if isinstance(v, torch.Tensor)}
|
|
69
|
+
x_vals_label = ", ".join(tensor_args.keys())
|
|
70
|
+
REGISTERED_X_VALS[KERNEL_NAME] = x_vals_label
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class Operator(BenchmarkOperator):
    """Benchmark operator that replays the single recorded kernel launch."""

    @register_benchmark(operator_name=KERNEL_NAME)
    def run_kernel(self, grid, args_dict):
        """Return a zero-argument callable that launches the kernel."""
        return lambda: _launch_kernel(grid, args_dict)

    def get_input_iter(self):
        """Yield the one recorded input: the module-level grid and arguments."""
        yield {"grid": grid, "args_dict": args_dict}

    def get_x_val(self, example_inputs):
        """Return the shapes of all tensor arguments as a tuple of tuples."""
        tensors_shapes = [
            tuple(v.shape)
            for v in example_inputs["args_dict"].values()
            if isinstance(v, torch.Tensor)
        ]
        return tuple(tensors_shapes)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
if __name__ == "__main__":
|
|
91
|
+
print("do_benchmark...")
|
|
92
|
+
|
|
93
|
+
args = [
|
|
94
|
+
"--benchmark-name",
|
|
95
|
+
KERNEL_NAME,
|
|
96
|
+
]
|
|
97
|
+
|
|
98
|
+
from tritonbench.utils.parser import get_parser
|
|
99
|
+
|
|
100
|
+
parser = get_parser(args)
|
|
101
|
+
tb_args, extra_args = parser.parse_known_args(args)
|
|
102
|
+
bench = Operator(tb_args, extra_args)
|
|
103
|
+
bench.run()
|
|
104
|
+
|
|
105
|
+
print(bench.output)
|
|
106
|
+
print("Benchmark completed successfully!")
|