tritonparse 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic. Click here for more details.

Files changed (40) hide show
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/common.py +409 -0
  3. tritonparse/event_diff.py +120 -0
  4. tritonparse/extract_source_mappings.py +49 -0
  5. tritonparse/ir_parser.py +220 -0
  6. tritonparse/mapper.py +100 -0
  7. tritonparse/reproducer/__init__.py +21 -0
  8. tritonparse/reproducer/__main__.py +81 -0
  9. tritonparse/reproducer/cli.py +37 -0
  10. tritonparse/reproducer/config.py +15 -0
  11. tritonparse/reproducer/factory.py +16 -0
  12. tritonparse/reproducer/ingestion/__init__.py +6 -0
  13. tritonparse/reproducer/ingestion/ndjson.py +165 -0
  14. tritonparse/reproducer/orchestrator.py +65 -0
  15. tritonparse/reproducer/param_generator.py +142 -0
  16. tritonparse/reproducer/prompts/__init__.py +1 -0
  17. tritonparse/reproducer/prompts/loader.py +18 -0
  18. tritonparse/reproducer/providers/__init__.py +1 -0
  19. tritonparse/reproducer/providers/base.py +14 -0
  20. tritonparse/reproducer/providers/gemini.py +47 -0
  21. tritonparse/reproducer/runtime/__init__.py +1 -0
  22. tritonparse/reproducer/runtime/executor.py +13 -0
  23. tritonparse/reproducer/utils/io.py +6 -0
  24. tritonparse/shared_vars.py +9 -0
  25. tritonparse/source_type.py +56 -0
  26. tritonparse/sourcemap_utils.py +72 -0
  27. tritonparse/structured_logging.py +1046 -0
  28. tritonparse/tools/__init__.py +0 -0
  29. tritonparse/tools/decompress_bin_ndjson.py +118 -0
  30. tritonparse/tools/format_fix.py +149 -0
  31. tritonparse/tools/load_tensor.py +58 -0
  32. tritonparse/tools/prettify_ndjson.py +315 -0
  33. tritonparse/tp_logger.py +9 -0
  34. tritonparse/trace_processor.py +331 -0
  35. tritonparse/utils.py +156 -0
  36. tritonparse-0.1.1.dist-info/METADATA +10 -0
  37. tritonparse-0.1.1.dist-info/RECORD +40 -0
  38. tritonparse-0.1.1.dist-info/WHEEL +5 -0
  39. tritonparse-0.1.1.dist-info/licenses/LICENSE +29 -0
  40. tritonparse-0.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1046 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import atexit
4
+ import fnmatch
5
+ import gzip
6
+ import importlib
7
+ import inspect
8
+ import io
9
+ import json
10
+ import logging
11
+ import math
12
+ import os
13
+ from collections import defaultdict
14
+ from collections.abc import Mapping
15
+ from dataclasses import asdict, is_dataclass
16
+ from datetime import date, datetime
17
+ from enum import Enum
18
+ from pathlib import Path
19
+ from typing import Any, Callable, Dict, List, Optional, Union
20
+
21
+ from triton.knobs import JITHook, LaunchHook
22
+
23
+ from .shared_vars import DEFAULT_TRACE_FILE_PREFIX
24
+
25
+
26
+ log = logging.getLogger(__name__)
27
+
28
+ TEXT_FILE_EXTENSIONS = [".ttir", ".ttgir", ".llir", ".ptx", ".amdgcn", ".json"]
29
+ MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB limit for file content extraction
30
+ # Enable ndjson output. json format is only for debugging purpose.
31
+ TRITONPARSE_NDJSON = os.getenv("TRITONPARSE_NDJSON", "1") in ["1", "true", "True"]
32
+ # Enable gzip compression for each line in trace files
33
+ TRITON_TRACE_GZIP = os.getenv("TRITON_TRACE_GZIP", "0") in ["1", "true", "True"]
34
+ triton_trace_log = logging.getLogger("tritonparse_trace")
35
+ # The folder to store the triton trace log.
36
+ triton_trace_folder = os.environ.get("TRITON_TRACE", None)
37
+ # Enable debug logging for tritonparse itself
38
+ TRITONPARSE_DEBUG = os.getenv("TRITONPARSE_DEBUG", None) in ["1", "true", "True"]
39
+ # Kernel allowlist for filtering traced kernels. Use comma separated list of fnmatch patterns.
40
+ TRITONPARSE_KERNEL_ALLOWLIST = os.environ.get("TRITONPARSE_KERNEL_ALLOWLIST", None)
41
+ # Parsed kernel allowlist patterns (set during init)
42
+ _KERNEL_ALLOWLIST_PATTERNS: Optional[List[str]] = None
43
+ # Enable launch trace. WARNNING: it will overwrite launch_metadata function for each triton kernel.
44
+ TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"]
45
+ # The flag to mark if launch is traced. It is used to avoid initilizing the launch hook twice.
46
+ _trace_launch_enabled = False
47
+
48
+ TRITON_TRACE_HANDLER = None
49
+ if importlib.util.find_spec("torch") is not None:
50
+ TORCH_INSTALLED = True
51
+ import torch
52
+ from torch.utils._traceback import CapturedTraceback
53
+ else:
54
+ TORCH_INSTALLED = False
55
+
56
+
57
class TritonLogRecord(logging.LogRecord):
    """
    LogRecord subclass that carries structured trace data.

    In addition to the standard LogRecord fields, every instance exposes
    ``metadata`` (always a dict) and ``payload`` (stored exactly as given).
    """

    def __init__(
        self,
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=None,
        payload=None,
        **kwargs,
    ):
        logging.LogRecord.__init__(
            self, name, level, pathname, lineno, msg, args, exc_info, **kwargs
        )
        # A falsy metadata argument (None or an empty mapping) becomes a fresh dict.
        self.metadata: Dict[str, Any] = metadata if metadata else {}
        self.payload: Optional[Union[str, Dict[str, Any], list]] = payload
