tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
tritonparse/ir_parser.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from typing import Any, Dict, List
|
|
8
|
+
|
|
9
|
+
# Same logger name as tritonparse/mapper.py, so messages from both modules are
# grouped under "SourceMapping".
logger = logging.getLogger("SourceMapping")

# Definition of a numbered #loc directive; these appear at the bottom of the
# IR files.
# Example: #loc2 = loc("/tmp/torchinductor_yhao/yp/abcdef.py":20:28)
# Groups: (loc_id, file, line, column).
# Note: this only matches numbered locs such as #loc1 or #loc2, not the bare
# #loc (the bare form is matched separately in extract_loc_definitions).
LOC_PATTERN = re.compile(r'#loc(\d+) = loc\("([^"]+)":(\d+):(\d+)\)')

# Reference to a #loc directive at the end of an IR code line.
# Example: loc(#loc2)
# Group 1 is the loc number; it is empty for a bare loc(#loc) reference.
CODE_LOC_PATTERN = re.compile(r".*loc\(#loc(\d*)\)\s*$")

# Inline file location, used e.g. on the first function-arguments line.
# Groups: (file, line, column). The leading ".*" is greedy, so when a line holds
# several inline locations the LAST one is captured.
DIRECT_FILE_PATTERN = re.compile(r'.*loc\("([^"]+)":(\d+):(\d+)\)')

# PTX .loc directive followed by a "// file:line:col" comment.
# Example: .loc 1 0 50 // abcdef.py:0:50
# Groups: (line, column, file, comment_line, comment_column); the file index
# after ".loc" is not captured.
PTX_LOC_PATTERN = re.compile(
    r"^\s*\.loc\s+\d+\s+(\d+)\s+(\d+)\s+//\s*(.+?):(\d+):(\d+)"
)

# AMDGCN .loc directive followed by a "; file:line:col" comment. Optional extra
# tokens (such as "is_stmt 0") may appear before the ";".
# Example: .loc 1 32 30 ; abcd.py:32:30
#          .loc 1 32 46 is_stmt 0 ; abcd.py:32:46
# Groups: (file_index, line, column, file, comment_line, comment_column).
AMDGCN_LOC_PATTERN = re.compile(
    r".*loc\s+(\d+)\s+(\d+)\s+(\d+)(?:\s+[^;]*)?;\s*(.+?):(\d+):(\d+)"
)


# Named alias loc definitions in TTGIR/TTIR.
# Example: #loc16 = loc("pid"(#loc2))
# Example: #loc13 = loc("x_ptr"(#loc)) - bare #loc without number
# Groups: (alias_id, name, target_id); target_id is empty for a bare #loc.
ALIAS_WITH_NAME_PATTERN = re.compile(
    r'#loc(\d+)\s*=\s*loc\("([^"]+)"\s*\(\s*#loc(\d*)\s*\)\s*\)'
)

# Unnamed alias loc definitions.
# Example: #loc20 = loc(#loc16)
# Groups: (alias_id, target_id); target_id is empty for a bare #loc.
ALIAS_SIMPLE_PATTERN = re.compile(r"#loc(\d+)\s*=\s*loc\(\s*#loc(\d*)\s*\)")

