tritonparse 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/common.py +409 -0
- tritonparse/event_diff.py +120 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/ir_parser.py +220 -0
- tritonparse/mapper.py +100 -0
- tritonparse/reproducer/__init__.py +21 -0
- tritonparse/reproducer/__main__.py +81 -0
- tritonparse/reproducer/cli.py +37 -0
- tritonparse/reproducer/config.py +15 -0
- tritonparse/reproducer/factory.py +16 -0
- tritonparse/reproducer/ingestion/__init__.py +6 -0
- tritonparse/reproducer/ingestion/ndjson.py +165 -0
- tritonparse/reproducer/orchestrator.py +65 -0
- tritonparse/reproducer/param_generator.py +142 -0
- tritonparse/reproducer/prompts/__init__.py +1 -0
- tritonparse/reproducer/prompts/loader.py +18 -0
- tritonparse/reproducer/providers/__init__.py +1 -0
- tritonparse/reproducer/providers/base.py +14 -0
- tritonparse/reproducer/providers/gemini.py +47 -0
- tritonparse/reproducer/runtime/__init__.py +1 -0
- tritonparse/reproducer/runtime/executor.py +13 -0
- tritonparse/reproducer/utils/io.py +6 -0
- tritonparse/shared_vars.py +9 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +72 -0
- tritonparse/structured_logging.py +1046 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +118 -0
- tritonparse/tools/format_fix.py +149 -0
- tritonparse/tools/load_tensor.py +58 -0
- tritonparse/tools/prettify_ndjson.py +315 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +331 -0
- tritonparse/utils.py +156 -0
- tritonparse-0.1.1.dist-info/METADATA +10 -0
- tritonparse-0.1.1.dist-info/RECORD +40 -0
- tritonparse-0.1.1.dist-info/WHEEL +5 -0
- tritonparse-0.1.1.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1046 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
import atexit
|
|
4
|
+
import fnmatch
|
|
5
|
+
import gzip
|
|
6
|
+
import importlib
|
|
7
|
+
import inspect
|
|
8
|
+
import io
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import math
|
|
12
|
+
import os
|
|
13
|
+
from collections import defaultdict
|
|
14
|
+
from collections.abc import Mapping
|
|
15
|
+
from dataclasses import asdict, is_dataclass
|
|
16
|
+
from datetime import date, datetime
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
20
|
+
|
|
21
|
+
from triton.knobs import JITHook, LaunchHook
|
|
22
|
+
|
|
23
|
+
from .shared_vars import DEFAULT_TRACE_FILE_PREFIX
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
log = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
TEXT_FILE_EXTENSIONS = [".ttir", ".ttgir", ".llir", ".ptx", ".amdgcn", ".json"]
|
|
29
|
+
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB limit for file content extraction
|
|
30
|
+
# Enable ndjson output. json format is only for debugging purpose.
|
|
31
|
+
TRITONPARSE_NDJSON = os.getenv("TRITONPARSE_NDJSON", "1") in ["1", "true", "True"]
|
|
32
|
+
# Enable gzip compression for each line in trace files
|
|
33
|
+
TRITON_TRACE_GZIP = os.getenv("TRITON_TRACE_GZIP", "0") in ["1", "true", "True"]
|
|
34
|
+
triton_trace_log = logging.getLogger("tritonparse_trace")
|
|
35
|
+
# The folder to store the triton trace log.
|
|
36
|
+
triton_trace_folder = os.environ.get("TRITON_TRACE", None)
|
|
37
|
+
# Enable debug logging for tritonparse itself
|
|
38
|
+
TRITONPARSE_DEBUG = os.getenv("TRITONPARSE_DEBUG", None) in ["1", "true", "True"]
|
|
39
|
+
# Kernel allowlist for filtering traced kernels. Use comma separated list of fnmatch patterns.
|
|
40
|
+
TRITONPARSE_KERNEL_ALLOWLIST = os.environ.get("TRITONPARSE_KERNEL_ALLOWLIST", None)
|
|
41
|
+
# Parsed kernel allowlist patterns (set during init)
|
|
42
|
+
_KERNEL_ALLOWLIST_PATTERNS: Optional[List[str]] = None
|
|
43
|
+
# Enable launch trace. WARNNING: it will overwrite launch_metadata function for each triton kernel.
|
|
44
|
+
TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"]
|
|
45
|
+
# The flag to mark if launch is traced. It is used to avoid initilizing the launch hook twice.
|
|
46
|
+
_trace_launch_enabled = False
|
|
47
|
+
|
|
48
|
+
TRITON_TRACE_HANDLER = None
|
|
49
|
+
if importlib.util.find_spec("torch") is not None:
|
|
50
|
+
TORCH_INSTALLED = True
|
|
51
|
+
import torch
|
|
52
|
+
from torch.utils._traceback import CapturedTraceback
|
|
53
|
+
else:
|
|
54
|
+
TORCH_INSTALLED = False
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TritonLogRecord(logging.LogRecord):
    """
    A ``logging.LogRecord`` that also carries tritonparse's structured data.

    On top of the standard LogRecord fields it stores:
      * ``metadata`` - dict of structured fields; any falsy value passed in
        (including ``None``) is replaced by a fresh empty dict;
      * ``payload``  - optional payload data, stored exactly as given.
    """

    def __init__(
        self,
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=None,
        payload=None,
        **kwargs,
    ):
        standard_fields = (name, level, pathname, lineno, msg, args, exc_info)
        super().__init__(*standard_fields, **kwargs)
        self.metadata: Dict[str, Any] = metadata if metadata else {}
        self.payload: Optional[Union[str, Dict[str, Any], list]] = payload
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def create_triton_log_record(
    name=None,
    level=logging.DEBUG,
    pathname=None,
    lineno=None,
    msg="",
    args=(),
    exc_info=None,
    metadata=None,
    payload=None,
    **kwargs,
):
    """
    Build a ``TritonLogRecord``, filling in defaults for omitted fields.

    Args:
        name (str, optional): Logger name. Defaults to ``triton_trace_log.name``.
        level (int, optional): Log level. Defaults to ``logging.DEBUG``.
        pathname (str, optional): Source path recorded on the record. Defaults
            to this module's own ``__file__`` (not the caller's file).
        lineno (int, optional): Line number recorded on the record. Defaults to
            the line in the calling frame that invoked this function.
        msg (str, optional): Log message. Defaults to "".
        args (tuple, optional): Message interpolation arguments. Defaults to ().
        exc_info (optional): Exception information. Defaults to None.
        metadata (Dict[str, Any], optional): Structured metadata for the record.
        payload (optional): Payload data. Defaults to None.
        **kwargs: Forwarded unchanged to ``TritonLogRecord``.

    Returns:
        TritonLogRecord: the newly created record.
    """
    pathname = __file__ if pathname is None else pathname
    if lineno is None:
        # f_back is the frame of whoever called this factory.
        lineno = inspect.currentframe().f_back.f_lineno
    name = triton_trace_log.name if name is None else name

    return TritonLogRecord(
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=metadata,
        payload=payload,
        **kwargs,
    )
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def convert(obj):
    """
    Recursively convert *obj* into a JSON-serializable structure.

    Rules, applied in order:
      * ``None``, ``bool``, ``int`` and ``str`` are returned unchanged.
      * Finite floats are returned unchanged. NaN/inf become their ``str()``
        form (``"nan"``, ``"inf"``, ``"-inf"``) because JSON has no literal
        for them.
      * Namedtuples become dicts of their fields; other lists/tuples become lists.
      * Sets/frozensets become lists, sorted by the ``str`` of each element.
      * Mappings become dicts with ``str`` keys.
      * ``datetime``/``date`` become ISO-8601 strings.
      * ``Enum`` members are replaced by their (converted) value.
      * ``Path`` objects become strings.
      * Dataclass *instances* go through ``dataclasses.asdict``.
      * Triton ``dtype`` objects become ``"triton.language.core.dtype('<name>')"``.
      * Anything else is logged as an unknown type and converted with ``str()``.

    Args:
        obj: The object to convert.

    Returns:
        A structure built only from JSON-compatible types.
    """
    # 1. primitives that JSON already supports -------------------------------
    if obj is None or isinstance(obj, (bool, int, str)):
        return obj

    if isinstance(obj, float):
        # JSON spec forbids NaN/Infinity - keep finite values, stringify the rest.
        return obj if math.isfinite(obj) else str(obj)

    # 2. simple containers ----------------------------------------------------
    if isinstance(obj, (list, tuple)):
        # Handle namedtuple specially to preserve field names
        if hasattr(obj, "_asdict"):
            return convert(obj._asdict())
        return [convert(x) for x in obj]

    if isinstance(obj, (set, frozenset)):
        return [convert(x) for x in sorted(obj, key=str)]

    if isinstance(obj, Mapping):
        return {str(k): convert(v) for k, v in obj.items()}

    # 3. time, enum, path, dataclass ------------------------------------------
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()

    if isinstance(obj, Enum):
        return convert(obj.value)

    if isinstance(obj, Path):
        return str(obj)

    # is_dataclass() is also True for dataclass *classes*, but asdict() only
    # accepts instances and raises TypeError for a class.
    if is_dataclass(obj) and not isinstance(obj, type):
        return convert(asdict(obj))

    # 4. Common Triton constexpr objects --------------------------------------
    # Imported lazily and only here, so converting plain values never requires
    # (or pays for) a triton import.
    try:
        from triton.language.core import dtype
    except ImportError:
        dtype = None
    if dtype is not None and isinstance(obj, dtype):
        return f"triton.language.core.dtype('{str(obj)}')"

    log.warning(f"Unknown type: {type(obj)}")
    return str(obj)  # Fall back to the string representation
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def maybe_enable_debug_logging():
    """
    Turn on DEBUG output for tritonparse's own module logger (``log``).

    This is for diagnostics of the logging module itself, not for the Triton
    compilation trace. It does nothing unless TRITONPARSE_DEBUG is set.

    When enabled, it:
      * sets ``log`` to DEBUG;
      * disables propagation to the root logger, so messages are not duplicated;
      * attaches a ``logging.StreamHandler``, unless a StreamHandler at DEBUG
        level or lower is already attached.

    Because of the last check, repeated calls add at most one handler.
    """
    if not TRITONPARSE_DEBUG:
        return

    # Always set debug level if TRITONPARSE_DEBUG is set
    log.setLevel(logging.DEBUG)
    log.propagate = False

    for existing in log.handlers:
        if isinstance(existing, logging.StreamHandler) and existing.level <= logging.DEBUG:
            return  # a suitable debug handler is already installed

    formatter = logging.Formatter("%(asctime)s[%(levelname)s] %(message)s")
    formatter.default_time_format = "%Y%m%d %H:%M:%S"
    formatter.default_msec_format = None

    debug_handler = logging.StreamHandler()
    debug_handler.setLevel(logging.DEBUG)
    debug_handler.setFormatter(formatter)
    log.addHandler(debug_handler)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def get_stack_trace(skip=1):
    """
    Capture the current call stack as a list of plain dicts.

    Uses torch's ``CapturedTraceback``; returns an empty list when torch is not
    installed.

    Args:
        skip (int): Number of frames to skip, passed straight to
            ``CapturedTraceback.extract``.

    Returns:
        List[Dict]: one dict per frame with keys ``line`` (line number),
        ``name`` (function name), ``filename`` and ``loc`` (source line text).
    """
    if not TORCH_INSTALLED:
        return []
    summary = CapturedTraceback.extract(skip=skip).summary()
    return [
        {
            "line": entry.lineno,
            "name": entry.name,
            "filename": entry.filename,
            "loc": entry.line,
        }
        for entry in summary
    ]
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def parse_kernel_allowlist() -> Optional[List[str]]:
    """
    Parse TRITONPARSE_KERNEL_ALLOWLIST into a list of fnmatch patterns.

    The variable is split on commas, each entry is whitespace-stripped, and
    empty entries are dropped.

    Returns:
        List[str] or None: the patterns, or None (meaning "trace every kernel")
        when the variable is unset/empty or yields no non-empty pattern.
    """
    raw_allowlist = TRITONPARSE_KERNEL_ALLOWLIST
    if not raw_allowlist:
        return None

    stripped = (entry.strip() for entry in raw_allowlist.split(","))
    patterns = [entry for entry in stripped if entry]
    if not patterns:
        return None

    log.debug(f"Kernel allowlist patterns: {patterns}")
    return patterns
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def extract_kernel_name(src) -> Optional[str]:
    """
    Extract the kernel name from a Triton source object.

    For an ``IRSource`` the name is its ``name`` attribute. For any other source
    (typically ``ASTSource``) it is ``src.fn.fn.__name__``, the name of the
    Python function wrapped by the JIT function.

    Args:
        src (Union[ASTSource, IRSource]): Source object containing kernel information

    Returns:
        str or None: Kernel name if extractable, None otherwise
    """
    from triton.compiler import IRSource

    try:
        if isinstance(src, IRSource):
            # Use the builtin getattr: IRSource has no `.getattr` method, so the
            # former `src.getattr(...)` always raised and returned None.
            # NOTE(review): confirm the format of IRSource.name (it may carry a
            # prefix) matches the allowlist patterns users write.
            return getattr(src, "name", None)

        # For ASTSource, get the wrapped Python function's name
        if (
            hasattr(src, "fn")
            and hasattr(src.fn, "fn")
            and hasattr(src.fn.fn, "__name__")
        ):
            return src.fn.fn.__name__
        return None
    except Exception as e:
        # Logger.warn is a deprecated alias of Logger.warning.
        log.warning(f"Error extracting kernel name: {e}")
        return None
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def should_trace_kernel(
    kernel_name: Optional[str], allowlist_patterns: Optional[List[str]]
) -> bool:
    """
    Decide whether a kernel passes the allowlist filter.

    Args:
        kernel_name (str or None): Name of the kernel
        allowlist_patterns (List[str] or None): fnmatch patterns to test against

    Returns:
        bool: The result is decided as follows:
          * True when there is no allowlist (``None``).
          * False when the kernel name is unknown.
          * Otherwise True only if some pattern matches the name.
    """
    if allowlist_patterns is None:
        return True  # no allowlist configured: trace everything

    if kernel_name is None:
        # An unknown name cannot match anything - conservatively skip it.
        log.debug("Cannot extract kernel name, skipping trace")
        return False

    matching = [p for p in allowlist_patterns if fnmatch.fnmatch(kernel_name, p)]
    if not matching:
        log.debug(
            f"Kernel '{kernel_name}' does not match any allowlist pattern, skipping trace"
        )
        return False

    first_match = matching[0]
    log.debug(f"Kernel '{kernel_name}' matches pattern '{first_match}', will trace")
    return True
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def extract_python_source_info(trace_data: Dict[str, Any], source):
    """
    Record the kernel's Python source under ``trace_data["python_source"]``.

    ``IRSource`` inputs are ignored, since there is no Python function to
    inspect. For other sources, ``source.fn`` is examined:
      * If it is a ``JITFunction`` exposing ``starting_line_number`` and
        ``raw_src``, those cached values are used.
      * Otherwise ``inspect.getsourcelines(source.fn.fn)`` is called.

    The stored dict has the keys ``file_path``, ``start_line``, ``end_line``
    and ``code``.

    Args:
        trace_data (Dict[str, Any]): Dictionary updated in place
        source (Union[ASTSource, IRSource]): Source object containing kernel function information
    """
    # @TODO: add support for IRSource
    from triton.compiler import IRSource
    from triton.runtime.jit import JITFunction

    if isinstance(source, IRSource):
        return

    jit_fn = source.fn
    has_cached_source = (
        isinstance(jit_fn, JITFunction)
        and hasattr(jit_fn, "starting_line_number")
        and hasattr(jit_fn, "raw_src")
    )
    if has_cached_source:
        first_line = jit_fn.starting_line_number
        lines = jit_fn.raw_src
    else:
        lines, first_line = inspect.getsourcelines(jit_fn.fn)

    trace_data["python_source"] = {
        "file_path": inspect.getfile(jit_fn.fn),
        "start_line": first_line,
        # NOTE(review): start + len(lines) points one past the last source
        # line; confirm consumers treat end_line as exclusive.
        "end_line": first_line + len(lines),
        "code": "".join(lines),
    }
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def extract_file_content(trace_data: Dict[str, Any], metadata_group: Dict[str, str]):
    """
    Record compilation artifact paths and, for text IR files, their contents.

    For every ``ir_filename -> file_path`` entry in *metadata_group*:
      * ``trace_data["file_path"][ir_filename]`` is set to ``file_path``.
      * If ``ir_filename`` ends with one of ``TEXT_FILE_EXTENSIONS``, the file
        is read as UTF-8 into ``trace_data["file_content"][ir_filename]``.
      * A file larger than ``MAX_FILE_SIZE`` bytes, or one that cannot be read
        or decoded, gets a ``<...>`` placeholder message instead.

    Args:
        trace_data (Dict): Must already hold dicts under "file_path" and
            "file_content"; both are updated in place.
        metadata_group (Dict): Dictionary mapping filenames to file paths
    """
    # str.endswith accepts a tuple of suffixes; build it once, outside the loop.
    text_suffixes = tuple(TEXT_FILE_EXTENSIONS)

    for ir_filename, file_path in metadata_group.items():
        # Add file path to trace data
        trace_data["file_path"][ir_filename] = file_path

        if not ir_filename.endswith(text_suffixes):
            continue  # not a text artifact: record the path only

        try:
            # Check file size before reading to avoid memory issues
            file_size = os.path.getsize(file_path)
            if file_size > MAX_FILE_SIZE:
                message = f"<file too large: {file_size} bytes>"
                trace_data["file_content"][ir_filename] = message
                continue

            # Decode explicitly as UTF-8 rather than the locale's preferred
            # encoding, which makes the result platform/locale dependent.
            with open(file_path, "r", encoding="utf-8") as f:
                trace_data["file_content"][ir_filename] = f.read()
        except (UnicodeDecodeError, OSError) as e:
            message = f"<error reading file: {str(e)}>"
            trace_data["file_content"][ir_filename] = message
            log.debug(f"Error reading file {file_path}: {e}")
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def extract_metadata_from_src(trace_data, src):
    """
    Merge source-derived metadata into ``trace_data["metadata"]`` in place.

    Keys written:
      * ``env`` - result of Triton's ``get_cache_invalidating_env_vars()``;
      * ``src_attrs`` - ``src.attrs``, or ``{}`` if *src* has no such attribute;
      * ``src_cache_key`` - ``src.fn.cache_key`` if *src* has ``fn``, else ``""``;
      * ``src_constants`` - ``src.constants``, or ``{}`` if absent.
    """
    from triton._C.libtriton import get_cache_invalidating_env_vars

    env_vars = get_cache_invalidating_env_vars()
    # TODO: consider also recording src.parse_options() merged with the
    # backend's parse_options() (previously sketched here as commented-out code).
    target = trace_data["metadata"]
    target.update(
        {
            "env": env_vars,
            "src_attrs": src.attrs if hasattr(src, "attrs") else {},
            "src_cache_key": src.fn.cache_key if hasattr(src, "fn") else "",
            "src_constants": src.constants if hasattr(src, "constants") else {},
        }
    )
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
class TritonJsonFormatter(logging.Formatter):
    """
    Format log records as JSON for Triton compilation tracing.

    Each record's ``metadata`` dict becomes the top-level JSON object. It is
    extended with a UTC ``timestamp`` and, when present, the record's
    ``payload``. The result is passed through ``convert`` so that it is
    JSON-serializable.

    When TRITONPARSE_NDJSON is enabled, the output is one compact JSON document
    terminated by a newline (NDJSON). Otherwise it is pretty-printed JSON with
    indent=2, which is meant for debugging only.
    """

    def format(self, record: logging.LogRecord):
        # Work on a copy so formatting never mutates the record's metadata
        # (which is the caller's dict) and tolerates plain LogRecords.
        log_entry = dict(getattr(record, "metadata", None) or {})
        payload = getattr(record, "payload", None)

        log_entry["timestamp"] = self._utc_timestamp(record)
        if payload is not None:
            # String payloads carry serialized JSON; already-structured payloads
            # (dict, list, ...) are used as-is instead of crashing json.loads.
            if isinstance(payload, (str, bytes, bytearray)):
                payload = json.loads(payload)
            log_entry["payload"] = payload
        clean_log_entry = convert(log_entry)
        if not TRITONPARSE_NDJSON:
            return json.dumps(clean_log_entry, indent=2)
        # NDJSON format requires a newline at the end of each line
        json_str = json.dumps(clean_log_entry, separators=(",", ":"))
        return json_str + "\n"

    @staticmethod
    def _utc_timestamp(record: logging.LogRecord) -> str:
        """Return ``record.created`` as an ISO-8601 UTC string with microseconds."""
        # logging.Formatter.formatTime() goes through time.strftime, where %f is
        # not a supported directive (its output is platform-dependent), and it
        # converts to local time, contradicting the "Z" (UTC) suffix.
        from datetime import timezone

        created = datetime.fromtimestamp(record.created, tz=timezone.utc)
        return created.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
