tritonparse 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/common.py +409 -0
- tritonparse/event_diff.py +120 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/ir_parser.py +220 -0
- tritonparse/mapper.py +100 -0
- tritonparse/reproducer/__init__.py +21 -0
- tritonparse/reproducer/__main__.py +81 -0
- tritonparse/reproducer/cli.py +37 -0
- tritonparse/reproducer/config.py +15 -0
- tritonparse/reproducer/factory.py +16 -0
- tritonparse/reproducer/ingestion/__init__.py +6 -0
- tritonparse/reproducer/ingestion/ndjson.py +165 -0
- tritonparse/reproducer/orchestrator.py +65 -0
- tritonparse/reproducer/param_generator.py +142 -0
- tritonparse/reproducer/prompts/__init__.py +1 -0
- tritonparse/reproducer/prompts/loader.py +18 -0
- tritonparse/reproducer/providers/__init__.py +1 -0
- tritonparse/reproducer/providers/base.py +14 -0
- tritonparse/reproducer/providers/gemini.py +47 -0
- tritonparse/reproducer/runtime/__init__.py +1 -0
- tritonparse/reproducer/runtime/executor.py +13 -0
- tritonparse/reproducer/utils/io.py +6 -0
- tritonparse/shared_vars.py +9 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +72 -0
- tritonparse/structured_logging.py +1046 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +118 -0
- tritonparse/tools/format_fix.py +149 -0
- tritonparse/tools/load_tensor.py +58 -0
- tritonparse/tools/prettify_ndjson.py +315 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +331 -0
- tritonparse/utils.py +156 -0
- tritonparse-0.1.1.dist-info/METADATA +10 -0
- tritonparse-0.1.1.dist-info/RECORD +40 -0
- tritonparse-0.1.1.dist-info/WHEEL +5 -0
- tritonparse-0.1.1.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.1.1.dist-info/top_level.txt +1 -0
tritonparse/__init__.py
ADDED
|
File without changes
|
tritonparse/common.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
import gzip
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
import importlib.util
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
import shutil
|
|
11
|
+
import tempfile
|
|
12
|
+
from collections import defaultdict
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional, Tuple
|
|
15
|
+
|
|
16
|
+
from .extract_source_mappings import parse_single_file
|
|
17
|
+
from .shared_vars import DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER as LOG_PREFIX
|
|
18
|
+
from .tp_logger import logger
|
|
19
|
+
|
|
20
|
+
LOG_RANK_REGEX = re.compile(r"rank_(\d+)")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def is_fbcode() -> bool:
    """
    Report whether the ``tritonparse.fb`` subpackage can be located.

    Returns:
        True if ``importlib.util.find_spec`` finds ``tritonparse.fb``,
        False otherwise (including when the parent package itself cannot
        be imported).
    """
    try:
        return importlib.util.find_spec("tritonparse.fb") is not None
    except ModuleNotFoundError:
        # find_spec imports the parent of a dotted name and raises when that
        # parent is missing; report "not available" instead of crashing.
        return False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