# Callsite loc definitions in TTIR/TTGIR.
# Example: #loc220 = loc(callsite(#loc57 at #loc190))
# Captures: loc_id, callee_loc_id, caller_loc_id
# Note: uses (\d*) for callee/caller so that bare #loc references also match
# (their captured number is then empty).
CALLSITE_PATTERN = re.compile(
    r"#loc(\d+)\s*=\s*loc\(\s*callsite\(\s*#loc(\d*)\s+at\s+#loc(\d*)\s*\)\s*\)"
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
    """
    Extracts location definitions from the given IR content.

    Recognised definition forms:

    * ``#loc = loc("file":line:col)``: the un-numbered main location at the
      top of the IR file. It is stored under the key ``""`` so that it cannot
      collide with ``#loc1``.
    * ``#locN = loc("file":line:col)``: concrete locations, keyed by ``"N"``.
    * ``#locN = loc("name"(#locM))`` and ``#locN = loc(#locM)``: aliases. They
      copy file/line/column from ``#locM`` and record ``alias_of``. The named
      form also records ``alias_name``.
    * ``#locN = loc(callsite(#locA at #locB))``: callsites. They copy
      file/line/column from the callee ``#locA`` and record ``is_callsite``,
      ``callsite_callee`` and ``callsite_caller``.

    ``#locN = loc(unknown)`` matches none of the patterns and is not recorded.

    Aliases and callsites may reference each other in any combination, for
    example a named alias of a callsite or a callsite whose callee is an
    alias. They are resolved by a single recursive resolver, so the result
    does not depend on which kind of definition is processed first.

    Args:
        ir_content (str): The content of the IR file as a string.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary mapping location IDs to
        ``file``, ``line`` and ``column``, plus optional metadata:
        ``def_line`` (1-based line of the definition), ``alias_of``,
        ``alias_name``, ``is_callsite``, ``callsite_callee`` and
        ``callsite_caller``. An alias whose target cannot be resolved is still
        recorded, but without ``file``/``line``/``column``.
    """
    locations: Dict[str, Dict[str, Any]] = {}

    # The un-numbered main "#loc" directive is a special case at the top of the
    # IR file. Store it under "" to avoid a conflict with "#loc1".
    main_match = re.search(r'#loc = loc\("([^"]+)":(\d+):(\d+)\)', ir_content)
    if main_match:
        locations[""] = {
            "file": main_match.group(1),
            "line": int(main_match.group(2)),
            "column": int(main_match.group(3)),
        }
    # "#loc1 = loc(unknown)" does not match LOC_PATTERN and is ignored.
    for loc_id, filename, line, col in LOC_PATTERN.findall(ir_content):
        locations[loc_id] = {"file": filename, "line": int(line), "column": int(col)}

    # alias_id -> target_id. An empty target_id means a bare "#loc", which
    # maps to the main loc key "".
    alias_map: Dict[str, str] = {}
    for m in ALIAS_WITH_NAME_PATTERN.finditer(ir_content):
        alias_id, _name, target_id = m.groups()
        alias_map[alias_id] = target_id or ""
    for m in ALIAS_SIMPLE_PATTERN.finditer(ir_content):
        alias_id, target_id = m.groups()
        alias_map[alias_id] = target_id or ""

    # A single line scan collects definition lines, alias names and callsites.
    def_line_map: Dict[str, int] = {}
    alias_name_map: Dict[str, str] = {}
    # callsite_id -> (callee_id, caller_id, definition line); "" = bare #loc.
    callsite_map: Dict[str, tuple] = {}
    main_loc_line: int = 0
    for i, line in enumerate(ir_content.split("\n"), start=1):
        if m := ALIAS_WITH_NAME_PATTERN.search(line):
            alias_id, name, target_id = m.groups()
            def_line_map[alias_id] = i
            alias_name_map[alias_id] = name
            alias_map.setdefault(alias_id, target_id or "")
        elif m := ALIAS_SIMPLE_PATTERN.search(line):
            alias_id, target_id = m.groups()
            def_line_map[alias_id] = i
            alias_map.setdefault(alias_id, target_id or "")
        if m2 := LOC_PATTERN.search(line):
            def_line_map[m2.group(1)] = i
        if re.search(r'#loc\s*=\s*loc\("[^"]+":\d+:\d+\)', line):
            # main #loc = loc("file":line:col) without id
            main_loc_line = main_loc_line or i
        if m3 := CALLSITE_PATTERN.search(line):
            loc_id, callee_id, caller_id = m3.groups()
            # If an id is defined twice, the first definition wins.
            callsite_map.setdefault(loc_id, (callee_id or "", caller_id or "", i))

    # Ids currently being resolved; used to break reference cycles.
    resolving = set()

    def resolve(current_id: str) -> Dict[str, Any]:
        """Return the location entry for ``current_id``, or {} if unresolvable."""
        if current_id in locations:
            return locations[current_id]
        if current_id in resolving:
            return {}
        resolving.add(current_id)
        try:
            if current_id in alias_map:
                base = resolve(alias_map[current_id])
                result: Dict[str, Any] = {}
                if "file" in base:
                    # Copy so that an alias never shares a dict with its target.
                    result = {
                        "file": base.get("file"),
                        "line": base.get("line"),
                        "column": base.get("column"),
                    }
                # Aliases are recorded even when unresolved, so that their
                # alias metadata can still be attached below.
                locations[current_id] = result
                return result
            if current_id in callsite_map:
                callee_id, caller_id, def_line = callsite_map[current_id]
                callee_info = resolve(callee_id)
                if "file" not in callee_info:
                    return {}
                # A callsite inherits its position from the callee (the code
                # being called) and keeps a reference to its caller.
                result = {
                    "file": callee_info["file"],
                    "line": callee_info["line"],
                    "column": callee_info["column"],
                    "def_line": def_line,
                    "is_callsite": True,
                    "callsite_callee": callee_id,
                    "callsite_caller": caller_id,
                }
                locations[current_id] = result
                return result
            # Neither concrete, alias nor callsite (e.g. loc(unknown)).
            return {}
        finally:
            resolving.discard(current_id)

    for alias_id in alias_map:
        resolve(alias_id)
    for loc_id in callsite_map:
        resolve(loc_id)

    # Report unresolved callsites. An undefined caller only produces a warning
    # and does not block the callsite.
    for loc_id, (callee_id, caller_id, _def_line) in callsite_map.items():
        if loc_id not in locations:
            logger.warning(
                f"Callsite #loc{loc_id} references undefined callee #loc{callee_id}"
            )
        elif caller_id and caller_id not in locations:
            logger.warning(
                f"Callsite #loc{loc_id} references undefined caller #loc{caller_id}"
            )

    # Attach definition line and alias metadata.
    for k, v in def_line_map.items():
        if k in locations:
            locations[k]["def_line"] = v
    for alias_id, target_id in alias_map.items():
        if alias_id in locations:
            locations[alias_id]["alias_of"] = target_id
            if alias_id in alias_name_map:
                locations[alias_id]["alias_name"] = alias_name_map[alias_id]
    if main_loc_line and "" in locations:
        locations[""]["def_line"] = main_loc_line
    return locations
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def extract_code_locations(ir_content: str) -> Dict[int, str]:
    """
    Map each IR line to the location it references, if any.

    Every line is tried against two patterns, in this order:

    1. ``CODE_LOC_PATTERN``: a trailing ``loc(#locN)`` reference. The recorded
       value is the number ``"N"``; a bare ``loc(#loc)`` is recorded as ``"0"``.
    2. ``DIRECT_FILE_PATTERN``: an inline ``loc("file":line:col)``, as used on
       the function-arguments line. The recorded value is
       ``"direct:<file>:<line>:<col>"``. The pattern's leading ``.*`` is greedy,
       so when a line carries several inline locations only the LAST one is
       recorded.

    NOTE(review): extract_loc_definitions stores the bare main ``#loc`` under
    the key ``""``, not ``"0"``. Confirm that callers translate ``"0"`` if bare
    references are meant to resolve to the main location.

    Args:
        ir_content (str): The content of the IR file as a string.

    Returns:
        Dict[int, str]: 1-based line numbers mapped to either a #loc number or a
        ``direct:`` file reference. Lines matching neither pattern are absent.
    """
    refs: Dict[int, str] = {}
    for line_no, text in enumerate(ir_content.split("\n"), start=1):
        loc_ref = CODE_LOC_PATTERN.search(text)
        if loc_ref is not None:
            refs[line_no] = loc_ref.group(1) or "0"
            continue
        direct = DIRECT_FILE_PATTERN.search(text)
        if direct is not None:
            path, src_line, src_col = direct.groups()
            refs[line_no] = f"direct:{path}:{src_line}:{src_col}"
    return refs
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def extract_ptx_amdgcn_mappings(
    content: str, other_mappings: List[Any] = None, ir_type: str = "ptx"
) -> Dict[str, Dict[str, Any]]:
    """
    Extract source-line mappings from PTX or AMDGCN code using `.loc` directives.

    Only the first function body is processed. The body starts after the
    first "-- Begin function" marker and ends at the next "-- End function"
    marker (introduced by "//" in PTX and by ";" in AMDGCN). Unlike the
    #loc-based IRs, PTX/AMDGCN are segmented by `.loc` directives:

    1. Identifies the function boundary.
    2. Only processes code within that boundary.
    3. Maps lines with a `.loc` directive to the source file, line and column
       given in the directive's trailing comment.
    4. Associates each following non-empty, non-comment line with the most
       recent `.loc` directive.

    Args:
        content: The PTX or AMDGCN source text.
        other_mappings: Optional list of mappings from other IRs, each a dict of
            line -> info. Their ``file`` values are used to resolve relative
            file names in `.loc` comments to full paths, matched by basename.
        ir_type: "ptx" or "amdgcn". Selects the `.loc` pattern and the comment
            syntax, and names the ``"<ir_type>_line"`` key of each entry.

    Returns:
        Dictionary mapping the 1-based line number (as a string) to a dict with
        ``file``, ``line``, ``column`` and ``"<ir_type>_line"``. It is empty
        when the function boundaries cannot be found.

    Raises:
        ValueError: If ``ir_type`` is neither "ptx" nor "amdgcn" and a `.loc`
            directive is matched inside the function body.
    """
    mappings: Dict[str, Dict[str, Any]] = {}
    current_mapping = None

    # basename -> {full file path, ...}, collected from the other IR mappings
    referenced_files = defaultdict(set)
    for other in other_mappings or []:
        for info in other.values():
            if "file" in info:
                referenced_files[os.path.basename(info["file"])].add(info["file"])

    def get_file_path(filename: str) -> str:
        """Resolve a relative `.loc` file name to a full path when possible."""
        if os.path.isabs(filename):
            return filename
        logger.debug(
            f"Filename '{filename}' does not contain a path. Attempting to resolve."
        )
        # Membership test (not indexing) so the defaultdict is not grown.
        if filename not in referenced_files:
            logger.debug(f"Filename '{filename}' not found in referenced files.")
            return filename
        candidates = referenced_files[filename]
        if len(candidates) > 1:
            logger.debug(
                f"Filename '{filename}' has multiple file paths. Using the first one."
            )
        # Sort before picking: set iteration order of str depends on hash
        # randomization, so list(set)[0] could differ between runs.
        file_path = sorted(candidates)[0]
        logger.debug(f"Resolved filename '{filename}' to {file_path}")
        return file_path

    # Regular expressions to match function start and end markers
    # @TODO: need to double check if the PTX content only contains one function
    begin_func_pattern = re.compile(
        r"(?:(?://|;)\s*(?:\.globl\s+\S+\s+)?|\.globl\s+\S+\s+;\s*)--\s*Begin function"
    )
    end_func_pattern = re.compile(r"(?://|;)\s*--\s*End function")

    # First scan: find the first Begin marker and the End marker following it
    # (both as 1-based line numbers; 0 means "not found").
    function_start_line = 0
    function_end_line = 0
    lines = content.split("\n")
    for i, line in enumerate(lines, 1):
        if begin_func_pattern.search(line) and function_start_line == 0:
            function_start_line = i
        elif end_func_pattern.search(line) and function_start_line > 0:
            function_end_line = i
            break

    if function_start_line == 0 or function_end_line == 0:
        logger.warning(
            f"Could not identify {ir_type} function boundaries. No {ir_type} mappings generated."
        )
        return mappings

    logger.debug(
        f"Processing {ir_type} function from line {function_start_line} to {function_end_line}"
    )

    is_ptx = ir_type == "ptx"
    is_amdgcn = ir_type == "amdgcn"
    loc_pattern = PTX_LOC_PATTERN if is_ptx else AMDGCN_LOC_PATTERN

    # Second scan: process the function body. `lines` is 0-based while line
    # numbers are 1-based, so this slice covers the lines after the Begin marker
    # up to and including the End marker line (a comment, hence skipped).
    for i, line in enumerate(
        lines[function_start_line:function_end_line], start=function_start_line + 1
    ):
        try:
            match = loc_pattern.match(line)
            if match:
                if is_ptx:
                    py_line, py_col, filename, _, _ = match.groups()
                elif is_amdgcn:
                    _file_index, py_line, py_col, filename, _, _ = match.groups()
                else:
                    logger.error(f"Unknown IR type: {ir_type}")
                    raise ValueError(f"Unknown IR type: {ir_type}")
                current_mapping = {
                    "file": get_file_path(filename),
                    "line": int(py_line),
                    "column": int(py_col),
                    f"{ir_type}_line": i,
                }
                mappings[str(i)] = current_mapping
            elif current_mapping:
                # Lines without their own `.loc` inherit the most recent one;
                # only non-empty, non-comment code lines are mapped.
                line_content = line.strip()
                if line_content and not (
                    (is_ptx and line_content.startswith("//"))
                    or (is_amdgcn and line_content.startswith(";"))
                ):
                    mappings[str(i)] = {
                        "file": current_mapping["file"],
                        "line": current_mapping["line"],
                        "column": current_mapping["column"],
                        f"{ir_type}_line": i,
                    }
        except Exception as e:
            logger.error(f"Error processing line {i}: {e}")
            logger.error(f"Line content: {line}")
            # Bare raise keeps the original traceback.
            raise
    return mappings
|
tritonparse/mapper.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Any, Dict, List, Tuple
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Same logger name as tritonparse/ir_parser.py, so messages from both modules
# are grouped under "SourceMapping".
logger = logging.getLogger("SourceMapping")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_python_mapping(
    ir_maps: List[Tuple[str, Dict[str, Dict[str, Any]]]],
) -> Dict[int, Dict[str, List[str]]]:
    """
    Create a mapping from Python source lines to the IR lines generated from them.

    We assume there is only one Python source file for each Triton kernel, so
    entries are grouped by their ``line`` field alone (``file`` is ignored).

    Args:
        ir_maps: A list of ``(ir_type, ir_map)`` tuples, where ``ir_map`` maps an
            IR line number (string key) to its source-location info.

    Returns:
        ``{python_line: {"<ir_type>_lines": [ir_line, ...]}}``. The IR line
        numbers are the ``ir_map`` keys, kept as-is and in ``ir_map`` iteration
        order. Entries without a ``line`` field are skipped.
    """
    py_map = defaultdict(lambda: defaultdict(list))
    for ir_type, ir_map in ir_maps:
        lines_key = f"{ir_type}_lines"
        for line_number, info in ir_map.items():
            # Same guard as create_ir_mapping: an entry without a source line
            # (e.g. an unresolved location) cannot be grouped.
            if "line" not in info:
                continue
            py_line_number: int = info["line"]
            py_map[py_line_number][lines_key].append(line_number)
    return {k: dict(v) for k, v in py_map.items()}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def create_ir_mapping(
    source_map: Dict[str, Dict[str, Any]], target_map: Dict[str, Dict[str, Any]]
) -> Dict[str, List[int]]:
    """
    Link source-IR lines to target-IR lines that share a source location.

    Two lines are linked when their ``file``, ``line`` and ``column`` agree. A
    missing ``column`` is treated as 0. Entries lacking ``file`` or ``line``
    are ignored on both sides.

    Args:
        source_map (Dict[str, Dict[str, Any]]): Source IR line number -> source
            file/line/column info.
        target_map (Dict[str, Dict[str, Any]]): Target IR line number -> source
            file/line/column info.

    Returns:
        Dict[str, List[int]]: Source IR line number (original key) -> sorted
        list of target IR line numbers converted to int. Source lines with no
        matching target are omitted.
    """

    def location_key(info: Dict[str, Any]) -> str:
        return f"{info['file']}:{info['line']}:{info.get('column', 0)}"

    # Index the target lines by their source location.
    lines_by_location: Dict[str, List[int]] = {}
    for tgt_line, tgt_info in target_map.items():
        if "file" in tgt_info and "line" in tgt_info:
            lines_by_location.setdefault(location_key(tgt_info), []).append(
                int(tgt_line)
            )

    mapping: Dict[str, List[int]] = {}
    for src_line, src_info in source_map.items():
        if "file" not in src_info or "line" not in src_info:
            continue
        hits = lines_by_location.get(location_key(src_info))
        if hits is not None:
            mapping[src_line] = sorted(hits)
    return mapping
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def create_bidirectional_mapping(
    source_map: Dict[str, Dict[str, Any]],
    target_map: Dict[str, Dict[str, Any]],
    source_type: str,
    target_type: str,
) -> None:
    """
    Cross-link two IR mappings in place, in both directions.

    Each source entry that shares a source location with target lines gets a
    ``"<target_type>_lines"`` list. Each target entry that shares a location
    with source lines gets a ``"<source_type>_lines"`` list. Matching is done
    by ``create_ir_mapping``.

    Args:
        source_map: Source IR line number -> source location info (mutated).
        target_map: Target IR line number -> source location info (mutated).
        source_type: Identifier of the source IR (e.g. 'ttir', 'ttgir', 'ptx').
        target_type: Identifier of the target IR (e.g. 'ttir', 'ttgir', 'ptx').
    """
    # Forward direction: annotate source entries with matching target lines.
    forward_key = f"{target_type}_lines"
    for src_line, tgt_lines in create_ir_mapping(source_map, target_map).items():
        if tgt_lines and src_line in source_map:
            source_map[src_line][forward_key] = tgt_lines

    # Reverse direction: annotate target entries with matching source lines.
    backward_key = f"{source_type}_lines"
    for tgt_line, src_lines in create_ir_mapping(target_map, source_map).items():
        if tgt_line in target_map:
            target_map[tgt_line][backward_key] = src_lines

    logger.debug(f"Created {source_type} to {target_type} mappings (and reverse)")
|
|
File without changes
|