tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tritonparse might be problematic.
Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,110 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ from pathlib import Path
+ from typing import Optional
+
+ from tritonparse.info.kernel_query import find_launch_index_by_kernel
+ from tritonparse.reproducer.ingestion.ndjson import build_context_bundle
+ from tritonparse.reproducer.placeholder_replacer import (
+     DefaultPlaceholderReplacer,
+     PlaceholderReplacer,
+ )
+ from tritonparse.reproducer.templates.loader import load_template_code
+ from tritonparse.reproducer.types import KernelImportMode
+ from tritonparse.reproducer.utils import determine_output_paths, format_python_code
+ from tritonparse.tools.prettify_ndjson import load_ndjson, save_prettified_json
+ from tritonparse.tp_logger import logger
+
+
+ def reproduce(
+     input_path: str,
+     line_index: int,
+     out_dir: str,
+     template: str,
+     kernel_name: Optional[str] = None,
+     launch_id: int = 0,
+     replacer: Optional[PlaceholderReplacer] = None,
+     kernel_import: KernelImportMode = KernelImportMode.DEFAULT,
+ ) -> dict[str, str]:
+     """
+     Generate a reproducer script from an NDJSON trace file.
+
+     Provide either line_index or (kernel_name + launch_id). If kernel_name is
+     provided, the line_index parameter is ignored and recalculated from the
+     kernel lookup.
+
+     Args:
+         input_path: Path to the NDJSON file. Supports uncompressed (.ndjson),
+             gzip-compressed (.ndjson.gz), and gzip member concatenation (.bin.ndjson) formats.
+         line_index: 0-based index in the events list. Ignored if kernel_name is provided.
+         out_dir: Output directory for reproducer files.
+         template: Template name to use for the reproducer.
+         kernel_name: Exact kernel name to match (case-sensitive). If provided, line_index will be recalculated.
+         launch_id: 0-based launch index for the kernel (default: 0, the first launch).
+         replacer: Optional custom PlaceholderReplacer instance. If None, uses DefaultPlaceholderReplacer.
+         kernel_import: Kernel import mode (DEFAULT, COPY, or OVERRIDE_TTIR).
+     """
+     events = load_ndjson(Path(input_path))
+     logger.debug(f"Loaded {len(events)} events")
+
+     # If kernel_name is provided, look up the actual line_index (overrides the parameter)
+     if kernel_name is not None:
+         logger.debug(
+             f"Looking up kernel '{kernel_name}' launch_id={launch_id} in {input_path}"
+         )
+         line_index = find_launch_index_by_kernel(events, kernel_name, launch_id)
+         logger.debug(
+             f"Found kernel '{kernel_name}' launch_id={launch_id} at line {line_index}"
+         )
+
+     logger.debug(f"Building bundle from {input_path} at line {line_index}")
+
+     # Build context bundle from the specified launch event
+     context_bundle = build_context_bundle(events, line_index)
+     logger.debug(
+         f"Built context bundle for kernel: {context_bundle.kernel_info.function_name}"
+     )
+     out_py_path, temp_json_path = determine_output_paths(
+         out_dir, context_bundle.kernel_info.function_name, template
+     )
+     save_prettified_json(context_bundle.raw_launch_event, temp_json_path)
+
+     # Save compilation event JSON if using OVERRIDE_TTIR mode
+     comp_json_path = None
+     if kernel_import == KernelImportMode.OVERRIDE_TTIR:
+         comp_json_path = (
+             temp_json_path.parent / f"{temp_json_path.stem}_compilation.json"
+         )
+         save_prettified_json(context_bundle.raw_comp_event, comp_json_path)
+
+     logger.debug("Loading reproducer template.")
+     template_code = load_template_code(template)
+
+     # Use PlaceholderReplacer to replace all placeholders
+     # If no custom replacer is provided, use the default one
+     if replacer is None:
+         replacer = DefaultPlaceholderReplacer()
+     final_code = replacer.replace(
+         template_code,
+         context_bundle,
+         temp_json_path=temp_json_path,
+         kernel_import=kernel_import,
+         comp_json_filename=comp_json_path.name if comp_json_path else None,
+     )
+
+     # Format the generated code
+     final_code = format_python_code(final_code)
+
+     out_py_path.write_text(final_code, encoding="utf-8")
+
+     filepath = context_bundle.kernel_info.file_path
+     filepath = "/".join(filepath.split("/")[5:])
+     ret = {
+         "kernel_src_path": filepath,
+         "kernel": context_bundle.kernel_info.function_name,
+         "repro_script": str(out_py_path.resolve()),
+         "repro_context": str(temp_json_path.resolve()),
+     }
+     logger.info("REPRODUCER_OUTPUT\n%s", ret)
+
+     return ret
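
As a quick orientation (this sketch is not part of the wheel), the reproduce() entry point above could be driven like the following; the trace path, kernel name, and output directory are hypothetical, and the module path is assumed from the file list above rather than confirmed by this hunk:

from tritonparse.reproducer.orchestrator import reproduce  # module path assumed from the file list

result = reproduce(
    input_path="traces/run.ndjson",  # hypothetical NDJSON trace
    line_index=0,                    # ignored because kernel_name is given
    out_dir="repro_out",             # hypothetical output directory
    template="example",              # builtin template name (see templates/loader.py below)
    kernel_name="my_add_kernel",     # hypothetical kernel, resolved via find_launch_index_by_kernel
    launch_id=0,                     # first recorded launch of that kernel
)
print(result["repro_script"], result["repro_context"])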
@@ -0,0 +1,335 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ from abc import ABC
+ from typing import Any, Dict, Optional, Protocol
+
+ from tritonparse.reproducer.function_extractor import extract_utility_functions
+ from tritonparse.reproducer.ingestion.ndjson import ContextBundle
+ from tritonparse.reproducer.templates.utils import (
+     _disable_triton_autotune,
+     get_function_source,
+ )
+ from tritonparse.reproducer.types import KernelImportMode
+ from tritonparse.reproducer.utils import (
+     _generate_import_statements,
+     _generate_invocation_snippet,
+     _parse_kernel_signature,
+ )
+ from tritonparse.tp_logger import logger
+
+
+ class HandlerProtocol(Protocol):
+     def __call__(
+         self, code: str, context_bundle: ContextBundle, **kwargs: Any
+     ) -> str: ...
+
+
+ class PlaceholderReplacer(ABC):
+     """
+     Abstract base class for template placeholder replacement.
+
+     Subclasses should register replacement handlers in their __init__ method
+     by calling self.register(placeholder, handler_function).
+
+     Each handler function should have the signature:
+         handler(code: str, context_bundle: ContextBundle, **kwargs) -> str
+     """
+
+     def __init__(self):
+         # Dictionary mapping placeholder strings to handler functions
+         self.handlers: Dict[str, HandlerProtocol] = {}
+
+     def register(self, placeholder: str, handler: HandlerProtocol):
+         """
+         Register a handler function for a specific placeholder.
+
+         Args:
+             placeholder: The placeholder string to replace (e.g., "{{JSON_FILE_NAME_PLACEHOLDER}}")
+             handler: A callable that takes (code, context_bundle, **kwargs) and returns modified code
+         """
+         self.handlers[placeholder] = handler
+
+     def replace(
+         self, template_code: str, context_bundle: ContextBundle, **kwargs: Any
+     ) -> str:
+         """
+         Replace all registered placeholders in the template code.
+
+         Args:
+             template_code: The template code containing placeholders
+             context_bundle: Context information about the kernel
+             **kwargs: Additional keyword arguments passed to handler functions
+
+         Returns:
+             The code with all placeholders replaced
+         """
+         code = template_code
+         for handler in self.handlers.values():
+             code = handler(code, context_bundle, **kwargs)
+         return code
+
+
+ class DefaultPlaceholderReplacer(PlaceholderReplacer):
+     """
+     Default implementation of PlaceholderReplacer.
+
+     Handles the following placeholders:
+     - {{JSON_FILE_NAME_PLACEHOLDER}}: Replaced with the JSON file name
+     - # {{KERNEL_SYSPATH_PLACEHOLDER}}: Replaced with sys.path setup code
+     - # {{KERNEL_IMPORT_PLACEHOLDER}}: Replaced with kernel import statement
+     - # {{KERNEL_INVOCATION_PLACEHOLDER}}: Replaced with kernel invocation code
+     """
+
+     KERNEL_NAME_PLACEHOLDER = "{{KERNEL_NAME_PLACEHOLDER}}"
+     JSON_FILE_NAME_PLACEHOLDER = "{{JSON_FILE_NAME_PLACEHOLDER}}"
+     IR_OVERRIDE_SETUP_PLACEHOLDER = "# {{IR_OVERRIDE_SETUP_PLACEHOLDER}}"
+     KERNEL_SYSPATH_PLACEHOLDER = "# {{KERNEL_SYSPATH_PLACEHOLDER}}"
+     KERNEL_IMPORT_PLACEHOLDER = "# {{KERNEL_IMPORT_PLACEHOLDER}}"
+     UTILITY_FUNCTIONS_PLACEHOLDER = "# {{UTILITY_FUNCTIONS_PLACEHOLDER}}"
+     KERNEL_INVOCATION_PLACEHOLDER = "# {{KERNEL_INVOCATION_PLACEHOLDER}}"
+
+     def __init__(self):
+         super().__init__()
+         # Register all default handlers
+         self.register(self.JSON_FILE_NAME_PLACEHOLDER, self._replace_json_filename)
+         self.register(
+             self.IR_OVERRIDE_SETUP_PLACEHOLDER, self._replace_ir_override_setup
+         )
+         self.register(self.KERNEL_SYSPATH_PLACEHOLDER, self._replace_kernel_syspath)
+         self.register(self.KERNEL_IMPORT_PLACEHOLDER, self._replace_kernel_import)
+         self.register(
+             self.UTILITY_FUNCTIONS_PLACEHOLDER, self._replace_utility_functions
+         )
+         self.register(
+             self.KERNEL_INVOCATION_PLACEHOLDER, self._replace_kernel_invocation
+         )
+         self.register(self.KERNEL_NAME_PLACEHOLDER, self._replace_kernel_name)
+
+     def _replace_kernel_name(
+         self, code: str, context_bundle: ContextBundle, **kwargs
+     ) -> str:
+         """Replace the kernel name placeholder."""
+         kernel_name = context_bundle.kernel_info.function_name
+         if not kernel_name:
+             raise ValueError("Kernel function name is not available")
+         return code.replace(self.KERNEL_NAME_PLACEHOLDER, kernel_name)
+
+     def _replace_json_filename(
+         self, code: str, context_bundle: ContextBundle, **kwargs
+     ) -> str:
+         """Replace the JSON file name placeholder."""
+         temp_json_path = kwargs.get("temp_json_path")
+         if temp_json_path is None:
+             raise ValueError("temp_json_path is required for JSON filename replacement")
+         return code.replace(self.JSON_FILE_NAME_PLACEHOLDER, temp_json_path.name)
+
+     def _replace_ir_override_setup(
+         self, code: str, context_bundle: ContextBundle, **kwargs
+     ) -> str:
+         """Replace the IR override setup placeholder."""
+         kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)
+
+         if kernel_import != KernelImportMode.OVERRIDE_TTIR:
+             return code.replace(self.IR_OVERRIDE_SETUP_PLACEHOLDER, "")
+
+         comp_json_filename = kwargs.get("comp_json_filename")
+         if not comp_json_filename:
+             raise ValueError("comp_json_filename is required for OVERRIDE_TTIR mode")
+
+         setup_code = f'''
+ def create_ttir_tempfile():
+     """Extract TTIR from compilation event and create temporary file."""
+     script_dir = Path(__file__).resolve().parent
+     comp_json_file = script_dir / "{comp_json_filename}"
+
+     with open(comp_json_file, 'r') as f:
+         comp_data = json.load(f)
+
+     # Extract TTIR content
+     kernel_name = comp_data['payload']['metadata']['name']
+     ttir_key = f"{{kernel_name}}.ttir"
+     ttir_content = comp_data['payload']['file_content'][ttir_key]
+
+     # Create temporary file
+     temp_file = tempfile.NamedTemporaryFile(
+         mode='w',
+         suffix='.ttir',
+         delete=False,
+         prefix=f'{{kernel_name}}_'
+     )
+     temp_file.write(ttir_content)
+     temp_file.close()
+     return temp_file.name
+
+
+ # Monkeypatch triton.autotune to use our TTIR
+ _ttir_file = create_ttir_tempfile()
+ _original_autotune = None
+
+ def _patched_autotune(configs, key=None, **kwargs):
+     """Patched autotune that uses our TTIR file."""
+     import triton
+     # Replace configs with our single config using ir_override
+     new_configs = [triton.Config(kwargs={{}}, ir_override=_ttir_file)]
+     # Call original autotune with our config
+     return _original_autotune(new_configs, key=[], **kwargs)
+
+ # Apply the monkeypatch before importing the kernel
+ import triton
+ _original_autotune = triton.autotune
+ triton.autotune = _patched_autotune
+ '''
+
+         return code.replace(self.IR_OVERRIDE_SETUP_PLACEHOLDER, setup_code)
+
+     def _replace_kernel_syspath(
+         self, code: str, context_bundle: ContextBundle, **kwargs
+     ) -> str:
+         """Replace the kernel sys.path placeholder."""
+         kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)
+
+         if kernel_import == KernelImportMode.DEFAULT:
+             sys_stmt, _ = _generate_import_statements(context_bundle.kernel_info)
+             return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, sys_stmt)
+         elif kernel_import == KernelImportMode.COPY:
+             comment = (
+                 "# Kernel sys.path setup skipped - kernel source code embedded below"
+             )
+             return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, comment)
+         elif kernel_import == KernelImportMode.OVERRIDE_TTIR:
+             comment = "# Kernel sys.path setup skipped - using IR override mode"
+             return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, comment)
+         else:
+             raise ValueError(f"Unknown kernel_import mode: {kernel_import}")
+
+     def _replace_kernel_import(
+         self, code: str, context_bundle: ContextBundle, **kwargs
+     ) -> str:
+         """Replace the kernel import placeholder."""
+         kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)
+
+         if kernel_import == KernelImportMode.DEFAULT:
+             _, import_statement = _generate_import_statements(
+                 context_bundle.kernel_info
+             )
+
+             final_stmt = "\n".join(
+                 [import_statement, ""] + get_function_source(_disable_triton_autotune)
+             )
+             return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, final_stmt)
+         elif kernel_import == KernelImportMode.COPY:
+             source_code = context_bundle.kernel_info.source_code
+             func_name = context_bundle.kernel_info.function_name
+
+             if not source_code or not source_code.strip():
+                 raise ValueError("Kernel source code is empty, cannot use 'copy' mode")
+             if not func_name:
+                 raise ValueError(
+                     "Cannot determine kernel function name for 'copy' mode"
+                 )
+
+             if kernel_import == KernelImportMode.COPY:
+                 dependent_source_map = get_dependent_source_map(
+                     context_bundle.kernel_info.function_name,
+                     context_bundle.kernel_info.file_path,
+                 )
+                 # Only add dependent functions if extraction was successful
+                 if dependent_source_map:
+                     # Add separator, import statements, and dependent functions
+                     dependent_code = (
+                         "\n\n# Dependent functions extracted from source file\n\n"
+                     )
+                     dependent_code += "\n\n".join(dependent_source_map.values())
+                     source_code += "\n\n" + dependent_code
+                     logger.debug("Appended dependent functions to kernel source code")
+
+             # Add common imports needed for most Triton kernels
+             import_lines = [
+                 "import torch",
+                 "import numpy as np",
+                 "import triton",
+                 "import triton.language as tl",
+                 "from typing import List, Tuple",
+                 "",
+             ] + get_function_source(_disable_triton_autotune)
+
+             # Combine: imports + kernel source code + alias
+             embedded_code = "\n".join(import_lines)
+             embedded_code += "\n" + source_code
+             embedded_code += f"\n\n# Use kernel function directly\nimported_kernel_function = {func_name}"
+
+             return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, embedded_code)
+         elif kernel_import == KernelImportMode.OVERRIDE_TTIR:
+             comment = "# Kernel import skipped - using IR override mode with TTIR"
+             return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, comment)
+         else:
+             raise ValueError(f"Unknown kernel_import mode: {kernel_import}")
+
+     def _replace_utility_functions(
+         self, code: str, context_bundle: ContextBundle, **kwargs
+     ) -> str:
+         """Replace the utility functions placeholder with extracted functions."""
+         utility_code = extract_utility_functions()
+         return code.replace(self.UTILITY_FUNCTIONS_PLACEHOLDER, utility_code)
+
+     def _replace_kernel_invocation(
+         self, code: str, context_bundle: ContextBundle, **kwargs
+     ) -> str:
+         """Replace the kernel invocation placeholder."""
+         source_code = context_bundle.kernel_info.source_code
+         pos_args, kw_args = _parse_kernel_signature(source_code)
+         invocation_snippet = _generate_invocation_snippet(pos_args, kw_args)
+         return code.replace(self.KERNEL_INVOCATION_PLACEHOLDER, invocation_snippet)
+
+
+ def get_dependent_source_map(
+     function_name: str, file_path: str
+ ) -> Optional[dict[str, str]]:
+     """
+     Extract dependent functions and their required imports.
+
+     Returns:
+         A dict mapping qualified function names to their source code, or
+         None if extraction fails (e.g., the source file does not exist or
+         AST analysis raises an exception).
+     """
+     from pathlib import Path
+
+     from tritonparse.tp_logger import logger
+
+     source_path = Path(file_path)
+     if not source_path.exists():
+         return None
+
+     try:
+         # Use MultiFileCallGraphAnalyzer for multi-file analysis
+         from tritonparse.reproducer.multi_file_analyzer import (
+             MultiFileCallGraphAnalyzer,
+         )
+
+         analyzer = MultiFileCallGraphAnalyzer(
+             entry_file=file_path,
+             entry_function=function_name,
+         )
+         result = analyzer.analyze()
+
+         logger.info(
+             f"Extracted {result.stats.total_functions_found} dependent functions "
+             f"from {result.stats.total_files_analyzed} files with "
+             f"{result.stats.total_imports} imports"
+         )
+
+         # Print dependent functions' short names
+         logger.info("\nDependent functions (short names):")
+         for func_name in sorted(result.function_short_names.keys()):
+             short_name = result.function_short_names[func_name]
+             logger.info(
+                 " - %s. %s", short_name, result.functions[func_name].splitlines()[0]
+             )
+
+         return result.functions
+
+     except Exception as e:
+         # If AST analysis fails, continue without dependent functions
+         logger.warning(f"Failed to extract dependent functions: {e}", exc_info=True)
+         return None
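
For context, the handler contract documented above can be extended by subclassing DefaultPlaceholderReplacer and registering additional handlers; the placeholder name and replacement text in this sketch are hypothetical and not part of the package:

from typing import Any

from tritonparse.reproducer.ingestion.ndjson import ContextBundle
from tritonparse.reproducer.placeholder_replacer import DefaultPlaceholderReplacer


class SeededReplacer(DefaultPlaceholderReplacer):
    # Hypothetical placeholder that a custom template could include.
    SEED_SETUP_PLACEHOLDER = "# {{SEED_SETUP_PLACEHOLDER}}"

    def __init__(self):
        super().__init__()
        self.register(self.SEED_SETUP_PLACEHOLDER, self._replace_seed_setup)

    def _replace_seed_setup(
        self, code: str, context_bundle: ContextBundle, **kwargs: Any
    ) -> str:
        # Each handler receives the full template code and returns it with its
        # own placeholder substituted; other placeholders are left to their handlers.
        return code.replace(self.SEED_SETUP_PLACEHOLDER, "torch.manual_seed(0)")


# A SeededReplacer() instance can then be passed as reproduce(..., replacer=...).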
File without changes
@@ -0,0 +1,38 @@
1
+ """
2
+ This file is automatically generated by TritonParse reproducer.
3
+ It contains a smallest testing example for a Triton kernel.
4
+ """
5
+
6
+ import logging
7
+
8
+ import torch
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # {{IR_OVERRIDE_SETUP_PLACEHOLDER}}
13
+
14
+ # {{KERNEL_SYSPATH_PLACEHOLDER}}
15
+
16
+ # {{KERNEL_IMPORT_PLACEHOLDER}}
17
+
18
+ # {{UTILITY_FUNCTIONS_PLACEHOLDER}}
19
+
20
+
21
+ def launch_kernel():
22
+ script_dir = Path(__file__).resolve().parent # noqa: F821
23
+ json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}"
24
+ grid, args_dict = create_args_from_json_file(str(json_file)) # noqa: F821
25
+
26
+ print("Generated kernel arguments dictionary:")
27
+ for name, arg in args_dict.items():
28
+ print(f" {name}: {arg}")
29
+ print(f"Grid: {grid}")
30
+
31
+ # {{KERNEL_INVOCATION_PLACEHOLDER}}
32
+
33
+ torch.cuda.synchronize()
34
+ print("Kernel execution finished.")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ launch_kernel()
@@ -0,0 +1,59 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ from importlib.resources import files as pkg_files
+ from pathlib import Path
+ from typing import List
+
+
+ BUILTIN_TEMPLATES_PACKAGE = "tritonparse.reproducer.templates"
+
+
+ def _is_path_like(template_arg: str) -> bool:
+     return "/" in template_arg or "\\" in template_arg or template_arg.endswith(".py")
+
+
+ def _read_file_text(path: Path) -> str:
+     p = path.expanduser().resolve()
+     if not p.exists() or not p.is_file():
+         raise FileNotFoundError(f"Template not found: {p}")
+     return p.read_text(encoding="utf-8")
+
+
+ def _read_builtin_template_text(name: str) -> str:
+     resource = pkg_files(BUILTIN_TEMPLATES_PACKAGE).joinpath(f"{name}.py")
+     # resource may not exist if an invalid name is provided
+     try:
+         with resource.open("r", encoding="utf-8") as f:
+             return f.read()
+     except FileNotFoundError as exc:
+         available = ", ".join(list_builtin_templates())
+         raise FileNotFoundError(
+             f"Builtin template '{name}' not found. Available: {available}"
+         ) from exc
+
+
+ def list_builtin_templates() -> List[str]:
+     """
+     Return the list of available builtin template names (without .py suffix).
+     """
+     names: List[str] = []
+     for entry in pkg_files(BUILTIN_TEMPLATES_PACKAGE).iterdir():
+         try:
+             if entry.is_file():
+                 filename = entry.name
+                 if filename.endswith(".py") and not filename.startswith("__"):
+                     names.append(filename[:-3])
+         except (OSError, FileNotFoundError):
+             # Defensive: in case entry access fails in some environments
+             continue
+     names.sort()
+     return names
+
+
+ def load_template_code(template_arg: str) -> str:
+     """
+     Load template code by name (builtin, without .py) or by filesystem path.
+     """
+     if _is_path_like(template_arg):
+         return _read_file_text(Path(template_arg))
+     return _read_builtin_template_text(template_arg)
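
As a small illustration (not part of the diff), load_template_code resolves its argument either as a builtin template name or as a filesystem path; the custom path below is hypothetical:

from tritonparse.reproducer.templates.loader import (
    list_builtin_templates,
    load_template_code,
)

print(list_builtin_templates())                # builtin names shipped in this wheel, e.g. "example", "tritonbench"
code = load_template_code("example")           # no path separator or .py suffix -> read via importlib.resources
code = load_template_code("./my_template.py")  # path-like argument -> read from the filesystem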
@@ -0,0 +1,106 @@
+ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+ import logging
+ from pathlib import Path
+ from typing import Any, Callable, Dict, Optional, Tuple
+
+ import torch
+ from tritonbench.utils.triton_op import (
+     BenchmarkOperator,
+     register_benchmark,
+     REGISTERED_X_VALS,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ imported_kernel_function: Optional[Callable[[Tuple[int], Dict[str, Any]], None]] = None
+
+ # {{IR_OVERRIDE_SETUP_PLACEHOLDER}}
+
+ # {{KERNEL_SYSPATH_PLACEHOLDER}}
+
+ # {{KERNEL_IMPORT_PLACEHOLDER}}
+
+ # {{UTILITY_FUNCTIONS_PLACEHOLDER}}
+
+ assert imported_kernel_function is not None, "imported_kernel_function is missing"
+
+ KERNEL_NAME = "{{KERNEL_NAME_PLACEHOLDER}}"
+ REPRO_CONTEXT_FILE_NAME = "{{JSON_FILE_NAME_PLACEHOLDER}}"
+
+
+ def _get_launch_kernel_args() -> Tuple[Tuple[int], Dict[str, Any]]:
+     script_dir = Path(__file__).resolve().parent # noqa: F821
+     json_file = script_dir / REPRO_CONTEXT_FILE_NAME
+
+     grid, args_dict = create_args_from_json_file(json_file) # noqa: F821, F841
+
+     print("Recorded kernel arguments dictionary:")
+     for name, arg in args_dict.items():
+         if isinstance(arg, torch.Tensor):
+             print(
+                 f" {name}: Tensor: {arg.shape} {arg.dtype} stride: {arg.stride()}, is_contiguous: {arg.is_contiguous()}"
+             )
+         else:
+             print(f" {name}: {arg}")
+     print(f"Grid: {grid}")
+
+     return tuple(grid), args_dict
+
+
+ grid, args_dict = _get_launch_kernel_args()
+
+
+ def _launch_kernel(grid: tuple[int], args_dict: dict[str, Any]):
+     try:
+         assert grid is not None
+         assert args_dict is not None
+
+         # {{KERNEL_INVOCATION_PLACEHOLDER}}
+
+     except Exception as e:
+         print(f"Error: {e}")
+         print("Failed to launch kernel!")
+
+
+ # HACK: @register_x_val doesn't allow us to pass `operator_name` as a parameter
+ tensor_args = {k: v for k, v in args_dict.items() if isinstance(v, torch.Tensor)}
+ x_vals_label = ", ".join(tensor_args.keys())
+ REGISTERED_X_VALS[KERNEL_NAME] = x_vals_label
+
+
+ class Operator(BenchmarkOperator):
+     @register_benchmark(operator_name=KERNEL_NAME)
+     def run_kernel(self, grid, args_dict):
+         return lambda: _launch_kernel(grid, args_dict)
+
+     def get_input_iter(self):
+         yield {"grid": grid, "args_dict": args_dict}
+
+     def get_x_val(self, example_inputs):
+         tensors_shapes = [
+             tuple(v.shape)
+             for v in example_inputs["args_dict"].values()
+             if isinstance(v, torch.Tensor)
+         ]
+         return tuple(tensors_shapes)
+
+
+ if __name__ == "__main__":
+     print("do_benchmark...")
+
+     args = [
+         "--benchmark-name",
+         KERNEL_NAME,
+     ]
+
+     from tritonbench.utils.parser import get_parser
+
+     parser = get_parser(args)
+     tb_args, extra_args = parser.parse_known_args(args)
+     bench = Operator(tb_args, extra_args)
+     bench.run()
+
+     print(bench.output)
+     print("Benchmark completed successfully!")