tritonparse 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/common.py +409 -0
- tritonparse/event_diff.py +120 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/ir_parser.py +220 -0
- tritonparse/mapper.py +100 -0
- tritonparse/reproducer/__init__.py +21 -0
- tritonparse/reproducer/__main__.py +81 -0
- tritonparse/reproducer/cli.py +37 -0
- tritonparse/reproducer/config.py +15 -0
- tritonparse/reproducer/factory.py +16 -0
- tritonparse/reproducer/ingestion/__init__.py +6 -0
- tritonparse/reproducer/ingestion/ndjson.py +165 -0
- tritonparse/reproducer/orchestrator.py +65 -0
- tritonparse/reproducer/param_generator.py +142 -0
- tritonparse/reproducer/prompts/__init__.py +1 -0
- tritonparse/reproducer/prompts/loader.py +18 -0
- tritonparse/reproducer/providers/__init__.py +1 -0
- tritonparse/reproducer/providers/base.py +14 -0
- tritonparse/reproducer/providers/gemini.py +47 -0
- tritonparse/reproducer/runtime/__init__.py +1 -0
- tritonparse/reproducer/runtime/executor.py +13 -0
- tritonparse/reproducer/utils/io.py +6 -0
- tritonparse/shared_vars.py +9 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +72 -0
- tritonparse/structured_logging.py +1046 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +118 -0
- tritonparse/tools/format_fix.py +149 -0
- tritonparse/tools/load_tensor.py +58 -0
- tritonparse/tools/prettify_ndjson.py +315 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +331 -0
- tritonparse/utils.py +156 -0
- tritonparse-0.1.1.dist-info/METADATA +10 -0
- tritonparse-0.1.1.dist-info/RECORD +40 -0
- tritonparse-0.1.1.dist-info/WHEEL +5 -0
- tritonparse-0.1.1.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.1.1.dist-info/top_level.txt +1 -0
tritonparse/ir_parser.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger("SourceMapping")

# Definition of a `#loc` alias. These definitions appear at the bottom of the IR files.
# Example: #loc2 = loc("/tmp/torchinductor_yhao/yp/abcdef.py":20:28)
# Groups: (1) alias suffix ("" for the bare `#loc`), (2) file, (3) line, (4) column.
# Definitions without a file location, e.g. `#loc1 = loc(unknown)`, do not match.
LOC_PATTERN = re.compile(r'#loc(\d*) = loc\("([^"]+)":(\d+):(\d+)\)')

# A reference to a `#loc` alias. These appear at the end of IR code lines.
# Example: loc(#loc2)
# Group (1) is the alias suffix ("" for a bare `loc(#loc)` reference).
CODE_LOC_PATTERN = re.compile(r".*loc\(#loc(\d*)\)\s*$")

# An inline file location, used e.g. in the first function-arguments line.
# Groups: (1) file, (2) line, (3) column. Because of the greedy leading `.*`,
# a search picks the LAST inline location on the line.
DIRECT_FILE_PATTERN = re.compile(r'.*loc\("([^"]+)":(\d+):(\d+)\)')

# A PTX `.loc` directive followed by its source-location comment.
# Example: .loc 1 0 50 // abcdef.py:0:50
# Groups: (1) line and (2) column from the directive, then (3) file, (4) line and
# (5) column from the trailing `//` comment. The file index is not captured.
PTX_LOC_PATTERN = re.compile(
    r"^\s*\.loc\s+\d+\s+(\d+)\s+(\d+)\s+//\s*(.+?):(\d+):(\d+)"
)

