tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic. Click here for more details.

Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
tritonparse/__init__.py: file without changes
tritonparse/__main__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from .cli import main
4
+
5
+
6
if __name__ == "__main__":
    # Executing the package (`python -m tritonparse`) runs __main__.py, which
    # simply delegates to the CLI entry point imported from .cli.
    main()
tritonparse/cli.py ADDED
@@ -0,0 +1,110 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ import argparse
4
+ from importlib.metadata import PackageNotFoundError, version
5
+
6
+ from .common import is_fbcode
7
+ from .info.cli import _add_info_args, info_command
8
+ from .reproducer.cli import _add_reproducer_args
9
+ from .reproducer.orchestrator import reproduce
10
+ from .utils import _add_parse_args, unified_parse
11
+
12
+
13
+ def _get_package_version() -> str:
14
+ try:
15
+ return version("tritonparse")
16
+ except PackageNotFoundError:
17
+ return "0+unknown"
18
+
19
+
20
def main():
    """Entry point of the TritonParse command-line interface.

    Builds an ``argparse`` parser with three subcommands and dispatches:

    * ``parse``     -> ``unified_parse`` with every parsed option except the
      bookkeeping keys ``command`` and ``func``;
    * ``reproduce`` -> ``reproduce`` (optionally with the fbcode placeholder
      replacer when ``--use-fbcode``-style ``args.use_fbcode`` is set);
    * ``info``      -> ``info_command``.
    """
    pkg_version = _get_package_version()

    # fbcode and OSS builds expose the command under different names.
    prog_name = "tritonparse" if is_fbcode() else "tritonparseoss"

    usage_examples = [
        f" {prog_name} parse /path/to/logs --out parsed_output\n",
        f" {prog_name} reproduce /path/to/trace.ndjson --line 1 --out-dir repro_output\n",
        f" {prog_name} info /path/to/trace.ndjson\n",
        f" {prog_name} info /path/to/trace.ndjson --kernel matmul_kernel\n",
    ]
    parser = argparse.ArgumentParser(
        prog=prog_name,
        description=(
            "TritonParse: parse structured logs and generate minimal reproducers"
        ),
        epilog="Examples:\n" + "".join(usage_examples),
        # Keep the hand-formatted epilog line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--version",
        action="version",
        version=f"%(prog)s {pkg_version}",
        help="Show program's version number and exit",
    )

    subparsers = parser.add_subparsers(dest="command", required=True)

    # (name, help text, argument installer, extra add_parser kwargs)
    subcommand_table = (
        (
            "parse",
            "Parse triton structured logs",
            _add_parse_args,
            {"conflict_handler": "resolve"},
        ),
        ("reproduce", "Build reproducer from trace file", _add_reproducer_args, {}),
        ("info", "Query kernel information from trace file", _add_info_args, {}),
    )
    command_parsers = {}
    for name, help_text, install_args, extra_kwargs in subcommand_table:
        sub = subparsers.add_parser(name, help=help_text, **extra_kwargs)
        install_args(sub)
        # `func` records which subcommand was chosen for the dispatch below.
        sub.set_defaults(func=name)
        command_parsers[name] = sub

    args = parser.parse_args()
    command = args.func

    if command == "parse":
        bookkeeping_keys = {"command", "func"}
        forwarded = {
            key: value
            for key, value in vars(args).items()
            if key not in bookkeeping_keys
        }
        unified_parse(**forwarded)
        return

    if command == "info":
        info_command(input_path=args.input, kernel_name=args.kernel)
        return

    if command != "reproduce":
        raise RuntimeError(f"Unknown command: {args.func}")

    # An event is selected either by --line or by --kernel/--launch-id, not both.
    if args.kernel and args.line != 0:
        command_parsers["reproduce"].error(
            "--line and --kernel/--launch-id are mutually exclusive"
        )

    replacer = None
    if args.use_fbcode:
        from tritonparse.fb.reproducer.replacer import FBCodePlaceholderReplacer

        replacer = FBCodePlaceholderReplacer()
        print(f"Using FBCode placeholder replacer for template: {args.template}")

    select_by_kernel = bool(args.kernel)
    reproduce(
        input_path=args.input,
        line_index=0 if select_by_kernel else args.line,
        out_dir=args.out_dir,
        template=args.template,
        kernel_name=args.kernel,
        launch_id=args.launch_id if select_by_kernel else 0,
        kernel_import=args.kernel_import,
        replacer=replacer,
    )