81
+
82
+
83
def create_triton_log_record(
    name=None,
    level=logging.DEBUG,
    pathname=None,
    lineno=None,
    msg="",
    args=(),
    exc_info=None,
    metadata=None,
    payload=None,
    **kwargs,
):
    """
    Build a TritonLogRecord, filling in any name/location field left as None.

    Args:
        name (str, optional): Logger name. Defaults to triton_trace_log.name.
        level (int, optional): Log level. Defaults to DEBUG.
        pathname (str, optional): Source path stored on the record. Defaults to this module's file.
        lineno (int, optional): Line number stored on the record. Defaults to the
            line in the calling frame from which this factory was invoked.
        msg (str, optional): Log message. Defaults to empty string.
        args (tuple, optional): Arguments to interpolate into the message. Defaults to empty tuple.
        exc_info (optional): Exception information. Defaults to None.
        metadata (Dict[str, Any], optional): Structured metadata for the log record.
        payload (optional): Payload data. Defaults to None.
        **kwargs: Additional keyword arguments forwarded to LogRecord

    Returns:
        TritonLogRecord: A custom log record with structured data
    """
    # Explicitly supplied values always win; only None is replaced.
    name = triton_trace_log.name if name is None else name
    pathname = __file__ if pathname is None else pathname
    if lineno is None:
        # f_back is the frame of whoever called this factory.
        lineno = inspect.currentframe().f_back.f_lineno

    return TritonLogRecord(
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=metadata,
        payload=payload,
        **kwargs,
    )
133
+
134
+
135
def convert(obj):
    """
    Recursively convert an object into a structure that ``json.dumps`` accepts.

    Handled inputs: JSON primitives, non-finite floats (stringified), lists,
    tuples, namedtuples (as dicts keyed by field name), sets (as lists sorted
    by ``str``), mappings (keys stringified), dates/datetimes (ISO format),
    enums (their value), paths, dataclass instances and Triton ``dtype``
    objects. Anything else is logged as unknown and returned as ``str(obj)``.

    Args:
        obj: The object to convert

    Returns:
        A JSON-serializable version of the input object
    """
    # 1. primitives that JSON already supports -------------------------------
    if obj is None or isinstance(obj, (bool, int, str)):
        return obj

    if isinstance(obj, float):
        # JSON spec forbids NaN/Infinity – keep precision but stay valid
        if math.isfinite(obj):
            return obj
        return str(obj)  # "nan", "inf", "-inf"

    # 2. simple containers ----------------------------------------------------
    if isinstance(obj, (list, tuple)):
        # Handle namedtuple specially to preserve field names
        if hasattr(obj, "_asdict"):
            return convert(obj._asdict())
        return [convert(x) for x in obj]

    if isinstance(obj, (set, frozenset)):
        # Sort by string form so the output order is deterministic.
        return [convert(x) for x in sorted(obj, key=str)]

    if isinstance(obj, Mapping):
        return {str(k): convert(v) for k, v in obj.items()}

    # 3. time, enum, path, dataclass -----------------------------------------
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()

    if isinstance(obj, Enum):
        return convert(obj.value)

    if isinstance(obj, Path):
        return str(obj)

    # is_dataclass() is also true for dataclass *classes*, but asdict() only
    # accepts instances and raises TypeError otherwise; classes fall through
    # to the generic string fallback below.
    if is_dataclass(obj) and not isinstance(obj, type):
        return convert(asdict(obj))

    # 4. Common Triton constexpr objects
    # Imported here so the common paths above do not need the triton import.
    from triton.language.core import dtype

    if isinstance(obj, dtype):
        return f"triton.language.core.dtype('{str(obj)}')"

    # Fallback: unknown types are reported and stringified.
    log.warning(f"Unknown type: {type(obj)}")
    return str(obj)
190
+
191
+
192
def maybe_enable_debug_logging():
    """
    Turn on DEBUG output for tritonparse's own logger when TRITONPARSE_DEBUG is set.

    This affects the module's diagnostic logger only, not the logging of the
    triton compilation traces.
    """
    if not TRITONPARSE_DEBUG:
        return

    # Always set debug level if TRITONPARSE_DEBUG is set
    log.setLevel(logging.DEBUG)
    # Prevent propagation to root logger to avoid duplicate messages
    log.propagate = False

    # Nothing more to do if a stream handler at DEBUG level (or lower) is already attached.
    for existing in log.handlers:
        if (
            isinstance(existing, logging.StreamHandler)
            and existing.level <= logging.DEBUG
        ):
            return

    debug_formatter = logging.Formatter("%(asctime)s[%(levelname)s] %(message)s")
    debug_formatter.default_time_format = "%Y%m%d %H:%M:%S"
    debug_formatter.default_msec_format = None

    debug_handler = logging.StreamHandler()
    debug_handler.setLevel(logging.DEBUG)
    debug_handler.setFormatter(debug_formatter)
    log.addHandler(debug_handler)
