tritonparse 0.2.4.dev20250930071536__py3-none-any.whl → 0.2.4.dev20251002071459__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__main__.py +5 -0
- tritonparse/reproducer/cli.py +6 -2
- tritonparse/reproducer/orchestrator.py +13 -7
- tritonparse/reproducer/templates/example.py +84 -10
- tritonparse/reproducer/utils.py +74 -11
- tritonparse/structured_logging.py +69 -4
- tritonparse/tools/disasm.py +81 -0
- {tritonparse-0.2.4.dev20250930071536.dist-info → tritonparse-0.2.4.dev20251002071459.dist-info}/METADATA +1 -1
- {tritonparse-0.2.4.dev20250930071536.dist-info → tritonparse-0.2.4.dev20251002071459.dist-info}/RECORD +13 -10
- tritonparse-0.2.4.dev20251002071459.dist-info/entry_points.txt +2 -0
- {tritonparse-0.2.4.dev20250930071536.dist-info → tritonparse-0.2.4.dev20251002071459.dist-info}/WHEEL +0 -0
- {tritonparse-0.2.4.dev20250930071536.dist-info → tritonparse-0.2.4.dev20251002071459.dist-info}/licenses/LICENSE +0 -0
- {tritonparse-0.2.4.dev20250930071536.dist-info → tritonparse-0.2.4.dev20251002071459.dist-info}/top_level.txt +0 -0
tritonparse/__main__.py
ADDED
tritonparse/reproducer/cli.py
CHANGED
|
@@ -5,9 +5,13 @@ def _add_reproducer_args(parser: argparse.ArgumentParser) -> None:
|
|
|
5
5
|
"""Add common arguments for the reproducer to a parser."""
|
|
6
6
|
parser.add_argument("input", help="Path to the ndjson/ndjson.gz log file")
|
|
7
7
|
parser.add_argument(
|
|
8
|
-
"--line
|
|
8
|
+
"--line",
|
|
9
9
|
type=int,
|
|
10
|
-
|
|
10
|
+
default=1,
|
|
11
|
+
help=(
|
|
12
|
+
"The line number (1-based) of the launch event in the input file to reproduce. "
|
|
13
|
+
"Defaults to 1."
|
|
14
|
+
),
|
|
11
15
|
)
|
|
12
16
|
parser.add_argument(
|
|
13
17
|
"--out-dir",
|
|
@@ -18,7 +18,7 @@ def reproduce(
|
|
|
18
18
|
line_index: int,
|
|
19
19
|
out_dir: str,
|
|
20
20
|
template: str,
|
|
21
|
-
):
|
|
21
|
+
) -> dict[str, Path]:
|
|
22
22
|
"""
|
|
23
23
|
Generate a reproducer script from NDJSON trace file.
|
|
24
24
|
|
|
@@ -55,9 +55,15 @@ def reproduce(
|
|
|
55
55
|
"# {{KERNEL_INVOCATION_PLACEHOLDER}}", invocation_snippet
|
|
56
56
|
)
|
|
57
57
|
out_py_path.write_text(final_code, encoding="utf-8")
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
58
|
+
|
|
59
|
+
filepath = context_bundle.kernel_info.file_path
|
|
60
|
+
filepath = "/".join(filepath.split("/")[5:])
|
|
61
|
+
ret = {
|
|
62
|
+
"kernel_src_path": filepath,
|
|
63
|
+
"kernel": context_bundle.kernel_info.function_name,
|
|
64
|
+
"repo_script": str(out_py_path.resolve()),
|
|
65
|
+
"repo_context": str(temp_json_path.resolve()),
|
|
66
|
+
}
|
|
67
|
+
logger.info("REPRODUCER_OUTPUT\n%s", ret)
|
|
68
|
+
|
|
69
|
+
return ret
|
|
@@ -1,6 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file is automatically generated by TritonParse reproducer.
|
|
3
|
+
It contains a minimal test example for a Triton kernel.
|
|
4
|
+
"""
|
|
5
|
+
|
|
1
6
|
import hashlib
|
|
2
7
|
import importlib
|
|
3
8
|
import json
|
|
9
|
+
import logging
|
|
4
10
|
import sys
|
|
5
11
|
from functools import lru_cache
|
|
6
12
|
from pathlib import Path
|
|
@@ -139,25 +145,62 @@ def _create_arg_from_info(arg_info):
|
|
|
139
145
|
elif arg_type == "tensor":
|
|
140
146
|
if arg_info.get("blob_path"):
|
|
141
147
|
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
|
|
148
|
+
|
|
149
|
+
# Extract basic tensor properties
|
|
142
150
|
dtype_str = arg_info.get("dtype")
|
|
143
151
|
try:
|
|
144
152
|
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
145
153
|
except AttributeError:
|
|
154
|
+
logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
|
|
146
155
|
torch_dtype = torch.float32
|
|
147
156
|
|
|
148
157
|
shape = arg_info.get("shape", [])
|
|
149
158
|
device = arg_info.get("device", "cpu")
|
|
150
159
|
|
|
160
|
+
# Extract statistical information if available
|
|
161
|
+
mean = arg_info.get("mean")
|
|
162
|
+
std = arg_info.get("std")
|
|
163
|
+
min_val = arg_info.get("min")
|
|
164
|
+
max_val = arg_info.get("max")
|
|
165
|
+
has_stats = (
|
|
166
|
+
mean is not None
|
|
167
|
+
and std is not None
|
|
168
|
+
and min_val is not None
|
|
169
|
+
and max_val is not None
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
if arg_info.get("tensor_capture_error", False):
|
|
173
|
+
logging.error(
|
|
174
|
+
f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
|
|
175
|
+
)
|
|
176
|
+
|
|
151
177
|
# Use a dummy tensor to check properties of the dtype
|
|
152
178
|
tensor_props = torch.empty(0, dtype=torch_dtype)
|
|
153
179
|
|
|
154
|
-
# Case 1: Floating point
|
|
180
|
+
# Case 1: Floating point types
|
|
155
181
|
if tensor_props.is_floating_point():
|
|
156
|
-
if
|
|
157
|
-
|
|
158
|
-
|
|
182
|
+
if has_stats:
|
|
183
|
+
# Generate tensor with statistical properties matching original data
|
|
184
|
+
if std == 0 or min_val == max_val:
|
|
185
|
+
# Constant tensor
|
|
186
|
+
return torch.full(shape, mean, dtype=torch_dtype, device=device)
|
|
187
|
+
# Generate normal distribution with mean and std, then clamp to [min, max]
|
|
188
|
+
tensor = (
|
|
189
|
+
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
190
|
+
)
|
|
191
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
192
|
+
return tensor.to(torch_dtype)
|
|
159
193
|
else:
|
|
160
|
-
|
|
194
|
+
# Fallback to original random generation
|
|
195
|
+
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
196
|
+
tmp = torch.rand(shape, dtype=torch.float32, device=device)
|
|
197
|
+
return tmp.to(torch_dtype)
|
|
198
|
+
else:
|
|
199
|
+
return torch.empty(
|
|
200
|
+
shape, dtype=torch_dtype, device=device
|
|
201
|
+
).random_()
|
|
202
|
+
|
|
203
|
+
# Case 2: Integer types
|
|
161
204
|
elif torch_dtype in [
|
|
162
205
|
torch.int8,
|
|
163
206
|
torch.int16,
|
|
@@ -166,9 +209,26 @@ def _create_arg_from_info(arg_info):
|
|
|
166
209
|
torch.uint8,
|
|
167
210
|
torch.bool,
|
|
168
211
|
]:
|
|
169
|
-
|
|
170
|
-
|
|
212
|
+
if has_stats and torch_dtype != torch.bool:
|
|
213
|
+
# Generate tensor with statistical properties, then round for integers
|
|
214
|
+
if std == 0 or min_val == max_val:
|
|
215
|
+
# Constant tensor
|
|
216
|
+
return torch.full(
|
|
217
|
+
shape, int(mean), dtype=torch_dtype, device=device
|
|
218
|
+
)
|
|
219
|
+
tensor = (
|
|
220
|
+
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
221
|
+
)
|
|
222
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
223
|
+
return torch.round(tensor).to(torch_dtype)
|
|
224
|
+
else:
|
|
225
|
+
# Fallback to original random generation
|
|
226
|
+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
227
|
+
|
|
228
|
+
# Case 3: Complex numbers need special handling
|
|
171
229
|
elif tensor_props.is_complex():
|
|
230
|
+
# Complex types: fallback to original logic for now
|
|
231
|
+
# TODO: Could be improved to use statistical info if available
|
|
172
232
|
float_dtype = (
|
|
173
233
|
torch.float32 if torch_dtype == torch.complex64 else torch.float64
|
|
174
234
|
)
|
|
@@ -176,10 +236,24 @@ def _create_arg_from_info(arg_info):
|
|
|
176
236
|
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
177
237
|
return torch.complex(real_part, imag_part)
|
|
178
238
|
|
|
179
|
-
# Case
|
|
239
|
+
# Case 4: Handle other unsigned integers (like uint32) which fail with random_()
|
|
180
240
|
elif "uint" in str(torch_dtype):
|
|
181
|
-
|
|
182
|
-
|
|
241
|
+
if has_stats:
|
|
242
|
+
# Generate tensor with statistical properties for unsigned integers
|
|
243
|
+
if std == 0 or min_val == max_val:
|
|
244
|
+
return torch.full(
|
|
245
|
+
shape, int(mean), dtype=torch_dtype, device=device
|
|
246
|
+
)
|
|
247
|
+
tensor = (
|
|
248
|
+
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
249
|
+
)
|
|
250
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
251
|
+
return torch.round(tensor).to(torch_dtype)
|
|
252
|
+
else:
|
|
253
|
+
# Fallback to original random generation
|
|
254
|
+
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
|
|
255
|
+
|
|
256
|
+
# Case 5: If we don't know how to handle the type, raise an error
|
|
183
257
|
else:
|
|
184
258
|
raise NotImplementedError(
|
|
185
259
|
f"Random data generation not implemented for dtype: {torch_dtype}"
|
tritonparse/reproducer/utils.py
CHANGED
|
@@ -84,25 +84,57 @@ def _create_arg_from_info(arg_info):
|
|
|
84
84
|
elif arg_type == "tensor":
|
|
85
85
|
if arg_info.get("blob_path"):
|
|
86
86
|
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
|
|
87
|
+
|
|
88
|
+
# Extract basic tensor properties
|
|
87
89
|
dtype_str = arg_info.get("dtype")
|
|
88
90
|
try:
|
|
89
91
|
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
90
92
|
except AttributeError:
|
|
93
|
+
logger.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
|
|
91
94
|
torch_dtype = torch.float32
|
|
92
95
|
|
|
93
96
|
shape = arg_info.get("shape", [])
|
|
94
97
|
device = arg_info.get("device", "cpu")
|
|
95
98
|
|
|
99
|
+
# Extract statistical information if available
|
|
100
|
+
mean = arg_info.get("mean")
|
|
101
|
+
std = arg_info.get("std")
|
|
102
|
+
min_val = arg_info.get("min")
|
|
103
|
+
max_val = arg_info.get("max")
|
|
104
|
+
has_stats = (
|
|
105
|
+
mean is not None
|
|
106
|
+
and std is not None
|
|
107
|
+
and min_val is not None
|
|
108
|
+
and max_val is not None
|
|
109
|
+
)
|
|
110
|
+
|
|
96
111
|
# Use a dummy tensor to check properties of the dtype
|
|
97
112
|
tensor_props = torch.empty(0, dtype=torch_dtype)
|
|
98
113
|
|
|
99
|
-
# Case 1: Floating point
|
|
114
|
+
# Case 1: Floating point types
|
|
100
115
|
if tensor_props.is_floating_point():
|
|
101
|
-
if
|
|
102
|
-
|
|
103
|
-
|
|
116
|
+
if has_stats:
|
|
117
|
+
# Generate tensor with statistical properties matching original data
|
|
118
|
+
if std == 0 or min_val == max_val:
|
|
119
|
+
# Constant tensor
|
|
120
|
+
return torch.full(shape, mean, dtype=torch_dtype, device=device)
|
|
121
|
+
# Generate normal distribution with mean and std, then clamp to [min, max]
|
|
122
|
+
tensor = (
|
|
123
|
+
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
124
|
+
)
|
|
125
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
126
|
+
return tensor.to(torch_dtype)
|
|
104
127
|
else:
|
|
105
|
-
|
|
128
|
+
# Fallback to original random generation
|
|
129
|
+
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
130
|
+
tmp = torch.rand(shape, dtype=torch.float32, device=device)
|
|
131
|
+
return tmp.to(torch_dtype)
|
|
132
|
+
else:
|
|
133
|
+
return torch.empty(
|
|
134
|
+
shape, dtype=torch_dtype, device=device
|
|
135
|
+
).random_()
|
|
136
|
+
|
|
137
|
+
# Case 2: Integer types
|
|
106
138
|
elif torch_dtype in [
|
|
107
139
|
torch.int8,
|
|
108
140
|
torch.int16,
|
|
@@ -111,9 +143,26 @@ def _create_arg_from_info(arg_info):
|
|
|
111
143
|
torch.uint8,
|
|
112
144
|
torch.bool,
|
|
113
145
|
]:
|
|
114
|
-
|
|
115
|
-
|
|
146
|
+
if has_stats and torch_dtype != torch.bool:
|
|
147
|
+
# Generate tensor with statistical properties, then round for integers
|
|
148
|
+
if std == 0 or min_val == max_val:
|
|
149
|
+
# Constant tensor
|
|
150
|
+
return torch.full(
|
|
151
|
+
shape, int(mean), dtype=torch_dtype, device=device
|
|
152
|
+
)
|
|
153
|
+
tensor = (
|
|
154
|
+
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
155
|
+
)
|
|
156
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
157
|
+
return torch.round(tensor).to(torch_dtype)
|
|
158
|
+
else:
|
|
159
|
+
# Fallback to original random generation
|
|
160
|
+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
161
|
+
|
|
162
|
+
# Case 3: Complex numbers need special handling
|
|
116
163
|
elif tensor_props.is_complex():
|
|
164
|
+
# Complex types: fallback to original logic for now
|
|
165
|
+
# TODO: Could be improved to use statistical info if available
|
|
117
166
|
float_dtype = (
|
|
118
167
|
torch.float32 if torch_dtype == torch.complex64 else torch.float64
|
|
119
168
|
)
|
|
@@ -121,14 +170,29 @@ def _create_arg_from_info(arg_info):
|
|
|
121
170
|
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
122
171
|
return torch.complex(real_part, imag_part)
|
|
123
172
|
|
|
124
|
-
# Case
|
|
173
|
+
# Case 4: Handle other unsigned integers (like uint32) which fail with random_()
|
|
125
174
|
elif "uint" in str(torch_dtype):
|
|
126
|
-
|
|
127
|
-
|
|
175
|
+
if has_stats:
|
|
176
|
+
# Generate tensor with statistical properties for unsigned integers
|
|
177
|
+
if std == 0 or min_val == max_val:
|
|
178
|
+
return torch.full(
|
|
179
|
+
shape, int(mean), dtype=torch_dtype, device=device
|
|
180
|
+
)
|
|
181
|
+
tensor = (
|
|
182
|
+
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
183
|
+
)
|
|
184
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
185
|
+
return torch.round(tensor).to(torch_dtype)
|
|
186
|
+
else:
|
|
187
|
+
# Fallback to original random generation
|
|
188
|
+
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
|
|
189
|
+
|
|
190
|
+
# Case 5: If we don't know how to handle the type, raise an error
|
|
128
191
|
else:
|
|
129
192
|
raise NotImplementedError(
|
|
130
193
|
f"Random data generation not implemented for dtype: {torch_dtype}"
|
|
131
194
|
)
|
|
195
|
+
|
|
132
196
|
elif arg_type == "triton_kernels.tensor.Tensor":
|
|
133
197
|
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
134
198
|
raise RuntimeError(
|
|
@@ -162,7 +226,6 @@ def _create_arg_from_info(arg_info):
|
|
|
162
226
|
)
|
|
163
227
|
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
|
|
164
228
|
return StridedLayout(shape=arg_info.get("initial_shape"))
|
|
165
|
-
|
|
166
229
|
else:
|
|
167
230
|
print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
|
|
168
231
|
return None
|
|
@@ -10,6 +10,7 @@ import json
|
|
|
10
10
|
import logging
|
|
11
11
|
import math
|
|
12
12
|
import os
|
|
13
|
+
import subprocess
|
|
13
14
|
from collections import defaultdict
|
|
14
15
|
from collections.abc import Mapping
|
|
15
16
|
from dataclasses import asdict, is_dataclass
|
|
@@ -41,6 +42,24 @@ TRITONPARSE_KERNEL_ALLOWLIST = os.environ.get("TRITONPARSE_KERNEL_ALLOWLIST", No
|
|
|
41
42
|
_KERNEL_ALLOWLIST_PATTERNS: Optional[List[str]] = None
|
|
42
43
|
# Enable launch trace. WARNING: it will overwrite the launch_metadata function for each Triton kernel.
|
|
43
44
|
TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"]
|
|
45
|
+
# Enable more tensor information collection in trace logs.
|
|
46
|
+
TRITONPARSE_MORE_TENSOR_INFORMATION = os.getenv(
|
|
47
|
+
"TRITONPARSE_MORE_TENSOR_INFORMATION", None
|
|
48
|
+
) in ["1", "true", "True"]
|
|
49
|
+
# Inductor compiled kernel's launch tracing needs this flag to be set.
|
|
50
|
+
# If TRITON_TRACE_LAUNCH is enabled, also enable TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK
|
|
51
|
+
TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = (
|
|
52
|
+
os.getenv("TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK", None) in ["1", "true", "True"]
|
|
53
|
+
or TRITON_TRACE_LAUNCH
|
|
54
|
+
)
|
|
55
|
+
# Enable NVIDIA SASS dump. It requires the CUBIN file to be locatable.
|
|
56
|
+
# WARNING: it will slow down the compilation significantly.
|
|
57
|
+
TRITONPARSE_DUMP_SASS = os.getenv("TRITONPARSE_DUMP_SASS", None) in [
|
|
58
|
+
"1",
|
|
59
|
+
"true",
|
|
60
|
+
"True",
|
|
61
|
+
]
|
|
62
|
+
|
|
44
63
|
# The flag to mark if launch is traced. It is used to avoid initializing the launch hook twice.
|
|
45
64
|
_trace_launch_enabled = False
|
|
46
65
|
|
|
@@ -252,6 +271,15 @@ def _log_torch_tensor_info(tensor_value):
|
|
|
252
271
|
# Add data_ptr for memory tracking (optional)
|
|
253
272
|
if hasattr(tensor_value, "data_ptr"):
|
|
254
273
|
arg_info["data_ptr"] = hex(tensor_value.data_ptr())
|
|
274
|
+
if TRITONPARSE_MORE_TENSOR_INFORMATION:
|
|
275
|
+
try:
|
|
276
|
+
arg_info["min"] = tensor_value.min().item()
|
|
277
|
+
arg_info["max"] = tensor_value.max().item()
|
|
278
|
+
arg_info["mean"] = tensor_value.float().mean().item()
|
|
279
|
+
arg_info["std"] = tensor_value.float().std().item()
|
|
280
|
+
except (RuntimeError, ValueError, TypeError) as e:
|
|
281
|
+
log.error(f"Error computing additional tensor statistics: {e}")
|
|
282
|
+
arg_info["tensor_capture_error"] = str(e)
|
|
255
283
|
return arg_info
|
|
256
284
|
|
|
257
285
|
|
|
@@ -466,6 +494,26 @@ def extract_file_content(trace_data: Dict[str, Any], metadata_group: Dict[str, s
|
|
|
466
494
|
message = f"<error reading file: {str(e)}>"
|
|
467
495
|
trace_data["file_content"][ir_filename] = message
|
|
468
496
|
log.debug(f"Error reading file {file_path}: {e}")
|
|
497
|
+
cubin_keys = [key for key in metadata_group.keys() if key.endswith(".cubin")]
|
|
498
|
+
cubin_path = metadata_group[cubin_keys[0]] if cubin_keys else None
|
|
499
|
+
|
|
500
|
+
if TRITONPARSE_DUMP_SASS and cubin_path:
|
|
501
|
+
filename_no_ext = os.path.splitext(os.path.basename(cubin_path))[0]
|
|
502
|
+
sass_filename = f"{filename_no_ext}.sass"
|
|
503
|
+
try:
|
|
504
|
+
import tritonparse.tools.disasm
|
|
505
|
+
|
|
506
|
+
sass_content = tritonparse.tools.disasm.extract(cubin_path)
|
|
507
|
+
trace_data["file_content"][sass_filename] = sass_content
|
|
508
|
+
except subprocess.CalledProcessError as e:
|
|
509
|
+
message = f"<nvdisasm failed: {str(e)}>"
|
|
510
|
+
trace_data["file_content"][sass_filename] = message
|
|
511
|
+
except OSError as e:
|
|
512
|
+
message = f"<error reading cubin file: {str(e)}>"
|
|
513
|
+
trace_data["file_content"][sass_filename] = message
|
|
514
|
+
except Exception as e:
|
|
515
|
+
message = f"<error dumping SASS: {str(e)}>"
|
|
516
|
+
trace_data["file_content"][sass_filename] = message
|
|
469
517
|
|
|
470
518
|
|
|
471
519
|
def extract_metadata_from_src(trace_data, src):
|
|
@@ -1068,7 +1116,7 @@ def init_basic(trace_folder: Optional[str] = None):
|
|
|
1068
1116
|
"""
|
|
1069
1117
|
Initialize the basic logging system for Triton compilation.
|
|
1070
1118
|
|
|
1071
|
-
This function sets up the basic logging system for Triton kernel compilation
|
|
1119
|
+
This function sets up the basic logging system for Triton kernel compilation.
|
|
1072
1120
|
|
|
1073
1121
|
Args:
|
|
1074
1122
|
trace_folder (Optional[str]): The folder to store the trace files.
|
|
@@ -1097,17 +1145,34 @@ def init_basic(trace_folder: Optional[str] = None):
|
|
|
1097
1145
|
maybe_enable_trace_launch()
|
|
1098
1146
|
|
|
1099
1147
|
|
|
1100
|
-
def init(
|
|
1148
|
+
def init(
|
|
1149
|
+
trace_folder: Optional[str] = None,
|
|
1150
|
+
enable_trace_launch: bool = False,
|
|
1151
|
+
enable_more_tensor_information: bool = False,
|
|
1152
|
+
enable_sass_dump: Optional[bool] = False,
|
|
1153
|
+
):
|
|
1101
1154
|
"""
|
|
1102
|
-
This function is a wrapper around init_basic() that also sets up the compilation listener.
|
|
1155
|
+
This function is a wrapper around init_basic() that also sets up the compilation listener. Its arguments have higher priority than the environment variables for the same settings.
|
|
1103
1156
|
|
|
1104
1157
|
Args:
|
|
1105
1158
|
trace_folder (Optional[str]): The folder to store the trace files.
|
|
1106
1159
|
enable_trace_launch (bool): Whether to enable the trace launch hook.
|
|
1160
|
+
enable_more_tensor_information (bool): Whether to enable more tensor information logging.
|
|
1161
|
+
It only works when enable_trace_launch/TRITON_TRACE_LAUNCH is True.
|
|
1162
|
+
enable_sass_dump (Optional[bool]): Whether to enable SASS dumping.
|
|
1107
1163
|
"""
|
|
1108
|
-
global
|
|
1164
|
+
global \
|
|
1165
|
+
TRITON_TRACE_LAUNCH, \
|
|
1166
|
+
TRITONPARSE_MORE_TENSOR_INFORMATION, \
|
|
1167
|
+
TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK, \
|
|
1168
|
+
TRITONPARSE_DUMP_SASS
|
|
1109
1169
|
if enable_trace_launch:
|
|
1110
1170
|
TRITON_TRACE_LAUNCH = True
|
|
1171
|
+
TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = True
|
|
1172
|
+
if enable_more_tensor_information:
|
|
1173
|
+
TRITONPARSE_MORE_TENSOR_INFORMATION = True
|
|
1174
|
+
if enable_sass_dump:
|
|
1175
|
+
TRITONPARSE_DUMP_SASS = True
|
|
1111
1176
|
|
|
1112
1177
|
init_basic(trace_folder)
|
|
1113
1178
|
from triton import knobs
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
import re
|
|
3
|
+
import subprocess
|
|
4
|
+
|
|
5
|
+
# Regex patterns for nvdisasm output
|
|
6
|
+
NVDISASM_FNAME_RE = re.compile(r"^\s*\.global\s+(\w+)")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def path_to_nvdisasm():
|
|
10
|
+
from triton import knobs
|
|
11
|
+
|
|
12
|
+
return knobs.nvidia.nvdisasm.path
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def is_nvdisasm_available():
|
|
16
|
+
try:
|
|
17
|
+
if path_to_nvdisasm():
|
|
18
|
+
return True
|
|
19
|
+
else:
|
|
20
|
+
return False
|
|
21
|
+
except RuntimeError:
|
|
22
|
+
return False
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def extract(file_path):
|
|
26
|
+
"""Extract SASS from CUBIN using nvdisasm.
|
|
27
|
+
|
|
28
|
+
nvdisasm output is much cleaner than cuobjdump:
|
|
29
|
+
- Single line per instruction (no encoding lines)
|
|
30
|
+
- Labels are already symbolized (.L_x_0 instead of addresses)
|
|
31
|
+
- Source line information is included
|
|
32
|
+
- No need for complex address remapping
|
|
33
|
+
|
|
34
|
+
nvdisasm Documentation:
|
|
35
|
+
https://docs.nvidia.com/cuda/cuda-binary-utilities/index.html
|
|
36
|
+
"""
|
|
37
|
+
nvdisasm = path_to_nvdisasm()
|
|
38
|
+
args = [nvdisasm, "-c", "-gp", "-g", "-gi", file_path]
|
|
39
|
+
sass_str = subprocess.check_output(args)
|
|
40
|
+
sass_lines = sass_str.splitlines()
|
|
41
|
+
line_idx = 0
|
|
42
|
+
|
|
43
|
+
while line_idx < len(sass_lines):
|
|
44
|
+
line = sass_lines[line_idx].decode()
|
|
45
|
+
|
|
46
|
+
# Find function definition (.global function_name)
|
|
47
|
+
while NVDISASM_FNAME_RE.match(line) is None:
|
|
48
|
+
line_idx += 1
|
|
49
|
+
if line_idx >= len(sass_lines):
|
|
50
|
+
return None
|
|
51
|
+
line = sass_lines[line_idx].decode()
|
|
52
|
+
|
|
53
|
+
# Extract function name
|
|
54
|
+
match = NVDISASM_FNAME_RE.match(line)
|
|
55
|
+
if match is None:
|
|
56
|
+
return None
|
|
57
|
+
fname = match.group(1)
|
|
58
|
+
ret = f"Function:{fname}\n"
|
|
59
|
+
|
|
60
|
+
# Find the actual start of function content (.text.kernel_name:)
|
|
61
|
+
text_section_pattern = f".text.{fname}:"
|
|
62
|
+
line_idx += 1
|
|
63
|
+
while line_idx < len(sass_lines):
|
|
64
|
+
line = sass_lines[line_idx].decode().strip()
|
|
65
|
+
if line == text_section_pattern:
|
|
66
|
+
line_idx += 1 # Move past the .text.kernel_name: line
|
|
67
|
+
break
|
|
68
|
+
line_idx += 1
|
|
69
|
+
|
|
70
|
+
# Process all lines until next .headerflags or end of file
|
|
71
|
+
while line_idx < len(sass_lines):
|
|
72
|
+
line = sass_lines[line_idx].decode().rstrip()
|
|
73
|
+
|
|
74
|
+
# Stop if we encounter next function's headerflags
|
|
75
|
+
if line.strip().startswith(".headerflags"):
|
|
76
|
+
break
|
|
77
|
+
ret += line + "\n"
|
|
78
|
+
line_idx += 1
|
|
79
|
+
|
|
80
|
+
ret += "\n"
|
|
81
|
+
return ret
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tritonparse
|
|
3
|
-
Version: 0.2.4.
|
|
3
|
+
Version: 0.2.4.dev20251002071459
|
|
4
4
|
Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
|
|
5
5
|
Author-email: Yueming Hao <yhao@meta.com>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
tritonparse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
tritonparse/__main__.py,sha256=YDBolsfpyAINExxCt7CDhHFno6nzEceE9Hzr3BUs6Hg,62
|
|
2
3
|
tritonparse/common.py,sha256=aT7zIPKEiuvTq_MbHgMREuVIz-gVcsRskSDlvuIOHuQ,13662
|
|
3
4
|
tritonparse/context_manager.py,sha256=M2HI4APCOZw3xQ05F1QAJIa89h-RajZByZJWpZDq7vI,596
|
|
4
5
|
tritonparse/event_diff.py,sha256=yOD6uNxLJroatfx2nEGr-erw24ObOrHU9P6V5pzr8do,4907
|
|
@@ -8,25 +9,27 @@ tritonparse/mapper.py,sha256=prrczfi13P7Aa042OrEBsmRF1HW3jDhwxicANgPkWIM,4150
|
|
|
8
9
|
tritonparse/shared_vars.py,sha256=-c9CvXJSDm9spYhDOJPEQProeT_xl3PaNmqTEYi_u4s,505
|
|
9
10
|
tritonparse/source_type.py,sha256=nmYEQS8rfkIN9BhNhQbkmEvKnvS-3zAxRGLY4TaZdi8,1676
|
|
10
11
|
tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV4,2037
|
|
11
|
-
tritonparse/structured_logging.py,sha256=
|
|
12
|
+
tritonparse/structured_logging.py,sha256=2M1UwC6eXUMV4ybIQiaibeUgYor2Zjh6S1CVGthOMDs,46720
|
|
12
13
|
tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
|
|
13
14
|
tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
|
|
14
15
|
tritonparse/utils.py,sha256=cO3c82PJfToW2pDsVicP3dFh1We3UVv3c3NqC_aTb_g,4312
|
|
15
16
|
tritonparse/reproducer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
-
tritonparse/reproducer/cli.py,sha256=
|
|
17
|
-
tritonparse/reproducer/orchestrator.py,sha256=
|
|
18
|
-
tritonparse/reproducer/utils.py,sha256=
|
|
17
|
+
tritonparse/reproducer/cli.py,sha256=bhtjD3k8pr7l2R2wmoleL-pGer2YndhUaLGnZq4rRBQ,948
|
|
18
|
+
tritonparse/reproducer/orchestrator.py,sha256=Uy3CSntjzgd1VZrsHKARE0XTqBgpUukGo7C8b37m2JA,2640
|
|
19
|
+
tritonparse/reproducer/utils.py,sha256=UTclw48vH49g6Z2ljJL5DOZ6Rl4UDudyr0PeUySa3p8,13857
|
|
19
20
|
tritonparse/reproducer/ingestion/ndjson.py,sha256=pEujTl5xXW2E2DEW8ngxXQ8qP9oawb90wBVTWHDs1jk,7372
|
|
20
|
-
tritonparse/reproducer/templates/example.py,sha256=
|
|
21
|
+
tritonparse/reproducer/templates/example.py,sha256=RExB1HVcHafopic3RF5_T40uNcRKmCMyLc18Bg94p4A,11686
|
|
21
22
|
tritonparse/reproducer/templates/loader.py,sha256=HqjfThdDVg7q2bYWry78sIaVRkUpkcA8KQDt83YrlVE,1920
|
|
22
23
|
tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
23
24
|
tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
|
|
25
|
+
tritonparse/tools/disasm.py,sha256=c4HmNNoPPeXPQBQkPVcMaHwDHbHNZNxuqXn4UIIs1Z0,2434
|
|
24
26
|
tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
|
|
25
27
|
tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
|
|
26
28
|
tritonparse/tools/prettify_ndjson.py,sha256=r2YlHwFDTHgML7KljRmMsHaDg29q8gOQAgyDKWJhxRM,11062
|
|
27
29
|
tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
|
|
28
|
-
tritonparse-0.2.4.
|
|
29
|
-
tritonparse-0.2.4.
|
|
30
|
-
tritonparse-0.2.4.
|
|
31
|
-
tritonparse-0.2.4.
|
|
32
|
-
tritonparse-0.2.4.
|
|
30
|
+
tritonparse-0.2.4.dev20251002071459.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
|
|
31
|
+
tritonparse-0.2.4.dev20251002071459.dist-info/METADATA,sha256=L9kPxL_yQ4BXVxTI0P-q-wP5c2ssREBOuSk6dwZqZRk,6580
|
|
32
|
+
tritonparse-0.2.4.dev20251002071459.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
33
|
+
tritonparse-0.2.4.dev20251002071459.dist-info/entry_points.txt,sha256=7P8TuH_nMXcPl1r8udA96SW8ccvAznZqTpCWLWDnV2o,53
|
|
34
|
+
tritonparse-0.2.4.dev20251002071459.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
|
|
35
|
+
tritonparse-0.2.4.dev20251002071459.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|