tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic. Click here for more details.

Files changed (62) hide show
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1634 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import atexit
4
+ import fnmatch
5
+ import gzip
6
+ import hashlib
7
+ import importlib
8
+ import inspect
9
+ import io
10
+ import json
11
+ import logging
12
+ import math
13
+ import os
14
+ import subprocess
15
+ import tempfile
16
+ from collections import defaultdict
17
+ from collections.abc import Mapping
18
+ from dataclasses import asdict, is_dataclass
19
+ from datetime import date, datetime
20
+ from enum import Enum
21
+ from functools import partial
22
+ from pathlib import Path
23
+ from typing import Any, Callable, Dict, List, Optional, Union
24
+
25
+ from triton.knobs import JITHook, LaunchHook
26
+
27
+ from .shared_vars import DEFAULT_TRACE_FILE_PREFIX
28
+
29
+
30
+ log = logging.getLogger(__name__)
31
+
32
+ TEXT_FILE_EXTENSIONS = [".ttir", ".ttgir", ".llir", ".ptx", ".amdgcn", ".json"]
33
+ MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB limit for file content extraction
34
+ # Enable gzip compression for each line in trace files
35
+ TRITON_TRACE_GZIP = os.getenv("TRITON_TRACE_GZIP", "0") in ["1", "true", "True"]
36
+ triton_trace_log = logging.getLogger("tritonparse_trace")
37
+ # The folder to store the triton trace log.
38
+ triton_trace_folder = os.environ.get("TRITON_TRACE", None)
39
+ # Enable debug logging for tritonparse itself
40
+ TRITONPARSE_DEBUG = os.getenv("TRITONPARSE_DEBUG", None) in ["1", "true", "True"]
41
+ # Kernel allowlist for filtering traced kernels. Use comma separated list of fnmatch patterns.
42
+ TRITONPARSE_KERNEL_ALLOWLIST = os.environ.get("TRITONPARSE_KERNEL_ALLOWLIST", None)
43
+ # Parsed kernel allowlist patterns (set during init)
44
+ _KERNEL_ALLOWLIST_PATTERNS: Optional[List[str]] = None
45
+ # Enable launch trace. WARNNING: it will overwrite launch_metadata function for each triton kernel.
46
+ TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"]
47
+ # Enable more tensor information collection in trace logs.
48
+ TRITONPARSE_MORE_TENSOR_INFORMATION = os.getenv(
49
+ "TRITONPARSE_MORE_TENSOR_INFORMATION", None
50
+ ) in ["1", "true", "True"]
51
+ # Enable full Python source file extraction instead of just the function definition
52
+ TRITON_FULL_PYTHON_SOURCE = os.getenv("TRITON_FULL_PYTHON_SOURCE", "0") in [
53
+ "1",
54
+ "true",
55
+ "True",
56
+ ]
57
+ # Maximum file size for full source extraction (default 10MB)
58
+ TRITON_MAX_SOURCE_SIZE = int(os.getenv("TRITON_MAX_SOURCE_SIZE", str(10 * 1024 * 1024)))
59
+ # Inductor compiled kernel's launch tracing needs this flag to be set.
60
+ # If TRITON_TRACE_LAUNCH is enabled, also enable TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK
61
+ TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = (
62
+ os.getenv("TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK", None) in ["1", "true", "True"]
63
+ or TRITON_TRACE_LAUNCH
64
+ )
65
+ # Enable NVIDIA SASS dump. It requires the CUBIN file to be localable.
66
+ # WARNNING: it will slow down the compilation significantly.
67
+ TRITONPARSE_DUMP_SASS = os.getenv("TRITONPARSE_DUMP_SASS", None) in [
68
+ "1",
69
+ "true",
70
+ "True",
71
+ ]
72
+
73
+ # The flag to mark if launch is traced. It is used to avoid initilizing the launch hook twice.
74
+ _trace_launch_enabled = False
75
+ # Enable tensor blob storage
76
+ TRITONPARSE_SAVE_TENSOR_BLOBS = os.getenv("TRITONPARSE_SAVE_TENSOR_BLOBS", "0") in [
77
+ "1",
78
+ "true",
79
+ "True",
80
+ ]
81
+ # Tensor size limit in bytes (default 10GB)
82
+ TRITONPARSE_TENSOR_SIZE_LIMIT = int(
83
+ os.getenv("TRITONPARSE_TENSOR_SIZE_LIMIT", str(10 * 1024 * 1024 * 1024))
84
+ )
85
+ # Tensor storage quota in bytes (default 100GB) - tracks compressed size for current run
86
+ TRITONPARSE_TENSOR_STORAGE_QUOTA = int(
87
+ os.getenv("TRITONPARSE_TENSOR_STORAGE_QUOTA", str(100 * 1024 * 1024 * 1024))
88
+ )
89
+ # Compression threshold in bytes (default 1MB) - only compress blobs >= this size
90
+ TRITONPARSE_COMPRESSION_THRESHOLD = 1 * 1024 * 1024
91
+ # Compression level for gzip (0-9, higher = better compression but slower)
92
+ TRITONPARSE_COMPRESSION_LEVEL = 4
93
+ # Log statistics every N saved blobs
94
+ TRITONPARSE_STATS_LOG_FREQUENCY = 100
95
+
96
+ TRITON_TRACE_HANDLER = None
97
+ # Global tensor blob manager instance
98
+ TENSOR_BLOB_MANAGER = None
99
+
100
+ if importlib.util.find_spec("torch") is not None:
101
+ TORCH_INSTALLED = True
102
+ import torch
103
+ from torch.utils._traceback import CapturedTraceback
104
+ else:
105
+ TORCH_INSTALLED = False
106
+
107
+
108
class TensorBlobManager:
    """
    Manager for storing tensor data as content-addressed blobs.

    Tensors are serialized with ``torch.save`` and addressed by the BLAKE2b hash
    of the serialized (uncompressed) bytes. Blobs are stored in a two-level
    directory structure (``<root>/saved_tensors/<first 2 hash chars>/<hash><ext>``)
    to avoid filesystem limitations with large numbers of files.

    The manager keeps an in-memory hash -> path cache for deduplication, tracks
    the bytes written during the current run against a storage quota, and
    disables itself once the quota would be exceeded or an I/O error occurs.
    """

    def __init__(
        self,
        root_dir: Optional[str] = None,
        storage_quota: Optional[int] = None,
    ):
        """
        Args:
            root_dir: Directory under which ``saved_tensors`` is created. When
                omitted, storage stays uninitialized until ``set_root_dir``.
            storage_quota: Maximum number of bytes written to disk in this run.
                Defaults to ``TRITONPARSE_TENSOR_STORAGE_QUOTA``.
        """
        self.root_dir = None
        self.hash_to_path_cache = {}  # In-memory cache for hash -> path mapping
        self.compression_threshold = TRITONPARSE_COMPRESSION_THRESHOLD
        self.storage_quota = (
            storage_quota
            if storage_quota is not None
            else TRITONPARSE_TENSOR_STORAGE_QUOTA
        )

        # Resource statistics (tracks current run only)
        self.total_compressed_bytes = 0  # Total on-disk size written in this run
        self.total_uncompressed_bytes = 0  # Total serialized size (for compression ratio)
        self.blob_count = 0  # Total blob references (including dedup hits)
        self.blob_saved_count = 0  # Actual blobs saved (excluding dedup hits)
        self.storage_disabled = False  # Whether storage has been disabled
        self.storage_disabled_reason = None  # Reason for disabling storage

        if root_dir:
            self.set_root_dir(root_dir)

    def set_root_dir(self, root_dir: str):
        """Set the root directory for blob storage and create ``saved_tensors`` in it."""
        self.root_dir = Path(root_dir) / "saved_tensors"
        self.root_dir.mkdir(parents=True, exist_ok=True)
        log.debug(f"TensorBlobManager: using root directory {self.root_dir}")

    def _compute_hash(self, data: bytes) -> str:
        """Compute the BLAKE2b hex digest of the data."""
        return hashlib.blake2b(data).hexdigest()

    def _get_blob_path(self, hash_hex: str, extension: str = ".bin.gz") -> Path:
        """Get the file path for a given hash using the two-level directory structure.

        Raises:
            ValueError: If the root directory has not been set.
        """
        if not self.root_dir:
            raise ValueError("Root directory not set")

        # Two-level directory: first 2 chars / full_hash{extension}
        subdir = hash_hex[:2]
        filename = f"{hash_hex}{extension}"
        return (self.root_dir / subdir / filename).resolve()

    def _get_tensor_size_bytes(self, tensor) -> int:
        """Get tensor size in bytes before serialization.

        Returns 0 when the object does not expose ``numel`` and ``element_size``.
        """
        if hasattr(tensor, "numel") and hasattr(tensor, "element_size"):
            return tensor.numel() * tensor.element_size()
        return 0

    def _log_statistics(self, final: bool = False):
        """Log statistics about tensor blob storage.

        Args:
            final: If True, this is the final statistics message (e.g., when storage is disabled)
        """
        prefix = "📊 Final" if final else "📊"
        compression_ratio = (
            self.total_uncompressed_bytes / self.total_compressed_bytes
            if self.total_compressed_bytes > 0
            else 0.0
        )
        dedup_count = self.blob_count - self.blob_saved_count

        log.info(
            f"{prefix} Tensor blob stats: "
            f"{self.blob_saved_count} saved ({self.blob_count} total, {dedup_count} dedup), "
            f"{self.total_compressed_bytes / 1024**3:.2f}GB compressed "
            f"({self.total_uncompressed_bytes / 1024**3:.2f}GB uncompressed), "
            f"compression ratio: {compression_ratio:.2f}x"
        )

    def _disable_storage(self, reason: str):
        """Disable blob storage and log a warning with statistics.

        Args:
            reason: The reason why storage is being disabled
        """
        if not self.storage_disabled:  # Only disable once
            self.storage_disabled = True
            self.storage_disabled_reason = reason
            log.warning(f"⚠️ TENSOR BLOB STORAGE DISABLED: {reason}")
            self._log_statistics(final=True)

    def _atomic_write(self, blob_path: Path, hash_hex: str, data: bytes) -> None:
        """Write *data* to *blob_path* through a temporary file in the same directory.

        The temporary file is moved into place with ``Path.replace`` so that a
        partially written blob is never visible under its final name. If writing
        or moving fails, the temporary file is removed before the exception
        propagates, so failed saves do not leave ``.tmp_*`` files behind.
        """
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode="wb",
                dir=blob_path.parent,
                prefix=f".tmp_{hash_hex}_",
                delete=False,
            ) as tmp_file:
                tmp_path = Path(tmp_file.name)
                tmp_file.write(data)
            # replace() overwrites an existing target on every platform.
            tmp_path.replace(blob_path)
        except BaseException:
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    # Best-effort cleanup; the original error is what matters.
                    pass
            raise

    def save_tensor_blob(self, tensor) -> Dict[str, Any]:
        """
        Save tensor as a blob and return metadata.

        Args:
            tensor: PyTorch tensor to save

        Returns:
            Dictionary with blob metadata or error information:
            - Success: {'tensor_hash': str, 'blob_path': str, 'blob_size': int,
                        'blob_size_uncompressed': int, 'compression': str,
                        'compression_ratio': float, 'serialization_method': str}
            - Dedup hit: Same as success plus 'deduplicated': True (not counted in quota)
            - Error: {'error': str, 'tensor_hash': None}
        """
        # Early exit: Check if storage is disabled
        if self.storage_disabled:
            return {"error": self.storage_disabled_reason, "tensor_hash": None}

        # Early exit: Check if root directory is set
        if not self.root_dir:
            return {"error": "Blob storage not initialized", "tensor_hash": None}

        try:
            # Check tensor size before serialization
            tensor_size = self._get_tensor_size_bytes(tensor)
            if tensor_size > TRITONPARSE_TENSOR_SIZE_LIMIT:
                log.warning(
                    f"Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes, skipping blob storage"
                )
                return {
                    "error": f"Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes",
                    "tensor_hash": None,
                }

            if not TORCH_INSTALLED:
                return {
                    "error": "PyTorch not available for tensor serialization",
                    "tensor_hash": None,
                }

            # Serialize tensor using torch.save
            buffer = io.BytesIO()
            torch.save(tensor.cpu(), buffer)
            blob_data = buffer.getvalue()
            uncompressed_size = len(blob_data)

            # Compute hash on uncompressed data for content addressing
            hash_hex = self._compute_hash(blob_data)

            # Check for deduplication (before compression to save work)
            if hash_hex in self.hash_to_path_cache:
                blob_path = self.hash_to_path_cache[hash_hex]
                try:
                    # stat() doubles as an existence check for the cached file
                    disk_size = blob_path.stat().st_size
                except OSError:
                    # File was deleted or inaccessible - remove from cache and continue to save
                    log.debug(
                        f"Cached blob file no longer exists: {blob_path}, will re-save"
                    )
                    self.hash_to_path_cache.pop(hash_hex, None)
                else:
                    compression = (
                        "gzip" if str(blob_path).endswith(".bin.gz") else "none"
                    )
                    # Deduplication hit - increment count but don't add to quota
                    self.blob_count += 1
                    return {
                        "tensor_hash": hash_hex,
                        "blob_path": str(blob_path),
                        "blob_size": disk_size,
                        "blob_size_uncompressed": uncompressed_size,
                        "compression": compression,
                        "compression_ratio": uncompressed_size / max(1, disk_size),
                        "serialization_method": "torch_save",
                        "deduplicated": True,
                    }

            # Decide whether to compress based on size threshold
            if uncompressed_size >= self.compression_threshold:
                data_to_write = gzip.compress(
                    blob_data, compresslevel=TRITONPARSE_COMPRESSION_LEVEL
                )
                file_extension = ".bin.gz"
                compression = "gzip"
            else:
                # Don't compress small files (overhead not worth it)
                data_to_write = blob_data
                file_extension = ".bin"
                compression = "none"

            disk_size = len(data_to_write)

            # Check quota BEFORE writing
            if self.total_compressed_bytes + disk_size > self.storage_quota:
                self._disable_storage(
                    f"Storage quota would be exceeded: "
                    f"{(self.total_compressed_bytes + disk_size) / 1024**3:.2f}GB > "
                    f"{self.storage_quota / 1024**3:.2f}GB limit"
                )
                return {"error": self.storage_disabled_reason, "tensor_hash": None}

            # Create blob file path with appropriate extension
            blob_path = self._get_blob_path(hash_hex, extension=file_extension)
            blob_path.parent.mkdir(parents=True, exist_ok=True)

            # Atomic write using temporary file + replace (cleans up on failure)
            self._atomic_write(blob_path, hash_hex, data_to_write)

            # Update cache and statistics
            self.hash_to_path_cache[hash_hex] = blob_path
            self.total_compressed_bytes += disk_size
            self.total_uncompressed_bytes += uncompressed_size
            self.blob_count += 1
            self.blob_saved_count += 1

            # Log progress periodically
            if self.blob_saved_count % TRITONPARSE_STATS_LOG_FREQUENCY == 0:
                self._log_statistics()

            log.debug(
                f"Saved tensor blob: {hash_hex} -> {blob_path} ({disk_size} bytes, compression={compression})"
            )

            return {
                "tensor_hash": hash_hex,
                "blob_path": str(blob_path),
                "blob_size": disk_size,
                "blob_size_uncompressed": uncompressed_size,
                "compression": compression,
                "compression_ratio": uncompressed_size / max(1, disk_size),
                "serialization_method": "torch_save",
            }

        except OSError as e:
            # Disk full, permission errors, etc. - disable storage to avoid repeated failures
            error_msg = f"Failed to save tensor blob (I/O error): {str(e)}"
            log.error(error_msg)
            self._disable_storage(error_msg)
            return {"error": error_msg, "tensor_hash": None}
        except Exception as e:
            # Other unexpected errors - log but don't disable storage
            error_msg = f"Failed to save tensor blob: {str(e)}"
            log.error(error_msg)
            return {"error": error_msg, "tensor_hash": None}