218
+
219
+
220
def get_stack_trace(skip=1):
    """
    Return the current call stack as a list of plain dictionaries.

    The stack is captured with torch's CapturedTraceback utility; when torch
    is not installed an empty list is returned.

    Args:
        skip (int): Passed through to ``CapturedTraceback.extract`` as the
            number of frames to skip

    Returns:
        List[Dict]: One dictionary per frame with the keys ``line``, ``name``,
        ``filename`` and ``loc`` (the source text of the line)
    """
    if not TORCH_INSTALLED:
        return []

    summary = CapturedTraceback.extract(skip=skip).summary()
    return [
        {
            "line": entry.lineno,
            "name": entry.name,
            "filename": entry.filename,
            "loc": entry.line,
        }
        for entry in summary
    ]
247
+
248
+
249
def parse_kernel_allowlist() -> Optional[List[str]]:
    """
    Turn the TRITONPARSE_KERNEL_ALLOWLIST setting into a list of patterns.

    The value is split on commas; surrounding whitespace is stripped and
    empty entries are discarded.

    Returns:
        List[str] or None: Kernel name patterns to trace, or None when the
        setting is unset/empty or contains no usable pattern (trace everything)
    """
    raw_value = TRITONPARSE_KERNEL_ALLOWLIST
    if not raw_value:
        return None

    stripped = (piece.strip() for piece in raw_value.split(","))
    patterns = [piece for piece in stripped if piece]
    if not patterns:
        return None

    log.debug(f"Kernel allowlist patterns: {patterns}")
    return patterns
269
+
270
+
271
def extract_kernel_name(src) -> Optional[str]:
    """
    Extract kernel name from the source object.

    Args:
        src (Union[ASTSource, IRSource]): Source object containing kernel information

    Returns:
        str or None: Kernel name if extractable, None otherwise. Any exception
        raised while inspecting ``src`` is logged as a warning and results in None.
    """
    from triton.compiler import IRSource

    try:
        if isinstance(src, IRSource):
            # Use the getattr() builtin: `src.getattr(...)` is an attribute
            # lookup on the object itself and always raised AttributeError,
            # which made every IRSource kernel nameless.
            return getattr(src, "name", None)

        # For ASTSource, get the name of the wrapped Python function
        if (
            hasattr(src, "fn")
            and hasattr(src.fn, "fn")
            and hasattr(src.fn.fn, "__name__")
        ):
            return src.fn.fn.__name__
        return None
    except Exception as e:
        # Best-effort: a failure here must not break compilation tracing.
        log.warning(f"Error extracting kernel name: {e}")
        return None
298
+
299
+
300
def should_trace_kernel(
    kernel_name: Optional[str], allowlist_patterns: Optional[List[str]]
) -> bool:
    """
    Decide whether a kernel passes the allowlist filter.

    Args:
        kernel_name (str or None): Name of the kernel
        allowlist_patterns (List[str] or None): fnmatch patterns to test the name against

    Returns:
        bool: True when no allowlist is configured or the name matches at
        least one pattern; False when the name is unknown or nothing matches
    """
    # No allowlist configured: every kernel is traced.
    if allowlist_patterns is None:
        return True

    # Conservative choice: a kernel whose name is unknown is not traced.
    if kernel_name is None:
        log.debug("Cannot extract kernel name, skipping trace")
        return False

    # First pattern (in list order) that matches the kernel name, if any.
    matched = next(
        (
            candidate
            for candidate in allowlist_patterns
            if fnmatch.fnmatch(kernel_name, candidate)
        ),
        None,
    )
    if matched is None:
        log.debug(
            f"Kernel '{kernel_name}' does not match any allowlist pattern, skipping trace"
        )
        return False

    log.debug(f"Kernel '{kernel_name}' matches pattern '{matched}', will trace")
    return True
332
+
333
+
334
def extract_python_source_info(trace_data: Dict[str, Any], source):
    """
    Record the kernel's Python source location and text under trace_data["python_source"].

    IRSource inputs are not supported yet and leave trace_data untouched.
    For other sources, the source lines come from the JITFunction's cached
    ``raw_src``/``starting_line_number`` when both exist, otherwise from
    ``inspect.getsourcelines`` on the wrapped function.

    Args:
        trace_data (Dict[str, Any]): Dictionary to store extracted information
        source (Union[ASTSource, IRSource]): Source object containing kernel function information
    """
    # @TODO: add support for IRSource
    from triton.compiler import IRSource
    from triton.runtime.jit import JITFunction

    if isinstance(source, IRSource):
        return

    fn = source.fn
    has_cached_source = (
        isinstance(fn, JITFunction)
        and hasattr(fn, "starting_line_number")
        and hasattr(fn, "raw_src")
    )
    if has_cached_source:
        first_line = fn.starting_line_number
        code_lines = fn.raw_src
    else:
        code_lines, first_line = inspect.getsourcelines(fn.fn)

    trace_data["python_source"] = {
        "file_path": inspect.getfile(fn.fn),
        "start_line": first_line,
        # start line plus the number of source lines
        "end_line": first_line + len(code_lines),
        "code": "".join(code_lines),
    }