107
+
108
+
109
if __name__ == "__main__":
    # Allow executing cli.py directly as a script; excluded from coverage.
    main()  # pragma: no cover
tritonparse/common.py ADDED
@@ -0,0 +1,409 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import gzip
4
+ import importlib
5
+ import importlib.util
6
+ import json
7
+ import os
8
+ import re
9
+ import shutil
10
+ import tempfile
11
+ from collections import defaultdict
12
+ from pathlib import Path
13
+ from typing import Optional, Tuple
14
+
15
+ from .extract_source_mappings import parse_single_file
16
+ from .shared_vars import DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER as LOG_PREFIX
17
+ from .tp_logger import logger
18
+
19
# Finds "rank_<digits>" inside a log file name; group(1) holds the digits,
# which parse_logs converts with int() to build a Rank.
LOG_RANK_REGEX = re.compile(r"rank_(\d+)")
20
+
21
+
22
def is_fbcode():
    """Report whether the optional ``tritonparse.fb`` subpackage is available.

    Returns:
        True if ``importlib.util.find_spec`` can locate ``tritonparse.fb``,
        False otherwise.
    """
    fb_spec = importlib.util.find_spec("tritonparse.fb")
    return fb_spec is not None
24
+
25
+
26
+ if is_fbcode():
27
+ from .fb.source_type import SourceType
28
+ else:
29
+ from .source_type import SourceType
30
+
31
+
32
class Rank:
    """A rank in distributed training, or the "default" (unranked) rank.

    ``Rank()`` is the default rank: it has ``value == 0`` but
    ``is_default == True`` and renders as an empty string. It is therefore
    distinct from ``Rank(0)``.

    Ranks compare and hash by ``(value, is_default)``. Two objects describing
    the same rank are interchangeable as dict keys, which lets log files be
    grouped by rank. Do not mutate ``value``/``is_default`` after using an
    instance as a key.
    """

    def __init__(self, rank_value: Optional[int] = None):
        """
        Initialize a Rank object.

        Args:
            rank_value: Specific rank value, or None for default rank
        """
        if rank_value is not None:
            self.value = rank_value
            self.is_default = False
        else:
            self.value = 0
            self.is_default = True

    def to_string(self, prefix: str = "", suffix: str = "") -> str:
        """
        Convert rank to string representation with optional prefix/suffix.

        Args:
            prefix: Prefix to add before rank string
            suffix: Suffix to add after rank string

        Returns:
            ``f"{prefix}rank_{value}{suffix}"``, or ``""`` for the default
            rank (prefix and suffix are not applied in that case).
        """
        if self.is_default:
            return ""
        return f"{prefix}rank_{self.value}{suffix}"

    def to_int(self) -> int:
        """
        Convert rank to integer value.

        Returns:
            Integer value of the rank (0 for the default rank)
        """
        return self.value

    def _key(self) -> Tuple[int, bool]:
        """Identity tuple used for equality and hashing."""
        return (self.value, self.is_default)

    def __eq__(self, other: object) -> bool:
        # Value equality: without this, every Rank(n) is a distinct dict key,
        # and grouping log files by rank silently fails.
        if not isinstance(other, Rank):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        # Must agree with __eq__ so equal ranks land in the same dict bucket.
        return hash(self._key())

    def __repr__(self) -> str:
        if self.is_default:
            return "Rank()"
        return f"Rank({self.value})"