367
+
368
+
369
class TritonLogRecord(logging.LogRecord):
    """
    LogRecord subclass carrying structured data for Triton trace logging.

    In addition to the standard LogRecord fields, each record holds a
    ``metadata`` dictionary and an optional ``payload``.
    """

    def __init__(
        self,
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=None,
        payload=None,
        **kwargs,
    ):
        """Initialize the base record, then attach ``metadata`` and ``payload``.

        A missing (or empty) ``metadata`` argument is stored as a new empty dict.
        """
        super().__init__(name, level, pathname, lineno, msg, args, exc_info, **kwargs)
        self.metadata: Dict[str, Any] = metadata if metadata else {}
        self.payload: Optional[Union[str, Dict[str, Any], list]] = payload
393
+
394
+
395
def create_triton_log_record(
    name=None,
    level=logging.DEBUG,
    pathname=None,
    lineno=None,
    msg="",
    args=(),
    exc_info=None,
    metadata=None,
    payload=None,
    **kwargs,
):
    """
    Factory method to create TritonLogRecord instances with sensible defaults.

    Args:
        name (str, optional): Logger name. Defaults to triton_trace_log.name.
        level (int, optional): Log level. Defaults to DEBUG.
        pathname (str, optional): Path recorded on the log record. Defaults to this module's ``__file__``.
        lineno (int, optional): Line number recorded on the log record. Defaults to the line of the
            caller of this function, or 0 when the interpreter cannot provide stack frames.
        msg (str, optional): Log message. Defaults to empty string.
        args (tuple, optional): Arguments to interpolate into the message. Defaults to empty tuple.
        exc_info (optional): Exception information. Defaults to None.
        metadata (Dict[str, Any], optional): Structured metadata for the log record. Defaults to empty dict.
        payload (optional): Payload data. Defaults to None.
        **kwargs: Additional keyword arguments for LogRecord

    Returns:
        TritonLogRecord: A custom log record with structured data
    """
    if pathname is None:
        pathname = __file__
    if lineno is None:
        # inspect.currentframe() returns None on interpreters without stack
        # frame support; fall back to 0 instead of raising AttributeError.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        lineno = caller.f_lineno if caller is not None else 0
    if name is None:
        name = triton_trace_log.name

    return TritonLogRecord(
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=metadata,
        payload=payload,
        **kwargs,
    )