372
+
373
+
374
def extract_file_content(trace_data: Dict[str, Any], metadata_group: Dict[str, str]):
    """
    Record artifact paths and, for text artifacts, their content in trace_data.

    Every entry is stored under trace_data["file_path"]. Entries whose name
    ends with one of TEXT_FILE_EXTENSIONS are also read and stored under
    trace_data["file_content"]; files larger than MAX_FILE_SIZE or files that
    cannot be read get a "<...>" placeholder message instead of their content.

    Args:
        trace_data (Dict): Dictionary to store extracted information
        metadata_group (Dict): Dictionary mapping filenames to file paths
    """
    text_suffixes = tuple(TEXT_FILE_EXTENSIONS)

    for ir_filename, file_path in metadata_group.items():
        # Add file path to trace data
        trace_data["file_path"][ir_filename] = file_path

        # Only known text formats are read
        if not ir_filename.endswith(text_suffixes):
            continue

        try:
            # Check file size before reading to avoid memory issues
            file_size = os.path.getsize(file_path)
            if file_size > MAX_FILE_SIZE:
                message = f"<file too large: {file_size} bytes>"
                trace_data["file_content"][ir_filename] = message
                continue

            # Decode explicitly as UTF-8 so the result does not depend on
            # the locale of the process doing the tracing.
            with open(file_path, "r", encoding="utf-8") as f:
                trace_data["file_content"][ir_filename] = f.read()
        except (UnicodeDecodeError, OSError) as e:
            # Best-effort: keep tracing and record why the content is missing.
            message = f"<error reading file: {str(e)}>"
            trace_data["file_content"][ir_filename] = message
            log.debug(f"Error reading file {file_path}: {e}")
403
+
404
+
405
def extract_metadata_from_src(trace_data, src):
    """
    Add environment and source-object details to trace_data["metadata"].

    Stores the cache-invalidating environment variables reported by triton
    under "env", plus the source's ``attrs``, ``fn.cache_key`` and
    ``constants`` (each replaced by an empty default when the attribute is
    missing on ``src``).

    Args:
        trace_data: Trace dictionary whose "metadata" entry is updated in place
        src: Source object of the kernel being compiled
    """
    from triton._C.libtriton import get_cache_invalidating_env_vars

    env_vars = get_cache_invalidating_env_vars()
    # NOTE: parsed backend options ("extra_options") are currently not recorded.

    src_details = {
        "env": env_vars,
        "src_attrs": getattr(src, "attrs", {}),
        "src_cache_key": src.fn.cache_key if hasattr(src, "fn") else "",
        "src_constants": getattr(src, "constants", {}),
    }
    trace_data["metadata"].update(src_details)
421
+
422
+
423
class TritonJsonFormatter(logging.Formatter):
    """
    Format log records as JSON for Triton compilation tracing.

    The record's ``metadata`` dictionary becomes the JSON object; a UTC
    ``timestamp`` and, when present, the ``payload`` are added to it. With
    TRITONPARSE_NDJSON enabled each record is one compact line terminated by
    a newline, otherwise it is pretty-printed JSON (debugging only).
    """

    def format(self, record: logging.LogRecord):
        """
        Serialize one record.

        Args:
            record: Normally a TritonLogRecord; a plain LogRecord without
                ``metadata``/``payload`` is treated as having none.

        Returns:
            str: The JSON text for the record
        """
        from datetime import timezone

        # Work on a copy so formatting never mutates the record's own metadata.
        log_entry = dict(getattr(record, "metadata", None) or {})
        payload = getattr(record, "payload", None)

        # Formatter.formatTime() goes through time.strftime(), which has no
        # "%f" directive and uses local time; build the UTC timestamp with
        # datetime so the microseconds and the trailing "Z" are correct.
        created = datetime.fromtimestamp(record.created, tz=timezone.utc)
        log_entry["timestamp"] = created.strftime("%Y-%m-%dT%H:%M:%S.%fZ")

        if payload is not None:
            # Payloads may arrive as JSON text or as already-structured data.
            if isinstance(payload, (str, bytes, bytearray)):
                payload = json.loads(payload)
            log_entry["payload"] = payload

        clean_log_entry = convert(log_entry)
        if not TRITONPARSE_NDJSON:
            return json.dumps(clean_log_entry, indent=2)
        # NDJSON format requires a newline at the end of each line
        return json.dumps(clean_log_entry, separators=(",", ":")) + "\n"