71
+
72
+
73
class RankConfig:
    """Describe which ranks' logs should be processed.

    The three selectors (``rank``, ``all_ranks``, ``is_local``) are stored as
    given. In its non-fbcode paths, ``from_cli_args`` produces an instance
    with exactly one of them set.
    """

    def __init__(
        self,
        rank: Optional[Rank] = None,
        all_ranks: bool = False,
        is_local: bool = False,
    ):
        """
        Store the rank-selection settings.

        Args:
            rank: A specific rank to process, or None
            all_ranks: True to process every rank found
            is_local: True when the logs come from a local source
        """
        self.rank = rank
        self.all_ranks = all_ranks
        self.is_local = is_local

    @classmethod
    def from_cli_args(
        cls, rank: Optional[int], all_ranks: bool, source_type: SourceType
    ) -> "RankConfig":
        """
        Build a RankConfig from parsed command-line options.

        Resolution order: ``--all-ranks``, then an explicit rank, then local
        source types, then the fbcode-specific resolver (when available),
        and finally "all ranks" as the fallback.

        Args:
            rank: Rank value given on the command line, or None
            all_ranks: Whether ``--all-ranks`` was given
            source_type: Kind of log source being parsed

        Returns:
            The resulting RankConfig

        Raises:
            ValueError: If both a rank and ``--all-ranks`` were given
        """
        has_explicit_rank = rank is not None

        if all_ranks and has_explicit_rank:
            raise ValueError("Can't specify both a rank and --all-ranks")
        if all_ranks:
            return cls(all_ranks=True)
        if has_explicit_rank:
            return cls(rank=Rank(rank))
        if source_type in (SourceType.LOCAL, SourceType.LOCAL_FILE):
            return cls(is_local=True)
        if not is_fbcode():
            return cls(all_ranks=True)

        from tritonparse.fb.utils import rank_config_from_cli_args

        return rank_config_from_cli_args(cls, source_type)

    def to_rank(self) -> Rank:
        """
        Return the configured rank, or a default ``Rank()`` when none is set.

        Returns:
            Rank object
        """
        return self.rank if self.rank else Rank()
135
+
136
+
137
def _format_file_size(size_bytes: int) -> str:
    """Render a byte count as ``B`` / ``KB`` / ``MB`` text (binary units)."""
    kib = 1024
    mib = kib * kib
    if size_bytes >= mib:
        return f"{size_bytes / mib:.1f}MB"
    if size_bytes >= kib:
        return f"{size_bytes / kib:.1f}KB"
    return f"{size_bytes}B"


def print_parsed_files_summary(parsed_log_dir: str) -> None:
    """
    Print a human-readable summary of every file under ``parsed_log_dir``.

    Files are found recursively, listed in sorted order of their full path,
    and shown relative to ``parsed_log_dir`` together with their size
    ("N/A" when the size cannot be read).

    Args:
        parsed_log_dir: Directory containing parsed files
    """
    parsed_files = sorted(
        os.path.join(dirpath, name)
        for dirpath, _, filenames in os.walk(parsed_log_dir)
        for name in filenames
    )
    banner = "=" * 80

    print("\n" + banner)
    print("📁 TRITONPARSE PARSING RESULTS")
    print(banner)

    # Print log file list (required for integration)
    print(f"📂 Parsed files directory: {parsed_log_dir}")
    print(f"📊 Total files generated: {len(parsed_files)}")

    if parsed_files:
        print("\n📄 Generated files:")
        print("-" * 50)
        for index, path in enumerate(parsed_files, start=1):
            shown_path = os.path.relpath(path, parsed_log_dir)
            try:
                size_text = _format_file_size(os.path.getsize(path))
            except OSError:
                # Unreadable/vanished file: still list it, just without a size.
                size_text = "N/A"
            print(f" {index:2d}. 📝 {shown_path} ({size_text})")

    print(banner)
    print("✅ Parsing completed successfully!")
    print(banner + "\n")
186
+
187
+
188
def gzip_single_file(file_path: str, verbose: bool = False) -> str:
    """
    Compress ``file_path`` into ``file_path + ".gz"`` and delete the original.

    A path that already ends in ``.gz`` is returned as-is and nothing on disk
    is touched.

    Args:
        file_path: File to compress
        verbose: When True, log the compression and the deletion

    Returns:
        Path of the gzipped file
    """
    if not file_path.endswith(".gz"):
        compressed_path = f"{file_path}.gz"
        if verbose:
            logger.info(f"Gzipping {file_path}")

        with open(file_path, "rb") as source, gzip.open(compressed_path, "wb") as sink:
            shutil.copyfileobj(source, sink)

        # Remove the original only after the compressed copy is fully written.
        os.remove(file_path)
        if verbose:
            logger.info(f"Deleted original file {file_path}")
        return compressed_path

    return file_path