445
+
446
+
447
def convert(obj):
    """
    Recursively converts an object into a JSON-serializable form.

    Handled, in this order: None/bool/int/str (returned unchanged), float
    (non-finite values become strings), list/tuple (namedtuples become dicts),
    set/frozenset (sorted by ``str``), Mapping (keys stringified), datetime/date
    (ISO format), Enum (its value), Path, dataclass instances, triton_kernels
    Layout objects, Triton dtypes and torch dtypes. Anything else is logged as
    an unknown type and converted with ``str``.

    Args:
        obj: The object to convert, which can be a dataclass instance, dictionary, list, or any other type

    Returns:
        A serializable version of the input object where dataclasses are converted to dictionaries
    """
    from triton.language.core import dtype

    # 1. primitives that JSON already supports -------------------------------
    if obj is None or isinstance(obj, (bool, int, str)):
        return obj

    if isinstance(obj, float):
        # JSON spec forbids NaN/Infinity – keep precision but stay valid
        if math.isfinite(obj):
            return obj
        return str(obj)  # "nan", "inf", "-inf"

    # 2. simple containers ----------------------------------------------------
    if isinstance(obj, (list, tuple)):
        # Handle namedtuple specially to preserve field names
        if hasattr(obj, "_asdict"):
            return convert(obj._asdict())
        return [convert(x) for x in obj]

    if isinstance(obj, (set, frozenset)):
        # Sort by string form so the output order is deterministic
        return [convert(x) for x in sorted(obj, key=str)]

    if isinstance(obj, Mapping):
        return {str(k): convert(v) for k, v in obj.items()}

    # 3. time, enum, path, dataclass -----------------------------------------
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()

    if isinstance(obj, Enum):
        return convert(obj.value)

    if isinstance(obj, Path):
        return str(obj)

    # is_dataclass() is also True for dataclass *classes*, on which asdict()
    # raises TypeError; only instances are converted here. Classes fall through
    # to the generic string fallback at the end.
    if is_dataclass(obj) and not isinstance(obj, type):
        return convert(asdict(obj))

    if _is_triton_kernels_layout(obj):
        layout_info = {"type": type(obj).__name__}
        if hasattr(obj, "initial_shape"):
            layout_info["initial_shape"] = convert(obj.initial_shape)
        if hasattr(obj, "name"):
            layout_info["name"] = convert(obj.name)
        return layout_info

    # 4. Common Triton constexpr objects
    if isinstance(obj, dtype):
        return f"triton.language.core.dtype('{str(obj)}')"

    if TORCH_INSTALLED and isinstance(obj, torch.dtype):
        return str(obj)

    # 5. Fallback: unknown types are logged and stringified
    log.warning(f"Unknown type: {type(obj)}")
    return str(obj)
514
+
515
+
516
+ def _is_triton_kernels_layout(obj):
517
+ """
518
+ Check if an object is an instance of a Layout class from a
519
+ triton_kernels module by checking its MRO.
520
+ """
521
+ t = type(obj)
522
+ for base_class in t.__mro__:
523
+ module_name = getattr(base_class, "__module__", "")
524
+ type_name = getattr(base_class, "__name__", "")
525
+ if type_name == "Layout" and module_name.startswith("triton_kernels"):
526
+ return True
527
+ return False
528
+
529
+
530
+ def _is_from_triton_kernels_module(obj):
531
+ """
532
+ Check if an object is an instance of Tensor or Storage from a
533
+ triton_kernels module.
534
+ """
535
+ t = type(obj)
536
+ module_name = getattr(t, "__module__", "")
537
+ type_name = getattr(t, "__name__", "")
538
+ return type_name in ("Tensor", "Storage") and module_name.startswith(
539
+ "triton_kernels"
540
+ )
541
+
542
+
543
def _log_torch_tensor_info(tensor_value):
    """
    Build a metadata dictionary describing a torch.Tensor.

    Always records type, shape, dtype, device, stride, numel, contiguity,
    element size, storage offset and memory usage, plus ``data_ptr`` when the
    object exposes it. When TRITONPARSE_MORE_TENSOR_INFORMATION is enabled,
    min/max/mean/std are computed on a float copy of the tensor. When tensor
    blob storage is enabled and a manager exists, the blob metadata returned by
    the manager is merged into the result.

    Args:
        tensor_value (torch.Tensor): The tensor to extract information from.

    Returns:
        dict: A dictionary containing tensor metadata.
    """
    info = {
        "type": "tensor",
        "shape": list(tensor_value.shape),
        "dtype": str(tensor_value.dtype),
        "device": str(tensor_value.device),
        "stride": list(tensor_value.stride()),
        "numel": tensor_value.numel(),
        "is_contiguous": tensor_value.is_contiguous(),
        "element_size": tensor_value.element_size(),
        "storage_offset": tensor_value.storage_offset(),
        # Memory usage in bytes
        "memory_usage": tensor_value.numel() * tensor_value.element_size(),
    }

    # data_ptr is optional and used for memory tracking
    if hasattr(tensor_value, "data_ptr"):
        info["data_ptr"] = hex(tensor_value.data_ptr())

    if TRITONPARSE_MORE_TENSOR_INFORMATION:
        try:
            # .float() yields a separate tensor, so statistics are computed
            # uniformly across dtypes without touching the original.
            as_float = tensor_value.float()
            for stat_name in ("min", "max", "mean", "std"):
                info[stat_name] = getattr(as_float, stat_name)().item()
        except (RuntimeError, ValueError, TypeError) as e:
            log.error(f"Unable to compute tensor statistics: {e}")
            info["tensor_capture_error"] = str(e)

    # Merge blob storage metadata (or its error entry) if the feature is on
    if TRITONPARSE_SAVE_TENSOR_BLOBS and TENSOR_BLOB_MANAGER is not None:
        info.update(TENSOR_BLOB_MANAGER.save_tensor_blob(tensor_value))

    return info