446
+
447
+
448
class TritonTraceHandler(logging.StreamHandler):
    """
    Logging handler that appends Triton trace records to an NDJSON file.

    The output file is opened lazily on the first emitted record, inside the
    directory returned by get_root_dir(). When torch.distributed is
    initialized, the rank is added to the file name. With TRITON_TRACE_GZIP
    enabled every record is written as its own gzip member.
    """

    def __init__(
        self, root_dir: Optional[str] = None, prefix=DEFAULT_TRACE_FILE_PREFIX
    ):
        # Only Handler.__init__ is run; the stream is opened later by emit().
        logging.Handler.__init__(self)
        self.root_dir = root_dir
        self.prefix = prefix
        self.stream = None
        self.first_record = True
        # The open file is kept in `self.stream`; registering a cleanup with
        # atexit makes sure it gets closed even if the program terminates
        # unexpectedly.
        atexit.register(self._cleanup)

    def get_root_dir(self):
        """Return the trace directory, or None when tracing has no usable directory."""
        # For meta internal runs, we use the /logs directory by default
        # reference implementation
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        if self.root_dir:
            return self.root_dir

        default_dir = "/logs"
        use_default_dir = True
        if TORCH_INSTALLED:
            import torch.version as torch_version

            has_git_version = hasattr(torch_version, "git_version")
            if has_git_version and os.getenv("MAST_HPC_JOB_NAME") is None:
                log.info(
                    "TritonTraceHandler: disabled because not fbcode or conda on mast"
                )
                use_default_dir = False
            # TODO: change to tritonparse knob
            elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                log.info(
                    "TritonTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                )
                use_default_dir = False

        if not use_default_dir:
            return self.root_dir

        # The default directory is only adopted when it exists and is writable.
        if not os.path.exists(default_dir):
            log.info(
                "TritonTraceHandler: disabled because %s does not exist",
                default_dir,
            )
        elif not os.access(default_dir, os.W_OK):
            log.info(
                "TritonTraceHandler: disabled because %s is not writeable",
                default_dir,
            )
        else:
            self.root_dir = default_dir
        return self.root_dir

    def emit(self, record):
        """Write one record to the trace file, opening the file on first use."""
        # reference implementation
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        try:
            if self.stream is None:
                root_dir = self.get_root_dir()
                if root_dir is None:
                    # No directory to write to: detach from the trace logger.
                    triton_trace_log.removeHandler(self)
                    return

                os.makedirs(root_dir, exist_ok=True)
                rank_part = ""
                if TORCH_INSTALLED:
                    import torch.distributed as dist

                    if dist.is_available() and dist.is_initialized():
                        rank_part = f"rank_{dist.get_rank()}_"
                self._ensure_stream_closed()
                # Choose file extension and mode based on compression setting
                if TRITON_TRACE_GZIP:
                    # Binary mode for gzip member concatenation
                    suffix, open_mode = ".bin.ndjson", "ab+"
                else:
                    suffix, open_mode = ".ndjson", "a+"
                target_path = os.path.abspath(
                    os.path.join(root_dir, f"{self.prefix}{rank_part}{suffix}")
                )
                self.stream = open(target_path, mode=open_mode)
                log.debug("TritonTraceHandler: logging to %s", target_path)

            if self.stream:
                line = self.format(record)
                if TRITON_TRACE_GZIP:
                    # Create a separate gzip member for each record; standard
                    # gzip readers handle concatenated members automatically.
                    member = io.BytesIO()
                    with gzip.GzipFile(fileobj=member, mode="wb") as gz:
                        gz.write(line.encode("utf-8"))
                    # Write the complete gzip member to the file
                    self.stream.write(member.getvalue())
                else:
                    self.stream.write(line)
                self.flush()
        except Exception as e:
            # record exception and ensure resources are released
            log.error(f"Error in TritonTraceHandler.emit: {e}")
            self._ensure_stream_closed()
            self.handleError(record)  # call Handler's standard error handling

    def close(self):
        """Close the current file."""
        self.acquire()
        try:
            if self.stream:
                try:
                    self.flush()
                finally:
                    self.stream.close()
                    self.stream = None
        finally:
            try:
                # Solution adopted from PyTorch PR #120289
                logging.StreamHandler.close(self)
            finally:
                self.release()

    def _cleanup(self):
        """Ensure proper cleanup on program exit"""
        if self.stream is None:
            return
        self.close()

    def _ensure_stream_closed(self):
        """Flush and close the stream if one is open, then forget it."""
        if self.stream is None:
            return
        try:
            self.flush()
        finally:
            self.stream.close()
            self.stream = None
595
+
596
+
597
def init_logs():
    """
    Set up (or refresh) tritonparse's trace logger and its file handler.

    Safe to call repeatedly:
      1. The handler is created on the first call, whether or not a trace
         folder is known yet.
      2. If the handler has no `root_dir` and `triton_trace_folder` is now
         set, the handler adopts it.
      3. If no usable directory exists, the handler is kept detached and the
         logger does not propagate to the root logger, which avoids empty
         `DEBUG:tritonparse_trace:` lines.
      4. Otherwise the handler is attached (once) with the JSON formatter.
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder

    # Basic logger settings (safe to run on every call)
    triton_trace_log.setLevel(logging.DEBUG)
    triton_trace_log.propagate = False  # stops bubbling to root logger. see 3)

    # 1) Create the handler on first use (root_dir may be None)
    handler = TRITON_TRACE_HANDLER
    if handler is None:
        handler = TRITON_TRACE_HANDLER = TritonTraceHandler(triton_trace_folder)

    # 2) Fill in a root_dir that was unknown when the handler was created
    if handler.root_dir is None and triton_trace_folder is not None:
        handler.root_dir = triton_trace_folder

    # 3) Re-evaluate whether we have a writable directory
    #    (`get_root_dir()` also checks /logs logic, permissions, etc.)
    tracing_enabled = handler.get_root_dir() is not None
    already_attached = handler in triton_trace_log.handlers

    if not tracing_enabled:
        # Tracing still disabled: ensure the handler is NOT attached
        if already_attached:
            triton_trace_log.removeHandler(handler)
        return  # quiet exit, no blank lines

    # 4) Tracing is enabled: attach the handler (if not already attached)
    #    and set the JSON formatter.
    if not already_attached:
        handler.setFormatter(TritonJsonFormatter())
        triton_trace_log.addHandler(handler)
635
+
636
+
637
def trace_structured_triton(
    name: str,
    metadata_fn: Optional[Callable[[], Dict[str, Any]]] = None,
    *,
    payload_fn: Optional[Callable[[], Optional[Union[str, object]]]] = None,
):
    """
    Emit one structured trace event through the tritonparse trace logger.

    The record's metadata contains the event type, the current process id,
    whatever ``metadata_fn`` returns, and the current stack trace; the
    record's payload is whatever ``payload_fn`` returns.

    Args:
        name (str): Name of the trace event (e.g., "compilation", "execution")
        metadata_fn (Callable): Function that returns a dictionary of metadata to include
            in the trace record; omitted means no extra metadata
        payload_fn (Callable): Function that returns the payload data (can be a string,
            dictionary, or other serializable object); omitted means no payload
    """
    metadata_dict: Dict[str, Any] = {"event_type": name, "pid": os.getpid()}

    custom_metadata = {} if metadata_fn is None else metadata_fn()
    if custom_metadata:
        metadata_dict.update(custom_metadata)

    # Captured directly in this function so the recorded frames stay the same.
    metadata_dict["stack"] = get_stack_trace()

    payload = None if payload_fn is None else payload_fn()

    # Build our custom LogRecord and hand it to the trace logger's handlers.
    triton_trace_log.handle(
        create_triton_log_record(metadata=metadata_dict, payload=payload)
    )
683
+
684
+
685
def maybe_trace_triton(
    src,
    metadata: Dict[str, Any],
    metadata_group: Dict[str, str],
    times: Any,
    event_type: str = "compilation",
    cache_hit: bool = False,
):
    """
    Collect and trace Triton kernel compilation information for debugging and profiling.

    This function gathers metadata, IR files, and source code information about a Triton
    kernel compilation, then logs it through the tracing system.
    It collects information from multiple sources:
    1. The `metadata` argument, or the JSON metadata file in `metadata_group` when
       `metadata` is empty
    2. PyTorch compilation context (if torch is installed)
    3. IR and other compilation artifact files
    4. Python source code of the kernel function

    This function is designed to be used as a CompilationListener in triton.knobs.compilation.listener,
    which now accepts a list of listeners.

    Args:
        src (Union[ASTSource, IRSource]): Source object containing kernel information
        metadata (Dict[str, Any]): Dictionary containing metadata for the compilation
        metadata_group (Dict[str, Any]): Dictionary mapping filenames to file paths for all compilation artifacts
        times (CompileTimes): Object containing timing information for the compilation
        event_type (str): Type of event being traced (default: "compilation")
        cache_hit (bool): Whether the compilation was a cache hit (default: False)

    Returns:
        Dict[str, Any]: Dictionary containing all collected trace data, or an empty
        dict when the kernel is filtered out by the kernel allowlist
    """
    # Check kernel allowlist early to avoid unnecessary work
    if _KERNEL_ALLOWLIST_PATTERNS is not None:
        kernel_name = extract_kernel_name(src)
        if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
            # Return empty dict to indicate no tracing was done
            return {}

    # Initialize a dictionary with defaultdict to avoid key errors
    trace_data = defaultdict(dict)
    # Add cache_hit to metadata
    trace_data["metadata"]["cache_hit"] = cache_hit

    if not metadata:
        # Fall back to the JSON metadata artifact. The explicit default keeps
        # next() from raising StopIteration out of the compilation listener
        # when the group contains no ".json" entry.
        metadata_path = next(
            (Path(p) for c, p in metadata_group.items() if c.endswith(".json")),
            None,
        )
        if metadata_path is None:
            log.debug("No JSON metadata file found in metadata_group")
            metadata = {}
        else:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
    trace_data["metadata"].update(metadata)

    # Handle torch._guards which might not be recognized by type checker
    if TORCH_INSTALLED:
        trace_id = torch._guards.CompileContext.current_trace_id()  # type: ignore
    else:
        trace_id = None
    cid = trace_id.compile_id if trace_id else None
    if cid is not None:
        for attr_name in ["compiled_autograd_id", "frame_id", "frame_compile_id"]:
            attr_value = getattr(cid, attr_name, None)
            if attr_value is not None:
                trace_data["pt_info"][attr_name] = attr_value
    if trace_id:
        trace_data["pt_info"]["attempt"] = trace_id.attempt

    # Extract content from all IR and other files in the metadata group
    extract_file_content(trace_data, metadata_group)
    # Extract Python source code information if available
    extract_python_source_info(trace_data, src)
    extract_metadata_from_src(trace_data, src)

    # Add timing information if available
    if times:
        trace_data["times"] = times

    # Log the collected information through the tracing system
    trace_structured_triton(
        event_type,
        payload_fn=lambda: json.dumps(convert(trace_data)),
    )

    return trace_data
767
+
768
+
769
def extract_arg_info(arg_dict):
    """
    Summarize each kernel argument as a small dictionary.

    PyTorch tensors are described by their properties (shape, dtype, device,
    stride, sizes, data pointer); ints, floats, bools and strings are stored
    by value; any other object is stored as its type name plus a string form
    truncated to 200 characters.

    Args:
        arg_dict: Dictionary of kernel arguments

    Returns:
        Dictionary mapping each argument name to its summary dictionary
    """
    summaries = {}

    for arg_name, arg_value in arg_dict.items():
        if TORCH_INSTALLED and isinstance(arg_value, torch.Tensor):
            info = {
                "type": "tensor",
                "shape": list(arg_value.shape),
                "dtype": str(arg_value.dtype),
                "device": str(arg_value.device),
                "stride": list(arg_value.stride()),
                "numel": arg_value.numel(),
                "is_contiguous": arg_value.is_contiguous(),
                "element_size": arg_value.element_size(),
                "storage_offset": arg_value.storage_offset(),
                # Memory usage in bytes
                "memory_usage": arg_value.numel() * arg_value.element_size(),
            }
            # Add data_ptr for memory tracking (optional)
            if hasattr(arg_value, "data_ptr"):
                info["data_ptr"] = hex(arg_value.data_ptr())
        elif isinstance(arg_value, (int, float, bool)):
            # Scalar values are stored as-is together with their type name
            info = {"type": type(arg_value).__name__, "value": arg_value}
        elif isinstance(arg_value, str):
            info = {"type": "str", "value": arg_value, "length": len(arg_value)}
        else:
            type_name = type(arg_value).__name__
            # String form for logging; very long representations are truncated
            text = str(arg_value)
            if len(text) > 200:
                text = text[:200] + "..."
            info = {"type": type_name, "repr": text}

        summaries[arg_name] = info

    return summaries
820
+
821
+
822
def add_launch_metadata(grid, metadata, arg_dict):
    """
    Build the tritonparse launch-metadata entry for one kernel launch.

    Args:
        grid: Launch grid, passed through unchanged.
        metadata: Object exposing ``_asdict()``; its dict form is stored.
        arg_dict: Kernel arguments, summarized via extract_arg_info().

    Returns:
        A single-key dict ``{"launch_metadata_tritonparse": (grid,
        metadata._asdict(), extracted_args)}``.
    """
    # Summarize the arguments first, then snapshot the metadata as a dict
    arg_details = extract_arg_info(arg_dict)
    payload = (grid, metadata._asdict(), arg_details)
    return {"launch_metadata_tritonparse": payload}
826
+
827
+
828
class JITHookImpl(JITHook):
    """
    JIT hook that installs tritonparse's launch_metadata function on Triton kernels.

    Without a custom launch_metadata function only minimal launch information
    (such as the kernel name) is available, as shown in:
    https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L504

    By pointing ``fn.jit_function.launch_metadata`` at add_launch_metadata, each
    launch additionally carries the grid, the kernel metadata and a summary of
    the kernel arguments.
    """

    def __call__(
        self,
        *,
        key: str,
        repr: str,
        fn,
        compile,
        is_manual_warmup: bool,
        already_compiled: bool,
    ) -> Optional[bool]:
        """
        Install add_launch_metadata on the kernel behind ``fn``.

        Args:
            key: Unique identifier for the kernel (unused here)
            repr: String representation of the kernel (unused here)
            fn: Object exposing ``name`` and ``jit_function``
            compile: Compilation function (unused here)
            is_manual_warmup: Whether this is a manual warmup call (unused here)
            already_compiled: Whether the kernel is already compiled (unused here)

        Returns:
            Always True.
            NOTE(review): how Triton interprets this value is not visible in
            this file - confirm against the Triton hook contract.
        """
        # Check kernel allowlist early to avoid unnecessary work
        if _KERNEL_ALLOWLIST_PATTERNS is not None:
            kernel_name = fn.name
            if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
                # Leave launch_metadata untouched for kernels outside the allowlist
                return True

        current_launch_metadata = getattr(fn.jit_function, "launch_metadata", None)
        if current_launch_metadata is add_launch_metadata:
            # A previous invocation already installed our function on this
            # kernel; re-assigning would be a no-op and warning about
            # "overriding" our own hook would only add noise to the log.
            return True

        if current_launch_metadata is not None:
            # A foreign launch_metadata callback is about to be replaced
            log.warning(
                "fn %s launch_metadata is not None: %s. It will be overridden by tritonparse.",
                fn,
                current_launch_metadata,
            )
        fn.jit_function.launch_metadata = add_launch_metadata
        return True
885
+
886
+
887
class LaunchHookImpl(LaunchHook):
    """
    Launch hook that logs one "launch" trace event per kernel launch.

    It reads the launch metadata handed over by Triton, including the
    ``launch_metadata_tritonparse`` entry produced by add_launch_metadata when
    JITHookImpl installed it, and forwards the collected fields to
    trace_structured_triton.

    Fields recorded:
    - name, function and stream (always)
    - grid, compilation_metadata and extracted_args (only when the
      ``launch_metadata_tritonparse`` entry is present)
    """

    def __call__(self, metadata):
        """
        Record a kernel launch.

        Args:
            metadata: Lazy container whose ``get()`` returns the launch
                metadata dict. Expected shape (per the upstream definition
                linked below):
                    {'name': 'add_kernel', 'function': None, 'stream': 0,
                     'launch_metadata_tritonparse': (grid, self.metadata, extracted_args)}
                where extracted_args holds the per-argument details:
                    - For tensors: shape, dtype, device, stride, memory_usage, etc.
                    - For scalars: type and value
                    - For other types: type and string representation
                Defined here:
                https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/
                python/triton/compiler/compiler.py#L512.
        """
        launch_info = metadata.get()

        # Bail out early for kernels the allowlist filters out
        if _KERNEL_ALLOWLIST_PATTERNS is not None and not should_trace_kernel(
            launch_info.get("name"), _KERNEL_ALLOWLIST_PATTERNS
        ):
            return

        record = defaultdict(dict)
        # These three keys are mandatory; a missing one raises KeyError
        for field in ("name", "function", "stream"):
            record[field] = launch_info[field]

        tritonparse_payload = launch_info.get("launch_metadata_tritonparse")
        if tritonparse_payload is not None:
            # Payload layout: (grid, compilation metadata dict, detailed arg info)
            record.update(
                grid=tritonparse_payload[0],
                compilation_metadata=tritonparse_payload[1],
                extracted_args=tritonparse_payload[2],
            )

        trace_structured_triton("launch", metadata_fn=lambda: convert(record))
948
+
949
+
950
def maybe_enable_trace_launch():
    """
    Register the launch-tracing hooks with Triton, at most once.

    Does nothing unless TRITON_TRACE_LAUNCH is truthy and the hooks have not
    been registered yet (tracked by the module-level _trace_launch_enabled).
    """
    global _trace_launch_enabled
    if not TRITON_TRACE_LAUNCH or _trace_launch_enabled:
        return

    # Imported lazily so triton is only required when launch tracing is on
    from triton import knobs

    on_launch = LaunchHookImpl()
    on_post_compile = JITHookImpl()
    knobs.runtime.jit_post_compile_hook = on_post_compile
    knobs.runtime.launch_enter_hook = on_launch

    _trace_launch_enabled = True
961
+
962
+
963
def init_basic(trace_folder: Optional[str] = None):
    """
    Initialize the basic logging system for Triton compilation.

    Steps, in order: enable debug logging if requested, record the trace
    folder, parse the kernel allowlist, initialize the log handlers and
    register the launch-tracing hooks when launch tracing is turned on.

    Args:
        trace_folder (Optional[str]): The folder to store the trace files.
            When given, it replaces any previously configured folder.
    """
    global triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    maybe_enable_debug_logging()

    if trace_folder is not None:
        # An explicit argument wins over an already configured folder
        if triton_trace_folder is not None:
            log.info(
                "Conflict settings: triton_trace_folder is already set to %s, we will use provided trace_folder(%s) instead.",
                triton_trace_folder,
                trace_folder,
            )
        triton_trace_folder = trace_folder

    # Parse and remember the kernel allowlist configuration
    _KERNEL_ALLOWLIST_PATTERNS = parse_kernel_allowlist()
    if not _KERNEL_ALLOWLIST_PATTERNS:
        log.debug("Kernel allowlist not set, tracing all kernels")
    else:
        log.debug(
            f"Kernel allowlist enabled with patterns: {_KERNEL_ALLOWLIST_PATTERNS}"
        )

    init_logs()
    maybe_enable_trace_launch()
994
+
995
+
996
def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
    """
    Run init_basic() and then register the Triton compilation listener.

    Args:
        trace_folder (Optional[str]): The folder to store the trace files.
        enable_trace_launch (bool): Whether to enable the trace launch hook.
            Passing False leaves the current TRITON_TRACE_LAUNCH setting as is.
    """
    global TRITON_TRACE_LAUNCH
    # The flag is only ever switched on here, never off
    if enable_trace_launch:
        TRITON_TRACE_LAUNCH = True

    init_basic(trace_folder)

    # triton is imported only after the basic setup has completed
    from triton import knobs as triton_knobs

    triton_knobs.compilation.listener = maybe_trace_triton
1012
+
1013
+
1014
def clear_logging_config():
    """
    Undo the configuration performed by init() and init_basic().

    Detaches and closes the trace log handler, resets the module-level trace
    folder / allowlist / launch-hook state, and sets the Triton knobs touched
    by this module back to None.

    WARNING: This function is not supposed to be called unless you are sure
    you want to clear the logging config.
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    global _trace_launch_enabled

    # 1. Detach and close the log handler, if one was installed
    handler = TRITON_TRACE_HANDLER
    if handler is not None:
        if handler in triton_trace_log.handlers:
            triton_trace_log.removeHandler(handler)
        handler.close()
        TRITON_TRACE_HANDLER = None

    # 2. Reset module-level state
    triton_trace_folder = None
    _KERNEL_ALLOWLIST_PATTERNS = None
    _trace_launch_enabled = False

    # 3. Reset the Triton knobs (triton is imported unconditionally here)
    from triton import knobs

    knobs.compilation.listener = None
    for hook_name in ("jit_post_compile_hook", "launch_enter_hook"):
        setattr(knobs.runtime, hook_name, None)