214
+
215
+
216
def copy_local_to_tmpdir(local_path: str, verbose: bool = False) -> str:
    """
    Copy trace log files into a freshly created temporary directory.

    Only regular files whose base name starts with ``LOG_PREFIX`` are copied.
    For a directory, only its top-level entries are considered (no
    recursion). A single file that does not match the prefix yields an empty
    temporary directory.

    Args:
        local_path: Path to local directory or single file containing logs
        verbose: Whether to print verbose information

    Returns:
        Path to temporary directory containing copied logs

    Raises:
        RuntimeError: If ``local_path`` does not exist, or exists but is
            neither a regular file nor a directory (e.g. a FIFO or socket)
    """
    if not os.path.exists(local_path):
        raise RuntimeError(f"Path does not exist: {local_path}")

    is_single_file = os.path.isfile(local_path)
    # Validate *before* creating the temp dir; previously an unsupported path
    # type raised after mkdtemp() and leaked an empty temporary directory.
    if not is_single_file and not os.path.isdir(local_path):
        raise RuntimeError(f"Path is neither a file nor a directory: {local_path}")

    temp_dir = tempfile.mkdtemp()
    try:
        if is_single_file:
            if os.path.basename(local_path).startswith(LOG_PREFIX):
                if verbose:
                    logger.info(f"Copying single file {local_path} to {temp_dir}")
                shutil.copy2(local_path, temp_dir)
            return temp_dir

        for item in os.listdir(local_path):
            item_path = os.path.join(local_path, item)
            # os.listdir yields base names, so `item` is the file's base name.
            if os.path.isfile(item_path) and item.startswith(LOG_PREFIX):
                if verbose:
                    logger.info(f"Copying {item_path} to {temp_dir}")
                shutil.copy2(item_path, temp_dir)
    except BaseException:
        # Don't leave a half-populated temp dir behind if listing/copying fails.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    return temp_dir
257
+
258
+
259
def _collect_logs_by_rank(raw_log_dir: str, rank_config: RankConfig) -> dict:
    """Group eligible trace logs in ``raw_log_dir`` by the rank in their name.

    Returns a dict mapping ``Rank`` -> list of file paths. Files with no
    "rank_<N>" in their name are kept (under the default rank) only when
    ``rank_config.is_local`` is set.
    """
    rank_filter = rank_config.to_rank().to_string("")
    pattern_text = f"{LOG_PREFIX}.*{rank_filter}"
    if rank_filter:
        # Without this, selecting "rank_1" also matched "rank_10", "rank_11", ...
        pattern_text += r"(?!\d)"
    # Compiled once: the pattern does not depend on the file being inspected.
    log_pattern = re.compile(pattern_text)

    ranks = defaultdict(list)  # Dict[Rank, List[str]]
    # Sorted so parsing order (and therefore output layout) is deterministic.
    for item in sorted(os.listdir(raw_log_dir)):
        path = os.path.join(raw_log_dir, item)
        if not os.path.isfile(path) or not log_pattern.search(item):
            continue
        rank_match = LOG_RANK_REGEX.search(item)
        if rank_match:
            ranks[Rank(int(rank_match.group(1)))].append(path)
        elif rank_config.is_local:
            # Local logs don't always have a rank associated with them, we can
            # push them under the default rank.
            ranks[Rank()].append(path)
    return ranks


def _finalize_file_mapping(file_mapping: dict) -> None:
    """Deduplicate each rank's ``regular_files`` in place (keeping first-seen
    order) and drop ``mapped_file`` entries that were never set."""
    for rank_key, rank_data in file_mapping.items():
        if rank_key == "tritonparse_url_prefix":
            continue
        # dict.fromkeys dedupes while preserving order; set() made the JSON
        # output order vary between runs (string hashes are randomized).
        rank_data["regular_files"] = list(dict.fromkeys(rank_data["regular_files"]))
        if rank_data["mapped_file"] is None:
            del rank_data["mapped_file"]