586
+
587
+
588
def maybe_enable_debug_logging():
    """
    Turn on DEBUG output for this module's own logger when TRITONPARSE_DEBUG is set.

    This concerns tritonparse's internal diagnostics only, not the Triton
    compilation trace. A stream handler is attached unless the logger already
    has a StreamHandler whose level is at or below DEBUG.
    """
    if not TRITONPARSE_DEBUG:
        return

    # Always set debug level if TRITONPARSE_DEBUG is set
    log.setLevel(logging.DEBUG)
    # Keep records away from the root logger to avoid duplicate messages
    log.propagate = False

    # Nothing more to do if a suitable debug handler is already attached
    for existing in log.handlers:
        if (
            isinstance(existing, logging.StreamHandler)
            and existing.level <= logging.DEBUG
        ):
            return

    fmt = logging.Formatter("%(asctime)s[%(levelname)s] %(message)s")
    fmt.default_time_format = "%Y%m%d %H:%M:%S"
    fmt.default_msec_format = None

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(fmt)
    log.addHandler(stream_handler)
614
+
615
+
616
def get_stack_trace(skip=1):
    """
    Capture the current call stack as a list of plain dictionaries.

    The stack is obtained through torch's CapturedTraceback utility. When torch
    is not installed, an empty list is returned.

    Args:
        skip (int): Number of frames to skip, passed through to
            ``CapturedTraceback.extract``

    Returns:
        List[Dict]: One dictionary per frame with keys ``line`` (line number),
        ``name`` (function name), ``filename`` and ``loc`` (source line text)
    """
    if not TORCH_INSTALLED:
        return []

    summary = CapturedTraceback.extract(skip=skip).summary()
    return [
        {
            "line": entry.lineno,
            "name": entry.name,
            "filename": entry.filename,
            "loc": entry.line,
        }
        for entry in summary
    ]
643
+
644
+
645
def parse_kernel_allowlist() -> Optional[List[str]]:
    """
    Turn TRITONPARSE_KERNEL_ALLOWLIST into a list of kernel name patterns.

    The value is split on commas; surrounding whitespace is stripped and empty
    entries are discarded.

    Returns:
        List[str] or None: Kernel name patterns to trace, or None when the
        allowlist is unset or contains no usable pattern (trace everything)
    """
    if not TRITONPARSE_KERNEL_ALLOWLIST:
        return None

    stripped = (entry.strip() for entry in TRITONPARSE_KERNEL_ALLOWLIST.split(","))
    patterns = [entry for entry in stripped if entry]

    if patterns:
        log.debug(f"Kernel allowlist patterns: {patterns}")
        return patterns
    return None
665
+
666
+
667
def extract_kernel_name(src) -> Optional[str]:
    """
    Extract kernel name from the source object.

    For an IRSource the ``name`` attribute is used (None when absent). For any
    other source, the name is taken from ``src.fn.fn.__name__`` when that
    attribute chain exists.

    Args:
        src (Union[ASTSource, IRSource]): Source object containing kernel information

    Returns:
        str or None: Kernel name if extractable, None otherwise
    """
    from triton.compiler import IRSource

    try:
        if isinstance(src, IRSource):
            # Use the getattr() builtin: source objects have no `.getattr`
            # method, so the previous `src.getattr(...)` call always raised and
            # the IRSource name was never returned.
            return getattr(src, "name", None)

        # For ASTSource, get the function name
        if (
            hasattr(src, "fn")
            and hasattr(src.fn, "fn")
            and hasattr(src.fn.fn, "__name__")
        ):
            return src.fn.fn.__name__
        return None
    except Exception as e:
        # Name extraction is best-effort; never let it break tracing.
        log.warning(f"Error extracting kernel name: {e}")
        return None
694
+
695
+
696
def should_trace_kernel(
    kernel_name: Optional[str], allowlist_patterns: Optional[List[str]]
) -> bool:
    """
    Decide whether a kernel is traced, given the allowlist.

    Args:
        kernel_name (str or None): Name of the kernel
        allowlist_patterns (List[str] or None): fnmatch patterns to match against;
            None means no allowlist is configured

    Returns:
        bool: True when there is no allowlist or the name matches one of the
        patterns; False when the name is unknown or matches none of them
    """
    # No allowlist configured: every kernel is traced
    if allowlist_patterns is None:
        return True

    # Unknown kernel name with an active allowlist: conservatively skip
    if kernel_name is None:
        log.debug("Cannot extract kernel name, skipping trace")
        return False

    # First pattern (in list order) that matches the kernel name, if any
    pattern = next(
        (
            candidate
            for candidate in allowlist_patterns
            if fnmatch.fnmatch(kernel_name, candidate)
        ),
        None,
    )
    if pattern is None:
        log.debug(
            f"Kernel '{kernel_name}' does not match any allowlist pattern, skipping trace"
        )
        return False

    log.debug(f"Kernel '{kernel_name}' matches pattern '{pattern}', will trace")
    return True
728
+
729
+
730
def extract_python_source_info(trace_data: Dict[str, Any], source):
    """
    Record the kernel's Python source under ``trace_data["python_source"]``.

    The source file is located with ``inspect.getfile``. The function's lines
    and starting line come from the JITFunction's ``raw_src`` and
    ``starting_line_number`` when both are present, otherwise from
    ``inspect.getsourcelines``. IRSource inputs are not supported yet and leave
    ``trace_data`` untouched.

    By default, only the function definition is stored. Set TRITON_FULL_PYTHON_SOURCE=1
    to store the entire Python source file together with the function's line range.
    @TODO: we should enable it by default in next diff and track the compilation time regression

    Environment Variables:
        TRITON_FULL_PYTHON_SOURCE: If set to "1", extract the full Python file
            instead of just the function definition.
        TRITON_MAX_SOURCE_SIZE: Maximum file size in bytes for full source extraction
            (default: 10MB). Files larger than this will fall back
            to function-only mode.

    Args:
        trace_data (Dict[str, Any]): Dictionary to store extracted information
        source (Union[ASTSource, IRSource]): Source object containing kernel function information
    """
    # @TODO: add support for IRSource
    from triton.compiler import IRSource
    from triton.runtime.jit import JITFunction

    if isinstance(source, IRSource):
        return

    # Unwrap the JITFunction to reach the plain Python function
    kernel_fn = source.fn
    is_jit = isinstance(kernel_fn, JITFunction)
    fn_ref = kernel_fn.fn if is_jit else kernel_fn

    python_source_file = inspect.getfile(fn_ref)

    # Function range: prefer what the JITFunction already captured
    if (
        is_jit
        and hasattr(kernel_fn, "starting_line_number")
        and hasattr(kernel_fn, "raw_src")
    ):
        function_start_line = kernel_fn.starting_line_number
        source_lines = kernel_fn.raw_src
    else:
        source_lines, function_start_line = inspect.getsourcelines(fn_ref)

    function_end_line = function_start_line + len(source_lines) - 1

    if TRITON_FULL_PYTHON_SOURCE:
        # Full file mode: only attempted when the file size can be checked
        # and is within the configured limit.
        use_full_source = False
        try:
            file_size = os.path.getsize(python_source_file)
        except OSError as e:
            log.warning(
                f"Failed to check file size for {python_source_file}: {e}. "
                f"Falling back to function-only mode."
            )
        else:
            if file_size > TRITON_MAX_SOURCE_SIZE:
                log.warning(
                    f"Source file {python_source_file} is too large ({file_size} bytes, "
                    f"limit: {TRITON_MAX_SOURCE_SIZE} bytes). Falling back to function-only mode."
                )
            else:
                use_full_source = True

        if use_full_source:
            try:
                with open(python_source_file, "r", encoding="utf-8") as f:
                    file_content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                log.warning(
                    f"Failed to read full source file {python_source_file}: {e}. "
                    f"Falling back to function-only mode."
                )
            else:
                # Number of "\n"-separated segments in the file
                total_lines = file_content.count("\n") + 1
                trace_data["python_source"] = {
                    "file_path": python_source_file,
                    "start_line": 1,
                    "end_line": total_lines,
                    "code": file_content,
                    # Add function range for frontend highlighting and scrolling
                    "function_start_line": function_start_line,
                    "function_end_line": function_end_line,
                }
                return

    # Default behavior (and fallback): only the function definition
    trace_data["python_source"] = {
        "file_path": python_source_file,
        "start_line": function_start_line,
        "end_line": function_end_line,
        "code": "".join(source_lines),
    }
