tritonparse-0.3.2.dev20251210071601-py3-none-any.whl
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
+++ tritonparse/reproducer/utils.py
@@ -0,0 +1,580 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import ast
import importlib
import importlib.util
import json
import logging
import sys
from datetime import datetime
from functools import lru_cache
from pathlib import Path

import torch
import triton.language as tl
from tritonparse.tools.load_tensor import load_tensor
from tritonparse.tp_logger import logger

TRITON_KERNELS_CUSTOM_TYPES = (
    importlib.util.find_spec("triton_kernels") is not None
    and importlib.util.find_spec("triton_kernels.tensor") is not None
)

# Mapping from dtype string representation to Triton dtype objects
TRITON_DTYPE_MAP = {
    # Signed integers
    "int8": tl.int8,
    "int16": tl.int16,
    "int32": tl.int32,
    "int64": tl.int64,
    # Unsigned integers
    "int1": tl.int1,
    "uint8": tl.uint8,
    "uint16": tl.uint16,
    "uint32": tl.uint32,
    "uint64": tl.uint64,
    # Standard floating point types
    "fp16": tl.float16,
    "bf16": tl.bfloat16,
    "fp32": tl.float32,
    "fp64": tl.float64,
    # FP8 variants
    "fp8e4b15": tl.float8e4b15,
    "fp8e4nv": tl.float8e4nv,
    "fp8e4b8": tl.float8e4b8,
    "fp8e5": tl.float8e5,
    "fp8e5b16": tl.float8e5b16,
}


@lru_cache(maxsize=1)
def _get_triton_tensor_types():
    mod = importlib.import_module("triton_kernels.tensor")
    return (
        mod.Tensor,
        mod.Storage,
        mod.StridedLayout,
    )

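# Editor's note (annotation, not part of the wheel): TRITON_DTYPE_MAP turns the
# string form of a tl.constexpr dtype argument back into the live Triton object,
# e.g. TRITON_DTYPE_MAP["bf16"] is tl.bfloat16. _get_triton_tensor_types() is
# wrapped in @lru_cache(maxsize=1) so the optional triton_kernels dependency is
# imported at most once, and only on first use.
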
def create_args_from_json_file(json_path):
    """
    Load and parse a reproducer JSON file.

    Args:
        json_path (str): Path to the JSON file describing the kernel launch.

    Returns:
        tuple[list, dict]: Grid specification list and map of argument name to value.
    """
    with open(json_path, "r") as f:
        data = json.load(f)
    return create_args_from_json(data)


def create_args_from_json(data):
    """
    Parse a reproducer JSON and build kernel grid and argument dictionary.

    Args:
        data (dict | list): JSON data describing the kernel launch.

    Returns:
        tuple[list, dict]: Grid specification list and map of argument name to value.
    """
    # Handle data format validation and extraction
    if isinstance(data, list):
        if len(data) != 1:
            print(
                f"Error: Expected single element list, got list with {len(data)} elements"
            )
            sys.exit(1)
        data = data[0]
    elif not isinstance(data, dict):
        print(f"Error: Expected list or dict, got {type(data)}")
        sys.exit(1)

    grid = data.get("grid", [])
    args_dict = {}
    extracted_args = data.get("extracted_args", {})

    for arg_name, arg_info in extracted_args.items():
        args_dict[arg_name] = _create_arg_from_info(arg_info)

    return grid, args_dict


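# Worked example (annotation, not part of the wheel): a minimal payload that
# create_args_from_json accepts, using only field names referenced above:
#
#     {
#       "grid": [256, 1, 1],
#       "extracted_args": {
#         "n_elements": {"type": "int", "value": 65536},
#         "x_ptr": {"type": "tensor", "dtype": "torch.float16",
#                   "shape": [65536], "device": "cuda:0"},
#         "DTYPE": {"type": "dtype", "repr": "fp16"}
#       }
#     }
#
# The same payload wrapped in a single-element list is also accepted; any other
# shape exits with an error. The result is ([256, 1, 1], {name: value, ...}).
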
def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
    """
    Apply custom stride and storage offset to a tensor if needed.

    Args:
        tensor: The base contiguous tensor
        shape: The desired shape
        stride: The desired stride (or None for contiguous)
        storage_offset: The desired storage offset

    Returns:
        torch.Tensor: The strided tensor view or original tensor if contiguous
    """
    if stride is None:
        return tensor

    # Calculate expected contiguous stride
    expected_contiguous_stride = []
    s = 1
    for dim_size in reversed(shape):
        expected_contiguous_stride.insert(0, s)
        s *= dim_size

    # If stride matches contiguous stride and no storage offset, return as-is
    if tuple(stride) == tuple(expected_contiguous_stride) and storage_offset == 0:
        return tensor

    # Calculate required storage size
    if len(shape) > 0 and len(stride) > 0:
        max_offset = storage_offset
        for dim_stride, dim_size in zip(stride, shape):
            if dim_size > 0:
                max_offset += dim_stride * (dim_size - 1)
        storage_size = max_offset + 1
    else:
        storage_size = storage_offset + 1

    # Create larger storage tensor and create strided view
    storage_tensor = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device)

    # Create strided view
    strided_view = storage_tensor.as_strided(
        size=shape, stride=stride, storage_offset=storage_offset
    )

    # Copy data from the base tensor into the strided layout
    strided_view.copy_(tensor.flatten()[: strided_view.numel()].view(shape))

    return strided_view


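# Worked example (annotation, not part of the wheel): for shape=[2, 2],
# stride=[4, 1], storage_offset=1, the helper computes
# storage_size = 1 + 4*(2-1) + 1*(2-1) + 1 = 7, allocates a flat 7-element
# buffer, and returns buffer.as_strided(size=[2, 2], stride=[4, 1],
# storage_offset=1), so non-contiguous views are reproduced faithfully rather
# than being flattened into contiguous tensors.
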
def _create_base_tensor(arg_info) -> torch.Tensor:
    """
    Create a base tensor without stride/offset modifications.

    Args:
        arg_info (dict): Argument information including dtype, shape, device, etc.

    Returns:
        torch.Tensor: The created base tensor
    """
    if arg_info.get("blob_path"):
        return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))

    # Extract basic tensor properties
    dtype_str = arg_info.get("dtype")
    try:
        torch_dtype = getattr(torch, dtype_str.split(".")[-1])
    except AttributeError:
        logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
        torch_dtype = torch.float32

    shape = arg_info.get("shape", [])
    device = arg_info.get("device", "cpu")
    # Normalize cuda device to cuda:0
    if isinstance(device, str) and device.startswith("cuda"):
        device = "cuda:0"

    # Extract statistical information if available
    mean = arg_info.get("mean")
    std = arg_info.get("std")
    min_val = arg_info.get("min")
    max_val = arg_info.get("max")
    has_stats = (
        mean is not None
        and std is not None
        and min_val is not None
        and max_val is not None
    )

    if arg_info.get("tensor_capture_error", False):
        logging.error(
            f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
        )

    # Use a dummy tensor to check properties of the dtype
    tensor_props = torch.empty(0, dtype=torch_dtype)

    # Case 1: Floating point types
    if tensor_props.is_floating_point():
        if has_stats:
            # Generate tensor with statistical properties matching original data
            if std == 0 or min_val == max_val:
                # Constant tensor
                return torch.full(shape, mean, dtype=torch_dtype, device=device)
            # Generate normal distribution with mean and std, then clamp to [min, max]
            tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
            tensor = torch.clamp(tensor, min=min_val, max=max_val)
            return tensor.to(torch_dtype)
        else:
            # Fallback to original random generation
            if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                tmp = torch.rand(shape, dtype=torch.float32, device=device)
                return tmp.to(torch_dtype)
            else:
                return torch.empty(shape, dtype=torch_dtype, device=device).random_()

    # Case 2: Integer types
    elif torch_dtype in [
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
        torch.bool,
    ]:
        if has_stats and torch_dtype != torch.bool:
            # Generate tensor with statistical properties, then round for integers
            if std == 0 or min_val == max_val:
                # Constant tensor
                return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
            tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
            tensor = torch.clamp(tensor, min=min_val, max=max_val)
            return torch.round(tensor).to(torch_dtype)
        else:
            # Fallback to original random generation
            return torch.empty(shape, dtype=torch_dtype, device=device).random_()

    # Case 3: Complex numbers need special handling
    elif tensor_props.is_complex():
        # Complex types: fallback to original logic for now
        # TODO: Could be improved to use statistical info if available
        float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
        real_part = torch.rand(shape, dtype=float_dtype, device=device)
        imag_part = torch.rand(shape, dtype=float_dtype, device=device)
        return torch.complex(real_part, imag_part)

    # Case 4: Handle other unsigned integers (like uint32) which fail with random_()
    elif "uint" in str(torch_dtype):
        if has_stats:
            # Generate tensor with statistical properties for unsigned integers
            if std == 0 or min_val == max_val:
                return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
            tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
            tensor = torch.clamp(tensor, min=min_val, max=max_val)
            return torch.round(tensor).to(torch_dtype)
        else:
            # Fallback to original random generation
            return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)

    # Case 5: If we don't know how to handle the type, raise an error
    else:
        raise NotImplementedError(
            f"Random data generation not implemented for dtype: {torch_dtype}"
        )


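# Worked example (annotation, not part of the wheel): given captured statistics
#
#     {"type": "tensor", "dtype": "torch.float32", "shape": [4, 4],
#      "device": "cuda:0", "mean": 0.0, "std": 1.0, "min": -3.0, "max": 3.0}
#
# _create_base_tensor draws randn(shape) * std + mean and clamps to [min, max],
# so regenerated inputs roughly match the original distribution instead of being
# uniform noise; with "blob_path" set, the saved tensor is loaded verbatim.
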
def _create_tensor(arg_info) -> torch.Tensor:
    """
    Create a tensor with stride and storage offset if needed.

    Args:
        arg_info (dict): Argument information including dtype, shape, stride, etc.

    Returns:
        torch.Tensor: The created tensor with applied stride/offset
    """
    tensor = _create_base_tensor(arg_info)

    # Apply stride and storage offset if needed
    shape = arg_info.get("shape", [])
    stride = arg_info.get("stride")
    storage_offset = arg_info.get("storage_offset", 0)
    return _apply_stride_and_offset(tensor, shape, stride, storage_offset)


def _create_arg_from_info(arg_info):
    """
    Recursively construct a kernel argument from its JSON schema.

    Args:
        arg_info (dict): JSON object describing a single argument, including
            fields like 'type', 'value', 'dtype', 'shape', 'device', etc.

    Returns:
        Any: The constructed Python object suitable for kernel invocation.

    Raises:
        RuntimeError: When required optional dependencies are missing.
        NotImplementedError: When a dtype or type is not supported yet.
    """
    arg_type = arg_info.get("type")

    if arg_type == "NoneType":
        return None

    if arg_type in ["int", "bool", "str", "float"]:
        return arg_info.get("value")

    elif arg_type == "tensor":
        return _create_tensor(arg_info)

    elif arg_type == "triton_kernels.tensor.Tensor":
        if not TRITON_KERNELS_CUSTOM_TYPES:
            raise RuntimeError(
                "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Tensor."
            )
        Tensor, Storage, StridedLayout = _get_triton_tensor_types()
        storage = _create_arg_from_info(arg_info.get("storage"))
        dtype_str = arg_info.get("dtype")
        torch_dtype = getattr(torch, dtype_str.split(".")[-1])
        return Tensor(
            storage=storage,
            shape=arg_info.get("shape"),
            shape_max=arg_info.get("shape_max"),
            dtype=torch_dtype,
        )

    elif arg_type == "triton_kernels.tensor.Storage":
        if not TRITON_KERNELS_CUSTOM_TYPES:
            raise RuntimeError(
                "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Storage."
            )
        Tensor, Storage, StridedLayout = _get_triton_tensor_types()
        data = _create_arg_from_info(arg_info.get("data"))
        layout = _create_arg_from_info(arg_info.get("layout"))
        return Storage(data=data, layout=layout)

    elif arg_type == "StridedLayout":
        if not TRITON_KERNELS_CUSTOM_TYPES:
            raise RuntimeError(
                "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct StridedLayout."
            )
        Tensor, Storage, StridedLayout = _get_triton_tensor_types()
        return StridedLayout(shape=arg_info.get("initial_shape"))

    elif arg_type == "dtype":
        dtype_repr = arg_info.get("repr")
        if dtype_repr in TRITON_DTYPE_MAP:
            return TRITON_DTYPE_MAP[dtype_repr]
        else:
            raise NotImplementedError(f"Unsupported Triton dtype: {dtype_repr}")

    else:
        print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
        return None


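# Worked example (annotation, not part of the wheel): _create_arg_from_info
# recurses through nested schemas, so a triton_kernels Tensor argument such as
#
#     {"type": "triton_kernels.tensor.Tensor", "dtype": "torch.bfloat16",
#      "shape": [128, 64], "shape_max": [128, 64],
#      "storage": {"type": "triton_kernels.tensor.Storage",
#                  "data": {"type": "tensor", "dtype": "torch.bfloat16",
#                           "shape": [128, 64], "device": "cuda:0"},
#                  "layout": {"type": "StridedLayout",
#                             "initial_shape": [128, 64]}}}
#
# is rebuilt bottom-up: layout, then data, then Storage, then the Tensor wrapper.
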
def determine_output_paths(out_dir: str, kernel_name: str, template: str):
    """
    Determine output file paths for reproducer script and context data.

    Args:
        out_dir: Output directory path. If empty, uses default location.
        kernel_name: Name of the kernel for default directory naming.
        template: Template name or path. If a path, extracts the filename.

    Returns:
        Tuple of (python_script_path, json_context_path) as Path objects.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_directory = Path(out_dir) / kernel_name
    output_directory.mkdir(parents=True, exist_ok=True)

    # Extract template name from path if needed
    template_name = (
        Path(template).stem if "/" in template or "\\" in template else template
    )

    filename_parts = ["repro"]
    if template != "example":
        filename_parts.append(template_name)
    filename_parts.append(timestamp)
    filename = "_".join(filename_parts) + ".py"
    out_py_path = output_directory / filename
    temp_json_path = output_directory / f"repro_context_{timestamp}.json"

    return out_py_path, temp_json_path


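# Worked example (annotation, not part of the wheel):
# determine_output_paths("out", "add_kernel", "tritonbench") creates
# out/add_kernel/ and returns paths like
# out/add_kernel/repro_tritonbench_20251210071601.py and
# out/add_kernel/repro_context_20251210071601.json; with the default "example"
# template the template name is omitted, giving repro_<timestamp>.py.
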
def _generate_import_statements(kernel_info) -> tuple[str, str]:
    """
    Generate (sys.path insertion statement, import statement) for the kernel.

    Strategy:
    - Always add the kernel file's parent directory to sys.path.
    - If the filename (without .py) is a valid identifier, import using that
      module name: `from <stem> import <func> as imported_kernel_function`.
    - Otherwise, fall back to dynamic import via importlib.util and bind
      `imported_kernel_function` from the loaded module.
    """
    file_path = Path(kernel_info.file_path)
    function_name = kernel_info.function_name

    if not file_path or not function_name:
        raise ValueError("Kernel file path or function name missing from context.")

    # Always add the file's parent directory to sys.path
    sys_stmt = (
        "import sys; p = r'" + str(file_path.parent) + "';\n"
        "if p not in sys.path: sys.path.insert(0, p)"
    )

    module_name = file_path.with_suffix("").name
    if module_name.isidentifier():
        import_stmt = (
            f"from {module_name} import {function_name} as imported_kernel_function"
        )
        logger.debug("Generated direct import statement: %s", import_stmt)
        return sys_stmt, import_stmt

    # Fallback: dynamic import when filename is not a valid identifier
    import_stmt = (
        "import importlib.util\n"
        f"_spec = importlib.util.spec_from_file_location('kernel_mod', r'{str(file_path)}')\n"
        "_mod = importlib.util.module_from_spec(_spec)\n"
        "_spec.loader.exec_module(_mod)\n"
        f"imported_kernel_function = getattr(_mod, '{function_name}')"
    )
    logger.debug("Generated dynamic import for file: %s", file_path)
    return sys_stmt, import_stmt


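# Worked example (annotation, not part of the wheel): for a kernel my_add
# defined in /tmp/kernels/my_kernel.py, the generated statements would be
#
#     import sys; p = r'/tmp/kernels';
#     if p not in sys.path: sys.path.insert(0, p)
#     from my_kernel import my_add as imported_kernel_function
#
# The importlib fallback only fires when the file stem is not a valid Python
# identifier (e.g. it contains dashes).
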
def _parse_kernel_signature(kernel_source_code: str) -> tuple[list[str], list[str]]:
    """
    Parses a Triton kernel's source code using AST to distinguish positional args
    from keyword args (those with default values).

    This implementation uses Python's ast module for robust parsing that handles:
    - Return type annotations (e.g., -> None)
    - Complex type annotations (e.g., Callable[[dict[str, int]], list[Tensor]])
    - Decorators (e.g., @triton.jit)
    - Keyword-only arguments (after *)
    - All Python syntax variations

    Args:
        kernel_source_code: Python source code containing the kernel function

    Returns:
        tuple[list[str], list[str]]: (positional_args, keyword_args)

    Raises:
        ValueError: If parsing fails or no function definition is found
    """
    try:
        # Parse source code into AST
        tree = ast.parse(kernel_source_code)

        # Find the first function definition
        func_def = None
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                func_def = node
                break

        if not func_def:
            raise ValueError("No function definition found in source code")

        positional_args = []
        keyword_args = []

        # Extract function arguments
        args = func_def.args

        # Calculate number of positional arguments
        # defaults are right-aligned with args, so:
        # num_positional = total_args - num_defaults
        num_defaults = len(args.defaults)
        num_args = len(args.args)
        num_positional = num_args - num_defaults

        # Classify regular arguments
        for i, arg in enumerate(args.args):
            arg_name = arg.arg
            if i < num_positional:
                positional_args.append(arg_name)
            else:
                keyword_args.append(arg_name)

        # Handle keyword-only arguments (after *)
        for arg in args.kwonlyargs:
            keyword_args.append(arg.arg)

        logger.debug("Parsed positional args: %s", positional_args)
        logger.debug("Parsed keyword args: %s", keyword_args)
        return positional_args, keyword_args

    except SyntaxError as e:
        raise ValueError(
            f"Invalid Python syntax in kernel source at line {e.lineno}: {e.msg}"
        ) from e
    except Exception as e:
        raise ValueError(f"Failed to parse kernel signature: {e}") from e


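# Worked example (annotation, not part of the wheel): for a source string
#
#     @triton.jit
#     def add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
#                    BLOCK_SIZE: tl.constexpr = 1024):
#         ...
#
# _parse_kernel_signature returns (["x_ptr", "y_ptr", "out_ptr", "n_elements"],
# ["BLOCK_SIZE"]), since defaults are right-aligned against the argument list.
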
def _generate_invocation_snippet(
    positional_args: list[str], keyword_args: list[str]
) -> str:
    """Generates a single-line Python code snippet for kernel invocation."""
    # Prepare positional args for direct injection into the call
    pos_args_str = ", ".join([f'args_dict["{arg}"]' for arg in positional_args])

    # Prepare keyword args for direct injection
    kw_args_str = ", ".join([f'{arg}=args_dict["{arg}"]' for arg in keyword_args])

    # Combine them, ensuring proper comma separation
    all_args = []
    if pos_args_str:
        all_args.append(pos_args_str)
    if kw_args_str:
        all_args.append(kw_args_str)

    # Create the single-line call
    return f"imported_kernel_function[tuple(grid)]({', '.join(all_args)})"


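# Worked example (annotation, not part of the wheel): continuing the signature
# above, _generate_invocation_snippet produces the single launch line (shown
# wrapped here)
#
#     imported_kernel_function[tuple(grid)](args_dict["x_ptr"],
#         args_dict["y_ptr"], args_dict["out_ptr"], args_dict["n_elements"],
#         BLOCK_SIZE=args_dict["BLOCK_SIZE"])
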
def format_python_code(code: str) -> str:
    """
    Format Python code using black and organize imports using isort if available.

    Args:
        code: Python code string to format.

    Returns:
        Formatted code with organized imports if tools are available,
        otherwise returns original code.
    """
    # Step 1: Organize imports using isort if available
    try:
        import isort

        # Configure isort to move imports to the top more aggressively
        code = isort.code(
            code,
            float_to_top=True,  # Move imports to the top
            remove_redundant_aliases=True,  # Remove redundant aliases like 'import torch as torch'
            force_single_line=False,
            line_length=88,
            profile="black",  # Compatible with black formatter
            treat_comments_as_code=[],  # Don't treat comments as code barriers
            treat_all_comments_as_code=False,  # Allow imports to move past comments
        )
        logger.debug("Successfully organized imports")
    except ImportError:
        logger.debug("isort library not available, import organization will be skipped")
    except Exception as e:
        logger.warning(f"Failed to organize imports: {e}")

    # Step 2: Format code using black if available
    try:
        import black

        formatted_code = black.format_str(code, mode=black.Mode())
        logger.debug("Successfully formatted generated code")
        return formatted_code
    except ImportError:
        logger.debug("black library not available, code formatting will be skipped")
        return code
    except black.InvalidInput as e:
        logger.warning(f"Failed to format generated code: {e}")
        return code
    except Exception as e:
        logger.warning(f"Unexpected error while formatting code: {e}")
        return code
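Taken together, these helpers sketch how a reproducer script is assembled. A minimal, hypothetical driver (the file names and the kernel_info object, which only needs file_path and function_name attributes, are assumptions, not API shipped in this wheel):

    grid, args_dict = create_args_from_json_file("repro_context.json")
    sys_stmt, import_stmt = _generate_import_statements(kernel_info)
    positional, keyword = _parse_kernel_signature(kernel_source_code)
    launch_line = _generate_invocation_snippet(positional, keyword)
    script = format_python_code("\n".join([sys_stmt, import_stmt, launch_line]))
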
+++ tritonparse/shared_vars.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# We'd like to separate the structured logging module and the tritonparse module as much as possible, so the shared variables live here.
import os

# The compilation information will be stored to /logs/DEFAULT_TRACE_FILE_PREFIX by default
# unless other flags disable or set another store. Add USER to avoid permission issues on shared servers.
DEFAULT_TRACE_FILE_PREFIX = (
    f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_"
)
DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER = "dedicated_log_triton_trace_"
# True if test outputs (e.g., temp dirs) should be preserved.
TEST_KEEP_OUTPUT = os.getenv("TEST_KEEP_OUTPUT", "0") in ["1", "true", "True"]
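As a worked example of the constants above: with USER=alice, trace files carry the prefix dedicated_log_triton_trace_alice_, while DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER simply drops the user segment.
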
+++ tritonparse/source_type.py
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from enum import Enum
from pathlib import Path
from typing import Tuple


class SourceType(str, Enum):
    """Enumeration of supported source types for OSS only."""

    LOCAL = "local"
    LOCAL_FILE = "local_file"

    @classmethod
    def _missing_(cls, value: object) -> "SourceType":
        """
        Handle unknown source types by raising a ValueError.

        Args:
            value: The unknown value that was attempted to be used as a SourceType

        Returns:
            Never returns; always raises ValueError
        """
        valid_types = [e.value for e in cls]
        raise ValueError(
            f"Invalid source type '{value}'. Valid types are: {', '.join(valid_types)}"
        )


class Source:
    """Represents a source of logs to parse."""

    def __init__(self, source_str: str, verbose: bool = False):
        """
        Initialize a Source object by parsing the source string.

        Args:
            source_str: String representing the source
            verbose: Whether to print verbose information
        """
        self.source_str = source_str
        self.verbose = verbose
        self.type, self.value = self._parse_source()

    def _parse_source(self) -> Tuple[SourceType, str]:
        # Check if it's a local path
        path = Path(self.source_str)
        if path.is_dir():
            return SourceType.LOCAL, str(path.absolute())
        elif path.is_file():
            return SourceType.LOCAL_FILE, str(path.absolute())
        else:
            raise ValueError(
                f"Source '{self.source_str}' is not a valid directory or file"
            )
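A short usage sketch for Source (illustrative; the path is hypothetical):

    src = Source("/tmp/triton_logs")      # an existing directory
    assert src.type is SourceType.LOCAL   # directories map to SourceType.LOCAL
    print(src.value)                      # absolute path to that directory

    SourceType("s3")  # raises ValueError via _missing_, listing the valid types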