def parse_logs(
    logs_to_parse: str,
    rank_config: RankConfig,
    verbose: bool = False,
    tritonparse_url_prefix: str = "",
    split_inductor_compilations: bool = True,
) -> Tuple[str, dict]:
    """
    Parse structured trace logs into a new temporary directory.

    Each eligible log is parsed with ``parse_single_file`` into
    ``<tmp>/<rank_N>`` (or ``<tmp>/<rank_N>/<filename>`` when a rank has
    several logs). Every generated file is then gzipped. A summary mapping is
    written to ``<tmp>/log_file_list.json`` and its path printed.

    Args:
        logs_to_parse: Path to directory containing logs to parse
        rank_config: Rank configuration
        verbose: Whether to print verbose information
        tritonparse_url_prefix: URL prefix for the generated file mapping
        split_inductor_compilations: Whether to split
            output files by frame_id, compile_id, attempt_id, and compiled_autograd_id.
            Defaults to True. This rule follows tlparse's behavior.
    Returns:
        Tuple of (parsed log directory, file mapping)

    Raises:
        RuntimeError: If no eligible structured trace logs are found
    """
    raw_log_dir = logs_to_parse
    ranks = _collect_logs_by_rank(raw_log_dir, rank_config)
    if not ranks:
        raise RuntimeError(f"No eligible structured trace logs found in {raw_log_dir}")

    # Created only after we know there is something to parse, so the failure
    # path above does not leak an empty temp directory.
    parsed_log_dir = tempfile.mkdtemp()
    file_mapping = {"tritonparse_url_prefix": tritonparse_url_prefix}

    for rank, files in ranks.items():
        use_filenames = len(files) > 1
        if use_filenames:
            logger.warning(
                "Warning: multiple logs found for the same rank. Using filenames."
            )
        rank_key = "rank_default" if rank.is_default else f"rank_{rank.value}"

        for file_path in files:
            filename = os.path.basename(file_path)
            if use_filenames:
                rank_prefix = "" if rank.is_default else f"{rank.to_string('')}/"
                relative_path = f"{rank_prefix}{filename}"
            else:
                relative_path = rank.to_string("")
            output_dir = os.path.join(parsed_log_dir, relative_path)

            parse_single_file(file_path, output_dir, split_inductor_compilations)
            if not os.path.exists(output_dir):
                continue

            # Gzip every generated file right away and classify it.
            generated_files = []
            mapped_file = None
            for generated_item in sorted(os.listdir(output_dir)):
                generated_path = os.path.join(output_dir, generated_item)
                if not os.path.isfile(generated_path):
                    continue
                gz_filename = os.path.basename(
                    gzip_single_file(generated_path, verbose)
                )
                # Files with 'mapped' in their name are treated as the mapped file.
                if "mapped" in generated_item.lower():
                    mapped_file = gz_filename
                else:
                    generated_files.append(gz_filename)

            rank_entry = file_mapping.setdefault(
                rank_key, {"regular_files": [], "mapped_file": None}
            )
            rank_entry["regular_files"].extend(generated_files)
            # This is used to generate the tritonparse url.
            # NOTE(review): this uses the *configured* rank, not `rank`, so in
            # all-ranks mode every entry gets "" -- confirm this is intended.
            rank_entry["rank_suffix"] = rank_config.to_rank().to_string(suffix="/")
            if mapped_file:
                rank_entry["mapped_file"] = mapped_file

    _finalize_file_mapping(file_mapping)

    log_file_list_path = os.path.join(parsed_log_dir, "log_file_list.json")
    with open(log_file_list_path, "w") as f:
        json.dump(file_mapping, f, indent=2)

    # NOTICE: this print is required for tlparser-tritonparse integration
    # DON'T REMOVE THIS PRINT
    print(f"tritonparse log file list: {log_file_list_path}")
    return parsed_log_dir, file_mapping