833
+
834
+
835
def extract_file_content(trace_data: Dict[str, Any], metadata_group: Dict[str, str]):
    """
    Record the path (and, for text artifacts, the content) of every compilation
    artifact in ``metadata_group``.

    For each entry the path is stored under ``trace_data["file_path"]``. Entries
    whose name ends with one of ``TEXT_FILE_EXTENSIONS`` are additionally read and
    stored under ``trace_data["file_content"]``; files larger than ``MAX_FILE_SIZE``
    or files that cannot be read get a ``<...>`` placeholder message instead.

    When ``TRITONPARSE_DUMP_SASS`` is set and the group contains a ``.cubin``
    entry, the first such cubin is disassembled and stored under
    ``<cubin basename>.sass``; failures are recorded as placeholder messages.

    Args:
        trace_data (Dict): Dictionary to store extracted information. Its
            "file_path" and "file_content" entries must already be (or be
            auto-created as) dictionaries, e.g. a ``defaultdict(dict)``.
        metadata_group (Dict): Dictionary mapping filenames to file paths
    """
    # str.endswith accepts a tuple, so one call covers every extension.
    text_extensions = tuple(TEXT_FILE_EXTENSIONS)

    for ir_filename, file_path in metadata_group.items():
        # Add file path to trace data
        trace_data["file_path"][ir_filename] = file_path

        # Only text artifacts are read into the trace
        if not ir_filename.endswith(text_extensions):
            continue

        try:
            # Check file size before reading to avoid memory issues
            file_size = os.path.getsize(file_path)
            if file_size > MAX_FILE_SIZE:
                message = f"<file too large: {file_size} bytes>"
                trace_data["file_content"][ir_filename] = message
                continue

            # Pin the encoding so the result does not depend on the process
            # locale (matches how the Python source file is read in this module).
            with open(file_path, "r", encoding="utf-8") as f:
                trace_data["file_content"][ir_filename] = f.read()
        except (UnicodeDecodeError, OSError) as e:
            message = f"<error reading file: {str(e)}>"
            trace_data["file_content"][ir_filename] = message
            log.debug(f"Error reading file {file_path}: {e}")

    # First .cubin entry (in dict order), if any
    cubin_path = next(
        (path for key, path in metadata_group.items() if key.endswith(".cubin")),
        None,
    )

    if TRITONPARSE_DUMP_SASS and cubin_path:
        filename_no_ext = os.path.splitext(os.path.basename(cubin_path))[0]
        sass_filename = f"{filename_no_ext}.sass"
        try:
            import tritonparse.tools.disasm

            sass_content = tritonparse.tools.disasm.extract(cubin_path)
            trace_data["file_content"][sass_filename] = sass_content
        except subprocess.CalledProcessError as e:
            # The disassembler subprocess exited with a non-zero status
            message = f"<nvdisasm failed: {str(e)}>"
            trace_data["file_content"][sass_filename] = message
        except OSError as e:
            message = f"<error reading cubin file: {str(e)}>"
            trace_data["file_content"][sass_filename] = message
        except Exception as e:
            # SASS dumping is best-effort: never let it break tracing
            message = f"<error dumping SASS: {str(e)}>"
            trace_data["file_content"][sass_filename] = message
884
+
885
+
886
def extract_metadata_from_src(trace_data, src):
    """
    Add source-object metadata to ``trace_data["metadata"]``.

    The following keys are written:
        - "env": result of Triton's ``get_cache_invalidating_env_vars()``
        - "src_attrs": ``src.attrs`` when present, otherwise ``{}``
        - "src_cache_key": ``src.fn.cache_key`` when ``src`` has ``fn``, otherwise ``""``
        - "src_constants": ``src.constants`` when present, otherwise ``{}``

    Args:
        trace_data: Dictionary whose "metadata" entry is updated in place
        src: Source object handed to the compilation listener
    """
    from triton._C.libtriton import get_cache_invalidating_env_vars

    # NOTE: parsed backend/extra options are intentionally not recorded here.
    src_metadata = {"env": get_cache_invalidating_env_vars()}

    if hasattr(src, "attrs"):
        src_metadata["src_attrs"] = src.attrs
    else:
        src_metadata["src_attrs"] = {}

    if hasattr(src, "fn"):
        src_metadata["src_cache_key"] = src.fn.cache_key
    else:
        src_metadata["src_cache_key"] = ""

    if hasattr(src, "constants"):
        src_metadata["src_constants"] = src.constants
    else:
        src_metadata["src_constants"] = {}

    trace_data["metadata"].update(src_metadata)
902
+
903
+
904
class TritonJsonFormatter(logging.Formatter):
    """
    Format log records as JSON for Triton compilation tracing.

    This formatter converts log records with metadata and payload into NDJSON format,
    suitable for structured logging and later analysis. It expects each record to
    carry two extra attributes: ``metadata`` (a dictionary) and ``payload``
    (a JSON string or ``None``).
    """

    def format(self, record: logging.LogRecord):
        """
        Serialize ``record`` into a single NDJSON line.

        Args:
            record: Log record carrying ``metadata`` (dict) and ``payload``
                (JSON-encoded string or None) attributes.

        Returns:
            str: Compact JSON document terminated by a newline.
        """
        from datetime import datetime, timezone

        # Work on a copy so formatting the same record more than once (e.g. with
        # several handlers) never mutates the caller's metadata dictionary.
        log_entry = dict(record.metadata)
        payload = record.payload

        # logging.Formatter.formatTime() delegates to time.strftime(), which has
        # no "%f" directive, and it formats local time. Build the timestamp from
        # record.created in UTC so the microseconds are real and the trailing
        # "Z" is truthful.
        created = datetime.fromtimestamp(record.created, tz=timezone.utc)
        log_entry["timestamp"] = created.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
        if payload is not None:
            log_entry["payload"] = json.loads(payload)
        clean_log_entry = convert(log_entry)
        # NDJSON format requires a newline at the end of each line
        json_str = json.dumps(clean_log_entry, separators=(",", ":"))
        return json_str + "\n"
