tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1634 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
import atexit
import fnmatch
import gzip
import hashlib
import importlib
import importlib.util
import inspect
import io
import json
import logging
import math
import os
import subprocess
import tempfile
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import asdict, is_dataclass
from datetime import date, datetime
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

from triton.knobs import JITHook, LaunchHook

from .shared_vars import DEFAULT_TRACE_FILE_PREFIX
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
log = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
TEXT_FILE_EXTENSIONS = [".ttir", ".ttgir", ".llir", ".ptx", ".amdgcn", ".json"]
|
|
33
|
+
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB limit for file content extraction
|
|
34
|
+
# Enable gzip compression for each line in trace files
|
|
35
|
+
TRITON_TRACE_GZIP = os.getenv("TRITON_TRACE_GZIP", "0") in ["1", "true", "True"]
|
|
36
|
+
triton_trace_log = logging.getLogger("tritonparse_trace")
|
|
37
|
+
# The folder to store the triton trace log.
|
|
38
|
+
triton_trace_folder = os.environ.get("TRITON_TRACE", None)
|
|
39
|
+
# Enable debug logging for tritonparse itself
|
|
40
|
+
TRITONPARSE_DEBUG = os.getenv("TRITONPARSE_DEBUG", None) in ["1", "true", "True"]
|
|
41
|
+
# Kernel allowlist for filtering traced kernels. Use comma separated list of fnmatch patterns.
|
|
42
|
+
TRITONPARSE_KERNEL_ALLOWLIST = os.environ.get("TRITONPARSE_KERNEL_ALLOWLIST", None)
|
|
43
|
+
# Parsed kernel allowlist patterns (set during init)
|
|
44
|
+
_KERNEL_ALLOWLIST_PATTERNS: Optional[List[str]] = None
|
|
45
|
+
# Enable launch trace. WARNNING: it will overwrite launch_metadata function for each triton kernel.
|
|
46
|
+
TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"]
|
|
47
|
+
# Enable more tensor information collection in trace logs.
|
|
48
|
+
TRITONPARSE_MORE_TENSOR_INFORMATION = os.getenv(
|
|
49
|
+
"TRITONPARSE_MORE_TENSOR_INFORMATION", None
|
|
50
|
+
) in ["1", "true", "True"]
|
|
51
|
+
# Enable full Python source file extraction instead of just the function definition
|
|
52
|
+
TRITON_FULL_PYTHON_SOURCE = os.getenv("TRITON_FULL_PYTHON_SOURCE", "0") in [
|
|
53
|
+
"1",
|
|
54
|
+
"true",
|
|
55
|
+
"True",
|
|
56
|
+
]
|
|
57
|
+
# Maximum file size for full source extraction (default 10MB)
|
|
58
|
+
TRITON_MAX_SOURCE_SIZE = int(os.getenv("TRITON_MAX_SOURCE_SIZE", str(10 * 1024 * 1024)))
|
|
59
|
+
# Inductor compiled kernel's launch tracing needs this flag to be set.
|
|
60
|
+
# If TRITON_TRACE_LAUNCH is enabled, also enable TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK
|
|
61
|
+
TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = (
|
|
62
|
+
os.getenv("TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK", None) in ["1", "true", "True"]
|
|
63
|
+
or TRITON_TRACE_LAUNCH
|
|
64
|
+
)
|
|
65
|
+
# Enable NVIDIA SASS dump. It requires the CUBIN file to be localable.
|
|
66
|
+
# WARNNING: it will slow down the compilation significantly.
|
|
67
|
+
TRITONPARSE_DUMP_SASS = os.getenv("TRITONPARSE_DUMP_SASS", None) in [
|
|
68
|
+
"1",
|
|
69
|
+
"true",
|
|
70
|
+
"True",
|
|
71
|
+
]
|
|
72
|
+
|
|
73
|
+
# The flag to mark if launch is traced. It is used to avoid initilizing the launch hook twice.
|
|
74
|
+
_trace_launch_enabled = False
|
|
75
|
+
# Enable tensor blob storage
|
|
76
|
+
TRITONPARSE_SAVE_TENSOR_BLOBS = os.getenv("TRITONPARSE_SAVE_TENSOR_BLOBS", "0") in [
|
|
77
|
+
"1",
|
|
78
|
+
"true",
|
|
79
|
+
"True",
|
|
80
|
+
]
|
|
81
|
+
# Tensor size limit in bytes (default 10GB)
|
|
82
|
+
TRITONPARSE_TENSOR_SIZE_LIMIT = int(
|
|
83
|
+
os.getenv("TRITONPARSE_TENSOR_SIZE_LIMIT", str(10 * 1024 * 1024 * 1024))
|
|
84
|
+
)
|
|
85
|
+
# Tensor storage quota in bytes (default 100GB) - tracks compressed size for current run
|
|
86
|
+
TRITONPARSE_TENSOR_STORAGE_QUOTA = int(
|
|
87
|
+
os.getenv("TRITONPARSE_TENSOR_STORAGE_QUOTA", str(100 * 1024 * 1024 * 1024))
|
|
88
|
+
)
|
|
89
|
+
# Compression threshold in bytes (default 1MB) - only compress blobs >= this size
|
|
90
|
+
TRITONPARSE_COMPRESSION_THRESHOLD = 1 * 1024 * 1024
|
|
91
|
+
# Compression level for gzip (0-9, higher = better compression but slower)
|
|
92
|
+
TRITONPARSE_COMPRESSION_LEVEL = 4
|
|
93
|
+
# Log statistics every N saved blobs
|
|
94
|
+
TRITONPARSE_STATS_LOG_FREQUENCY = 100
|
|
95
|
+
|
|
96
|
+
TRITON_TRACE_HANDLER = None
|
|
97
|
+
# Global tensor blob manager instance
|
|
98
|
+
TENSOR_BLOB_MANAGER = None
|
|
99
|
+
|
|
100
|
+
if importlib.util.find_spec("torch") is not None:
|
|
101
|
+
TORCH_INSTALLED = True
|
|
102
|
+
import torch
|
|
103
|
+
from torch.utils._traceback import CapturedTraceback
|
|
104
|
+
else:
|
|
105
|
+
TORCH_INSTALLED = False
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class TensorBlobManager:
    """
    Manager for storing tensor data as content-addressed blobs.

    Uses BLAKE2b hashing for content addressing and stores blobs in a two-level
    directory structure to avoid filesystem limitations with large numbers of files.

    Blobs whose serialized size is at least ``compression_threshold`` bytes are
    gzip-compressed and stored as ``<hash>.bin.gz``; smaller blobs are stored raw as
    ``<hash>.bin``. Storage is disabled for the rest of the run once the storage
    quota would be exceeded or an I/O error occurs while saving.
    """

    def __init__(
        self,
        root_dir: Optional[str] = None,
        storage_quota: Optional[int] = None,
    ):
        """
        Args:
            root_dir: Base directory; blobs are written under ``<root_dir>/saved_tensors``.
                If not given, ``set_root_dir`` must be called before saving.
            storage_quota: Maximum number of bytes written to disk in this run.
                Defaults to TRITONPARSE_TENSOR_STORAGE_QUOTA.
        """
        self.root_dir = None
        self.hash_to_path_cache = {}  # In-memory cache for hash -> path mapping
        self.compression_threshold = TRITONPARSE_COMPRESSION_THRESHOLD
        self.storage_quota = (
            storage_quota
            if storage_quota is not None
            else TRITONPARSE_TENSOR_STORAGE_QUOTA
        )

        # Resource statistics (tracks current run only)
        self.total_compressed_bytes = 0  # Total on-disk bytes written in this run
        self.total_uncompressed_bytes = 0  # Total uncompressed size (for compression ratio)
        self.blob_count = 0  # Total blob references (including dedup hits)
        self.blob_saved_count = 0  # Actual blobs saved (excluding dedup hits)
        self.storage_disabled = False  # Set once the quota is hit or an I/O error occurs
        self.storage_disabled_reason = None  # Reason for disabling storage

        if root_dir:
            self.set_root_dir(root_dir)

    def set_root_dir(self, root_dir: str):
        """Set the root directory for blob storage, creating ``<root_dir>/saved_tensors``."""
        self.root_dir = Path(root_dir) / "saved_tensors"
        self.root_dir.mkdir(parents=True, exist_ok=True)
        log.debug(f"TensorBlobManager: using root directory {self.root_dir}")

    def _compute_hash(self, data: bytes) -> str:
        """Compute BLAKE2b hash of the data."""
        return hashlib.blake2b(data).hexdigest()

    def _get_blob_path(self, hash_hex: str, extension: str = ".bin.gz") -> Path:
        """Get the file path for a given hash using two-level directory structure.

        Raises:
            ValueError: if the root directory has not been set.
        """
        if not self.root_dir:
            raise ValueError("Root directory not set")

        # Two-level directory: first 2 chars / full_hash{extension}
        subdir = hash_hex[:2]
        filename = f"{hash_hex}{extension}"
        return (self.root_dir / subdir / filename).resolve()

    def _get_tensor_size_bytes(self, tensor) -> int:
        """Get tensor size in bytes before serialization (0 if it cannot be determined)."""
        if hasattr(tensor, "numel") and hasattr(tensor, "element_size"):
            return tensor.numel() * tensor.element_size()
        return 0

    def _log_statistics(self, final: bool = False):
        """Log statistics about tensor blob storage.

        Args:
            final: If True, this is the final statistics message (e.g., when storage is disabled)
        """
        prefix = "📊 Final" if final else "📊"
        compression_ratio = (
            self.total_uncompressed_bytes / max(1, self.total_compressed_bytes)
            if self.total_compressed_bytes > 0
            else 0.0
        )
        dedup_count = self.blob_count - self.blob_saved_count

        log.info(
            f"{prefix} Tensor blob stats: "
            f"{self.blob_saved_count} saved ({self.blob_count} total, {dedup_count} dedup), "
            f"{self.total_compressed_bytes / 1024**3:.2f}GB compressed "
            f"({self.total_uncompressed_bytes / 1024**3:.2f}GB uncompressed), "
            f"compression ratio: {compression_ratio:.2f}x"
        )

    def _disable_storage(self, reason: str):
        """Disable blob storage and log warning with statistics.

        Args:
            reason: The reason why storage is being disabled
        """
        if not self.storage_disabled:  # Only disable once
            self.storage_disabled = True
            self.storage_disabled_reason = reason
            log.warning(f"⚠️ TENSOR BLOB STORAGE DISABLED: {reason}")
            self._log_statistics(final=True)

    @staticmethod
    def _blob_result(
        hash_hex: str,
        blob_path: Path,
        disk_size: int,
        uncompressed_size: int,
        compression: str,
    ) -> Dict[str, Any]:
        """Build the metadata dict returned for a stored (or deduplicated) blob."""
        return {
            "tensor_hash": hash_hex,
            "blob_path": str(blob_path),
            "blob_size": disk_size,
            "blob_size_uncompressed": uncompressed_size,
            "compression": compression,
            "compression_ratio": uncompressed_size / max(1, disk_size),
            "serialization_method": "torch_save",
        }

    def _atomic_write(self, target: Path, data: bytes, tmp_prefix: str = ".tmp_") -> None:
        """Write ``data`` to ``target`` atomically via a temp file in the same directory.

        ``Path.replace`` is used rather than ``Path.rename`` because it overwrites an
        existing target on every platform (``rename`` raises on Windows if another
        writer already produced the same blob). The temp file is removed if writing
        or replacing fails, so failures do not leave orphaned temp files behind.

        Raises:
            OSError: if the temp file cannot be written or moved into place.
        """
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode="wb",
                dir=target.parent,
                prefix=tmp_prefix,
                delete=False,
            ) as tmp_file:
                tmp_path = Path(tmp_file.name)
                tmp_file.write(data)
            tmp_path.replace(target)
        except BaseException:
            if tmp_path is not None:
                try:
                    tmp_path.unlink(missing_ok=True)
                except OSError:
                    # Best-effort cleanup; the original error is re-raised below.
                    pass
            raise

    def save_tensor_blob(self, tensor) -> Dict[str, Any]:
        """
        Save tensor as a blob and return metadata.

        The tensor is moved to CPU, serialized with ``torch.save``, hashed (BLAKE2b
        over the uncompressed bytes) and written atomically. If a blob with the same
        hash was already saved in this run and its file still exists, nothing is
        written (dedup hit).

        Args:
            tensor: PyTorch tensor to save

        Returns:
            Dictionary with blob metadata or error information:
            - Success: {'tensor_hash': str, 'blob_path': str, 'blob_size': int,
                        'blob_size_uncompressed': int, 'compression': str,
                        'compression_ratio': float, 'serialization_method': str}
            - Dedup hit: Same as success plus 'deduplicated': True (not counted in quota)
            - Error: {'error': str, 'tensor_hash': None}
        """
        # Early exit: Check if storage is disabled
        if self.storage_disabled:
            return {"error": self.storage_disabled_reason, "tensor_hash": None}

        # Early exit: Check if root directory is set
        if not self.root_dir:
            return {"error": "Blob storage not initialized", "tensor_hash": None}

        try:
            # Check tensor size before serialization
            tensor_size = self._get_tensor_size_bytes(tensor)
            if tensor_size > TRITONPARSE_TENSOR_SIZE_LIMIT:
                log.warning(
                    f"Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes, skipping blob storage"
                )
                return {
                    "error": f"Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes",
                    "tensor_hash": None,
                }

            if not TORCH_INSTALLED:
                return {
                    "error": "PyTorch not available for tensor serialization",
                    "tensor_hash": None,
                }

            # Serialize tensor using torch.save
            buffer = io.BytesIO()
            torch.save(tensor.cpu(), buffer)
            blob_data = buffer.getvalue()
            uncompressed_size = len(blob_data)

            # Compute hash on uncompressed data for content addressing
            hash_hex = self._compute_hash(blob_data)

            # Check for deduplication (before compression to save work)
            cached_path = self.hash_to_path_cache.get(hash_hex)
            if cached_path is not None:
                try:
                    # The cached file may have been deleted since it was written.
                    disk_size = cached_path.stat().st_size
                except OSError:
                    # File was deleted or inaccessible - remove from cache and re-save
                    log.debug(
                        f"Cached blob file no longer exists: {cached_path}, will re-save"
                    )
                    self.hash_to_path_cache.pop(hash_hex, None)
                else:
                    # Deduplication hit - increment count but don't add to quota
                    self.blob_count += 1
                    compression = (
                        "gzip" if str(cached_path).endswith(".bin.gz") else "none"
                    )
                    result = self._blob_result(
                        hash_hex, cached_path, disk_size, uncompressed_size, compression
                    )
                    result["deduplicated"] = True
                    return result

            # Decide whether to compress based on size threshold; small blobs are
            # stored raw because gzip overhead is not worth it.
            if uncompressed_size >= self.compression_threshold:
                data_to_write = gzip.compress(
                    blob_data, compresslevel=TRITONPARSE_COMPRESSION_LEVEL
                )
                file_extension = ".bin.gz"
                compression = "gzip"
            else:
                data_to_write = blob_data
                file_extension = ".bin"
                compression = "none"

            disk_size = len(data_to_write)

            # Check quota BEFORE writing
            if self.total_compressed_bytes + disk_size > self.storage_quota:
                self._disable_storage(
                    f"Storage quota would be exceeded: "
                    f"{(self.total_compressed_bytes + disk_size) / 1024**3:.2f}GB > "
                    f"{self.storage_quota / 1024**3:.2f}GB limit"
                )
                return {"error": self.storage_disabled_reason, "tensor_hash": None}

            # Create blob file path with appropriate extension
            blob_path = self._get_blob_path(hash_hex, extension=file_extension)
            blob_path.parent.mkdir(parents=True, exist_ok=True)
            self._atomic_write(blob_path, data_to_write, tmp_prefix=f".tmp_{hash_hex}_")

            # Update cache and statistics
            self.hash_to_path_cache[hash_hex] = blob_path
            self.total_compressed_bytes += disk_size
            self.total_uncompressed_bytes += uncompressed_size
            self.blob_count += 1
            self.blob_saved_count += 1

            # Log progress periodically
            if self.blob_saved_count % TRITONPARSE_STATS_LOG_FREQUENCY == 0:
                self._log_statistics()

            log.debug(
                f"Saved tensor blob: {hash_hex} -> {blob_path} ({disk_size} bytes, compression={compression})"
            )

            return self._blob_result(
                hash_hex, blob_path, disk_size, uncompressed_size, compression
            )

        except OSError as e:
            # Disk full, permission errors, etc. - disable storage to avoid repeated failures
            error_msg = f"Failed to save tensor blob (I/O error): {str(e)}"
            log.error(error_msg)
            self._disable_storage(error_msg)
            return {"error": error_msg, "tensor_hash": None}
        except Exception as e:
            # Other unexpected errors - log but don't disable storage
            error_msg = f"Failed to save tensor blob: {str(e)}"
            log.error(error_msg)
            return {"error": error_msg, "tensor_hash": None}
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class TritonLogRecord(logging.LogRecord):
    """
    A ``logging.LogRecord`` that also carries structured Triton trace data.

    Beyond the standard record fields it exposes two attributes:

    * ``metadata`` - a dict of structured metadata; a fresh empty dict is used
      when the caller passes nothing (or an empty/falsy value).
    * ``payload`` - an optional payload (string, dict or list), stored unchanged.
    """

    def __init__(
        self,
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=None,
        payload=None,
        **kwargs,
    ):
        standard_fields = (name, level, pathname, lineno, msg, args, exc_info)
        super().__init__(*standard_fields, **kwargs)
        self.metadata: Dict[str, Any] = metadata if metadata else {}
        self.payload: Optional[Union[str, Dict[str, Any], list]] = payload
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def create_triton_log_record(
    name=None,
    level=logging.DEBUG,
    pathname=None,
    lineno=None,
    msg="",
    args=(),
    exc_info=None,
    metadata=None,
    payload=None,
    **kwargs,
):
    """
    Factory method to create TritonLogRecord instances with sensible defaults.

    Args:
        name (str, optional): Logger name. Defaults to triton_trace_log.name.
        level (int, optional): Log level. Defaults to DEBUG.
        pathname (str, optional): Path stored on the record. Defaults to this
            module's file (``__file__``).
        lineno (int, optional): Line number stored on the record. Defaults to the
            line in the caller's frame that invoked this function, or 0 if frame
            introspection is unavailable.
        msg (str, optional): Log message. Defaults to empty string.
        args (tuple, optional): Arguments to interpolate into the message. Defaults to empty tuple.
        exc_info (optional): Exception information. Defaults to None.
        metadata (Dict[str, Any], optional): Structured metadata for the log record. Defaults to empty dict.
        payload (optional): Payload data. Defaults to None.
        **kwargs: Additional keyword arguments for LogRecord

    Returns:
        TritonLogRecord: A custom log record with structured data

    NOTE(review): the default pathname is this module while the default lineno comes
    from the caller's frame, so the two may refer to different files.
    """
    if pathname is None:
        pathname = __file__
    if lineno is None:
        # inspect.currentframe() returns None on Python implementations without
        # stack-frame support; fall back to 0 instead of raising AttributeError.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        lineno = caller.f_lineno if caller is not None else 0
        # Release frame references promptly so they are not kept alive by cycles.
        del frame, caller
    if name is None:
        name = triton_trace_log.name

    return TritonLogRecord(
        name,
        level,
        pathname,
        lineno,
        msg,
        args,
        exc_info,
        metadata=metadata,
        payload=payload,
        **kwargs,
    )
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def convert(obj):
    """
    Recursively convert an object into a JSON-serializable form.

    Handled, in order: JSON primitives (non-finite floats become the strings "nan",
    "inf", "-inf"); lists/tuples (namedtuples become dicts of their fields); sets
    (sorted by ``str``); mappings (keys stringified); datetime/date (ISO format);
    Enum (its converted value); Path (``str``); dataclass instances (via ``asdict``);
    triton_kernels ``Layout`` objects; Triton and torch dtypes. Anything else is
    logged as an unknown type and converted with ``str``.

    Args:
        obj: The object to convert, which can be a dataclass instance, dictionary, list, or any other type

    Returns:
        A JSON-serializable version of the input object.
    """
    # 1. primitives that JSON already supports -------------------------------
    if obj is None or isinstance(obj, (bool, int, str)):
        return obj

    if isinstance(obj, float):
        # JSON spec forbids NaN/Infinity – keep precision but stay valid
        if math.isfinite(obj):
            return obj
        return str(obj)  # "nan", "inf", "-inf"

    # 2. simple containers ----------------------------------------------------
    if isinstance(obj, (list, tuple)):
        # Handle namedtuple specially to preserve field names
        if hasattr(obj, "_asdict"):
            return convert(obj._asdict())
        return [convert(x) for x in obj]

    if isinstance(obj, (set, frozenset)):
        return [convert(x) for x in sorted(obj, key=str)]

    if isinstance(obj, Mapping):
        return {str(k): convert(v) for k, v in obj.items()}

    # 3. time, enum, path -----------------------------------------------------
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()

    if isinstance(obj, Enum):
        return convert(obj.value)

    if isinstance(obj, Path):
        return str(obj)

    # is_dataclass() is also True for dataclass *classes*, but asdict() only
    # accepts instances; classes fall through to the generic handling below.
    if is_dataclass(obj) and not isinstance(obj, type):
        # Convert dataclass to dict and then process that dict
        return convert(asdict(obj))

    if _is_triton_kernels_layout(obj):
        layout_info = {"type": type(obj).__name__}
        if hasattr(obj, "initial_shape"):
            layout_info["initial_shape"] = convert(obj.initial_shape)
        if hasattr(obj, "name"):
            layout_info["name"] = convert(obj.name)
        return layout_info

    # 4. Common Triton constexpr objects. Imported here rather than at the top of
    # the function so recursion over primitives/containers doesn't re-run the
    # import statement for every element.
    from triton.language.core import dtype

    if isinstance(obj, dtype):
        return f"triton.language.core.dtype('{str(obj)}')"

    if TORCH_INSTALLED and isinstance(obj, torch.dtype):
        return str(obj)

    log.warning(f"Unknown type: {type(obj)}")
    return str(obj)  # Fallback: stringify unsupported types
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def _is_triton_kernels_layout(obj):
|
|
517
|
+
"""
|
|
518
|
+
Check if an object is an instance of a Layout class from a
|
|
519
|
+
triton_kernels module by checking its MRO.
|
|
520
|
+
"""
|
|
521
|
+
t = type(obj)
|
|
522
|
+
for base_class in t.__mro__:
|
|
523
|
+
module_name = getattr(base_class, "__module__", "")
|
|
524
|
+
type_name = getattr(base_class, "__name__", "")
|
|
525
|
+
if type_name == "Layout" and module_name.startswith("triton_kernels"):
|
|
526
|
+
return True
|
|
527
|
+
return False
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def _is_from_triton_kernels_module(obj):
|
|
531
|
+
"""
|
|
532
|
+
Check if an object is an instance of Tensor or Storage from a
|
|
533
|
+
triton_kernels module.
|
|
534
|
+
"""
|
|
535
|
+
t = type(obj)
|
|
536
|
+
module_name = getattr(t, "__module__", "")
|
|
537
|
+
type_name = getattr(t, "__name__", "")
|
|
538
|
+
return type_name in ("Tensor", "Storage") and module_name.startswith(
|
|
539
|
+
"triton_kernels"
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def _log_torch_tensor_info(tensor_value):
    """
    Collect descriptive metadata for a torch.Tensor.

    Always records: type, shape, dtype, device, stride, numel, is_contiguous,
    element_size, storage_offset, memory_usage (numel * element_size, in bytes)
    and, when available, data_ptr as a hex string.

    When TRITONPARSE_MORE_TENSOR_INFORMATION is set, min/max/mean/std computed on a
    float copy are added; if that fails, ``tensor_capture_error`` is recorded instead.
    When tensor blob saving is enabled and TENSOR_BLOB_MANAGER exists, the dict
    returned by ``save_tensor_blob`` is merged into the result.

    Args:
        tensor_value (torch.Tensor): The tensor to describe.

    Returns:
        dict: Tensor metadata.
    """
    numel = tensor_value.numel()
    element_size = tensor_value.element_size()
    info = {
        "type": "tensor",
        "shape": list(tensor_value.shape),
        "dtype": str(tensor_value.dtype),
        "device": str(tensor_value.device),
        "stride": list(tensor_value.stride()),
        "numel": numel,
        "is_contiguous": tensor_value.is_contiguous(),
        "element_size": element_size,
        "storage_offset": tensor_value.storage_offset(),
        # Memory usage in bytes
        "memory_usage": numel * element_size,
    }
    # data_ptr is recorded for memory tracking when the object provides it
    if hasattr(tensor_value, "data_ptr"):
        info["data_ptr"] = hex(tensor_value.data_ptr())

    if TRITONPARSE_MORE_TENSOR_INFORMATION:
        try:
            # A float copy gives reliable statistics across dtypes and leaves the
            # original tensor untouched. Entries are stored one by one, so any
            # statistic computed before a failure is kept.
            as_float = tensor_value.float()
            info["min"] = as_float.min().item()
            info["max"] = as_float.max().item()
            info["mean"] = as_float.mean().item()
            info["std"] = as_float.std().item()
        except (RuntimeError, ValueError, TypeError) as err:
            log.error(f"Unable to compute tensor statistics: {err}")
            info["tensor_capture_error"] = str(err)

    if TRITONPARSE_SAVE_TENSOR_BLOBS and TENSOR_BLOB_MANAGER is not None:
        info.update(TENSOR_BLOB_MANAGER.save_tensor_blob(tensor_value))
    return info
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def maybe_enable_debug_logging():
    """
    Configure debug output for tritonparse's own module logger (not for the
    logging of the triton compilation itself).

    Does nothing unless TRITONPARSE_DEBUG is set. Otherwise sets the logger level
    to DEBUG, stops propagation to ancestor loggers (avoiding duplicate messages),
    and attaches a DEBUG-level StreamHandler unless the logger already has a
    StreamHandler whose level is DEBUG or lower.
    """
    if not TRITONPARSE_DEBUG:
        return

    log.setLevel(logging.DEBUG)
    log.propagate = False

    for existing in log.handlers:
        if isinstance(existing, logging.StreamHandler) and existing.level <= logging.DEBUG:
            return  # a suitable debug handler is already attached

    formatter = logging.Formatter("%(asctime)s[%(levelname)s] %(message)s")
    formatter.default_time_format = "%Y%m%d %H:%M:%S"
    formatter.default_msec_format = None

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    log.addHandler(stream_handler)
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def get_stack_trace(skip=1):
    """
    Capture the current call stack as a list of plain dicts.

    Uses torch's CapturedTraceback; returns an empty list when torch is not
    installed.

    Args:
        skip (int): Number of frames to skip, passed to CapturedTraceback.extract.

    Returns:
        List[Dict]: One dict per frame with keys "line" (line number), "name"
        (function name), "filename" and "loc" (source line text).
    """
    if not TORCH_INSTALLED:
        return []
    # extract() is called directly from this function (not from inside the
    # comprehension) so the frame depth that ``skip`` counts from is unchanged.
    summary = CapturedTraceback.extract(skip=skip).summary()
    return [
        {
            "line": entry.lineno,
            "name": entry.name,
            "filename": entry.filename,
            "loc": entry.line,
        }
        for entry in summary
    ]
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
def parse_kernel_allowlist() -> Optional[List[str]]:
    """
    Parse TRITONPARSE_KERNEL_ALLOWLIST into a list of fnmatch patterns.

    The value is split on commas; each piece is stripped and empty pieces are
    dropped.

    Returns:
        List[str] or None: The kernel name patterns to trace, or None when the
        variable is unset/empty or contains no non-empty patterns (meaning every
        kernel should be traced).
    """
    raw_value = TRITONPARSE_KERNEL_ALLOWLIST
    if not raw_value:
        return None

    stripped = (piece.strip() for piece in raw_value.split(","))
    patterns = [piece for piece in stripped if piece]
    if not patterns:
        return None

    log.debug(f"Kernel allowlist patterns: {patterns}")
    return patterns
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def extract_kernel_name(src) -> Optional[str]:
    """
    Extract the kernel name from a Triton compilation source object.

    Args:
        src (Union[ASTSource, IRSource]): Source object containing kernel information

    Returns:
        str or None: For an IRSource, its ``name`` attribute (None if absent). For
        other sources, ``src.fn.fn.__name__`` when that attribute chain exists.
        None otherwise, or if extraction raises (the error is logged).
    """
    from triton.compiler import IRSource

    try:
        if isinstance(src, IRSource):
            # Builtin getattr(): objects have no ``.getattr`` method, so calling
            # ``src.getattr(...)`` would always raise and never yield a name.
            return getattr(src, "name", None)

        # For ASTSource, get the function name
        if (
            hasattr(src, "fn")
            and hasattr(src.fn, "fn")
            and hasattr(src.fn.fn, "__name__")
        ):
            return src.fn.fn.__name__
        return None
    except Exception as e:
        # Logger.warn is a deprecated alias; use warning().
        log.warning(f"Error extracting kernel name: {e}")
        return None
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
def should_trace_kernel(
    kernel_name: Optional[str], allowlist_patterns: Optional[List[str]]
) -> bool:
    """
    Decide whether a kernel passes the allowlist filter.

    Args:
        kernel_name (str or None): Name of the kernel, or None if it could not
            be determined.
        allowlist_patterns (List[str] or None): ``fnmatch`` patterns; None means
            no allowlist is configured.

    Returns:
        bool: True when no allowlist is set or the name matches at least one
        pattern; False when the name is unknown or matches nothing.
    """
    if allowlist_patterns is None:
        # No allowlist configured: every kernel is traced.
        return True

    if kernel_name is None:
        # Conservative: an unnamed kernel is never traced under an allowlist.
        log.debug("Cannot extract kernel name, skipping trace")
        return False

    first_hit = next(
        (candidate for candidate in allowlist_patterns if fnmatch.fnmatch(kernel_name, candidate)),
        None,
    )
    if first_hit is None:
        log.debug(
            f"Kernel '{kernel_name}' does not match any allowlist pattern, skipping trace"
        )
        return False

    log.debug(f"Kernel '{kernel_name}' matches pattern '{first_hit}', will trace")
    return True
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def extract_python_source_info(trace_data: Dict[str, Any], source):
    """
    Extract Python source code information from the source object and add it to trace_data.

    This function uses Python's inspect module to extract source code information
    from the provided source object (typically an ASTSource or IRSource instance).
    It adds file path, line numbers, and the actual source code to the trace_data.

    By default, only the function definition is extracted. Set TRITON_FULL_PYTHON_SOURCE=1
    to extract the entire Python source file.
    @TODO: we should enable it by default in next diff and track the compilation time regression

    Environment Variables:
        TRITON_FULL_PYTHON_SOURCE: If set to "1", extract the full Python file
            instead of just the function definition.
        TRITON_MAX_SOURCE_SIZE: Maximum file size in bytes for full source extraction
            (default: 10MB). Files larger than this will fall back
            to function-only mode.

    Args:
        trace_data (Dict[str, Any]): Dictionary to store extracted information
        source (Union[ASTSource, IRSource]): Source object containing kernel function information

    Side effects:
        Sets ``trace_data["python_source"]`` to a dict with ``file_path``,
        ``start_line``, ``end_line`` (1-based, inclusive) and ``code``. In
        full-file mode it also carries ``function_start_line`` and
        ``function_end_line`` so a viewer can locate the kernel. IRSource
        inputs are left untouched.
    """
    # @TODO: add support for IRSource
    from triton.compiler import IRSource
    from triton.runtime.jit import JITFunction

    if isinstance(source, IRSource):
        return

    fn = source.fn
    is_jit = isinstance(fn, JITFunction)
    # A JITFunction wraps the user's Python function in ``.fn``.
    fn_ref = fn.fn if is_jit else fn

    python_source_file = inspect.getfile(fn_ref)

    # Get function range information, reusing what the JITFunction already
    # captured when available instead of re-reading the source.
    if is_jit and hasattr(fn, "starting_line_number") and hasattr(fn, "raw_src"):
        function_start_line = fn.starting_line_number
        source_lines = fn.raw_src
    else:
        source_lines, function_start_line = inspect.getsourcelines(fn_ref)

    function_end_line = function_start_line + len(source_lines) - 1

    if TRITON_FULL_PYTHON_SOURCE:
        # Full file mode: read the entire Python file, unless it is unreadable
        # or too large, in which case fall through to function-only mode.
        use_full_source = False
        try:
            # Check file size before reading
            file_size = os.path.getsize(python_source_file)
        except OSError as e:
            log.warning(
                f"Failed to check file size for {python_source_file}: {e}. "
                f"Falling back to function-only mode."
            )
        else:
            if file_size > TRITON_MAX_SOURCE_SIZE:
                log.warning(
                    f"Source file {python_source_file} is too large ({file_size} bytes, "
                    f"limit: {TRITON_MAX_SOURCE_SIZE} bytes). Falling back to function-only mode."
                )
            else:
                use_full_source = True

        if use_full_source:
            try:
                with open(python_source_file, "r", encoding="utf-8") as f:
                    file_content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                log.warning(
                    f"Failed to read full source file {python_source_file}: {e}. "
                    f"Falling back to function-only mode."
                )
            else:
                # Text mode normalizes line endings to "\n". ``split("\n")`` would
                # count a phantom empty line after the usual trailing newline, so
                # count terminated lines plus any unterminated final line instead.
                # This keeps end_line consistent with function_end_line above.
                total_lines = file_content.count("\n") + (
                    1 if file_content and not file_content.endswith("\n") else 0
                )

                trace_data["python_source"] = {
                    "file_path": python_source_file,
                    "start_line": 1,
                    "end_line": total_lines,
                    "code": file_content,
                    # Add function range for frontend highlighting and scrolling
                    "function_start_line": function_start_line,
                    "function_end_line": function_end_line,
                }
                return

    # Default behavior: only extract function definition
    trace_data["python_source"] = {
        "file_path": python_source_file,
        "start_line": function_start_line,
        "end_line": function_end_line,
        "code": "".join(source_lines),
    }
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
def extract_file_content(trace_data: Dict[str, Any], metadata_group: Dict[str, str]):
    """
    Extract file content from metadata_group and add it to trace_data.

    Every artifact path is recorded in ``trace_data["file_path"]``. Artifacts
    whose name ends with one of TEXT_FILE_EXTENSIONS are also read (as UTF-8)
    into ``trace_data["file_content"]``. Files larger than MAX_FILE_SIZE, or
    unreadable ones, get a ``<...>`` placeholder message instead. When
    TRITONPARSE_DUMP_SASS is enabled and a ``.cubin`` artifact exists, its
    disassembly is stored under ``<cubin basename>.sass``.

    Args:
        trace_data (Dict): Dictionary to store extracted information; the
            ``"file_path"`` and ``"file_content"`` entries must already support
            item assignment (e.g. a ``defaultdict(dict)``).
        metadata_group (Dict): Dictionary mapping filenames to file paths
    """
    for ir_filename, file_path in metadata_group.items():
        # Add file path to trace data
        trace_data["file_path"][ir_filename] = file_path

        # Only inline files with a known text extension; others are path-only.
        if not any(ir_filename.endswith(ext) for ext in TEXT_FILE_EXTENSIONS):
            continue

        try:
            # Check file size before reading to avoid memory issues
            file_size = os.path.getsize(file_path)
            if file_size > MAX_FILE_SIZE:
                message = f"<file too large: {file_size} bytes>"
                trace_data["file_content"][ir_filename] = message
                continue

            # Decode explicitly as UTF-8 instead of the locale's preferred
            # encoding, so the captured IR does not depend on the host locale
            # (matches how Python sources are read elsewhere in this module).
            with open(file_path, "r", encoding="utf-8") as f:
                trace_data["file_content"][ir_filename] = f.read()
        except (UnicodeDecodeError, OSError) as e:
            # Record the failure in place of the content so the trace stays complete.
            message = f"<error reading file: {str(e)}>"
            trace_data["file_content"][ir_filename] = message
            log.debug(f"Error reading file {file_path}: {e}")

    # The first ``.cubin`` artifact (if any) is the SASS disassembly input.
    cubin_path = next(
        (path for key, path in metadata_group.items() if key.endswith(".cubin")),
        None,
    )

    if TRITONPARSE_DUMP_SASS and cubin_path:
        filename_no_ext = os.path.splitext(os.path.basename(cubin_path))[0]
        sass_filename = f"{filename_no_ext}.sass"
        try:
            import tritonparse.tools.disasm

            sass_content = tritonparse.tools.disasm.extract(cubin_path)
            trace_data["file_content"][sass_filename] = sass_content
        except subprocess.CalledProcessError as e:
            message = f"<nvdisasm failed: {str(e)}>"
            trace_data["file_content"][sass_filename] = message
        except OSError as e:
            message = f"<error reading cubin file: {str(e)}>"
            trace_data["file_content"][sass_filename] = message
        except Exception as e:
            # Best effort: SASS dumping must never break trace collection.
            message = f"<error dumping SASS: {str(e)}>"
            trace_data["file_content"][sass_filename] = message
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
def extract_metadata_from_src(trace_data, src):
    """
    Merge source-level compilation metadata into ``trace_data["metadata"]``.

    Adds Triton's cache-invalidating environment variables (``env``) and, when
    ``src`` provides them, its ``attrs``, ``fn.cache_key`` and ``constants``.
    Missing attributes are recorded as ``{}`` (attrs/constants) or ``""``
    (cache key).

    Args:
        trace_data: Mapping whose ``"metadata"`` entry is a dict (for example
            the ``defaultdict(dict)`` built by ``maybe_trace_triton``).
        src: Compilation source object (ASTSource/IRSource).
    """
    from triton._C.libtriton import get_cache_invalidating_env_vars

    cache_env = get_cache_invalidating_env_vars()
    src_fields = {
        "env": cache_env,
        "src_attrs": getattr(src, "attrs", {}),
        "src_cache_key": src.fn.cache_key if hasattr(src, "fn") else "",
        "src_constants": getattr(src, "constants", {}),
    }
    trace_data["metadata"].update(src_fields)
|
|
902
|
+
|
|
903
|
+
|
|
904
|
+
class TritonJsonFormatter(logging.Formatter):
    """
    Format log records as JSON for Triton compilation tracing.

    This formatter converts log records with metadata and payload into NDJSON format,
    suitable for structured logging and later analysis. It handles special attributes
    added by the tritonparse, such as metadata dictionaries and payload data.
    """

    def format(self, record: logging.LogRecord):
        """
        Serialize ``record`` into one compact NDJSON line.

        The record must carry a ``metadata`` dict and a ``payload`` that is
        either None or a JSON string. The metadata dict is extended in place
        with a ``timestamp`` (UTC, ISO 8601 with microseconds) and, when
        present, the decoded ``payload``, then passed through ``convert``
        before dumping.

        Returns:
            str: The JSON text followed by a single newline.
        """
        from datetime import datetime, timezone

        log_entry = record.metadata
        payload = record.payload

        # ``Formatter.formatTime`` formats with ``time.strftime`` on local time.
        # ``time.strftime`` has no ``%f`` directive, so the previous format
        # produced a literal "%f" (or raised on some platforms), and the "Z"
        # suffix mislabeled local time as UTC. Build the timestamp from the
        # record's creation time in real UTC instead.
        log_entry["timestamp"] = datetime.fromtimestamp(
            record.created, tz=timezone.utc
        ).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
        if payload is not None:
            log_entry["payload"] = json.loads(payload)
        clean_log_entry = convert(log_entry)
        # NDJSON format requires a newline at the end of each line
        json_str = json.dumps(clean_log_entry, separators=(",", ":"))
        return json_str + "\n"
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
class TritonTraceHandler(logging.StreamHandler):
    """
    A handler for Triton compilation tracing that outputs NDJSON files.

    This handler creates and manages log files for Triton kernel compilation traces.
    It supports creating new files for different compilation events and handles
    proper cleanup of file resources. When running in a distributed environment,
    it automatically adds rank information to filenames.

    The output file is opened lazily on the first emitted record, in append
    mode. With TRITON_TRACE_GZIP enabled, each record is written as its own
    gzip member to ``<prefix>[rank_N_].bin.ndjson``; otherwise plain NDJSON
    text goes to ``<prefix>[rank_N_].ndjson``.
    """

    def __init__(
        self, root_dir: Optional[str] = None, prefix=DEFAULT_TRACE_FILE_PREFIX
    ):
        # Call Handler.__init__ directly rather than StreamHandler.__init__, which
        # would default the stream to sys.stderr; the file is opened in emit().
        logging.Handler.__init__(self)
        self.root_dir = root_dir
        self.prefix = prefix
        self.stream = None
        self.first_record = True
        # If the program terminates unexpectedly, atexit makes sure the file
        # stream kept in ``self.stream`` is flushed and closed.
        atexit.register(self._cleanup)

    def get_root_dir(self):
        """
        Return the directory to write traces into, or None if tracing is disabled.

        An explicitly configured ``root_dir`` wins. Otherwise ``/logs`` is used
        (and cached on ``self.root_dir``) only when the torch environment checks
        below pass and the directory exists and is writable.
        """
        # For meta internal runs, we use the /logs directory by default
        # reference implementation
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        if self.root_dir:
            return self.root_dir
        TRACE_LOG_DIR = "/logs"
        should_set_root_dir = True
        if TORCH_INSTALLED:
            import torch.version as torch_version

            if (
                hasattr(torch_version, "git_version")
                and os.getenv("MAST_HPC_JOB_NAME") is None
            ):
                log.info(
                    "TritonTraceHandler: disabled because not fbcode or conda on mast"
                )
                should_set_root_dir = False
            # TODO: change to tritonparse knob
            # The following check is necessary because the possible version mismatch between torch and tritonparse
            elif (
                hasattr(torch, "_utils_internal")
                and hasattr(torch._utils_internal, "justknobs_check")
                and not torch._utils_internal.justknobs_check("pytorch/trace:enable")
            ):
                log.info(
                    "TritonTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                )
                should_set_root_dir = False
        if should_set_root_dir:
            if not os.path.exists(TRACE_LOG_DIR):
                log.info(
                    "TritonTraceHandler: disabled because %s does not exist",
                    TRACE_LOG_DIR,
                )
            elif not os.access(TRACE_LOG_DIR, os.W_OK):
                log.info(
                    "TritonTraceHandler: disabled because %s is not writable",
                    TRACE_LOG_DIR,
                )
            else:
                self.root_dir = TRACE_LOG_DIR
        return self.root_dir

    def emit(self, record):
        """
        Write one formatted record, opening the trace file on first use.

        If no root directory is available, the handler detaches itself from
        ``triton_trace_log`` and drops the record. Errors are logged and routed
        to ``handleError``; they are never propagated to the caller.
        """
        # reference implementation
        # https://github.com/pytorch/pytorch/blob/5fe58ab5bd9e14cce3107150a9956a2ed40d2f79/torch/_logging/_internal.py#L1071
        try:
            if self.stream is None:
                root_dir = self.get_root_dir()
                if root_dir is None:
                    # Tracing is disabled: detach so later records skip this handler.
                    triton_trace_log.removeHandler(self)
                    return
                os.makedirs(root_dir, exist_ok=True)
                ranksuffix = ""
                if TORCH_INSTALLED:
                    import torch.distributed as dist

                    if dist.is_available() and dist.is_initialized():
                        ranksuffix = f"rank_{dist.get_rank()}_"
                filename = f"{self.prefix}{ranksuffix}"
                # Choose file extension and mode based on compression setting
                if TRITON_TRACE_GZIP:
                    file_extension = ".bin.ndjson"
                    file_mode = "ab+"  # Binary mode for gzip member concatenation
                else:
                    file_extension = ".ndjson"
                    file_mode = "a+"
                log_file_name = os.path.abspath(
                    os.path.join(root_dir, f"{filename}{file_extension}")
                )
                self.stream = open(
                    log_file_name,
                    mode=file_mode,
                )
                log.debug("TritonTraceHandler: logging to %s", log_file_name)

            if self.stream:
                formatted = self.format(record)
                if TRITON_TRACE_GZIP:
                    # Create a separate gzip member for each record
                    # This allows standard gzip readers to handle member concatenation automatically
                    buffer = io.BytesIO()
                    with gzip.GzipFile(fileobj=buffer, mode="wb") as gz:
                        gz.write(formatted.encode("utf-8"))
                    # Write the complete gzip member to the file
                    compressed_data = buffer.getvalue()
                    self.stream.write(compressed_data)
                else:
                    self.stream.write(formatted)
                self.flush()
        except Exception as e:
            # record exception and ensure resources are released
            log.error(f"Error in TritonTraceHandler.emit: {e}")
            # A logging handler must not propagate errors into the traced
            # program. Closing a stream whose write just failed (e.g. disk full)
            # can raise again while flushing buffered data, so contain that too.
            try:
                self._ensure_stream_closed()
            except Exception as close_error:
                log.debug(f"Error closing trace stream: {close_error}")
            self.handleError(record)  # call Handler's standard error handling

    def close(self):
        """Close the current file."""
        self.acquire()
        try:
            try:
                self._ensure_stream_closed()
            finally:
                # Solution adopted from PyTorch PR #120289
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def _cleanup(self):
        """Ensure proper cleanup on program exit"""
        if self.stream is not None:
            self.close()

    def _ensure_stream_closed(self):
        """Flush and close the current stream (if any) and forget it."""
        stream = self.stream
        if stream is None:
            return
        try:
            self.flush()
        finally:
            # Drop the reference before closing: if close() raises, the handler
            # must not keep a dead stream that every later emit() would write to.
            self.stream = None
            stream.close()
|
|
1078
|
+
|
|
1079
|
+
|
|
1080
|
+
def init_logs():
    """
    Initialise tritonparse's logging system; safe to call repeatedly.

    Requirements handled:
    1. The first call may or may not have ``triton_trace_folder`` set.
    2. A later call *can* provide ``triton_trace_folder`` and must activate an
       existing handler whose ``root_dir`` was None.
    3. When tracing is disabled (no writable directory), the handler stays
       detached and propagation to the root logger is blocked, so no empty
       ``DEBUG:tritonparse_trace:`` lines are printed.

    When tracing is active and TRITONPARSE_SAVE_TENSOR_BLOBS is set, the
    global TensorBlobManager is created (or given its missing root_dir).
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder, TENSOR_BLOB_MANAGER

    # Logger-level settings are idempotent, so they are applied on every call.
    triton_trace_log.setLevel(logging.DEBUG)
    triton_trace_log.propagate = False  # requirement 3

    handler = TRITON_TRACE_HANDLER
    if handler is None:
        # Requirement 1: create the handler lazily; its root_dir may be None.
        handler = TRITON_TRACE_HANDLER = TritonTraceHandler(triton_trace_folder)
    if handler.root_dir is None and triton_trace_folder is not None:
        # Requirement 2: a folder became known after the handler was created.
        handler.root_dir = triton_trace_folder

    # get_root_dir() also applies the /logs fallback and permission checks.
    root_dir = handler.get_root_dir()
    already_attached = handler in triton_trace_log.handlers

    if root_dir is None:
        # Tracing still disabled: make sure the handler is not attached.
        if already_attached:
            triton_trace_log.removeHandler(handler)
        return

    if not already_attached:
        handler.setFormatter(TritonJsonFormatter())
        triton_trace_log.addHandler(handler)

    if not (TRITONPARSE_SAVE_TENSOR_BLOBS and root_dir):
        return
    if TENSOR_BLOB_MANAGER is None:
        TENSOR_BLOB_MANAGER = TensorBlobManager(
            root_dir=root_dir, storage_quota=TRITONPARSE_TENSOR_STORAGE_QUOTA
        )
    elif TENSOR_BLOB_MANAGER.root_dir is None:
        # The manager exists but was created before a directory was known.
        TENSOR_BLOB_MANAGER.set_root_dir(root_dir)
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
def trace_structured_triton(
    name: str,
    metadata_fn: Optional[Callable[[], Dict[str, Any]]] = None,
    *,
    payload_fn: Optional[Callable[[], Optional[Union[str, object]]]] = None,
):
    """
    Emit one structured trace event on the tritonparse trace logger.

    The event metadata always starts with ``event_type`` and ``pid``. Entries
    returned by ``metadata_fn`` are merged in (and may override those), and
    the current call stack is then stored under ``stack``. The payload is
    whatever ``payload_fn`` returns.

    Args:
        name (str): Name of the trace event (e.g., "compilation", "execution")
        metadata_fn (Callable, optional): Returns a dict of extra metadata for
            the record; treated as returning ``{}`` when omitted.
        payload_fn (Callable, optional): Returns the payload; treated as
            returning None when omitted.
    """
    event_metadata: Dict[str, Any] = {"event_type": name, "pid": os.getpid()}

    extra_metadata = metadata_fn() if metadata_fn is not None else {}
    if extra_metadata:
        event_metadata.update(extra_metadata)

    event_metadata["stack"] = get_stack_trace()

    event_payload = payload_fn() if payload_fn is not None else None
    # Build a LogRecord carrying our metadata/payload and hand it straight to
    # the logger's handlers.
    trace_record = create_triton_log_record(metadata=event_metadata, payload=event_payload)
    triton_trace_log.handle(trace_record)
|
|
1176
|
+
|
|
1177
|
+
|
|
1178
|
+
def maybe_trace_triton(
    src,
    metadata: Dict[str, Any],
    metadata_group: Dict[str, str],
    times: Any,
    event_type: str = "compilation",
    cache_hit: bool = False,
):
    """
    Collect and trace Triton kernel compilation information for debugging and profiling.

    This function gathers metadata, IR files, and source code information about a Triton
    kernel compilation, then logs it through the tracing system if tracing is enabled.
    It collects information from multiple sources:
    1. JSON metadata file (if provided)
    2. PyTorch compilation context (if available)
    3. IR and other compilation artifact files
    4. Python source code of the kernel function

    This function is designed to be used as a CompilationListener in triton.knobs.compilation.listener,
    which now accepts a list of listeners.

    Args:
        src (Union[ASTSource, IRSource]): Source object containing kernel information
        metadata (Dict[str, Any]): Dictionary containing metadata for the compilation.
            When empty, the first ``.json`` artifact in ``metadata_group`` is loaded
            instead; if there is none, tracing proceeds without that metadata.
        metadata_group (Dict[str, Any]): Dictionary mapping filenames to file paths for all compilation artifacts
        times (CompileTimes): Object containing timing information for the compilation
        event_type (str): Type of event being traced (default: "compilation")
        cache_hit (bool): Whether the compilation was a cache hit (default: False)

    Returns:
        Dict[str, Any]: Dictionary containing all collected trace data, even if tracing
        is disabled. An empty dict is returned when the kernel is filtered out by the
        kernel allowlist.
    """
    # Check kernel allowlist early to avoid unnecessary work
    if _KERNEL_ALLOWLIST_PATTERNS is not None:
        kernel_name = extract_kernel_name(src)
        if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
            # Return empty dict to indicate no tracing was done
            return {}

    # Initialize a dictionary with defaultdict to avoid key errors
    trace_data = defaultdict(dict)
    # Add cache_hit to metadata
    trace_data["metadata"]["cache_hit"] = cache_hit
    if not metadata:
        # Fall back to the JSON metadata file among the compilation artifacts.
        # ``next`` gets a default: a bare StopIteration escaping here would be
        # an opaque failure (and becomes RuntimeError inside generators).
        metadata_path = next(
            (Path(p) for c, p in metadata_group.items() if c.endswith(".json")),
            None,
        )
        if metadata_path is None:
            log.warning(
                "maybe_trace_triton: no metadata given and no .json artifact in "
                "metadata_group; tracing without compilation metadata"
            )
            metadata = {}
        else:
            # JSON text is UTF-8; don't depend on the locale's default encoding.
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
    trace_data["metadata"].update(metadata)

    # Handle torch._guards which might not be recognized by type checker
    if TORCH_INSTALLED:
        trace_id = torch._guards.CompileContext.current_trace_id()  # type: ignore
    else:
        trace_id = None
    cid = trace_id.compile_id if trace_id else None
    if cid is not None:
        for attr_name in ["compiled_autograd_id", "frame_id", "frame_compile_id"]:
            attr_value = getattr(cid, attr_name, None)
            if attr_value is not None:
                trace_data["pt_info"][attr_name] = attr_value
    if trace_id:
        trace_data["pt_info"]["attempt"] = trace_id.attempt
    # Extract content from all IR and other files in the metadata group
    extract_file_content(trace_data, metadata_group)
    # Extract Python source code information if available
    extract_python_source_info(trace_data, src)
    extract_metadata_from_src(trace_data, src)

    # Add timing information if available
    if times:
        trace_data["metadata"]["times"] = times
    # Log the collected information through the tracing system
    trace_structured_triton(
        event_type,
        payload_fn=lambda: json.dumps(convert(trace_data)),
    )

    return trace_data
|
|
1260
|
+
|
|
1261
|
+
|
|
1262
|
+
def extract_arg_info(arg_dict):
    """
    Summarize kernel launch arguments, with extra detail for tensors.

    Args:
        arg_dict: Mapping of argument name to argument value.

    Returns:
        Dict mapping each argument name to a description dict that always has
        a ``type`` key, plus:
          - torch tensors: the fields from ``_log_torch_tensor_info``;
          - triton_kernels ``Tensor``: ``shape``/``shape_max``/``dtype`` and a
            recursively described ``storage`` (when present);
          - triton_kernels ``Storage``: ``data`` (if it is a torch tensor) and
            ``layout`` (when present);
          - int/float/bool: ``value``; str: ``value`` and ``length``;
          - anything else: ``repr`` (``str()`` truncated to 200 chars + "...").
    """
    summary = {}

    for name, value in arg_dict.items():
        info = {}

        if TORCH_INSTALLED and isinstance(value, torch.Tensor):
            info["type"] = "tensor"
            info.update(_log_torch_tensor_info(value))
        elif _is_from_triton_kernels_module(value):
            # Custom Tensor/Storage wrappers from triton_kernels.
            kind = type(value).__name__
            info["type"] = f"triton_kernels.tensor.{kind}"

            if kind == "Tensor":
                # Attributes needed to rebuild the Tensor wrapper.
                for attr in ("shape", "shape_max", "dtype"):
                    if hasattr(value, attr):
                        info[attr] = convert(getattr(value, attr))
                if hasattr(value, "storage"):
                    # The storage may be another triton_kernels type or a
                    # torch.Tensor, so describe it through this same function.
                    info["storage"] = extract_arg_info({"storage": value.storage})["storage"]
            elif kind == "Storage":
                # Attributes needed to rebuild the Storage object.
                if (
                    hasattr(value, "data")
                    and TORCH_INSTALLED
                    and isinstance(value.data, torch.Tensor)
                ):
                    info["data"] = _log_torch_tensor_info(value.data)
                if hasattr(value, "layout"):
                    info["layout"] = convert(value.layout)
            else:
                log.warning(f"Unknown type: {type(value)}")
        elif isinstance(value, (int, float, bool)):
            info["type"] = type(value).__name__
            info["value"] = value
        elif isinstance(value, str):
            info["type"] = "str"
            info["value"] = value
            info["length"] = len(value)
        else:
            info["type"] = type(value).__name__
            # Fall back to a string form, capped to keep traces small.
            text = str(value)
            info["repr"] = text if len(text) <= 200 else text[:200] + "..."

        summary[name] = info

    return summary
|
|
1335
|
+
|
|
1336
|
+
|
|
1337
|
+
def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
    """
    Build the launch-metadata payload tritonparse attaches to a kernel launch.

    Returns:
        dict: ``{"launch_metadata_tritonparse": (grid, metadata._asdict(),
        args_info, inductor_args_info)}``. While the current CUDA stream is
        being captured into a graph, argument inspection is skipped (it can
        trigger cudaErrorStreamCaptureUnsupported). In that case ``args_info``
        is a single ``_note`` entry and ``inductor_args_info`` is empty.
    """
    capturing_graph = False
    if TORCH_INSTALLED:
        try:
            capturing_graph = torch.cuda.is_current_stream_capturing()
        except (AttributeError, RuntimeError):
            # No usable CUDA query: assume we are not capturing.
            pass

    if capturing_graph:
        args_info = {"_note": "argument extraction skipped during CUDA graph capture"}
        inductor_info = {}
    else:
        args_info = extract_arg_info(arg_dict)
        inductor_info = extract_arg_info(inductor_args) if inductor_args else {}

    return {
        "launch_metadata_tritonparse": (
            grid,
            metadata._asdict(),
            args_info,
            inductor_info,
        )
    }
|
|
1369
|
+
|
|
1370
|
+
|
|
1371
|
+
class JITHookImpl(JITHook):
    """
    JIT Hook implementation that overrides or sets the launch_metadata function for Triton kernels.

    This hook is essential for capturing detailed kernel launch information beyond the basic
    metadata (like kernel name) that Triton provides by default. Without setting a custom
    launch_metadata function, only minimal launch information is available as shown in:
    https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L504

    By intercepting the JIT compilation process and setting a custom launch_metadata function,
    we can capture comprehensive runtime information including grid parameters, kernel metadata,
    and argument dictionaries for detailed analysis and logging.
    """

    def __call__(
        self,
        *,
        key: str,
        repr: str,
        fn,
        compile,
        is_manual_warmup: bool,
        already_compiled: bool,
        inductor_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[bool]:
        """
        Override or set the launch_metadata function for the JIT-compiled kernel.

        This method is called during the JIT compilation process and allows us to
        inject our custom launch_metadata function that will be used to collect
        detailed kernel launch information.

        Args:
            key: Unique identifier for the kernel (not used here)
            repr: String representation of the kernel (not used here)
            fn: The JIT function object; its ``jit_function`` attribute is used
                when present, otherwise ``fn`` itself receives the hook
            compile: Compilation function (not used here)
            is_manual_warmup: Whether this is a manual warmup call (not used here)
            already_compiled: Whether the kernel is already compiled (not used here)
            inductor_args: Optional argument dict bound into the installed
                launch_metadata callable and forwarded to add_launch_metadata

        Returns:
            True to continue with compilation, None/False to skip. This
            implementation returns True on every path.
        """
        # Check kernel allowlist early to avoid unnecessary work
        if _KERNEL_ALLOWLIST_PATTERNS is not None:
            kernel_name = fn.name
            if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
                # Skip overriding launch_metadata if kernel is not in allowlist
                return True

        # The hook is attached to the underlying JIT function when fn wraps one.
        function = getattr(fn, "jit_function", fn)

        current_launch_metadata = getattr(function, "launch_metadata", None)
        # A partial of add_launch_metadata means this hook already ran for this
        # function (e.g. a recompilation); replacing it is not worth a warning.
        installed_by_tritonparse = (
            isinstance(current_launch_metadata, partial)
            and current_launch_metadata.func is add_launch_metadata
        )
        if current_launch_metadata is not None and not installed_by_tritonparse:
            log.warning(
                f"fn {fn} launch_metadata is not None: {current_launch_metadata}. It will be overridden by tritonparse."
            )
        # Rebind on every call so the latest inductor_args are the ones captured.
        function.launch_metadata = partial(
            add_launch_metadata, inductor_args=inductor_args
        )
        return True
|
|
1433
|
+
|
|
1434
|
+
|
|
1435
|
+
class LaunchHookImpl(LaunchHook):
    """
    Launch Hook implementation for capturing and logging kernel launch metadata.

    This hook is responsible for intercepting kernel launches and extracting the detailed
    metadata that was set up by the JITHookImpl. It provides an entry point for
    kernel execution, allowing logging and analysis of kernel launches.

    The metadata captured includes:
    - Kernel name and function details
    - Grid dimensions and launch parameters
    - Kernel arguments and their values
    - Stream information
    - Custom metadata added by the launch_metadata function
    """

    def __call__(self, metadata):
        """
        Handle kernel launch entry point.

        This method is called when a kernel is about to be launched, providing
        access to the launch metadata. It emits one "launch" trace record via
        trace_structured_triton, unless an allowlist is configured and the kernel
        name does not match it.

        Args:
            metadata: LazyDict containing launch information; ``metadata.get()``
                yields a dict of the form::

                    {'name': 'add_kernel', 'function': None, 'stream': 0,
                     'launch_metadata_tritonparse': (
                         grid, compilation_metadata, extracted_args,
                         extracted_inductor_args)}

                The ``launch_metadata_tritonparse`` 4-tuple is produced by the
                launch_metadata callable that JITHookImpl installs, and is absent
                for kernels that never received it. ``extracted_args`` holds
                per-argument details (tensors: shape, dtype, device, stride,
                memory usage, etc.; scalars: type and value; other types: type
                and string representation). During CUDA graph capture,
                ``extracted_args`` is only a ``{"_note": ...}`` placeholder and
                ``extracted_inductor_args`` is empty.

                The base format is defined here:
                https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L512
        """
        metadata_dict = metadata.get()
        # Check kernel allowlist early to avoid unnecessary work
        if _KERNEL_ALLOWLIST_PATTERNS is not None:
            kernel_name = metadata_dict.get("name")

            if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
                # Skip tracing if kernel is not in allowlist
                return

        trace_data = defaultdict(dict)
        trace_data["name"] = metadata_dict["name"]
        trace_data["function"] = metadata_dict["function"]
        trace_data["stream"] = metadata_dict["stream"]
        launch_metadata_tritonparse = metadata_dict.get(
            "launch_metadata_tritonparse", None
        )
        if launch_metadata_tritonparse is not None:
            # Tuple layout: (grid, compilation_metadata, extracted_args,
            # extracted_inductor_args); see the docstring above.
            trace_data["grid"] = launch_metadata_tritonparse[0]
            trace_data["compilation_metadata"] = launch_metadata_tritonparse[1]
            trace_data["extracted_args"] = launch_metadata_tritonparse[2]
            trace_data["extracted_inductor_args"] = launch_metadata_tritonparse[3]
        # convert() is deferred via metadata_fn rather than run eagerly here.
        trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
|
|
1497
|
+
|
|
1498
|
+
|
|
1499
|
+
def maybe_enable_trace_launch():
    """
    Install tritonparse's launch-tracing hooks into Triton, at most once.

    Does nothing unless TRITON_TRACE_LAUNCH is truthy and the hooks have not
    already been installed (tracked by _trace_launch_enabled). Otherwise it
    sets knobs.runtime.jit_post_compile_hook to a JITHookImpl and
    knobs.runtime.launch_enter_hook to a LaunchHookImpl, then records that
    the hooks are in place.
    """
    global _trace_launch_enabled
    # Guard: tracing disabled, or hooks already installed.
    if not TRITON_TRACE_LAUNCH or _trace_launch_enabled:
        return

    from triton import knobs

    enter_hook = LaunchHookImpl()
    post_compile_hook = JITHookImpl()
    knobs.runtime.jit_post_compile_hook = post_compile_hook
    knobs.runtime.launch_enter_hook = enter_hook

    _trace_launch_enabled = True
|
|
1510
|
+
|
|
1511
|
+
|
|
1512
|
+
def init_basic(trace_folder: Optional[str] = None):
    """
    Initialize the basic logging system for Triton compilation.

    This function sets up the basic logging system for Triton kernel compilation:
    it enables debug logging if configured, updates the trace folder, re-reads the
    kernel allowlist, calls init_logs(), and installs the launch hooks when launch
    tracing is enabled.

    Args:
        trace_folder (Optional[str]): The folder to store the trace files. When
            given, it replaces any previously configured triton_trace_folder.
    """
    global triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    maybe_enable_debug_logging()
    # Only a *different* folder is a conflict. init_with_env() passes the
    # already-configured triton_trace_folder back in, which must not be reported.
    if (
        triton_trace_folder is not None
        and trace_folder is not None
        and trace_folder != triton_trace_folder
    ):
        log.info(
            "Conflict settings: triton_trace_folder is already set to %s, we will use provided trace_folder(%s) instead.",
            triton_trace_folder,
            trace_folder,
        )
    if trace_folder is not None:
        triton_trace_folder = trace_folder

    # Parse and store kernel allowlist configuration
    _KERNEL_ALLOWLIST_PATTERNS = parse_kernel_allowlist()
    # NOTE(review): the hooks filter on `is not None`, while this log branch uses
    # truthiness; confirm parse_kernel_allowlist() never returns an empty list.
    if _KERNEL_ALLOWLIST_PATTERNS:
        log.debug(
            f"Kernel allowlist enabled with patterns: {_KERNEL_ALLOWLIST_PATTERNS}"
        )
    else:
        log.debug("Kernel allowlist not set, tracing all kernels")

    init_logs()
    maybe_enable_trace_launch()
|
|
1543
|
+
|
|
1544
|
+
|
|
1545
|
+
def init(
    trace_folder: Optional[str] = None,
    enable_trace_launch: bool = False,
    enable_more_tensor_information: bool = False,
    enable_sass_dump: Optional[bool] = False,
    enable_tensor_blob_storage: bool = False,
    tensor_storage_quota: Optional[int] = None,
):
    """
    Configure tritonparse feature flags, run init_basic(), and register the compilation listener.

    Truthy feature arguments switch the matching module-level flag on,
    overriding whatever default (e.g. from the environment) was there. Falsy
    arguments leave the existing flag value untouched.

    Args:
        trace_folder (Optional[str]): The folder to store the trace files.
        enable_trace_launch (bool): Enable the trace launch hook; this also turns on
            TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK.
        enable_more_tensor_information (bool): Log more tensor information. Only
            effective when enable_trace_launch/TRITON_TRACE_LAUNCH is True.
        enable_sass_dump (Optional[bool]): Enable SASS dumping.
        enable_tensor_blob_storage (bool): Enable tensor blob storage.
        tensor_storage_quota (Optional[int]): Storage quota in bytes for tensor
            blobs. None keeps the current TRITONPARSE_TENSOR_STORAGE_QUOTA.
    """
    global TRITON_TRACE_LAUNCH, TRITONPARSE_MORE_TENSOR_INFORMATION
    global TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK, TRITONPARSE_DUMP_SASS
    global TRITONPARSE_SAVE_TENSOR_BLOBS, TRITONPARSE_TENSOR_STORAGE_QUOTA

    # All flags must be final before init_basic() -> init_logs() reads them.
    if enable_trace_launch:
        TRITON_TRACE_LAUNCH = TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK = True
    if enable_more_tensor_information:
        TRITONPARSE_MORE_TENSOR_INFORMATION = True
    if enable_sass_dump:
        TRITONPARSE_DUMP_SASS = True
    if enable_tensor_blob_storage:
        TRITONPARSE_SAVE_TENSOR_BLOBS = True
    # The quota is consumed when init_logs() creates the TensorBlobManager.
    if tensor_storage_quota is not None:
        TRITONPARSE_TENSOR_STORAGE_QUOTA = tensor_storage_quota

    init_basic(trace_folder)

    from triton import knobs

    # Route Triton compilation events to tritonparse.
    knobs.compilation.listener = maybe_trace_triton
|
|
1588
|
+
|
|
1589
|
+
|
|
1590
|
+
def init_with_env():
    """
    Initialize TritonParse from the TRITON_TRACE_FOLDER / TRITON_TRACE_LAUNCH settings.

    Acts only when triton_trace_folder is set (truthy). In that case it calls
    init() with that folder and enable_trace_launch=TRITON_TRACE_LAUNCH.
    It is only supposed to be used in OSS triton's source code.
    """
    if not triton_trace_folder:
        # No trace folder configured: leave tritonparse uninitialized.
        return
    init(triton_trace_folder, enable_trace_launch=TRITON_TRACE_LAUNCH)
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
def clear_logging_config():
    """
    Clear all configurations made by init() and init_basic().

    This function resets the logging handlers, global state variables,
    and Triton knobs to their default states, effectively disabling
    the custom tracing. Concretely, it:
      1. detaches TRITON_TRACE_HANDLER from triton_trace_log and closes it;
      2. resets triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS and
         _trace_launch_enabled;
      3. drops TENSOR_BLOB_MANAGER;
      4. sets the Triton compilation listener, JIT post-compile hook and
         launch-enter hook to None.

    NOTE(review): steps 1-4 do not reset the feature flags that init() can
    set (TRITON_TRACE_LAUNCH, TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK,
    TRITONPARSE_MORE_TENSOR_INFORMATION, TRITONPARSE_DUMP_SASS,
    TRITONPARSE_SAVE_TENSOR_BLOBS, TRITONPARSE_TENSOR_STORAGE_QUOTA). Because
    _trace_launch_enabled is reset, a later init_basic() will re-install the
    launch hooks if TRITON_TRACE_LAUNCH is still set. Confirm this is intended.

    WARNING: This function is not supposed to be called unless you are sure
    you want to clear the logging config.
    """
    global TRITON_TRACE_HANDLER, triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
    global _trace_launch_enabled
    global TENSOR_BLOB_MANAGER
    # 1. Clean up the log handler (remove it first so no record is routed to a
    #    closed handler, then release its resources)
    if TRITON_TRACE_HANDLER is not None:
        if TRITON_TRACE_HANDLER in triton_trace_log.handlers:
            triton_trace_log.removeHandler(TRITON_TRACE_HANDLER)
        TRITON_TRACE_HANDLER.close()
        TRITON_TRACE_HANDLER = None

    # 2. Reset global state variables
    triton_trace_folder = None
    _KERNEL_ALLOWLIST_PATTERNS = None
    _trace_launch_enabled = False

    # 3. Drop the tensor blob manager (the TRITONPARSE_SAVE_TENSOR_BLOBS flag
    #    is left unchanged)
    TENSOR_BLOB_MANAGER = None

    # 4. Reset Triton knobs
    # NOTE(review): this import is unconditional; it does not check whether
    # triton was imported or used, and raises ImportError if triton is missing.
    from triton import knobs

    knobs.compilation.listener = None
    knobs.runtime.jit_post_compile_hook = None
    knobs.runtime.launch_enter_hook = None
|