# An AMDGCN `.loc` directive followed by its source-location comment.
# Example: .loc 1 32 30 ; abcd.py:32:30
#          .loc 1 32 46 is_stmt 0 ; abcd.py:32:46
# Groups: (1) file index, (2) line, (3) column from the directive, then (4) file,
# (5) line, (6) column from the trailing `;` comment. Optional flags between the
# column and the `;` (such as `is_stmt 0`) are skipped.
AMDGCN_LOC_PATTERN = re.compile(
    r".*loc\s+(\d+)\s+(\d+)\s+(\d+)(?:\s+[^;]*)?;\s*(.+?):(\d+):(\d+)"
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def extract_loc_definitions(ir_content: str) -> Dict[str, Dict[str, Any]]:
    """
    Extract `#loc` alias definitions from IR text.

    Every ``#locN = loc("file":line:col)`` definition is returned under the key
    ``"N"``. The bare alias ``#loc = loc("file":line:col)`` (found near the top of
    the IR) is special-cased and stored under three keys:

    * ``""``  - its literal (empty) suffix;
    * ``"0"`` - the id that ``extract_code_locations`` assigns to bare
      ``loc(#loc)`` references, so those references can be resolved;
    * ``"1"`` - a fallback for ``#loc1``, which is commonly ``loc(unknown)`` and
      so has no file location of its own. A real
      ``#loc1 = loc("file":...)`` definition overrides this fallback.

    Definitions without a file location (e.g. ``loc(unknown)``) are skipped.

    Args:
        ir_content (str): The content of the IR file as a string.

    Returns:
        Dict[str, Dict[str, Any]]: Location id -> ``{"file", "line", "column"}``.
    """
    locations: Dict[str, Dict[str, Any]] = {}

    # The first #loc directive is a special case. It is located at the top of the IR files.
    main_match = re.search(r'#loc = loc\("([^"]+)":(\d+):(\d+)\)', ir_content)
    if main_match:
        main_loc = {
            "file": main_match.group(1),
            "line": int(main_match.group(2)),
            "column": int(main_match.group(3)),
        }
        locations["1"] = main_loc
        # `loc(#loc)` references are reported with id "0"; give them a target.
        # A separate copy keeps later per-key mutation by callers independent.
        locations["0"] = dict(main_loc)

    # #loc1 = loc(unknown) does not match LOC_PATTERN and is therefore ignored.
    for loc_id, filename, line, col in LOC_PATTERN.findall(ir_content):
        locations[loc_id] = {"file": filename, "line": int(line), "column": int(col)}
    return locations
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def extract_code_locations(ir_content: str) -> Dict[int, str]:
    """
    Map IR line numbers to the location each line references.

    Lines are numbered from 1 after splitting ``ir_content`` on newlines.

    * A line ending in a ``loc(#locN)`` reference maps to ``"N"``; a bare
      ``loc(#loc)`` maps to ``"0"``.
    * Otherwise, a line containing an inline ``loc("file":line:col)`` maps to
      ``"direct:file:line:col"``. When a line (e.g. the first function-arguments
      line) carries several inline locations, the greedy pattern selects the
      last one.
    * Lines with neither are omitted.

    Args:
        ir_content (str): The content of the IR file as a string.

    Returns:
        Dict[int, str]: Line number -> ``#loc`` id or ``"direct:..."`` reference.
    """
    result: Dict[int, str] = {}
    for lineno, text in enumerate(ir_content.split("\n"), start=1):
        alias_ref = CODE_LOC_PATTERN.search(text)
        if alias_ref is not None:
            result[lineno] = alias_ref.group(1) or "0"
            continue
        inline_ref = DIRECT_FILE_PATTERN.search(text)
        if inline_ref is not None:
            path, src_line, src_col = inline_ref.groups()
            result[lineno] = f"direct:{path}:{src_line}:{src_col}"
    return result
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Markers delimiting a function body in PTX (`//`) and AMDGCN (`;`) output.
_BEGIN_FUNC_PATTERN = re.compile(
    r"(?:(?://|;)\s*(?:\.globl\s+\S+\s+)?|\.globl\s+\S+\s+;\s*)--\s*Begin function"
)
_END_FUNC_PATTERN = re.compile(r"(?://|;)\s*--\s*End function")


def _group_paths_by_basename(other_mappings: List[Any]) -> Dict[str, List[str]]:
    """Collect every ``"file"`` path in *other_mappings*, grouped by basename.

    Paths are de-duplicated but kept in first-seen order, so choosing the first
    candidate for an ambiguous basename is deterministic across runs. A ``set``
    would not be: string hashing is randomized per process.
    """
    grouped: Dict[str, Dict[str, None]] = defaultdict(dict)
    for other in other_mappings:
        for info in other.values():
            if "file" in info:
                grouped[os.path.basename(info["file"])].setdefault(info["file"], None)
    return {name: list(paths) for name, paths in grouped.items()}


def _resolve_source_path(filename: str, paths_by_basename: Dict[str, List[str]]) -> str:
    """Return a full path for *filename* when one can be found.

    Absolute names are returned unchanged. A relative name is looked up by exact
    match among the basenames collected from other IR mappings. If several full
    paths share that basename, the first one seen is used. If none matches,
    *filename* is returned unchanged.
    """
    if os.path.isabs(filename):
        return filename
    logger.debug(
        f"Filename '{filename}' does not contain a path. Attempting to resolve."
    )
    candidates = paths_by_basename.get(filename)
    if not candidates:
        logger.debug(f"Filename '{filename}' not found in referenced files.")
        return filename
    if len(candidates) > 1:
        logger.debug(
            f"Filename '{filename}' has multiple file paths. Using the first one."
        )
    file_path = candidates[0]
    logger.debug(f"Resolved filename '{filename}' to {file_path}")
    return file_path


def _find_function_bounds(lines: List[str]):
    """Return 1-based ``(start, end)`` line numbers of the first function.

    ``start`` is the first "Begin function" marker line and ``end`` the first
    "End function" marker after it. Either value is 0 when it is not found.
    """
    start = 0
    for lineno, line in enumerate(lines, 1):
        if start == 0:
            if _BEGIN_FUNC_PATTERN.search(line):
                start = lineno
        elif _END_FUNC_PATTERN.search(line):
            return start, lineno
    return start, 0


def extract_ptx_amdgcn_mappings(
    content: str, other_mappings: List[Any] = None, ir_type: str = "ptx"
) -> Dict[str, Dict[str, Any]]:
    """
    Map PTX/AMDGCN lines to Python source locations using `.loc` directives.

    Only the first function in *content* is processed. The range runs from the
    line after the first "Begin function" marker up to and including the next
    "End function" marker line. Within that range:

    1. A line matching the IR's `.loc` pattern starts a new mapping. It takes
       the line/column from the directive and the file name from the trailing
       comment.
    2. Every later non-empty, non-comment line (``//`` for PTX, ``;`` for AMDGCN)
       is associated with the most recent `.loc` mapping.
    3. Lines before the first `.loc` in the range stay unmapped.

    Relative file names are resolved to full paths using the ``"file"`` entries
    of *other_mappings*, matched by basename. When several full paths share a
    basename, the first one encountered is used.

    Args:
        content: The PTX or AMDGCN text.
        other_mappings: Mappings from other IRs (each a dict of line -> info
            dict), used only to resolve relative file names. ``None`` means no
            hints.
        ir_type: ``"ptx"`` or ``"amdgcn"``.

    Returns:
        Dict mapping the 1-based line number (as ``str``) to
        ``{"file", "line", "column", "<ir_type>_line"}``. The dict is empty if
        no function boundaries are found.

    Raises:
        ValueError: if *ir_type* is neither ``"ptx"`` nor ``"amdgcn"``.
    """
    is_ptx = ir_type == "ptx"
    is_amdgcn = ir_type == "amdgcn"
    # Validate up front. Previously an unknown type silently fell back to the
    # AMDGCN pattern and was only rejected if some line happened to match it.
    if not (is_ptx or is_amdgcn):
        logger.error(f"Unknown IR type: {ir_type}")
        raise ValueError(f"Unknown IR type: {ir_type}")

    paths_by_basename = _group_paths_by_basename(other_mappings or [])

    # NOTE: only the first Begin/End function pair is used (original TODO:
    # double check whether the content can contain more than one function).
    lines = content.split("\n")
    function_start_line, function_end_line = _find_function_bounds(lines)
    if function_start_line == 0 or function_end_line == 0:
        logger.warning(
            f"Could not identify {ir_type} function boundaries. No {ir_type} mappings generated."
        )
        return {}

    logger.debug(
        f"Processing {ir_type} function from line {function_start_line} to {function_end_line}"
    )

    loc_pattern = PTX_LOC_PATTERN if is_ptx else AMDGCN_LOC_PATTERN
    comment_prefix = "//" if is_ptx else ";"
    line_key = f"{ir_type}_line"
    mappings: Dict[str, Dict[str, Any]] = {}
    current_mapping = None

    # lines[k] holds 1-based line k + 1. Slicing from function_start_line
    # therefore skips the "Begin function" marker and ends on the
    # "End function" marker line.
    body = lines[function_start_line:function_end_line]
    for i, line in enumerate(body, start=function_start_line + 1):
        try:
            match = loc_pattern.match(line)
            if match:
                groups = match.groups()
                # PTX captures (line, col, file, ...); AMDGCN prepends the file index.
                py_line, py_col, filename = groups[:3] if is_ptx else groups[1:4]
                current_mapping = {
                    "file": _resolve_source_path(filename, paths_by_basename),
                    "line": int(py_line),
                    "column": int(py_col),
                    line_key: i,
                }
                mappings[str(i)] = current_mapping
            elif current_mapping:
                # Attribute code lines without their own .loc to the nearest
                # preceding .loc; skip blank and comment lines.
                stripped = line.strip()
                if stripped and not stripped.startswith(comment_prefix):
                    mappings[str(i)] = {
                        "file": current_mapping["file"],
                        "line": current_mapping["line"],
                        "column": current_mapping["column"],
                        line_key: i,
                    }
        except Exception as e:
            logger.error(f"Error processing line {i}: {e}")
            logger.error(f"Line content: {line}")
            raise
    return mappings
|
tritonparse/mapper.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Uses the same "SourceMapping" logger name as the IR parser module, so all
# source-mapping diagnostics are grouped under one logger.
logger = logging.getLogger("SourceMapping")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def create_python_mapping(
    ir_maps: List[Tuple[str, Dict[str, Dict[str, Any]]]],
) -> Dict[int, Dict[str, List[int]]]:
    """
    Invert per-IR line mappings into one mapping keyed by Python source line.

    Every entry is assumed to refer to the single Python source file of the
    Triton kernel; the ``"file"`` field is not consulted.

    Args:
        ir_maps: ``(ir_type, ir_map)`` pairs. Each ``ir_map`` maps an IR line
            number to an info dict that must contain a ``"line"`` key (a missing
            key raises KeyError).

    Returns:
        ``{python_line: {"<ir_type>_lines": [ir_line, ...]}}``. IR line numbers
        are appended in iteration order and are not converted, so they keep
        the type of the ``ir_map`` keys.
    """
    by_py_line: Dict[int, Dict[str, List[int]]] = {}
    for ir_type, ir_map in ir_maps:
        bucket_name = f"{ir_type}_lines"
        for ir_line, info in ir_map.items():
            per_line = by_py_line.setdefault(info["line"], {})
            per_line.setdefault(bucket_name, []).append(ir_line)
    return by_py_line
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def create_ir_mapping(
    source_map: Dict[str, Dict[str, Any]], target_map: Dict[str, Dict[str, Any]]
) -> Dict[str, List[int]]:
    """
    Link source-IR lines to target-IR lines that share a source location.

    Two entries match when their ``"file:line:column"`` strings are equal. A
    missing ``"column"`` counts as 0. Entries lacking ``"file"`` or ``"line"``
    are ignored on both sides.

    Args:
        source_map (Dict[str, Dict[str, Any]]): Source IR line -> location info.
        target_map (Dict[str, Dict[str, Any]]): Target IR line -> location info.
            Its keys must be convertible with ``int()``.

    Returns:
        Dict[str, List[int]]: Source IR line -> sorted target IR line numbers.
        Source lines with no matching target are omitted.
    """

    def has_location(info: Dict[str, Any]) -> bool:
        return "file" in info and "line" in info

    def location_key(info: Dict[str, Any]) -> str:
        return f"{info['file']}:{info['line']}:{info.get('column', 0)}"

    # Index the target lines by the source location they point at.
    targets_by_location: Dict[str, List[int]] = {}
    for tgt_line, tgt_info in target_map.items():
        if not has_location(tgt_info):
            continue
        targets_by_location.setdefault(location_key(tgt_info), []).append(int(tgt_line))

    linked: Dict[str, List[int]] = {}
    for src_line, src_info in source_map.items():
        if not has_location(src_info):
            continue
        matches = targets_by_location.get(location_key(src_info))
        if matches is not None:
            linked[src_line] = sorted(matches)
    return linked
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def create_bidirectional_mapping(
    source_map: Dict[str, Dict[str, Any]],
    target_map: Dict[str, Dict[str, Any]],
    source_type: str,
    target_type: str,
) -> None:
    """
    Cross-link two IR mappings in place.

    For every source line whose location also occurs in *target_map*, the
    sorted matching target line numbers are stored under
    ``source_map[line]["<target_type>_lines"]``. Symmetrically, the matching
    source line numbers are stored under
    ``target_map[line]["<source_type>_lines"]``. Both dicts are mutated; nothing
    is returned.

    Args:
        source_map: Source IR line -> location info (mutated).
        target_map: Target IR line -> location info (mutated).
        source_type: Name of the source IR (e.g. 'ttir', 'ttgir', 'ptx').
        target_type: Name of the target IR (e.g. 'ttir', 'ttgir', 'ptx').
    """
    forward_key = f"{target_type}_lines"
    forward = create_ir_mapping(source_map, target_map)
    for src_line in forward:
        matched = forward[src_line]
        if src_line not in source_map or not matched:
            continue
        source_map[src_line][forward_key] = matched

    backward_key = f"{source_type}_lines"
    backward = create_ir_mapping(target_map, source_map)
    for tgt_line in backward:
        if tgt_line in target_map:
            target_map[tgt_line][backward_key] = backward[tgt_line]

    logger.debug(f"Created {source_type} to {target_type} mappings (and reverse)")
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Reproducer subpackage: generate runnable Triton repro scripts from traces.
|
|
2
|
+
|
|
3
|
+
Contains:
|
|
4
|
+
- ingestion.ndjson: parse NDJSON and build a context bundle
|
|
5
|
+
- orchestrator: LLM-based code generation with optional execute/repair
|
|
6
|
+
- providers: LLM provider protocol and Gemini provider
|
|
7
|
+
- prompts: simple prompt loader and templates
|
|
8
|
+
- runtime.executor: helper to run generated Python scripts
|
|
9
|
+
- param_generator: synthesize tensor/scalar allocations to reduce LLM burden
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from .ingestion.ndjson import build_context_bundle
|
|
13
|
+
from .orchestrator import generate_from_ndjson
|
|
14
|
+
from .param_generator import generate_allocation_snippet, generate_kwargs_dict
|
|
15
|
+
|
|
16
|
+
# Public names re-exported by this subpackage (see the imports above).
__all__ = [
    "build_context_bundle",
    "generate_from_ndjson",
    "generate_allocation_snippet",
    "generate_kwargs_dict",
]
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the reproducer entry point."""
    parser = argparse.ArgumentParser(
        description="Generate a runnable Triton repro script from a tritonparse NDJSON trace"
    )
    parser.add_argument("--ndjson", required=True, help="Path to NDJSON trace file")
    parser.add_argument(
        "--launch-index", type=int, default=0, help="Launch index to reproduce"
    )
    parser.add_argument("--out", default="repro.py", help="Output Python file path")
    parser.add_argument(
        "--execute", action="store_true", help="Execute the generated script"
    )
    parser.add_argument(
        "--retries", type=int, default=0, help="Auto-repair attempts if execution fails"
    )
    parser.add_argument(
        "--temperature", type=float, help="Override sampling temperature"
    )
    parser.add_argument(
        "--max-tokens", type=int, help="Override max tokens for generation"
    )
    return parser


def main() -> None:
    """Generate (and optionally execute) a Triton repro script from a trace.

    Arguments are parsed before anything else is imported, so ``--help`` works
    even when optional provider dependencies are missing. If building the
    Gemini provider raises ModuleNotFoundError, an install hint is printed and
    the process exits with status 2. The orchestrator's result is printed to
    stdout.
    """
    args = _build_parser().parse_args()

    # Deferred imports: keep `--help` usable without optional deps installed.
    from .config import load_config
    from .orchestrator import generate_from_ndjson

    try:
        from .factory import make_gemini_provider
    except Exception:  # pragma: no cover
        print(
            "Failed to import provider factory. Ensure optional deps are installed (e.g. google-genai).",
            file=sys.stderr,
        )
        raise

    cfg = load_config()
    try:
        provider = make_gemini_provider()
    except ModuleNotFoundError:  # pragma: no cover
        print(
            "Gemini provider requires 'google-genai'. Install via: pip install google-genai",
            file=sys.stderr,
        )
        sys.exit(2)

    def _pick(override, fallback):
        # CLI flags default to None, meaning "use the configured value".
        return fallback if override is None else override

    result = generate_from_ndjson(
        args.ndjson,
        provider,
        launch_index=args.launch_index,
        out_py=args.out,
        execute=args.execute,
        retries=args.retries,
        temperature=_pick(args.temperature, cfg.temperature),
        max_tokens=_pick(args.max_tokens, cfg.max_tokens),
    )
    print(result)


if __name__ == "__main__":  # pragma: no cover
    main()
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
from .config import load_config
|
|
4
|
+
from .factory import make_gemini_provider
|
|
5
|
+
from .orchestrator import generate_from_ndjson
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def add_reproducer_subparser(parser: argparse.ArgumentParser) -> None:
    """Register the ``repro`` subcommand on *parser*.

    The selected subcommand is stored in ``args.subcommand``. argparse allows
    only one subparsers group per parser, so *parser* must not already have
    one.

    The options mirror the standalone ``python -m`` entry point of this
    package, including the optional ``--temperature`` / ``--max-tokens``
    overrides. Both default to ``None``, meaning "use the configured value".
    """
    sub = parser.add_subparsers(dest="subcommand")
    repro = sub.add_parser(
        "repro",
        help="Generate a runnable Triton repro script from NDJSON",
    )
    repro.add_argument("--ndjson", required=True, help="Path to NDJSON trace file")
    repro.add_argument(
        "--launch-index", type=int, default=0, help="Launch index to reproduce"
    )
    repro.add_argument("--out", default="repro.py", help="Output Python file path")
    repro.add_argument(
        "--execute", action="store_true", help="Execute the generated script"
    )
    repro.add_argument(
        "--retries", type=int, default=0, help="Auto-repair attempts if execution fails"
    )
    repro.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Override sampling temperature",
    )
    repro.add_argument(
        "--max-tokens",
        type=int,
        default=None,
        help="Override max tokens for generation",
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def maybe_handle_reproducer(args: argparse.Namespace) -> bool:
    """Run the ``repro`` subcommand if *args* selected it.

    Returns False without doing anything unless ``args.subcommand == "repro"``.
    Otherwise it builds the Gemini provider and generates the repro script. It
    then prints the orchestrator's result and returns True.

    ``args.temperature`` / ``args.max_tokens`` override the configured values
    when present and not None. They are read with ``getattr``, so namespaces
    from parsers without those options still work.
    """
    if getattr(args, "subcommand", None) != "repro":
        return False
    cfg = load_config()
    provider = make_gemini_provider()
    temperature = getattr(args, "temperature", None)
    max_tokens = getattr(args, "max_tokens", None)
    res = generate_from_ndjson(
        args.ndjson,
        provider,
        launch_index=args.launch_index,
        out_py=args.out,
        execute=args.execute,
        retries=args.retries,
        temperature=cfg.temperature if temperature is None else temperature,
        max_tokens=cfg.max_tokens if max_tokens is None else max_tokens,
    )
    print(res)
    return True
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import os
from dataclasses import dataclass, field
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass
class ReproducerConfig:
    """Settings for the LLM-backed reproducer.

    Each default is read from the environment when an instance is created, not
    when this module is imported. Environment changes made after import are
    therefore honoured, and a malformed numeric variable no longer breaks the
    import itself. Explicit constructor arguments take precedence over the
    environment.
    """

    # Google Cloud project id passed to the Gemini provider ("" when unset).
    project: str = field(
        default_factory=lambda: os.getenv("GOOGLE_CLOUD_PROJECT", "")
    )
    # Google Cloud location passed to the Gemini provider.
    location: str = field(
        default_factory=lambda: os.getenv("GOOGLE_LOCATION", "us-central1")
    )
    # Model name passed to the Gemini provider.
    model: str = field(
        default_factory=lambda: os.getenv("TP_REPRO_MODEL", "gemini-2.5-pro")
    )
    # Sampling temperature; TP_REPRO_TEMPERATURE is parsed with float().
    temperature: float = field(
        default_factory=lambda: float(os.getenv("TP_REPRO_TEMPERATURE", "0.1"))
    )
    # Token limit for generation; TP_REPRO_MAX_TOKENS is parsed with int().
    max_tokens: int = field(
        default_factory=lambda: int(os.getenv("TP_REPRO_MAX_TOKENS", "10240"))
    )


def load_config() -> ReproducerConfig:
    """Return a ReproducerConfig built from the current environment.

    Raises:
        ValueError: if TP_REPRO_TEMPERATURE or TP_REPRO_MAX_TOKENS is set to a
            value that float()/int() cannot parse.
    """
    return ReproducerConfig()
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Provider factory for reproducer.
|
|
2
|
+
|
|
3
|
+
Currently supports Gemini only.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .config import load_config
|
|
7
|
+
from .providers.gemini import GeminiProvider
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def make_gemini_provider() -> GeminiProvider:
    """Construct a GeminiProvider from the reproducer configuration.

    The project, location and model are taken from ``load_config()``. Those
    values come from environment variables, with built-in defaults.
    """
    settings = load_config()
    return GeminiProvider(
        model=settings.model,
        location=settings.location,
        project=settings.project,
    )
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Dict, List
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _iter_events(path: str):
    """Lazily yield each JSON value from the NDJSON file at *path*.

    Each line is stripped first. Blank lines and lines that fail to parse as
    JSON are skipped without error.
    """
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            try:
                yield json.loads(text)
            except json.JSONDecodeError:
                # Tolerate malformed or truncated lines.
                continue
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _index_compilations(events: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Map ``payload.metadata.hash`` -> compilation event.

    Non-compilation events and events without a truthy hash are ignored. A null
    ``payload`` or ``metadata`` is treated as empty. When several events share a
    hash, the last one wins.
    """
    by_hash: Dict[str, Dict[str, Any]] = {}
    compilations = (ev for ev in events if ev.get("event_type") == "compilation")
    for ev in compilations:
        metadata = (ev.get("payload") or {}).get("metadata") or {}
        kernel_hash = metadata.get("hash")
        if kernel_hash:
            by_hash[kernel_hash] = ev
    return by_hash
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _get_launches(events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return the ``launch`` events, preserving their original order."""
    return list(filter(lambda ev: ev.get("event_type") == "launch", events))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _resolve_kernel_source(
    launch: Dict[str, Any], comp_idx: Dict[str, Dict[str, Any]]
) -> str:
    """Return the Python source of the kernel compiled for *launch*.

    The compilation hash is read from ``launch["compilation_metadata"]`` (the
    newer format) or from ``launch["payload"]["compilation_metadata"]``. It is
    then looked up in *comp_idx*. Returns "" when the hash is missing or
    unknown, or when the compilation event has no ``python_source.code``. A
    null ``payload`` or ``code`` is treated as absent.
    """
    # `or {}` (rather than `.get(key, {})`) also covers keys present with a
    # JSON null value, matching how _index_compilations reads payloads.
    launch_payload = launch.get("payload") or {}
    comp_meta = (
        launch.get("compilation_metadata")
        or launch_payload.get("compilation_metadata")
        or {}
    )
    h = comp_meta.get("hash")
    if not h:
        return ""
    comp = comp_idx.get(h) or {}
    payload = comp.get("payload") or {}
    py = payload.get("python_source") or {}
    # Honour the `-> str` contract even if "code" is stored as null.
    return py.get("code") or ""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise extracted launch arguments for the context bundle.

    * Tensor descriptors (dicts with ``type == "tensor"``) keep their type plus
      shape, dtype, device, stride, is_contiguous and numel. Missing fields
      become None.
    * Other dict descriptors become ``{"type", "value"}``, where value falls
      back to ``repr`` when ``value`` is absent.
    * Non-dict values are wrapped as ``{"type": None, "value": v}``.
    """
    tensor_fields = ("shape", "dtype", "device", "stride", "is_contiguous", "numel")
    packed: Dict[str, Any] = {}
    for name, spec in args.items():
        if not isinstance(spec, dict):
            packed[name] = {"type": None, "value": spec}
        elif spec.get("type") == "tensor":
            packed[name] = {"type": "tensor", **{f: spec.get(f) for f in tensor_fields}}
        else:
            packed[name] = {
                "type": spec.get("type"),
                "value": spec.get("value", spec.get("repr")),
            }
    return packed
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Marker returned by _decode_arg for arguments that must not become launch
# kwargs (tensors). Callers compare against it by identity.
_SKIP = object()


def _decode_arg(raw: Any):
    """Turn one extracted-argument descriptor into a plain Python value.

    * Non-dict values are returned unchanged.
    * Tensor descriptors yield the ``_SKIP`` sentinel.
    * ``NoneType`` descriptors yield None.
    * Anything else yields ``value``, falling back to ``repr``; None if both
      are missing.
    """
    if isinstance(raw, dict):
        kind = raw.get("type")
        if kind == "tensor":
            return _SKIP
        return None if kind == "NoneType" else raw.get("value", raw.get("repr"))
    return raw
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _launch_field(launch: Dict[str, Any], key: str) -> Any:
    """Return ``launch[key]``, falling back to ``launch["payload"][key]``.

    This supports both the top-level and the ``payload``-nested launch formats.
    A missing or null ``payload`` is treated as empty. A falsy top-level value
    falls through to the payload lookup.
    """
    return launch.get(key) or (launch.get("payload") or {}).get(key)


def build_context_bundle(ndjson_path: str, launch_index: int = 0) -> Dict[str, Any]:
    """
    Build the reproducer context for one kernel launch in an NDJSON trace.

    Args:
        ndjson_path: Path to the tritonparse NDJSON trace.
        launch_index: 0-based index into the trace's launch events, in file
            order.

    Returns:
        A dict with these keys:

        - ``kernel_source``: the kernel's Python source from the first
          ``@triton.jit`` onward. It is the whole source if the marker is
          absent, and "" if no source can be resolved.
        - ``compile``: selected compilation metadata (num_warps, num_stages,
          arch, backend, triton_version, hash).
        - ``launch``: ``{"grid": ..., "kwargs": ...}``, where kwargs holds
          every non-tensor extracted argument decoded to a plain value.
        - ``args``: all extracted arguments, packed.
        - ``tensor_args``: only the tensor arguments, packed.

    Raises:
        RuntimeError: if the trace contains no launch events.
        IndexError: if *launch_index* is out of range.
    """
    events = list(_iter_events(ndjson_path))
    launches = _get_launches(events)
    if not launches:
        raise RuntimeError("No launch events found in NDJSON.")
    if launch_index < 0 or launch_index >= len(launches):
        raise IndexError(
            f"launch_index out of range: {launch_index} (total {len(launches)})"
        )
    launch = launches[launch_index]

    comp_idx = _index_compilations(events)
    # Guard against a None source so the marker search below cannot fail.
    kernel_source = _resolve_kernel_source(launch, comp_idx) or ""
    # Keep only the text from the first "@triton.jit" onward.
    jit_pos = kernel_source.find("@triton.jit")
    if jit_pos != -1:
        kernel_source = kernel_source[jit_pos:]

    # Read launch fields from either the top level or `payload`; a null
    # `payload` no longer raises AttributeError.
    grid = _launch_field(launch, "grid")
    comp_meta = _launch_field(launch, "compilation_metadata") or {}
    extracted_args = _launch_field(launch, "extracted_args") or {}

    # Subset of the compilation metadata the reproducer cares about.
    compile_block = {
        "num_warps": comp_meta.get("num_warps"),
        "num_stages": comp_meta.get("num_stages"),
        "arch": comp_meta.get("arch"),
        "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
        "triton_version": comp_meta.get("triton_version"),
        "hash": comp_meta.get("hash"),
    }

    # kwargs: constexpr and explicit scalar arguments (tensors are skipped).
    kwargs = {}
    for name, raw in extracted_args.items():
        value = _decode_arg(raw)
        if value is not _SKIP:
            kwargs[name] = value

    tensor_args = {
        name: raw
        for name, raw in extracted_args.items()
        if isinstance(raw, dict) and raw.get("type") == "tensor"
    }

    return {
        "kernel_source": kernel_source,
        "compile": compile_block,
        "launch": {
            "grid": grid,
            "kwargs": kwargs,
        },
        "args": _pack_args(extracted_args),
        "tensor_args": _pack_args(tensor_args),
    }
|