924
+
925
+
926
class TritonTraceHandler(logging.StreamHandler):
    """
    Logging handler that appends Triton trace records to an NDJSON file.

    The output file is opened lazily on the first emitted record. Its name is
    built from ``prefix`` plus, when ``torch.distributed`` is initialized, a
    ``rank_<N>_`` suffix. With ``TRITON_TRACE_GZIP`` enabled each record is
    written as its own gzip member to a ``.bin.ndjson`` file; otherwise plain
    text is appended to a ``.ndjson`` file.
    """

    def __init__(
        self, root_dir: Optional[str] = None, prefix=DEFAULT_TRACE_FILE_PREFIX
    ):
        """
        Args:
            root_dir: Directory for trace files; may be None and resolved later
                by ``get_root_dir()``.
            prefix: Filename prefix for the trace file.
        """
        # Handler.__init__ is called directly (not StreamHandler.__init__) so
        # that no stream is attached until the first record is emitted.
        logging.Handler.__init__(self)
        self.root_dir = root_dir
        self.prefix = prefix
        self.stream = None
        self.first_record = True
        # self.stream holds an open file; make sure it is closed even when the
        # program terminates unexpectedly.
        atexit.register(self._cleanup)

    def get_root_dir(self):
        """
        Return the directory trace files should be written to, or a falsy value
        when tracing is disabled.

        An explicitly configured ``root_dir`` always wins. Otherwise "/logs" is
        adopted (and cached on the handler) only if the torch-related checks
        below do not disable it and the directory exists and is writable.
        """
        # For meta internal runs, we use the /logs directory by default
        # reference implementation
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        if self.root_dir:
            return self.root_dir

        default_dir = "/logs"

        if TORCH_INSTALLED:
            import torch.version as torch_version

            has_git_version = hasattr(torch_version, "git_version")
            if has_git_version and os.getenv("MAST_HPC_JOB_NAME") is None:
                log.info(
                    "TritonTraceHandler: disabled because not fbcode or conda on mast"
                )
                return self.root_dir
            # TODO: change to tritonparse knob
            # The hasattr checks guard against a version mismatch between
            # torch and tritonparse.
            if (
                hasattr(torch, "_utils_internal")
                and hasattr(torch._utils_internal, "justknobs_check")
                and not torch._utils_internal.justknobs_check("pytorch/trace:enable")
            ):
                log.info(
                    "TritonTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                )
                return self.root_dir

        if not os.path.exists(default_dir):
            log.info(
                "TritonTraceHandler: disabled because %s does not exist",
                default_dir,
            )
        elif not os.access(default_dir, os.W_OK):
            log.info(
                "TritonTraceHandler: disabled because %s is not writable",
                default_dir,
            )
        else:
            self.root_dir = default_dir
        return self.root_dir

    def emit(self, record):
        """
        Write one formatted record to the trace file, opening it on first use.

        If no root directory can be resolved the handler detaches itself from
        ``triton_trace_log``. Any exception closes the stream and is routed to
        the standard ``handleError`` hook.
        """
        # reference implementation
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        try:
            if self.stream is None:
                root_dir = self.get_root_dir()
                if root_dir is None:
                    # Nowhere to write: stop receiving records.
                    triton_trace_log.removeHandler(self)
                    return

                os.makedirs(root_dir, exist_ok=True)
                ranksuffix = ""
                if TORCH_INSTALLED:
                    import torch.distributed as dist

                    if dist.is_available() and dist.is_initialized():
                        ranksuffix = f"rank_{dist.get_rank()}_"
                filename = f"{self.prefix}{ranksuffix}"
                self._ensure_stream_closed()
                # Choose file extension and mode based on compression setting
                if TRITON_TRACE_GZIP:
                    # Binary mode for gzip member concatenation
                    file_extension, file_mode = ".bin.ndjson", "ab+"
                else:
                    file_extension, file_mode = ".ndjson", "a+"
                log_file_name = os.path.abspath(
                    os.path.join(root_dir, f"{filename}{file_extension}")
                )
                self.stream = open(log_file_name, mode=file_mode)
                log.debug("TritonTraceHandler: logging to %s", log_file_name)

            if self.stream:
                formatted = self.format(record)
                if TRITON_TRACE_GZIP:
                    # One complete gzip member per record; standard gzip
                    # readers handle concatenated members automatically.
                    buffer = io.BytesIO()
                    with gzip.GzipFile(fileobj=buffer, mode="wb") as gz:
                        gz.write(formatted.encode("utf-8"))
                    self.stream.write(buffer.getvalue())
                else:
                    self.stream.write(formatted)
                self.flush()
        except Exception as e:
            # record exception and ensure resources are released
            log.error(f"Error in TritonTraceHandler.emit: {e}")
            self._ensure_stream_closed()
            self.handleError(record)  # call Handler's standard error handling

    def close(self):
        """Flush and close the current file, then run the base-class close."""
        self.acquire()
        try:
            try:
                if self.stream:
                    self._ensure_stream_closed()
            finally:
                # Solution adopted from PyTorch PR #120289
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def _cleanup(self):
        """atexit hook: close the handler if a file is still open."""
        if self.stream is not None:
            self.close()

    def _ensure_stream_closed(self):
        """Flush and close the stream (if any) and reset it to None."""
        if self.stream is None:
            return
        try:
            self.flush()
        finally:
            self.stream.close()
            self.stream = None
1078
+
1079
+
1080
def init_logs():
    """
    Initialise tritonparse's logging system.

    Safe to call repeatedly. Handles these situations:
    1. The first call may or may not have a trace folder configured.
    2. A later call, made once a trace folder is known, activates an existing
       handler whose `root_dir` was None.
    3. When tracing is disabled (no usable directory), the handler is kept
       detached and propagation to the root logger is blocked, which prevents
       empty "DEBUG:tritonparse_trace:" lines.
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder, TENSOR_BLOB_MANAGER

    # Logger basics; idempotent. propagate=False stops bubbling to the root logger.
    triton_trace_log.setLevel(logging.DEBUG)
    triton_trace_log.propagate = False

    # Lazily create the handler (its root_dir may still be None at this point).
    if TRITON_TRACE_HANDLER is None:
        TRITON_TRACE_HANDLER = TritonTraceHandler(triton_trace_folder)
    handler = TRITON_TRACE_HANDLER

    # Fill in a trace folder that became known after the handler was created.
    if handler.root_dir is None and triton_trace_folder is not None:
        handler.root_dir = triton_trace_folder

    # Re-evaluate whether there is a usable directory
    # (get_root_dir() also applies the /logs logic, permissions, etc.).
    root_dir = handler.get_root_dir()
    attached = handler in triton_trace_log.handlers

    if root_dir is None:
        # Tracing still disabled: make sure the handler is NOT attached.
        if attached:
            triton_trace_log.removeHandler(handler)
        return  # quiet exit, no blank lines

    # Tracing enabled: attach once, with the JSON formatter.
    if not attached:
        handler.setFormatter(TritonJsonFormatter())
        triton_trace_log.addHandler(handler)

    # Tensor blob storage is optional and needs a concrete directory.
    if not (TRITONPARSE_SAVE_TENSOR_BLOBS and root_dir):
        return
    if TENSOR_BLOB_MANAGER is None:
        TENSOR_BLOB_MANAGER = TensorBlobManager(
            root_dir=root_dir, storage_quota=TRITONPARSE_TENSOR_STORAGE_QUOTA
        )
    elif TENSOR_BLOB_MANAGER.root_dir is None:
        # The manager exists but was created without a directory.
        TENSOR_BLOB_MANAGER.set_root_dir(root_dir)
1128
+
1129
+
1130
def trace_structured_triton(
    name: str,
    metadata_fn: Optional[Callable[[], Dict[str, Any]]] = None,
    *,
    payload_fn: Optional[Callable[[], Optional[Union[str, object]]]] = None,
):
    """
    Emit one structured trace event through ``triton_trace_log``.

    The record's metadata always contains "event_type" (``name``), "pid" and
    "stack"; anything returned by ``metadata_fn`` is merged in before the stack
    is added. The value returned by ``payload_fn`` is attached as the payload.

    Args:
        name (str): Name of the trace event (e.g., "compilation", "launch")
        metadata_fn (Callable): Optional zero-argument callable returning extra
            metadata; treated as returning ``{}`` when omitted
        payload_fn (Callable): Optional zero-argument callable returning the
            payload; treated as returning ``None`` when omitted
    """
    metadata_dict: Dict[str, Any] = {"event_type": name, "pid": os.getpid()}

    extra_metadata = metadata_fn() if metadata_fn is not None else {}
    if extra_metadata:
        metadata_dict.update(extra_metadata)

    metadata_dict["stack"] = get_stack_trace()

    payload = payload_fn() if payload_fn is not None else None

    # Build the custom record and hand it straight to the logger's handlers.
    triton_trace_log.handle(
        create_triton_log_record(metadata=metadata_dict, payload=payload)
    )
1176
+
1177
+
1178
def maybe_trace_triton(
    src,
    metadata: Dict[str, Any],
    metadata_group: Dict[str, str],
    times: Any,
    event_type: str = "compilation",
    cache_hit: bool = False,
):
    """
    Collect and trace Triton kernel compilation information for debugging and profiling.

    This function gathers metadata, IR files, and source code information about a Triton
    kernel compilation, then logs it through the tracing system.
    It collects information from multiple sources:
    1. The ``metadata`` argument, or the first ``.json`` file in ``metadata_group``
       when ``metadata`` is empty
    2. PyTorch compilation context (if torch is installed)
    3. IR and other compilation artifact files
    4. Python source code of the kernel function

    This function is designed to be used as a CompilationListener in
    triton.knobs.compilation.listener.

    Args:
        src (Union[ASTSource, IRSource]): Source object containing kernel information
        metadata (Dict[str, Any]): Dictionary containing metadata for the compilation
        metadata_group (Dict[str, Any]): Dictionary mapping filenames to file paths for all compilation artifacts
        times (CompileTimes): Object containing timing information for the compilation
        event_type (str): Type of event being traced (default: "compilation")
        cache_hit (bool): Whether the compilation was a cache hit (default: False)

    Returns:
        Dict[str, Any]: Dictionary containing all collected trace data, or an empty
        dict when the kernel is filtered out by the kernel allowlist.
    """
    # Check kernel allowlist early to avoid unnecessary work
    if _KERNEL_ALLOWLIST_PATTERNS is not None:
        kernel_name = extract_kernel_name(src)
        if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
            # Return empty dict to indicate no tracing was done
            return {}

    # Initialize a dictionary with defaultdict to avoid key errors
    trace_data = defaultdict(dict)
    # Add cache_hit to metadata
    trace_data["metadata"]["cache_hit"] = cache_hit

    if not metadata:
        # Fall back to the metadata JSON artifact. Use a default so a group
        # without one does not surface as a bare StopIteration.
        metadata_path = next(
            (Path(p) for c, p in metadata_group.items() if c.endswith(".json")),
            None,
        )
        if metadata_path is None:
            log.warning(
                "No metadata provided and no .json file found in metadata_group; "
                "tracing without compilation metadata."
            )
            metadata = {}
        else:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
    trace_data["metadata"].update(metadata)

    # Handle torch._guards which might not be recognized by type checker
    if TORCH_INSTALLED:
        trace_id = torch._guards.CompileContext.current_trace_id()  # type: ignore
    else:
        trace_id = None
    cid = trace_id.compile_id if trace_id else None
    if cid is not None:
        for attr_name in ["compiled_autograd_id", "frame_id", "frame_compile_id"]:
            attr_value = getattr(cid, attr_name, None)
            if attr_value is not None:
                trace_data["pt_info"][attr_name] = attr_value
    if trace_id:
        trace_data["pt_info"]["attempt"] = trace_id.attempt

    # Extract content from all IR and other files in the metadata group
    extract_file_content(trace_data, metadata_group)
    # Extract Python source code information if available
    extract_python_source_info(trace_data, src)
    extract_metadata_from_src(trace_data, src)

    # Add timing information if available
    if times:
        trace_data["metadata"]["times"] = times

    # Log the collected information through the tracing system
    trace_structured_triton(
        event_type,
        payload_fn=lambda: json.dumps(convert(trace_data)),
    )

    return trace_data
1260
+
1261
+
1262
def extract_arg_info(arg_dict):
    """
    Describe each kernel argument in a JSON-friendly form.

    Per argument the result holds a small dictionary:
        - torch tensors: "type" = "tensor" plus the fields produced by
          ``_log_torch_tensor_info``
        - objects from the triton_kernels module: "type" =
          "triton_kernels.tensor.<ClassName>" plus the attributes needed to
          rebuild a ``Tensor`` (shape, shape_max, dtype, storage) or a
          ``Storage`` (data, layout); other class names only log a warning
        - int / float / bool: "type" and "value"
        - str: "type", "value" and "length"
        - anything else: "type" and a "repr" truncated to 200 characters

    Args:
        arg_dict: Dictionary of kernel arguments

    Returns:
        Dictionary mapping each argument name to its extracted information
    """
    extracted_args = {}

    for arg_name, arg_value in arg_dict.items():
        if TORCH_INSTALLED and isinstance(arg_value, torch.Tensor):
            info = {"type": "tensor"}
            info.update(_log_torch_tensor_info(arg_value))

        elif _is_from_triton_kernels_module(arg_value):
            # Custom Tensor/Storage wrapper types from triton_kernels
            type_name = type(arg_value).__name__
            info = {"type": f"triton_kernels.tensor.{type_name}"}

            if type_name == "Tensor":
                # Everything needed to reconstruct the Tensor wrapper
                for attr in ("shape", "shape_max", "dtype"):
                    if hasattr(arg_value, attr):
                        info[attr] = convert(getattr(arg_value, attr))
                if hasattr(arg_value, "storage"):
                    # The storage may itself be a custom type or a
                    # torch.Tensor, so describe it recursively.
                    nested = extract_arg_info({"storage": arg_value.storage})
                    info["storage"] = nested["storage"]

            elif type_name == "Storage":
                # Everything needed to reconstruct the Storage object
                if (
                    hasattr(arg_value, "data")
                    and TORCH_INSTALLED
                    and isinstance(arg_value.data, torch.Tensor)
                ):
                    # 'data' is a torch.Tensor: log its metadata fully
                    info["data"] = _log_torch_tensor_info(arg_value.data)
                if hasattr(arg_value, "layout"):
                    info["layout"] = convert(arg_value.layout)

            else:
                log.warning(f"Unknown type: {type(arg_value)}")

        elif isinstance(arg_value, (int, float, bool)):
            # Scalars are stored by value
            info = {"type": type(arg_value).__name__, "value": arg_value}

        elif isinstance(arg_value, str):
            info = {"type": "str", "value": arg_value, "length": len(arg_value)}

        else:
            # Fall back to a (possibly truncated) string form for logging
            text = str(arg_value)
            if len(text) > 200:
                text = text[:200] + "..."
            info = {"type": type(arg_value).__name__, "repr": text}

        extracted_args[arg_name] = info

    return extracted_args
1335
+
1336
+
1337
def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
    """
    Build the ``launch_metadata_tritonparse`` entry for a kernel launch.

    Args:
        grid: Launch grid, passed through unchanged
        metadata: Kernel metadata object; its ``_asdict()`` result is recorded
        arg_dict: Dictionary of kernel arguments
        inductor_args: Optional dictionary of inductor arguments

    Returns:
        ``{"launch_metadata_tritonparse": (grid, metadata_dict, args_info,
        inductor_args_info)}``. While the current CUDA stream is capturing,
        ``args_info`` is a one-entry note and ``inductor_args_info`` is empty.
    """
    # Inspecting arguments during CUDA graph capture can trigger CUDA errors
    # (cudaErrorStreamCaptureUnsupported), so detect capture mode first.
    capturing = False
    if TORCH_INSTALLED:
        try:
            capturing = torch.cuda.is_current_stream_capturing()
        except (AttributeError, RuntimeError):
            pass

    if capturing:
        # Minimal metadata only; no argument extraction.
        args_info = {"_note": "argument extraction skipped during CUDA graph capture"}
        inductor_info = {}
    else:
        # Detailed argument information (only when NOT capturing)
        args_info = extract_arg_info(arg_dict)
        inductor_info = extract_arg_info(inductor_args) if inductor_args else {}

    return {
        "launch_metadata_tritonparse": (
            grid,
            metadata._asdict(),
            args_info,
            inductor_info,
        )
    }
1369
+
1370
+
1371
class JITHookImpl(JITHook):
    """
    JIT hook that installs tritonparse's ``launch_metadata`` function on kernels.

    Triton only exposes minimal launch information (such as the kernel name)
    unless a custom launch_metadata function is set; see
    https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L504

    This hook attaches ``add_launch_metadata`` so that grid parameters, kernel
    metadata and argument details become available at launch time.
    """

    def __call__(
        self,
        *,
        key: str,
        repr: str,
        fn,
        compile,
        is_manual_warmup: bool,
        already_compiled: bool,
        inductor_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[bool]:
        """
        Set (or override) ``launch_metadata`` on the kernel behind ``fn``.

        Args:
            key: Unique identifier for the kernel
            repr: String representation of the kernel
            fn: The JIT function object; its ``jit_function`` attribute is used
                as the target when present, otherwise ``fn`` itself
            compile: Compilation function
            is_manual_warmup: Whether this is a manual warmup call
            already_compiled: Whether the kernel is already compiled
            inductor_args: Optional inductor arguments bound into the installed
                launch_metadata function

        Returns:
            Always True, including when the kernel is skipped by the allowlist.
        """
        # Kernels outside the allowlist keep whatever launch_metadata they have.
        if _KERNEL_ALLOWLIST_PATTERNS is not None and not should_trace_kernel(
            fn.name, _KERNEL_ALLOWLIST_PATTERNS
        ):
            return True

        target = getattr(fn, "jit_function", fn)

        existing = getattr(target, "launch_metadata", None)
        if existing is not None:
            log.warning(
                f"fn {fn} launch_metadata is not None: {existing}. It will be overridden by tritonparse."
            )

        target.launch_metadata = partial(
            add_launch_metadata, inductor_args=inductor_args
        )
        return True
1433
+
1434
+
1435
class LaunchHookImpl(LaunchHook):
    """
    Launch hook that logs kernel launch metadata as a "launch" trace event.

    It reads the launch metadata (including the ``launch_metadata_tritonparse``
    entry produced by ``add_launch_metadata``) and forwards it to
    ``trace_structured_triton``.

    The recorded fields are:
    - name, function and stream of the launch
    - grid, compilation metadata, extracted kernel arguments and extracted
      inductor arguments (when ``launch_metadata_tritonparse`` is present)
    """

    def __call__(self, metadata):
        """
        Handle the kernel launch entry point.

        Args:
            metadata: Object whose ``get()`` returns the launch dictionary, e.g.
                {'name': 'add_kernel', 'function': None, 'stream': 0,
                 'launch_metadata_tritonparse': (grid, compilation_metadata,
                                                 extracted_args,
                                                 extracted_inductor_args)}
                The dictionary layout is defined here:
                https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/
                python/triton/compiler/compiler.py#L512.
        """
        launch_info = metadata.get()

        # Check kernel allowlist early to avoid unnecessary work
        if _KERNEL_ALLOWLIST_PATTERNS is not None and not should_trace_kernel(
            launch_info.get("name"), _KERNEL_ALLOWLIST_PATTERNS
        ):
            return

        trace_data = defaultdict(dict)
        for field in ("name", "function", "stream"):
            trace_data[field] = launch_info[field]

        tritonparse_entry = launch_info.get("launch_metadata_tritonparse", None)
        if tritonparse_entry is not None:
            # Positional layout matches the tuple built by add_launch_metadata.
            tuple_fields = (
                "grid",
                "compilation_metadata",
                "extracted_args",
                "extracted_inductor_args",
            )
            for position, field in enumerate(tuple_fields):
                trace_data[field] = tritonparse_entry[position]

        trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
1497
+
1498
+
1499
def maybe_enable_trace_launch():
    """
    Install the tritonparse JIT and launch hooks into Triton's runtime knobs.

    Does nothing unless ``TRITON_TRACE_LAUNCH`` is set, and installs the hooks
    at most once (tracked by ``_trace_launch_enabled``).
    """
    global _trace_launch_enabled
    if not TRITON_TRACE_LAUNCH or _trace_launch_enabled:
        return

    from triton import knobs

    launch_hook = LaunchHookImpl()
    jit_hook = JITHookImpl()
    knobs.runtime.jit_post_compile_hook = jit_hook
    knobs.runtime.launch_enter_hook = launch_hook

    _trace_launch_enabled = True
1510
+
1511
+
1512
def init_basic(trace_folder: Optional[str] = None):
    """
    Set up the basic tritonparse logging system.

    Stores the trace folder (an explicit ``trace_folder`` overrides a
    previously configured one), parses the kernel allowlist, initialises the
    log handler and, if enabled, the launch-tracing hooks.

    Args:
        trace_folder (Optional[str]): The folder to store the trace files.
    """
    global triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    maybe_enable_debug_logging()

    if trace_folder is not None:
        if triton_trace_folder is not None:
            log.info(
                "Conflict settings: triton_trace_folder is already set to %s, we will use provided trace_folder(%s) instead.",
                triton_trace_folder,
                trace_folder,
            )
        triton_trace_folder = trace_folder

    # Parse and store kernel allowlist configuration
    _KERNEL_ALLOWLIST_PATTERNS = parse_kernel_allowlist()
    if not _KERNEL_ALLOWLIST_PATTERNS:
        log.debug("Kernel allowlist not set, tracing all kernels")
    else:
        log.debug(
            f"Kernel allowlist enabled with patterns: {_KERNEL_ALLOWLIST_PATTERNS}"
        )

    init_logs()
    maybe_enable_trace_launch()
1543
+
1544
+
1545
def init(
    trace_folder: Optional[str] = None,
    enable_trace_launch: bool = False,
    enable_more_tensor_information: bool = False,
    enable_sass_dump: Optional[bool] = False,
    enable_tensor_blob_storage: bool = False,
    tensor_storage_quota: Optional[int] = None,
):
    """
    Initialise tritonparse and register ``maybe_trace_triton`` as Triton's
    compilation listener.

    Wraps ``init_basic()``. Truthy arguments switch the corresponding
    module-level flags on; falsy arguments leave the existing flag values
    untouched.

    Args:
        trace_folder (Optional[str]): The folder to store the trace files.
        enable_trace_launch (bool): Whether to enable the trace launch hook.
        enable_more_tensor_information (bool): Whether to enable more tensor information logging.
            It only works when enable_trace_launch/TRITON_TRACE_LAUNCH is True.
        enable_sass_dump (Optional[bool]): Whether to enable SASS dumping.
        enable_tensor_blob_storage (bool): Whether to enable tensor blob storage.
        tensor_storage_quota (Optional[int]): Storage quota in bytes for tensor blobs;
            when None the existing module-level quota is kept.
    """
    global TRITON_TRACE_LAUNCH, TRITONPARSE_MORE_TENSOR_INFORMATION
    global TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK, TRITONPARSE_DUMP_SASS
    global TRITONPARSE_SAVE_TENSOR_BLOBS, TRITONPARSE_TENSOR_STORAGE_QUOTA

    # The flags must be in place BEFORE init_basic(), because init_logs()
    # (called from there) reads them.
    if enable_trace_launch:
        TRITON_TRACE_LAUNCH = True
        TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = True
    if enable_more_tensor_information:
        TRITONPARSE_MORE_TENSOR_INFORMATION = True
    if enable_sass_dump:
        TRITONPARSE_DUMP_SASS = True
    if enable_tensor_blob_storage:
        TRITONPARSE_SAVE_TENSOR_BLOBS = True

    # The quota is consumed when init_logs() creates the TensorBlobManager.
    if tensor_storage_quota is not None:
        TRITONPARSE_TENSOR_STORAGE_QUOTA = tensor_storage_quota

    init_basic(trace_folder)

    from triton import knobs

    knobs.compilation.listener = maybe_trace_triton
1588
+
1589
+
1590
def init_with_env():
    """
    Initialise tritonparse from the module-level ``triton_trace_folder`` and
    ``TRITON_TRACE_LAUNCH`` settings (per the original note, these come from the
    TRITON_TRACE_FOLDER and TRITON_TRACE_LAUNCH environment variables).

    Nothing happens when no trace folder is configured.
    It is only supposed to be used in OSS triton's source code.
    """
    if not triton_trace_folder:
        return
    init(triton_trace_folder, enable_trace_launch=TRITON_TRACE_LAUNCH)
1597
+
1598
+
1599
def clear_logging_config():
    """
    Undo the configuration made by init() and init_basic().

    Detaches and closes the trace handler, resets the trace folder, kernel
    allowlist, launch-tracing marker and tensor blob manager, and clears the
    Triton knobs that tritonparse sets. The feature flags switched on by
    init() are not reset here.

    WARNING: This function is not supposed to be called unless you are sure
    you want to clear the logging config.
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    global _trace_launch_enabled
    global TENSOR_BLOB_MANAGER

    # 1. Detach and close the log handler
    handler = TRITON_TRACE_HANDLER
    if handler is not None:
        if handler in triton_trace_log.handlers:
            triton_trace_log.removeHandler(handler)
        handler.close()
        TRITON_TRACE_HANDLER = None

    # 2. Reset global state variables
    triton_trace_folder = None
    _KERNEL_ALLOWLIST_PATTERNS = None
    _trace_launch_enabled = False

    # 3. Drop the tensor blob manager
    TENSOR_BLOB_MANAGER = None

    # 4. Reset Triton knobs (imports triton unconditionally)
    from triton import knobs

    knobs.compilation.listener = None
    knobs.runtime.jit_post_compile_hook = None
    knobs.runtime.launch_enter_hook = None
1634
+ knobs.runtime.launch_enter_hook = None