if is_fbcode():
|
|
28
|
+
from tritonparse.fb.source_type import SourceType
|
|
29
|
+
else:
|
|
30
|
+
from tritonparse.source_type import SourceType
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Rank:
    """
    Class representing a rank in distributed training.

    Instances compare and hash by ``(value, is_default)`` so that equal ranks
    land on the same key when Rank objects are used as dictionary keys.
    Do not mutate ``value`` or ``is_default`` while an instance is in use as
    a key.
    """

    def __init__(self, rank_value: Optional[int] = None):
        """
        Initialize a Rank object.

        Args:
            rank_value: Specific rank value, or None for default rank
        """
        if rank_value is not None:
            self.value = rank_value
            self.is_default = False
        else:
            # The default rank carries value 0 but stays distinguishable
            # from an explicit Rank(0) through is_default.
            self.value = 0
            self.is_default = True

    def __eq__(self, other: object) -> bool:
        """Return True when both ranks share the same value and default flag."""
        if not isinstance(other, Rank):
            return NotImplemented
        return (self.value, self.is_default) == (other.value, other.is_default)

    def __hash__(self) -> int:
        """Hash consistently with __eq__ so equal ranks share a dict key."""
        return hash((self.value, self.is_default))

    def __repr__(self) -> str:
        """Return a debugging representation of the rank."""
        if self.is_default:
            return f"{type(self).__name__}()"
        return f"{type(self).__name__}({self.value!r})"

    def to_string(self, prefix: str = "", suffix: str = "") -> str:
        """
        Convert rank to string representation with optional prefix and suffix.

        Args:
            prefix: Prefix to add before rank string
            suffix: Suffix to add after rank string

        Returns:
            String representation of the rank; an empty string (without
            prefix or suffix) for the default rank
        """
        if self.is_default:
            return ""
        return f"{prefix}rank_{self.value}{suffix}"

    def to_int(self) -> int:
        """
        Convert rank to integer value.

        Returns:
            Integer value of the rank
        """
        return self.value
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class RankConfig:
    """Rank-selection settings used while processing logs."""

    def __init__(
        self,
        rank: Optional[Rank] = None,
        all_ranks: bool = False,
        is_local: bool = False,
    ):
        """
        Store the rank-selection settings.

        Args:
            rank: Specific rank to process
            all_ranks: Whether to process all ranks
            is_local: Whether processing local logs
        """
        self.rank, self.all_ranks, self.is_local = rank, all_ranks, is_local

    @classmethod
    def from_cli_args(
        cls, rank: Optional[int], all_ranks: bool, source_type: SourceType
    ) -> "RankConfig":
        """
        Build a RankConfig from the values given on the command line.

        Args:
            rank: Specific rank value from CLI
            all_ranks: Whether --all-ranks flag was specified
            source_type: Type of source

        Returns:
            Configured RankConfig object

        Raises:
            ValueError: If both a rank and --all-ranks are given
        """
        if all_ranks and rank is not None:
            raise ValueError("Can't specify both a rank and --all-ranks")
        if all_ranks:
            return cls(all_ranks=True)
        if rank is not None:
            return cls(rank=Rank(rank))

        # No explicit rank selection: decide from the source type.
        if source_type in (SourceType.LOCAL, SourceType.LOCAL_FILE):
            return cls(is_local=True)
        if not is_fbcode():
            return cls(all_ranks=True)

        # Imported lazily: only reached when is_fbcode() reports the
        # subpackage as present.
        from tritonparse.fb.utils import rank_config_from_cli_args

        return rank_config_from_cli_args(cls, source_type)

    def to_rank(self) -> Rank:
        """
        Return the configured rank, or a default Rank when none is set.

        Returns:
            Rank object
        """
        return self.rank if self.rank else Rank()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def print_parsed_files_summary(parsed_log_dir: str) -> None:
    """
    Walk a parsed-log directory and print a formatted report of its files.

    Args:
        parsed_log_dir: Directory containing parsed files
    """
    banner = "=" * 80

    # Gather every file below the directory, sorted for consistent output
    all_parsed_files = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(parsed_log_dir)
        for name in names
    )

    print("\n" + banner)
    print("📁 TRITONPARSE PARSING RESULTS")
    print(banner)

    # Print log file list (required for integration)
    print(f"📂 Parsed files directory: {parsed_log_dir}")
    print(f"📊 Total files generated: {len(all_parsed_files)}")

    if all_parsed_files:
        print("\n📄 Generated files:")
        print("-" * 50)
        for index, file_path in enumerate(all_parsed_files, 1):
            # Show paths relative to the report root for cleaner display
            rel_path = os.path.relpath(file_path, parsed_log_dir)
            try:
                size_bytes = os.path.getsize(file_path)
            except OSError:
                # Size lookup is best-effort; the file is still listed
                file_size = "N/A"
            else:
                if size_bytes < 1024:
                    file_size = f"{size_bytes}B"
                elif size_bytes < 1024 * 1024:
                    file_size = f"{size_bytes / 1024:.1f}KB"
                else:
                    file_size = f"{size_bytes / (1024 * 1024):.1f}MB"

            print(f" {index:2d}. 📝 {rel_path} ({file_size})")

    print(banner)
    print("✅ Parsing completed successfully!")
    print(banner + "\n")
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def gzip_single_file(file_path: str, verbose: bool = False) -> str:
    """
    Compress one file to ``<file_path>.gz`` and remove the uncompressed file.

    A path that already ends in ``.gz`` is returned untouched.

    Args:
        file_path: Path to the file to gzip
        verbose: Whether to print verbose information
    Returns:
        Path to the gzipped file
    """
    if file_path.endswith(".gz"):
        # Already compressed: nothing to do
        return file_path

    compressed_path = file_path + ".gz"
    if verbose:
        logger.info(f"Gzipping {file_path}")

    # Stream the bytes across so the whole file is never held in memory
    with open(file_path, "rb") as source, gzip.open(compressed_path, "wb") as sink:
        shutil.copyfileobj(source, sink)

    # Only reached when the copy above finished without raising
    os.remove(file_path)
    if verbose:
        logger.info(f"Deleted original file {file_path}")

    return compressed_path
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def copy_local_to_tmpdir(local_path: str, verbose: bool = False) -> str:
    """
    Copy log files whose names start with LOG_PREFIX into a new temp directory.

    Args:
        local_path: Path to local directory or single file containing logs
        verbose: Whether to print verbose information

    Returns:
        Path to temporary directory containing copied logs

    Raises:
        RuntimeError: If the local_path does not exist, or is neither a file
            nor a directory
    """
    if not os.path.exists(local_path):
        raise RuntimeError(f"Path does not exist: {local_path}")

    temp_dir = tempfile.mkdtemp()

    # Single file: copy it only when its name carries the log prefix
    if os.path.isfile(local_path):
        if os.path.basename(local_path).startswith(LOG_PREFIX):
            if verbose:
                logger.info(f"Copying single file {local_path} to {temp_dir}")
            shutil.copy2(local_path, temp_dir)
        return temp_dir

    # Anything else must be a directory
    if not os.path.isdir(local_path):
        raise RuntimeError(f"Path is neither a file nor a directory: {local_path}")

    # Directory: copy its direct children that are prefixed log files
    for entry in os.listdir(local_path):
        entry_path = os.path.join(local_path, entry)
        if not os.path.isfile(entry_path):
            continue
        if not os.path.basename(entry_path).startswith(LOG_PREFIX):
            continue
        if verbose:
            logger.info(f"Copying {entry_path} to {temp_dir}")
        shutil.copy2(entry_path, temp_dir)

    return temp_dir
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def parse_logs(
    logs_to_parse: str,
    rank_config: RankConfig,
    verbose: bool = False,
    tritonparse_url_prefix: str = "",
) -> Tuple[str, dict]:
    """
    Parse the raw trace logs of a directory into a new temporary directory.

    Every eligible log is passed to ``parse_single_file``. Each file found in
    the resulting output directory is gzipped, and the gzipped names are
    recorded per rank in a mapping that is also written to
    ``log_file_list.json`` inside the output directory.

    Args:
        logs_to_parse: Path to directory containing logs to parse
        rank_config: Rank configuration
        verbose: Whether to print verbose information
        tritonparse_url_prefix: URL prefix for the generated file mapping

    Returns:
        Tuple of (parsed log directory, file mapping)

    Raises:
        RuntimeError: If no eligible log is found in ``logs_to_parse``
    """
    raw_log_dir = logs_to_parse
    # Dictionary to store ranks and their log files
    ranks = defaultdict(list)  # Dict[Rank, List[str]]

    # The name pattern depends only on rank_config, so compile it once
    # rather than once per directory entry.
    pattern = re.compile(f"{LOG_PREFIX}.*{rank_config.to_rank().to_string('')}")

    # Find all eligible logs in the raw log directory
    for item in os.listdir(raw_log_dir):
        path = os.path.join(raw_log_dir, item)
        if not os.path.isfile(path) or not pattern.search(item):
            continue
        # Check if the log has a rank in its name
        rank_match = LOG_RANK_REGEX.search(item)
        if rank_match:
            ranks[Rank(int(rank_match.group(1)))].append(path)
        elif rank_config.is_local:
            # Local logs don't always have a rank associated with them, we can push as default
            ranks[Rank()].append(path)
    if not ranks:
        raise RuntimeError(f"No eligible structured trace logs found in {raw_log_dir}")

    # Create the output directory only once there is something to parse, so
    # the error path above does not leave an empty temp directory behind.
    parsed_log_dir = tempfile.mkdtemp()
    file_mapping = {"tritonparse_url_prefix": tritonparse_url_prefix}

    # this is used to generate the tritonparse url
    # NOTE(review): derived from rank_config rather than from the rank of
    # each file, exactly as before — confirm this is intended.
    rank_suffix = rank_config.to_rank().to_string(suffix="/")

    # Parse each eligible log
    for rank, files in ranks.items():
        use_filenames = len(files) > 1
        if use_filenames:
            logger.warning(
                "Warning: multiple logs found for the same rank. Using filenames."
            )
        # Determine rank key for file mapping
        rank_key = "rank_default" if rank.is_default else f"rank_{rank.value}"
        for file_path in files:
            filename = os.path.basename(file_path)
            input_file = os.path.join(raw_log_dir, filename)

            if use_filenames:
                # One sub-directory per input file keeps outputs apart
                rank_prefix = "" if rank.is_default else f"{rank.to_string('')}/"
                relative_path = f"{rank_prefix}{filename}"
            else:
                relative_path = rank.to_string("")
            output_dir = os.path.join(parsed_log_dir, relative_path)
            # Parse the file
            parse_single_file(input_file, output_dir)
            if not os.path.exists(output_dir):
                continue

            # Collect generated files after parsing and gzip them immediately
            generated_files = []
            mapped_file = None
            for generated_item in os.listdir(output_dir):
                generated_path = os.path.join(output_dir, generated_item)
                if not os.path.isfile(generated_path):
                    continue
                gz_file_path = gzip_single_file(generated_path, verbose)
                gz_filename = os.path.basename(gz_file_path)
                # Check if it's a mapped file (assuming files with 'mapped' in name)
                if "mapped" in generated_item.lower():
                    mapped_file = gz_filename
                else:
                    generated_files.append(gz_filename)

            # Initialize rank entry if not exists
            entry = file_mapping.setdefault(
                rank_key, {"regular_files": [], "mapped_file": None}
            )
            # Add files to the mapping (now with .gz extensions)
            entry["regular_files"].extend(generated_files)
            entry["rank_suffix"] = rank_suffix
            if mapped_file:
                entry["mapped_file"] = mapped_file

    # Clean up the file mapping - remove None mapped_files and ensure no duplicates
    for rank_key, entry in file_mapping.items():
        if rank_key == "tritonparse_url_prefix":
            continue
        # dict.fromkeys drops duplicates while keeping first-seen order, so
        # the saved JSON is deterministic (list(set(...)) was not).
        entry["regular_files"] = list(dict.fromkeys(entry["regular_files"]))
        # Remove mapped_file if None
        if entry["mapped_file"] is None:
            del entry["mapped_file"]

    # Save file mapping to parsed_log_dir
    log_file_list_path = os.path.join(parsed_log_dir, "log_file_list.json")
    with open(log_file_list_path, "w", encoding="utf-8") as f:
        json.dump(file_mapping, f, indent=2)

    # NOTICE: this print is required for tlparser-tritonparse integration
    # DON'T REMOVE THIS PRINT
    print(f"tritonparse log file list: {log_file_list_path}")
    return parsed_log_dir, file_mapping
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def save_logs(out_dir: Path, parsed_logs: str, overwrite: bool, verbose: bool) -> None:
    """
    Save logs to a local directory.

    Every top-level item of ``parsed_logs`` is copied into ``out_dir``, which
    is created if it does not exist.

    Args:
        out_dir: Path to output directory
        parsed_logs: Path to directory containing parsed logs
        overwrite: Whether to overwrite existing logs. When True, a
            sub-directory that already exists in ``out_dir`` is copied into
            instead of making ``shutil.copytree`` raise ``FileExistsError``.
            When False, an existing sub-directory still raises, as before.
        verbose: Whether to print verbose information
    """
    if not out_dir.is_absolute():
        out_dir = out_dir.resolve()

    os.makedirs(out_dir, exist_ok=True)

    logger.info(f"Copying parsed logs from {parsed_logs} to {out_dir}")

    # Copy each item in the parsed_logs directory to the output directory
    for item in os.listdir(parsed_logs):
        src_path = os.path.join(parsed_logs, item)
        dst_path = os.path.join(out_dir, item)

        if os.path.isdir(src_path):
            if verbose:
                logger.info(f"Copying directory {src_path}/ to {dst_path}/")
            # Honour the overwrite flag, which was previously ignored
            shutil.copytree(src_path, dst_path, dirs_exist_ok=overwrite)
        else:
            if verbose:
                logger.info(f"Copying file from {src_path} to {dst_path}")
            shutil.copy2(src_path, dst_path)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