378
+
379
+
380
def save_logs(out_dir: Path, parsed_logs: str, overwrite: bool, verbose: bool) -> None:
    """
    Copy the top-level contents of ``parsed_logs`` into ``out_dir``.

    Args:
        out_dir: Destination directory, created if missing. A relative path is
            resolved to an absolute one; a plain ``str`` is also accepted.
        parsed_logs: Path to directory containing parsed logs
        overwrite: When True, sub-directories that already exist in
            ``out_dir`` are merged into (files in them are replaced) instead of
            raising. Top-level files are always replaced.
        verbose: Whether to log every copied entry

    Raises:
        FileExistsError: If ``overwrite`` is False and a directory being
            copied already exists in ``out_dir``.
    """
    out_dir = Path(out_dir)
    if not out_dir.is_absolute():
        out_dir = out_dir.resolve()

    os.makedirs(out_dir, exist_ok=True)

    logger.info(f"Copying parsed logs from {parsed_logs} to {out_dir}")

    # Copy each item in the parsed_logs directory to the output directory
    for item in os.listdir(parsed_logs):
        src_path = os.path.join(parsed_logs, item)
        dst_path = os.path.join(out_dir, item)

        if os.path.isdir(src_path):
            if verbose:
                logger.info(f"Copying directory {src_path}/ to {dst_path}/")
            # `overwrite` used to be ignored: copytree always refused an
            # existing destination, so re-running into the same out_dir failed.
            shutil.copytree(src_path, dst_path, dirs_exist_ok=overwrite)
        else:
            if verbose:
                logger.info(f"Copying file from {src_path} to {dst_path}")
            shutil.copy2(src_path, dst_path)
tritonparse/context_manager.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import os
4
+ import shutil
5
+ import tempfile
6
+
7
+ from .shared_vars import TEST_KEEP_OUTPUT
8
+ from .structured_logging import clear_logging_config, init
9
+ from .utils import unified_parse
10
+
11
+
12
def createUniqueTempDirectory():
    """Create a new, uniquely named, empty temporary directory and return its path."""
    new_dir = tempfile.mkdtemp()
    return new_dir
14
+
15
+
16
class TritonParseManager:
    """Context manager wrapping a tritonparse trace-and-parse session.

    On entry it creates a fresh temporary directory and calls ``init`` with it.
    On exit it runs ``unified_parse`` over that directory, storing the result
    in ``output_link``. It then calls ``clear_logging_config`` and, unless
    ``TEST_KEEP_OUTPUT`` is set, deletes the temporary directory. Exceptions
    raised inside the ``with`` body are not suppressed.
    """

    def __init__(
        self,
        enable_trace_launch=False,
        split_inductor_compilations=True,
        enable_tensor_blob_storage=False,
        tensor_storage_quota=None,
        **parse_kwargs,
    ):
        """
        Context manager for tritonparse workflow.

        Args:
            enable_trace_launch: Whether to enable trace launch
            split_inductor_compilations: Whether to split inductor compilations in the output
            enable_tensor_blob_storage: Whether to enable tensor blob storage
            tensor_storage_quota: Storage quota in bytes for tensor blobs. When
                None it is not forwarded, so ``init``'s own default applies
                (previously documented as 100GB -- NOTE(review): confirm)
            **parse_kwargs: Additional keyword arguments to pass to unified_parse
        """
        self.enable_trace_launch = enable_trace_launch
        self.split_inductor_compilations = split_inductor_compilations
        self.enable_tensor_blob_storage = enable_tensor_blob_storage
        self.tensor_storage_quota = tensor_storage_quota
        self.parse_kwargs = parse_kwargs
        self.dir_path = None  # raw trace directory, set in __enter__
        self.output_link = None  # unified_parse result, set in __exit__

    def __enter__(self):
        self.dir_path = createUniqueTempDirectory()
        init_kwargs = {
            "enable_trace_launch": self.enable_trace_launch,
            "enable_tensor_blob_storage": self.enable_tensor_blob_storage,
        }
        if self.tensor_storage_quota is not None:
            init_kwargs["tensor_storage_quota"] = self.tensor_storage_quota

        init(self.dir_path, **init_kwargs)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self.output_link = unified_parse(
                source=self.dir_path,
                overwrite=True,
                split_inductor_compilations=self.split_inductor_compilations,
                **self.parse_kwargs,
            )
        finally:
            # Teardown must run even if parsing raises; previously a parse
            # failure skipped clear_logging_config() and leaked the temp dir.
            clear_logging_config()
            if os.path.exists(self.dir_path) and not TEST_KEEP_OUTPUT:
                shutil.rmtree(self.dir_path)