tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
ImportResolver: Resolves Python import statements to file paths.
|
|
5
|
+
|
|
6
|
+
This module provides functionality to resolve import statements to their
|
|
7
|
+
corresponding file paths using Python's importlib system, without actually
|
|
8
|
+
importing the modules (no side effects).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import importlib.util
|
|
12
|
+
import logging
|
|
13
|
+
import sys
|
|
14
|
+
from importlib.machinery import ModuleSpec
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ImportResolver:
|
|
20
|
+
"""
|
|
21
|
+
Resolves import statements to absolute file paths using importlib.
|
|
22
|
+
|
|
23
|
+
Uses Python's import resolution system (importlib.util.find_spec) to
|
|
24
|
+
locate module files without actually importing them, avoiding side effects.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(self, project_root: str) -> None:
|
|
28
|
+
"""
|
|
29
|
+
Initialize the ImportResolver.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
project_root: Absolute path to project root directory
|
|
33
|
+
(e.g., /path/to/project/root)
|
|
34
|
+
"""
|
|
35
|
+
self.project_root = project_root
|
|
36
|
+
|
|
37
|
+
# Ensure project_root is in sys.path for importlib to find modules
|
|
38
|
+
if project_root not in sys.path:
|
|
39
|
+
sys.path.insert(0, project_root)
|
|
40
|
+
|
|
41
|
+
# Common third-party modules to exclude from analysis
|
|
42
|
+
self.external_modules: set[str] = {
|
|
43
|
+
"torch",
|
|
44
|
+
"triton",
|
|
45
|
+
"numpy",
|
|
46
|
+
"pandas",
|
|
47
|
+
"transformers",
|
|
48
|
+
"tqdm",
|
|
49
|
+
"typing_extensions",
|
|
50
|
+
"pydantic",
|
|
51
|
+
"requests",
|
|
52
|
+
"pytest",
|
|
53
|
+
"unittest",
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
def resolve_import(
|
|
57
|
+
self,
|
|
58
|
+
module_name: str,
|
|
59
|
+
package: str | None = None,
|
|
60
|
+
) -> tuple[str | None, bool]:
|
|
61
|
+
"""
|
|
62
|
+
Resolve import to absolute file path using importlib.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
module_name: Module to import (e.g., "torch" or "pytorch.tritonparse")
|
|
66
|
+
package: Package context for relative imports (e.g., "pytorch.tritonparse")
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
Tuple of (file_path, is_external):
|
|
70
|
+
- file_path: Absolute path to .py file, or None if not resolvable
|
|
71
|
+
- is_external: True if third-party/built-in module
|
|
72
|
+
|
|
73
|
+
Examples:
|
|
74
|
+
>>> resolver = ImportResolver("/data/.../fbcode")
|
|
75
|
+
>>> path, is_ext = resolver.resolve_import("pytorch.tritonparse.module")
|
|
76
|
+
>>> # path: "/data/.../fbcode/pytorch/tritonparse/module.py"
|
|
77
|
+
>>> # is_ext: False
|
|
78
|
+
|
|
79
|
+
>>> path, is_ext = resolver.resolve_import("torch")
|
|
80
|
+
>>> # path: None
|
|
81
|
+
>>> # is_ext: True
|
|
82
|
+
"""
|
|
83
|
+
# Check if it's a known external module
|
|
84
|
+
base_module = module_name.split(".")[0]
|
|
85
|
+
if base_module in self.external_modules:
|
|
86
|
+
logger.debug(
|
|
87
|
+
"Import '%s' marked as external (known third-party module)",
|
|
88
|
+
module_name,
|
|
89
|
+
)
|
|
90
|
+
return None, True
|
|
91
|
+
|
|
92
|
+
try:
|
|
93
|
+
# Use importlib to find the module (without importing it)
|
|
94
|
+
spec: ModuleSpec | None = importlib.util.find_spec(module_name, package)
|
|
95
|
+
|
|
96
|
+
if spec is None or spec.origin is None:
|
|
97
|
+
# Module not found or built-in (has no file)
|
|
98
|
+
logger.debug(
|
|
99
|
+
"Import '%s' not resolvable (spec=%s, origin=%s)",
|
|
100
|
+
module_name,
|
|
101
|
+
spec,
|
|
102
|
+
spec.origin if spec else None,
|
|
103
|
+
)
|
|
104
|
+
return None, True
|
|
105
|
+
|
|
106
|
+
origin = spec.origin
|
|
107
|
+
|
|
108
|
+
# Skip special cases like frozen/built-in modules
|
|
109
|
+
if origin == "frozen" or origin == "built-in":
|
|
110
|
+
logger.debug("Import '%s' is a built-in module", module_name)
|
|
111
|
+
return None, True
|
|
112
|
+
|
|
113
|
+
# Check if the module is within project root
|
|
114
|
+
is_internal = origin.startswith(self.project_root)
|
|
115
|
+
|
|
116
|
+
if is_internal:
|
|
117
|
+
logger.debug(
|
|
118
|
+
"Import '%s' resolved to INTERNAL file: %s",
|
|
119
|
+
module_name,
|
|
120
|
+
origin,
|
|
121
|
+
)
|
|
122
|
+
return origin, False
|
|
123
|
+
else:
|
|
124
|
+
# External module (outside project)
|
|
125
|
+
logger.warning(
|
|
126
|
+
"Import '%s' resolved to file '%s' which is OUTSIDE project_root '%s'. "
|
|
127
|
+
"This import will be skipped. If this is unexpected, verify your --code-root parameter matches "
|
|
128
|
+
"the directory containing your source files.",
|
|
129
|
+
module_name,
|
|
130
|
+
origin,
|
|
131
|
+
self.project_root,
|
|
132
|
+
)
|
|
133
|
+
return None, True
|
|
134
|
+
|
|
135
|
+
except (ImportError, ValueError, AttributeError) as e:
|
|
136
|
+
# Module doesn't exist or can't be resolved
|
|
137
|
+
logger.debug("Import '%s' failed to resolve: %s", module_name, str(e))
|
|
138
|
+
return None, True
|
|
139
|
+
|
|
140
|
+
def is_external_module(self, module_name: str) -> bool:
|
|
141
|
+
"""
|
|
142
|
+
Check if a module is external (third-party or built-in).
|
|
143
|
+
|
|
144
|
+
Args:
|
|
145
|
+
module_name: Name of the module to check
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
True if module is external, False if it's internal (project)
|
|
149
|
+
"""
|
|
150
|
+
base_module = module_name.split(".")[0]
|
|
151
|
+
return base_module in self.external_modules
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
from tritonparse.tp_logger import logger
|
|
7
|
+
|
|
8
|
+
# Sentinel object to mark arguments that should be skipped during processing
|
|
9
|
+
_SKIP = object()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class KernelInfo:
    """Information about a Triton kernel extracted from compilation events."""

    # Path of the Python file defining the kernel (from the event's
    # python_source.file_path field).
    file_path: str
    # Kernel function name (from the compilation metadata "name" field).
    function_name: str
    # Kernel source code; may be sliced to start at the '@triton.jit' marker.
    source_code: str
    # Call stack entries recorded on the compilation event ("stack" field).
    call_stack: List[Dict[str, Any]]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class ContextBundle:
    """Bundle of all context information needed to reproduce a kernel launch."""

    # Kernel identity and source extracted from the compilation event.
    kernel_info: KernelInfo
    # Compile metadata subset: num_warps, num_stages, arch, backend,
    # triton_version, hash.
    compile: Dict[str, Any]
    # Launch details: "grid" plus the non-tensor "kwargs".
    launch: Dict[str, Any]
    # All extracted launch args (tensors and scalars) in packed form.
    args: Dict[str, Any]
    # Only the tensor args, in packed form.
    tensor_args: Dict[str, Any]
    # The original launch event dict, unmodified.
    raw_launch_event: Dict[str, Any]
    # The original compilation event dict, unmodified.
    raw_comp_event: Dict[str, Any]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_launch_and_compilation_events(
    events: List[Dict[str, Any]], line_index: Optional[int] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Extract launch and compilation events from the event list.

    Args:
        events: List of parsed event dictionaries.
        line_index: 0-based index of the launch event to process.

    Returns:
        Tuple of (launch_event, compilation_event).

    Raises:
        ValueError: If line_index is out of range or the event at line_index
            is not a launch event.
        RuntimeError: If compilation event cannot be found or is ambiguous.
    """
    # Reject None, negative, and out-of-range indices. The original check
    # (`line_index >= len(events)`) let a negative index through, silently
    # selecting an event from the end of the list.
    if line_index is None or not 0 <= line_index < len(events):
        raise ValueError(f"Invalid line_index: {line_index}")

    launch_event = events[line_index]
    # Use .get() so malformed events missing "event_type" raise the
    # documented ValueError rather than an undocumented KeyError.
    if launch_event.get("event_type") != "launch":
        raise ValueError(f"Event at index {line_index} is not a launch event")

    # `or {}` tolerates the key being present but set to None.
    comp_meta = launch_event.get("compilation_metadata") or {}
    comp_hash = comp_meta.get("hash")
    if not comp_hash:
        raise RuntimeError("Could not find compilation hash in launch event.")

    # Find the first compilation event whose metadata hash matches.
    comp_event = next(
        (
            event
            for event in events
            if event.get("event_type") == "compilation"
            and (event.get("payload") or {}).get("metadata", {}).get("hash")
            == comp_hash
        ),
        None,
    )
    if not comp_event:
        raise RuntimeError(f"Could not find compilation event for hash {comp_hash}.")
    return launch_event, comp_event
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_kernel_info(comp_event: Dict[str, Any]) -> KernelInfo:
    """
    Extract kernel information from a compilation event.

    Args:
        comp_event: Compilation event dictionary containing kernel metadata.

    Returns:
        KernelInfo object with extracted kernel details.

    Raises:
        RuntimeError: If file path or function name cannot be resolved.
    """
    payload = comp_event.get("payload") or {}
    py_source = payload.get("python_source") or {}
    code = py_source.get("code", "")

    # Extract file path and function name
    file_path = py_source.get("file_path")
    # The function name is in the compilation metadata payload. Reuse the
    # None-safe `payload` computed above: the original re-fetched it with
    # .get("payload", {}), which raises AttributeError when the "payload"
    # key is present but set to None.
    func_name = (payload.get("metadata") or {}).get("name")

    # Find '@triton.jit' decorator and slice the string from there
    jit_marker = "@triton.jit"
    jit_pos = code.find(jit_marker)
    if jit_pos != -1:
        code = code[jit_pos:]
        logger.debug("Extracted kernel source starting from '@triton.jit'.")

    if not file_path or not func_name:
        raise RuntimeError(
            "Could not resolve kernel file path or function name from compilation event."
            " The import-based strategy cannot proceed."
        )
    return KernelInfo(file_path, func_name, code, comp_event.get("stack", []))
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _decode_arg(raw: Any) -> Any:
|
|
115
|
+
"""
|
|
116
|
+
Decode a raw argument value from event data.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
raw: Raw argument value from event data.
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
Decoded argument value, or _SKIP sentinel for tensors.
|
|
123
|
+
"""
|
|
124
|
+
if not isinstance(raw, dict):
|
|
125
|
+
return raw
|
|
126
|
+
t = raw.get("type")
|
|
127
|
+
if t == "tensor":
|
|
128
|
+
return _SKIP
|
|
129
|
+
if t == "NoneType":
|
|
130
|
+
return None
|
|
131
|
+
return raw.get("value", raw.get("repr"))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
|
|
135
|
+
"""
|
|
136
|
+
Pack argument values into a standardized format.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
args: Dictionary of argument names to values.
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
Dictionary with packed argument information including type and metadata.
|
|
143
|
+
"""
|
|
144
|
+
packed = {}
|
|
145
|
+
for k, v in args.items():
|
|
146
|
+
t = v.get("type") if isinstance(v, dict) else None
|
|
147
|
+
if t == "tensor":
|
|
148
|
+
packed[k] = {
|
|
149
|
+
"type": "tensor",
|
|
150
|
+
"shape": v.get("shape") if isinstance(v, dict) else None,
|
|
151
|
+
"dtype": v.get("dtype") if isinstance(v, dict) else None,
|
|
152
|
+
"device": v.get("device") if isinstance(v, dict) else None,
|
|
153
|
+
"stride": v.get("stride") if isinstance(v, dict) else None,
|
|
154
|
+
"is_contiguous": (
|
|
155
|
+
v.get("is_contiguous") if isinstance(v, dict) else None
|
|
156
|
+
),
|
|
157
|
+
"numel": v.get("numel") if isinstance(v, dict) else None,
|
|
158
|
+
}
|
|
159
|
+
else:
|
|
160
|
+
# scalar / NoneType etc
|
|
161
|
+
if isinstance(v, dict):
|
|
162
|
+
packed[k] = {
|
|
163
|
+
"type": v.get("type"),
|
|
164
|
+
"value": v.get("value", v.get("repr")),
|
|
165
|
+
}
|
|
166
|
+
else:
|
|
167
|
+
packed[k] = {
|
|
168
|
+
"type": None,
|
|
169
|
+
"value": v,
|
|
170
|
+
}
|
|
171
|
+
return packed
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def build_context_bundle(
    events: List[Dict[str, Any]], line_index: Optional[int] = None
):
    """
    Build a complete context bundle from events and line index.

    Args:
        events: List of parsed event dictionaries.
        line_index: 0-based index of the launch event to process.

    Returns:
        ContextBundle containing all information needed to reproduce the kernel launch.

    Raises:
        ValueError: If line_index is invalid or event is not a launch event.
        RuntimeError: If compilation event cannot be found.
    """
    launch_event, comp_event = get_launch_and_compilation_events(events, line_index)
    kernel_info = get_kernel_info(comp_event)

    extracted_args = launch_event.get("extracted_args", {})
    comp_meta = launch_event.get("compilation_metadata", {})

    # Compile metadata subset we care about
    compile_block = {
        "num_warps": comp_meta.get("num_warps"),
        "num_stages": comp_meta.get("num_stages"),
        "arch": comp_meta.get("arch"),
        "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
        "triton_version": comp_meta.get("triton_version"),
        "hash": comp_meta.get("hash"),
    }

    # kwargs: include constexpr + explicit scalars used for launch (skip tensor args)
    decoded = {name: _decode_arg(raw) for name, raw in extracted_args.items()}
    kwargs = {name: val for name, val in decoded.items() if val is not _SKIP}

    # tensor args: only tensors
    raw_tensor_args = {
        name: raw
        for name, raw in extracted_args.items()
        if isinstance(raw, dict) and raw.get("type") == "tensor"
    }

    launch_block = {
        "grid": launch_event.get("grid"),
        "kwargs": kwargs,
    }

    return ContextBundle(
        kernel_info,
        compile_block,
        launch_block,
        _pack_args(extracted_args),
        _pack_args(raw_tensor_args),
        launch_event,
        comp_event,
    )
|