from .sourcemap_utils import _flatten_dict, _to_ranges, _unflatten_dict
|
|
6
|
+
|
|
7
|
+
# Fields that are expected to vary but are not useful to list out in the diff.
|
|
8
|
+
SUMMARY_FIELDS = ["pid", "timestamp", "stream", "function", "data_ptr"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _generate_launch_diff(
|
|
12
|
+
launches: List[Tuple[Dict[str, Any], int]],
|
|
13
|
+
) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, int]]]:
|
|
14
|
+
"""
|
|
15
|
+
Compares a list of launch events and returns sames, diffs, and an index map.
|
|
16
|
+
"""
|
|
17
|
+
if not launches:
|
|
18
|
+
return {}, {}, []
|
|
19
|
+
|
|
20
|
+
launch_events = [launch[0] for launch in launches]
|
|
21
|
+
launch_index_map = [launch[1] for launch in launches]
|
|
22
|
+
|
|
23
|
+
if len(launch_events) == 1:
|
|
24
|
+
return (
|
|
25
|
+
_unflatten_dict(_flatten_dict(launch_events[0])),
|
|
26
|
+
{},
|
|
27
|
+
_to_ranges(launch_index_map),
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Group values by key
|
|
31
|
+
data_by_key = defaultdict(lambda: defaultdict(list))
|
|
32
|
+
for i, launch in enumerate(launch_events):
|
|
33
|
+
launch_flat = _flatten_dict(launch)
|
|
34
|
+
for key, value in launch_flat.items():
|
|
35
|
+
# JSON doesn't support all Python types as values directly, str is safer
|
|
36
|
+
value_str = json.dumps(value, sort_keys=True)
|
|
37
|
+
data_by_key[key][value_str].append(i)
|
|
38
|
+
|
|
39
|
+
sames_flat = {}
|
|
40
|
+
diffs_flat = {}
|
|
41
|
+
|
|
42
|
+
for key, value_groups in data_by_key.items():
|
|
43
|
+
if len(value_groups) == 1:
|
|
44
|
+
# This key has the same value across all launches
|
|
45
|
+
value_str = list(value_groups.keys())[0]
|
|
46
|
+
sames_flat[key] = json.loads(value_str)
|
|
47
|
+
else:
|
|
48
|
+
# This key has different values
|
|
49
|
+
is_summary = any(summary_key in key for summary_key in SUMMARY_FIELDS)
|
|
50
|
+
if is_summary:
|
|
51
|
+
diffs_flat[key] = {
|
|
52
|
+
"diff_type": "summary",
|
|
53
|
+
"summary_text": f"Varies across {len(value_groups)} unique values",
|
|
54
|
+
}
|
|
55
|
+
else:
|
|
56
|
+
values_dist = []
|
|
57
|
+
for value_str, indices in value_groups.items():
|
|
58
|
+
values_dist.append(
|
|
59
|
+
{
|
|
60
|
+
"value": json.loads(value_str),
|
|
61
|
+
"count": len(indices),
|
|
62
|
+
"launches": _to_ranges(indices),
|
|
63
|
+
}
|
|
64
|
+
)
|
|
65
|
+
# Sort by first occurrence
|
|
66
|
+
values_dist.sort(key=lambda x: x["launches"][0]["start"])
|
|
67
|
+
diffs_flat[key] = {
|
|
68
|
+
"diff_type": "distribution",
|
|
69
|
+
"values": values_dist,
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
# Unflatten the results
|
|
73
|
+
sames_unflattened = _unflatten_dict(sames_flat)
|
|
74
|
+
diffs_unflattened = _unflatten_dict(diffs_flat)
|
|
75
|
+
|
|
76
|
+
# Special handling for extracted_args to create argument_diff structures
|
|
77
|
+
if "extracted_args" in sames_unflattened or "extracted_args" in diffs_unflattened:
|
|
78
|
+
sames_args = sames_unflattened.pop("extracted_args", {})
|
|
79
|
+
diffs_args_flat = diffs_unflattened.pop("extracted_args", {})
|
|
80
|
+
|
|
81
|
+
all_arg_names = set(sames_args.keys()) | set(diffs_args_flat.keys())
|
|
82
|
+
|
|
83
|
+
final_arg_diffs = {}
|
|
84
|
+
|
|
85
|
+
for arg_name in all_arg_names:
|
|
86
|
+
if arg_name in diffs_args_flat:
|
|
87
|
+
# This argument has at least one differing sub-field.
|
|
88
|
+
arg_sames = {}
|
|
89
|
+
arg_diffs_internal = {}
|
|
90
|
+
|
|
91
|
+
# Collect all sub-fields for this argument from the original data
|
|
92
|
+
all_sub_fields = set()
|
|
93
|
+
for launch in launch_events:
|
|
94
|
+
arg_data = launch.get("extracted_args", {}).get(arg_name, {})
|
|
95
|
+
all_sub_fields.update(arg_data.keys())
|
|
96
|
+
|
|
97
|
+
for sub_field in all_sub_fields:
|
|
98
|
+
flat_key = f"extracted_args.{arg_name}.{sub_field}"
|
|
99
|
+
if flat_key in diffs_flat:
|
|
100
|
+
arg_diffs_internal[sub_field] = diffs_flat[flat_key]
|
|
101
|
+
elif flat_key in sames_flat:
|
|
102
|
+
arg_sames[sub_field] = sames_flat[flat_key]
|
|
103
|
+
|
|
104
|
+
if arg_sames or arg_diffs_internal:
|
|
105
|
+
final_arg_diffs[arg_name] = {
|
|
106
|
+
"diff_type": "argument_diff",
|
|
107
|
+
"sames": arg_sames,
|
|
108
|
+
"diffs": arg_diffs_internal,
|
|
109
|
+
}
|
|
110
|
+
elif arg_name in sames_args:
|
|
111
|
+
# This argument is entirely the same across all launches.
|
|
112
|
+
# We move it back to the main sames dict for consistency.
|
|
113
|
+
if "extracted_args" not in sames_unflattened:
|
|
114
|
+
sames_unflattened["extracted_args"] = {}
|
|
115
|
+
sames_unflattened["extracted_args"][arg_name] = sames_args[arg_name]
|
|
116
|
+
|
|
117
|
+
if final_arg_diffs:
|
|
118
|
+
diffs_unflattened["extracted_args"] = final_arg_diffs
|
|
119
|
+
|
|
120
|
+
return sames_unflattened, diffs_unflattened, _to_ranges(launch_index_map)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Extract source code mappings from Triton trace files and update the original JSON.
|
|
6
|
+
This script reads a JSON trace file containing Triton IR (TTIR, TTGIR) and PTX(AMDGCN),
|
|
7
|
+
and extracts bidirectional mappings between:
|
|
8
|
+
- Python ↔ TTIR
|
|
9
|
+
- Python ↔ TTGIR
|
|
10
|
+
- Python ↔ PTX(AMDGCN)
|
|
11
|
+
- TTIR ↔ TTGIR
|
|
12
|
+
- TTIR ↔ PTX(AMDGCN)
|
|
13
|
+
- TTGIR ↔ PTX(AMDGCN)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from .trace_processor import parse_single_file
|
|
20
|
+
|
|
21
|
+
logging.basicConfig(level=logging.INFO)
|
|
22
|
+
logger = logging.getLogger("SourceMapping")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def parse_args():
    """Build the command-line parser for this script and parse the arguments."""
    parser = argparse.ArgumentParser(
        description="Extract source code mappings from Triton trace files."
    )
    # (flags, keyword settings) for each option, in the order they are added
    option_specs = (
        (("-i", "--input"), {"help": "Path to the Triton trace NDJSON file"}),
        (
            ("--output-dir",),
            {
                "default": None,
                "help": "Directory to save the output files. If not specified, the input file's directory will be used.",
            },
        ),
        (
            ("-o", "--output"),
            {
                "default": None,
                "help": "Output NDJSON path. If it is None, the default output file name will be set to {input}_mapped.ndjson in the parse function.",
            },
        ),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
if __name__ == "__main__":
    # Script entry point: parse one trace file, or report that none was given
    args = parse_args()
    if not args.input:
        logger.error("No input file specified.")
    else:
        parse_single_file(args.input, args.output_dir)
|