class TritonTraceHandler(logging.StreamHandler):
    """
    Logging handler that appends Triton trace records to an NDJSON file.

    The file is opened lazily by the first ``emit`` call, inside the directory
    returned by ``get_root_dir()``. Its name is built from three parts:
      * ``prefix``;
      * ``rank_<N>_`` when torch.distributed is available and initialized;
      * ``.ndjson``, or ``.bin.ndjson`` when TRITON_TRACE_GZIP is enabled (each
        record is then written as its own gzip member).

    When no directory is available, the handler removes itself from
    ``triton_trace_log``. An ``atexit`` hook closes a still-open file when the
    interpreter exits.
    """

    def __init__(
        self, root_dir: Optional[str] = None, prefix=DEFAULT_TRACE_FILE_PREFIX
    ):
        # logging.Handler.__init__ is called directly instead of
        # StreamHandler.__init__ (which would default self.stream to
        # sys.stderr); the stream stays None until emit() opens the file.
        logging.Handler.__init__(self)
        self.root_dir = root_dir
        self.prefix = prefix
        self.stream = None
        self.first_record = True
        # If the program terminates unexpectedly, this hook still flushes and
        # closes the file kept open in self.stream.
        atexit.register(self._cleanup)

    def get_root_dir(self):
        """
        Return the directory for trace files, or a falsy value if unavailable.

        An explicitly configured (truthy) ``root_dir`` is returned as-is.
        Otherwise ``/logs`` is adopted and cached in ``self.root_dir`` only if
        all of the following hold:
          * the torch policy check allows it (always true without torch);
          * ``/logs`` exists;
          * ``/logs`` is writable.
        """
        # For meta internal runs, we use the /logs directory by default.
        # Reference implementation:
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        if self.root_dir:
            return self.root_dir

        TRACE_LOG_DIR = "/logs"
        if self._torch_permits_default_dir():
            if not os.path.exists(TRACE_LOG_DIR):
                log.info(
                    "TritonTraceHandler: disabled because %s does not exist",
                    TRACE_LOG_DIR,
                )
            elif not os.access(TRACE_LOG_DIR, os.W_OK):
                log.info(
                    "TritonTraceHandler: disabled because %s is not writeable",
                    TRACE_LOG_DIR,
                )
            else:
                self.root_dir = TRACE_LOG_DIR
        return self.root_dir

    @staticmethod
    def _torch_permits_default_dir():
        """Return False if the torch build/environment rules out the /logs default."""
        if not TORCH_INSTALLED:
            return True

        import torch.version as torch_version

        if (
            hasattr(torch_version, "git_version")
            and os.getenv("MAST_HPC_JOB_NAME") is None
        ):
            log.info("TritonTraceHandler: disabled because not fbcode or conda on mast")
            return False
        # TODO: change to tritonparse knob
        if not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
            log.info(
                "TritonTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
            )
            return False
        return True

    def _open_trace_file(self, root_dir):
        """Create *root_dir* if needed and open the trace file for appending."""
        os.makedirs(root_dir, exist_ok=True)

        rank_part = ""
        if TORCH_INSTALLED:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized():
                rank_part = f"rank_{dist.get_rank()}_"

        if TRITON_TRACE_GZIP:
            # Binary mode so gzip members can be concatenated in one file.
            extension, open_mode = ".bin.ndjson", "ab+"
        else:
            extension, open_mode = ".ndjson", "a+"

        log_file_name = os.path.abspath(
            os.path.join(root_dir, f"{self.prefix}{rank_part}{extension}")
        )
        self.stream = open(log_file_name, mode=open_mode)
        log.debug("TritonTraceHandler: logging to %s", log_file_name)

    def _write_formatted(self, formatted):
        """Write one formatted record, gzip-compressing it when enabled."""
        if not TRITON_TRACE_GZIP:
            self.stream.write(formatted)
            return
        # One complete gzip member per record: standard gzip readers handle
        # concatenated members transparently.
        member = io.BytesIO()
        with gzip.GzipFile(fileobj=member, mode="wb") as gz:
            gz.write(formatted.encode("utf-8"))
        self.stream.write(member.getvalue())

    def emit(self, record):
        # Reference implementation:
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        try:
            if self.stream is None:
                root_dir = self.get_root_dir()
                if root_dir is None:
                    # No usable directory: stop receiving records altogether.
                    triton_trace_log.removeHandler(self)
                    return
                self._open_trace_file(root_dir)

            if self.stream:
                self._write_formatted(self.format(record))
                self.flush()
        except Exception as e:
            # Record the failure, release the file, then defer to logging's
            # standard error handling.
            log.error(f"Error in TritonTraceHandler.emit: {e}")
            self._ensure_stream_closed()
            self.handleError(record)

    def close(self):
        """Close the current file."""
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        self.stream.close()
                        self.stream = None
            finally:
                # Solution adopted from PyTorch PR #120289
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def _cleanup(self):
        """Ensure proper cleanup on program exit"""
        if self.stream is not None:
            self.close()

    def _ensure_stream_closed(self):
        """Flush and close the open stream (if any) and reset it to None."""
        if self.stream is None:
            return
        try:
            self.flush()
        finally:
            self.stream.close()
            self.stream = None
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def init_logs():
    """
    Initialize tritonparse's trace logger and attach its handler when possible.

    Safe to call repeatedly:
      1. The first call creates the shared ``TritonTraceHandler`` from
         ``triton_trace_folder``, which may still be None.
      2. A later call made once ``triton_trace_folder`` is set fills it into a
         handler whose ``root_dir`` is still None.
      3. While no writable directory is available, the handler stays detached.
         ``triton_trace_log.propagate`` is always False, so records are never
         forwarded to the root logger. Otherwise the root logger would print
         empty "DEBUG:tritonparse_trace:" lines.
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder

    # Logger settings that are safe to reapply on every call.
    triton_trace_log.setLevel(logging.DEBUG)
    triton_trace_log.propagate = False  # see point 3 above

    if TRITON_TRACE_HANDLER is None:
        TRITON_TRACE_HANDLER = TritonTraceHandler(triton_trace_folder)
    handler = TRITON_TRACE_HANDLER

    if handler.root_dir is None and triton_trace_folder is not None:
        handler.root_dir = triton_trace_folder

    # get_root_dir() re-applies the /logs fallback and permission checks.
    tracing_enabled = handler.get_root_dir() is not None
    attached = handler in triton_trace_log.handlers

    if not tracing_enabled:
        # Keep the handler detached; quiet exit, no blank lines.
        if attached:
            triton_trace_log.removeHandler(handler)
        return

    if not attached:
        handler.setFormatter(TritonJsonFormatter())
        triton_trace_log.addHandler(handler)
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def trace_structured_triton(
    name: str,
    metadata_fn: Optional[Callable[[], Dict[str, Any]]] = None,
    *,
    payload_fn: Optional[Callable[[], Optional[Union[str, object]]]] = None,
):
    """
    Emit one structured trace event through ``triton_trace_log``.

    The record's metadata is built in this order:
      * ``event_type`` (= *name*) and ``pid``;
      * whatever ``metadata_fn()`` returns, if it is non-empty; this may
        override the first two keys;
      * ``stack``, taken from ``get_stack_trace()``.

    The record's payload is ``payload_fn()``, or None when no callback is given.

    Args:
        name (str): Name of the trace event (e.g. "compilation", "execution")
        metadata_fn (Callable): Returns a dict of extra metadata for the record
        payload_fn (Callable): Returns the payload. The trace formatter decodes
            string payloads with json.loads, so a string is presumably JSON text.
    """
    event_metadata: Dict[str, Any] = {"event_type": name, "pid": os.getpid()}

    extra_metadata = metadata_fn() if metadata_fn is not None else {}
    if extra_metadata:
        event_metadata.update(extra_metadata)

    event_metadata["stack"] = get_stack_trace()

    payload = payload_fn() if payload_fn is not None else None
    record = create_triton_log_record(metadata=event_metadata, payload=payload)
    triton_trace_log.handle(record)
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def maybe_trace_triton(
    src,
    metadata: Dict[str, Any],
    metadata_group: Dict[str, str],
    times: Any,
    event_type: str = "compilation",
    cache_hit: bool = False,
):
    """
    Collect Triton kernel compilation information and emit it as a trace event.

    Information is gathered from:
      1. ``metadata``, or, when it is empty, the first ``*.json`` entry of
         ``metadata_group`` (loaded from disk)
      2. the PyTorch compile context (only when torch is installed)
      3. IR and other compilation artifact files in ``metadata_group``
      4. the Python source of the kernel function

    ``init()`` registers this function as ``triton.knobs.compilation.listener``.

    Args:
        src (Union[ASTSource, IRSource]): Source object describing the kernel.
        metadata (Dict[str, Any]): Compilation metadata. If empty/None, it is
            read from the JSON file in ``metadata_group``.
        metadata_group (Dict[str, str]): Maps artifact filenames to file paths.
        times (CompileTimes): Compilation timing information (stored only if
            truthy).
        event_type (str): Event type name to record (default: "compilation").
        cache_hit (bool): Whether the compilation was a cache hit.

    Returns:
        Dict[str, Any]: The collected trace data (a ``defaultdict(dict)``), or
        an empty dict when the kernel is excluded by the allowlist.
    """
    # Check the kernel allowlist first so filtered kernels cost nothing.
    if _KERNEL_ALLOWLIST_PATTERNS is not None:
        kernel_name = extract_kernel_name(src)
        if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
            # An empty dict signals that nothing was traced.
            return {}

    # defaultdict lets sections such as "metadata" and "pt_info" be filled lazily.
    trace_data = defaultdict(dict)
    trace_data["metadata"]["cache_hit"] = cache_hit

    if not metadata:
        # Fall back to the JSON metadata artifact in the group. Use a default
        # for next() so a missing file degrades to a warning instead of
        # leaking a bare StopIteration out of the compilation listener.
        metadata_path = next(
            (Path(p) for c, p in metadata_group.items() if c.endswith(".json")),
            None,
        )
        if metadata_path is None:
            log.warning(
                "No .json metadata file found in metadata_group; "
                "tracing without compilation metadata"
            )
            metadata = {}
        else:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
    trace_data["metadata"].update(metadata)

    # torch._guards may not be recognized by the type checker.
    if TORCH_INSTALLED:
        trace_id = torch._guards.CompileContext.current_trace_id()  # type: ignore
    else:
        trace_id = None
    cid = trace_id.compile_id if trace_id else None
    if cid is not None:
        for attr_name in ["compiled_autograd_id", "frame_id", "frame_compile_id"]:
            attr_value = getattr(cid, attr_name, None)
            if attr_value is not None:
                trace_data["pt_info"][attr_name] = attr_value
    if trace_id:
        trace_data["pt_info"]["attempt"] = trace_id.attempt

    # Extract content from all IR and other files in the metadata group.
    extract_file_content(trace_data, metadata_group)
    # Extract Python source code information if available.
    extract_python_source_info(trace_data, src)
    extract_metadata_from_src(trace_data, src)

    # Add timing information if available.
    if times:
        trace_data["times"] = times

    # Serialization is deferred into payload_fn and runs when the event is emitted.
    trace_structured_triton(
        event_type,
        payload_fn=lambda: json.dumps(convert(trace_data)),
    )

    return trace_data
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def extract_arg_info(arg_dict):
    """
    Extract detailed information from kernel arguments, especially PyTorch tensors.

    Args:
        arg_dict: Mapping of argument name -> argument value.

    Returns:
        Dict mapping each argument name to a dict describing its value:
          - torch.Tensor (only when torch is installed): type, shape, dtype,
            device, stride, numel, is_contiguous, element_size,
            storage_offset, memory_usage (numel * element_size, in bytes)
            and data_ptr (hex string)
          - int / float / bool: type name and value
          - str: type, value and length
          - anything else: type name and a string representation truncated
            to 200 characters (with "..." appended when truncated)
    """
    extracted_args = {}

    for arg_name, arg_value in arg_dict.items():
        arg_info = {}

        if TORCH_INSTALLED and isinstance(arg_value, torch.Tensor):
            arg_info["type"] = "tensor"
            arg_info["shape"] = list(arg_value.shape)
            arg_info["dtype"] = str(arg_value.dtype)
            arg_info["device"] = str(arg_value.device)
            arg_info["stride"] = list(arg_value.stride())
            arg_info["numel"] = arg_value.numel()
            arg_info["is_contiguous"] = arg_value.is_contiguous()
            arg_info["element_size"] = arg_value.element_size()
            arg_info["storage_offset"] = arg_value.storage_offset()
            # Memory usage in bytes.
            arg_info["memory_usage"] = arg_value.numel() * arg_value.element_size()
            # Data pointer, useful for memory tracking.
            if hasattr(arg_value, "data_ptr"):
                arg_info["data_ptr"] = hex(arg_value.data_ptr())
        elif isinstance(arg_value, (int, float, bool)):
            # bool is a subclass of int; type().__name__ still reports "bool".
            arg_info["type"] = type(arg_value).__name__
            arg_info["value"] = arg_value
        elif isinstance(arg_value, str):
            arg_info["type"] = "str"
            arg_info["value"] = arg_value
            arg_info["length"] = len(arg_value)
        else:
            arg_info["type"] = type(arg_value).__name__
            # This function is reached from add_launch_metadata, i.e. while a
            # kernel is being launched. A misbehaving __str__ on an arbitrary
            # argument must not abort the launch, so fall back to a placeholder.
            try:
                text = str(arg_value)
            except Exception as exc:
                text = (
                    f"<unrepresentable {type(arg_value).__name__}: "
                    f"{type(exc).__name__}>"
                )
            if len(text) > 200:  # Truncate very long representations.
                text = text[:200] + "..."
            arg_info["repr"] = text

        extracted_args[arg_name] = arg_info

    return extracted_args
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
def add_launch_metadata(grid, metadata, arg_dict):
    """
    Build the custom launch-metadata dict that tritonparse attaches to launches.

    JITHookImpl installs this function as a JIT function's ``launch_metadata``.
    It returns a single key, ``"launch_metadata_tritonparse"``, whose value is
    the tuple ``(grid, metadata._asdict(), extract_arg_info(arg_dict))``.
    LaunchHookImpl later unpacks that tuple by position.

    Args:
        grid: The launch grid, stored as-is.
        metadata: Compiled-kernel metadata; must provide ``_asdict()``
            (presumably a namedtuple).
        arg_dict: Mapping of argument name -> argument value.
    """
    # Arguments are summarized before the metadata is converted, matching the
    # original evaluation order.
    arg_summary = extract_arg_info(arg_dict)
    metadata_as_dict = metadata._asdict()
    return {"launch_metadata_tritonparse": (grid, metadata_as_dict, arg_summary)}
|
|
826
|
+
|
|
827
|
+
|
|
828
|
+
class JITHookImpl(JITHook):
    """
    JIT hook that sets a JIT function's ``launch_metadata`` to tritonparse's.

    Without a custom ``launch_metadata`` function, Triton only exposes minimal
    launch information (e.g. the kernel name), see:
    https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L504

    Installing ``add_launch_metadata`` lets launches carry the grid, the
    compiled-kernel metadata and a per-argument summary, which LaunchHookImpl
    then records.
    """

    def __call__(
        self,
        *,
        key: str,
        repr: str,
        fn,
        compile,
        is_manual_warmup: bool,
        already_compiled: bool,
    ) -> Optional[bool]:
        """
        Install ``add_launch_metadata`` on ``fn.jit_function``.

        Kernels excluded by the allowlist are left untouched.

        Args:
            key: Unique identifier for the kernel.
            repr: String representation of the kernel.
            fn: The JIT function info object. ``fn.name`` and
                ``fn.jit_function`` are used here.
            compile: Compilation information (unused here).
            is_manual_warmup: Whether this is a manual warmup call (unused).
            already_compiled: Whether the kernel is already compiled (unused).

        Returns:
            Always True. NOTE(review): how Triton interprets this return value
            for the post-compile hook should be confirmed against the
            installed Triton version.
        """
        # Check the kernel allowlist early to avoid unnecessary work.
        if _KERNEL_ALLOWLIST_PATTERNS is not None:
            kernel_name = fn.name
            if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
                # Skip overriding launch_metadata if kernel is not in allowlist.
                return True

        current_launch_metadata = getattr(fn.jit_function, "launch_metadata", None)
        # This hook can run more than once for the same JIT function. After the
        # first run the attribute already holds add_launch_metadata, so only
        # warn when we are about to replace somebody else's function.
        if (
            current_launch_metadata is not None
            and current_launch_metadata is not add_launch_metadata
        ):
            log.warning(
                "fn %s launch_metadata is not None: %s. "
                "It will be overridden by tritonparse.",
                fn,
                current_launch_metadata,
            )
        fn.jit_function.launch_metadata = add_launch_metadata
        return True
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
class LaunchHookImpl(LaunchHook):
    """
    Launch-enter hook that records each kernel launch as a "launch" trace event.

    It reads the launch metadata Triton provides and copies ``name``,
    ``function`` and ``stream`` into the event. When JITHookImpl's
    ``add_launch_metadata`` contributed data, it also records:
      - ``grid``: the launch grid
      - ``compilation_metadata``: the compiled kernel's metadata as a dict
      - ``extracted_args``: per-argument summaries (tensor properties, scalar
        values, or truncated string representations)
    """

    def __call__(self, metadata):
        """
        Record one kernel launch.

        Args:
            metadata: LazyDict of launch information; ``metadata.get()``
                yields a dict such as::

                    {'name': 'add_kernel', 'function': None, 'stream': 0,
                     'launch_metadata_tritonparse': (grid, self.metadata, extracted_args)}

                The ``launch_metadata_tritonparse`` entry is only present when
                tritonparse's launch_metadata function ran. See
                https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L512
        """
        launch_info = metadata.get()

        # Kernels outside the allowlist are ignored entirely.
        if _KERNEL_ALLOWLIST_PATTERNS is not None and not should_trace_kernel(
            launch_info.get("name"), _KERNEL_ALLOWLIST_PATTERNS
        ):
            return

        trace_data = defaultdict(dict)
        for field in ("name", "function", "stream"):
            trace_data[field] = launch_info[field]

        tritonparse_extra = launch_info.get("launch_metadata_tritonparse", None)
        if tritonparse_extra is not None:
            # Positional layout is defined by add_launch_metadata:
            # (grid, compilation metadata dict, extracted argument info).
            trace_data["grid"] = tritonparse_extra[0]
            trace_data["compilation_metadata"] = tritonparse_extra[1]
            trace_data["extracted_args"] = tritonparse_extra[2]

        trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
def maybe_enable_trace_launch():
    """
    Install tritonparse's launch-tracing hooks, at most once.

    This does nothing unless ``TRITON_TRACE_LAUNCH`` is truthy and the hooks
    have not been installed yet (``_trace_launch_enabled``). Otherwise it
    assigns a JITHookImpl to ``knobs.runtime.jit_post_compile_hook`` and a
    LaunchHookImpl to ``knobs.runtime.launch_enter_hook``, replacing any hooks
    already there, and marks launch tracing as enabled.
    """
    global _trace_launch_enabled
    if not TRITON_TRACE_LAUNCH or _trace_launch_enabled:
        return

    from triton import knobs as triton_knobs

    enter_hook = LaunchHookImpl()
    post_compile_hook = JITHookImpl()
    triton_knobs.runtime.jit_post_compile_hook = post_compile_hook
    triton_knobs.runtime.launch_enter_hook = enter_hook

    _trace_launch_enabled = True
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
def init_basic(trace_folder: Optional[str] = None):
    """
    Set up tritonparse's basic logging state without registering a listener.

    Steps, in order:
      1. ``maybe_enable_debug_logging()``.
      2. If ``trace_folder`` is given, it replaces the module-level
         ``triton_trace_folder``. An info message is logged when a folder was
         already configured.
      3. The kernel allowlist is (re)parsed via ``parse_kernel_allowlist()``.
      4. ``init_logs()`` and ``maybe_enable_trace_launch()`` are called.

    Args:
        trace_folder (Optional[str]): Folder to store trace files in. When
            None, the currently configured folder is kept.
    """
    global triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    maybe_enable_debug_logging()

    if trace_folder is not None:
        if triton_trace_folder is not None:
            log.info(
                "Conflict settings: triton_trace_folder is already set to %s, we will use provided trace_folder(%s) instead.",
                triton_trace_folder,
                trace_folder,
            )
        triton_trace_folder = trace_folder

    _KERNEL_ALLOWLIST_PATTERNS = parse_kernel_allowlist()
    if not _KERNEL_ALLOWLIST_PATTERNS:
        log.debug("Kernel allowlist not set, tracing all kernels")
    else:
        log.debug(
            f"Kernel allowlist enabled with patterns: {_KERNEL_ALLOWLIST_PATTERNS}"
        )

    init_logs()
    maybe_enable_trace_launch()
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
    """
    Initialize tritonparse and register it as Triton's compilation listener.

    This runs ``init_basic(trace_folder)`` and then assigns
    ``maybe_trace_triton`` to ``triton.knobs.compilation.listener``, replacing
    any listener already set there.

    Args:
        trace_folder (Optional[str]): Folder to store trace files in.
        enable_trace_launch (bool): When true, sets ``TRITON_TRACE_LAUNCH``
            before ``init_basic`` runs, so launch hooks get installed. A false
            value leaves the current setting untouched.
    """
    global TRITON_TRACE_LAUNCH
    if enable_trace_launch:
        # Must happen before init_basic(), which may install the launch hooks.
        TRITON_TRACE_LAUNCH = True

    init_basic(trace_folder)

    from triton import knobs as triton_knobs

    triton_knobs.compilation.listener = maybe_trace_triton
|
|
1012
|
+
|
|
1013
|
+
|
|
1014
|
+
def clear_logging_config():
    """
    Undo the configuration made by init() and init_basic().

    Steps, in order:
      1. Detaches ``TRITON_TRACE_HANDLER`` from ``triton_trace_log`` (if
         attached), closes it, and clears the global.
      2. Resets ``triton_trace_folder``, ``_KERNEL_ALLOWLIST_PATTERNS`` and
         ``_trace_launch_enabled``.
      3. Sets Triton's compilation listener, JIT post-compile hook and
         launch-enter hook to None, which also removes hooks set by others.

    NOTE(review): ``TRITON_TRACE_LAUNCH``, which ``init(enable_trace_launch=True)``
    sets, is not reset here. Confirm whether that is intended.

    WARNING: This function is not supposed to be called unless you are sure
    you want to clear the logging config.
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    global _trace_launch_enabled

    # Step 1: tear down the trace handler.
    active_handler = TRITON_TRACE_HANDLER
    if active_handler is not None:
        if active_handler in triton_trace_log.handlers:
            triton_trace_log.removeHandler(active_handler)
        active_handler.close()
        TRITON_TRACE_HANDLER = None

    # Step 2: forget module-level state.
    triton_trace_folder = None
    _KERNEL_ALLOWLIST_PATTERNS = None
    _trace_launch_enabled = False

    # Step 3: clear Triton's knobs (imports triton unconditionally).
    from triton import knobs as triton_knobs

    triton_knobs.compilation.listener = None
    triton_knobs.runtime.jit_post_compile_hook = None
    triton_knobs.runtime.launch_enter_hook = None
|