tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
tritonparse/tools/__init__.py
File without changes
tritonparse/tools/decompress_bin_ndjson.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+"""
+Script to decompress .bin.ndjson files back to regular .ndjson format.
+
+The .bin.ndjson format stores each JSON record as a separate gzip member,
+concatenated in sequence within a single binary file. This script uses
+gzip.open() which automatically handles member concatenation to read
+the compressed file and write out the original NDJSON format.
+
+Usage:
+    python decompress_bin_ndjson.py trace.bin.ndjson
+"""
+
+import argparse
+import gzip
+import sys
+from pathlib import Path
+
+
+def decompress_bin_ndjson(input_file: str, output_file: str = None) -> None:
+    """
+    Decompress a .bin.ndjson file to regular .ndjson format.
+
+    Args:
+        input_file: Path to the .bin.ndjson file
+        output_file: Path for the output .ndjson file (optional)
+    """
+    input_path = Path(input_file)
+
+    # Validate input file
+    if not input_path.exists():
+        print(f"Error: Input file '{input_file}' does not exist", file=sys.stderr)
+        return
+
+    if not input_path.name.endswith(".bin.ndjson"):
+        print(f"Warning: Input file '{input_file}' doesn't have .bin.ndjson extension")
+
+    # Determine output file path
+    if output_file is None:
+        if input_path.name.endswith(".bin.ndjson"):
+            # Replace .bin.ndjson with .ndjson
+            output_file = str(input_path.with_suffix("").with_suffix(".ndjson"))
+        else:
+            # Add .decompressed.ndjson suffix
+            output_file = str(input_path.with_suffix(".decompressed.ndjson"))
+
+    output_path = Path(output_file)
+
+    try:
+        line_count = 0
+        # Because we use NDJSON format, each line is a complete JSON record.
+        # It is guaranteed here https://github.com/meta-pytorch/tritonparse/blob/
+        # c8dcc2a94ac10ede4342dba7456f6ebd8409b95d/tritonparse/structured_logging.py#L320
+        with gzip.open(input_path, "rt", encoding="utf-8") as compressed_file:
+            with open(output_path, "w", encoding="utf-8") as output:
+                for line in compressed_file:
+                    # gzip.open automatically handles member concatenation
+                    # Each line is already a complete JSON record with newline
+                    output.write(line)
+                    line_count += 1
+
+        # Get file sizes for comparison
+        input_size = input_path.stat().st_size
+        output_size = output_path.stat().st_size
+        compression_ratio = (
+            (1 - input_size / output_size) * 100 if output_size > 0 else 0
+        )
+
+        print(f"Successfully decompressed '{input_file}' to '{output_file}'")
+        print(f"  Input size: {input_size:,} bytes")
+        print(f"  Output size: {output_size:,} bytes")
+        print(f"  Compression ratio: {compression_ratio:.1f}%")
+        print(f"  Records processed: {line_count:,}")
+
+    except gzip.BadGzipFile as e:
+        print(f"Error: Invalid gzip format in '{input_file}': {e}", file=sys.stderr)
+    except UnicodeDecodeError as e:
+        print(f"Error: Unicode decode error in '{input_file}': {e}", file=sys.stderr)
+    except Exception as e:
+        print(f"Error: Failed to decompress '{input_file}': {e}", file=sys.stderr)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Decompress .bin.ndjson files to regular .ndjson format",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  %(prog)s trace.bin.ndjson
+  %(prog)s trace.bin.ndjson -o output.ndjson
+  %(prog)s /logs/dedicated_log_triton_trace_user_.bin.ndjson
+        """,
+    )
+
+    parser.add_argument("input_file", help="Input .bin.ndjson file to decompress")
+
+    parser.add_argument(
+        "-o",
+        "--output",
+        help="Output .ndjson file path (default: replace .bin.ndjson with .ndjson)",
+    )
+
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Enable verbose output"
+    )
+
+    args = parser.parse_args()
+
+    if args.verbose:
+        print(f"Decompressing: {args.input_file}")
+        if args.output:
+            print(f"Output file: {args.output}")
+
+    decompress_bin_ndjson(args.input_file, args.output)
+
+
+if __name__ == "__main__":
+    main()
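The docstring above hinges on one property of gzip: concatenated gzip members read back as a single stream. As a minimal standalone sketch (not part of the package; the file name and record contents are invented), the following writes two records in the .bin.ndjson layout and streams them back:

import gzip
import json

records = [{"event_type": "compilation"}, {"event_type": "launch"}]

# Write each record as its own complete gzip member, back to back in one file.
with open("trace.bin.ndjson", "wb") as f:
    for record in records:
        f.write(gzip.compress((json.dumps(record) + "\n").encode("utf-8")))

# gzip.open() transparently reads across member boundaries, so the file
# behaves like one ordinary NDJSON stream.
with gzip.open("trace.bin.ndjson", "rt", encoding="utf-8") as f:
    for line in f:
        print(json.loads(line))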
tritonparse/tools/disasm.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import re
+import subprocess
+
+# Regex patterns for nvdisasm output
+NVDISASM_FNAME_RE = re.compile(r"^\s*\.global\s+(\w+)")
+
+
+def path_to_nvdisasm():
+    from triton import knobs
+
+    return knobs.nvidia.nvdisasm.path
+
+
+def is_nvdisasm_available():
+    try:
+        if path_to_nvdisasm():
+            return True
+        else:
+            return False
+    except RuntimeError:
+        return False
+
+
+def extract(file_path):
+    """Extract SASS from CUBIN using nvdisasm.
+
+    nvdisasm output is much cleaner than cuobjdump:
+    - Single line per instruction (no encoding lines)
+    - Labels are already symbolized (.L_x_0 instead of addresses)
+    - Source line information is included
+    - No need for complex address remapping
+
+    nvdisasm Documentation:
+    https://docs.nvidia.com/cuda/cuda-binary-utilities/index.html
+    """
+    nvdisasm = path_to_nvdisasm()
+    args = [nvdisasm, "-c", "-gp", "-g", "-gi", file_path]
+    sass_str = subprocess.check_output(args)
+    sass_lines = sass_str.splitlines()
+    line_idx = 0
+
+    while line_idx < len(sass_lines):
+        line = sass_lines[line_idx].decode()
+
+        # Find function definition (.global function_name)
+        while NVDISASM_FNAME_RE.match(line) is None:
+            line_idx += 1
+            if line_idx >= len(sass_lines):
+                return None
+            line = sass_lines[line_idx].decode()
+
+        # Extract function name
+        match = NVDISASM_FNAME_RE.match(line)
+        if match is None:
+            return None
+        fname = match.group(1)
+        ret = f"Function:{fname}\n"
+
+        # Find the actual start of function content (.text.kernel_name:)
+        text_section_pattern = f".text.{fname}:"
+        line_idx += 1
+        while line_idx < len(sass_lines):
+            line = sass_lines[line_idx].decode().strip()
+            if line == text_section_pattern:
+                line_idx += 1  # Move past the .text.kernel_name: line
+                break
+            line_idx += 1
+
+        # Process all lines until next .headerflags or end of file
+        while line_idx < len(sass_lines):
+            line = sass_lines[line_idx].decode().rstrip()
+
+            # Stop if we encounter next function's headerflags
+            if line.strip().startswith(".headerflags"):
+                break
+            ret += line + "\n"
+            line_idx += 1
+
+        ret += "\n"
+        return ret
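To make the parsing loop in extract() concrete, here is a small sketch against an invented, heavily trimmed stand-in for nvdisasm output (real dumps are richer). It shows the two anchors the loop scans for: a ".global <name>" directive giving the kernel name, and the ".text.<name>:" label that opens its SASS body.

import re

NVDISASM_FNAME_RE = re.compile(r"^\s*\.global\s+(\w+)")

# Invented stand-in for what nvdisasm prints for a single kernel.
fake_dump = """\
        .headerflags    @"EF_CUDA_SM90"
        .global add_kernel
.text.add_kernel:
        MOV R1, c[0x0][0x28] ;
        EXIT ;
"""

for line in fake_dump.splitlines():
    m = NVDISASM_FNAME_RE.match(line)
    if m:
        print(m.group(1))  # -> add_kernel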
tritonparse/tools/extract_irs.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+"""
+Extract IR files from NDJSON trace logs.
+
+This script extracts intermediate representation (IR) files from a Triton trace NDJSON file.
+For compilation events, it extracts the IR files (ttir, ttgir, llir, ptx, etc.) contained in
+the file_content field and saves them as individual files.
+
+Example:
+    Extract IRs from line 0 (first line) of the NDJSON file:
+    python extract_irs.py -i logs.ndjson --line 0 -o output_folder
+
+    Extract from line 5:
+    python extract_irs.py -i logs.ndjson --line 5 -o ./irs
+
+Usage:
+    python extract_irs.py -i <input.ndjson> --line <line_number> -o <output_folder>
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+def read_ndjson_line(file_path: Path, line_number: int) -> Optional[Dict[str, Any]]:
+    """
+    Read a specific line from an NDJSON file (0-based indexing).
+
+    Args:
+        file_path: Path to the NDJSON file
+        line_number: Line number to read (0-based, where 0 = first line)
+
+    Returns:
+        Parsed JSON object from the specified line, or None if line doesn't exist
+
+    Raises:
+        FileNotFoundError: If the input file doesn't exist
+        json.JSONDecodeError: If the line contains invalid JSON
+    """
+    if not file_path.exists():
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    try:
+        with open(file_path, "r", encoding="utf-8") as f:
+            for current_line_num, line in enumerate(f):
+                if current_line_num == line_number:
+                    line = line.strip()
+                    if not line:
+                        print(f"Warning: Line {line_number} is empty", file=sys.stderr)
+                        return None
+                    return json.loads(line)
+
+        print(
+            f"Error: Line {line_number} not found in file (file has fewer lines)",
+            file=sys.stderr,
+        )
+        return None
+
+    except json.JSONDecodeError as e:
+        print(f"Error: Invalid JSON on line {line_number}: {e}", file=sys.stderr)
+        raise
+
+
+def extract_irs(
+    json_obj: Dict[str, Any], output_dir: Path, kernel_name: Optional[str] = None
+) -> int:
+    """
+    Extract IR files from a JSON object and save them to the output directory.
+
+    Args:
+        json_obj: Parsed JSON object from the NDJSON file
+        output_dir: Directory to save the extracted IR files
+        kernel_name: Optional kernel name to use for file naming (overrides metadata.name)
+
+    Returns:
+        Number of files extracted
+
+    Raises:
+        ValueError: If the JSON object is not a compilation event or missing required fields
+    """
+    # Validate that this is a compilation event
+    event_type = json_obj.get("event_type")
+    if event_type != "compilation":
+        raise ValueError(f"Not a compilation event (event_type: {event_type})")
+
+    payload = json_obj.get("payload")
+    if not payload:
+        raise ValueError("Missing 'payload' field in JSON object")
+
+    # Get file_content
+    file_content = payload.get("file_content")
+    if not file_content:
+        raise ValueError("Missing 'file_content' field in payload")
+
+    # Determine kernel name
+    if kernel_name is None:
+        metadata = payload.get("metadata", {})
+        kernel_name = metadata.get("name", "kernel")
+
+    # Create output directory if it doesn't exist
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Extract each IR file
+    files_extracted = 0
+    for file_key, content in file_content.items():
+        # Determine file extension from the key
+        # file_key is typically like "embedding_forward_kernel.ttir"
+        # We want to extract just the extension
+        if "." in file_key:
+            extension = file_key.split(".")[-1]
+        else:
+            extension = "txt"
+
+        # Create output filename
+        output_filename = f"{kernel_name}.{extension}"
+        output_path = output_dir / output_filename
+
+        # Write content to file
+        try:
+            with open(output_path, "w", encoding="utf-8") as f:
+                f.write(content)
+            print(f"Extracted: {output_path}")
+            files_extracted += 1
+        except OSError as e:
+            print(f"Error writing file {output_path}: {e}", file=sys.stderr)
+
+    # Optionally extract Python source code
+    python_source = payload.get("python_source")
+    if python_source and isinstance(python_source, dict):
+        source_code = python_source.get("code")
+        if source_code:
+            output_path = output_dir / f"{kernel_name}_source.py"
+            try:
+                with open(output_path, "w", encoding="utf-8") as f:
+                    # Add header comment with file path and line range
+                    file_path_info = python_source.get("file_path", "unknown")
+                    start_line = python_source.get("start_line", "?")
+                    end_line = python_source.get("end_line", "?")
+                    f.write(f"# Source: {file_path_info}\n")
+                    f.write(f"# Lines: {start_line}-{end_line}\n\n")
+                    f.write(source_code)
+                print(f"Extracted Python source: {output_path}")
+                files_extracted += 1
+            except OSError as e:
+                print(
+                    f"Error writing Python source file {output_path}: {e}",
+                    file=sys.stderr,
+                )
+
+    return files_extracted
+
+
+def main():
+    """Main function to handle command line arguments and orchestrate IR extraction."""
+    parser = argparse.ArgumentParser(
+        description="Extract IR files from Triton trace NDJSON logs",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  Extract IRs from line 0 (first line):
+  python extract_irs.py -i logs.ndjson --line 0 -o output_folder
+
+  Extract from line 5:
+  python extract_irs.py -i logs.ndjson --line 5 -o ./irs
+
+  Specify custom kernel name:
+  python extract_irs.py -i logs.ndjson --line 0 -o ./irs --kernel-name my_kernel
+        """,
+    )
+
+    parser.add_argument(
+        "-i", "--input", type=str, required=True, help="Path to the input NDJSON file"
+    )
+
+    parser.add_argument(
+        "--line",
+        type=int,
+        required=True,
+        help="Line number to extract (0-based indexing, where 0 = first line)",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=True,
+        help="Output directory to save extracted IR files",
+    )
+
+    parser.add_argument(
+        "--kernel-name",
+        type=str,
+        help="Custom kernel name for output files (default: use metadata.name from JSON)",
+    )
+
+    args = parser.parse_args()
+
+    # Validate line number
+    if args.line < 0:
+        print(
+            f"Error: Line number must be non-negative (got {args.line})",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+
+    # Convert to Path objects
+    input_path = Path(args.input)
+    output_dir = Path(args.output)
+
+    try:
+        # Read the specified line
+        print(f"Reading line {args.line} from {input_path}...")
+        json_obj = read_ndjson_line(input_path, args.line)
+
+        if json_obj is None:
+            print("Error: Failed to read JSON from specified line", file=sys.stderr)
+            sys.exit(1)
+
+        # Extract IRs
+        print(f"Extracting IRs to {output_dir}...")
+        num_files = extract_irs(json_obj, output_dir, args.kernel_name)
+
+        print(f"\nSuccess! Extracted {num_files} file(s) to {output_dir}")
+
+    except FileNotFoundError as e:
+        print(f"Error: {e}", file=sys.stderr)
+        sys.exit(1)
+    except ValueError as e:
+        print(f"Error: {e}", file=sys.stderr)
+        sys.exit(1)
+    except json.JSONDecodeError as e:
+        print(f"Error: Failed to parse JSON - {e}", file=sys.stderr)
+        sys.exit(1)
+    except Exception as e:
+        print(f"Unexpected error: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
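The validation chain in extract_irs() doubles as a schema for the events it accepts. A hedged sketch of calling it on a synthetic compilation event: only the keys checked above ("event_type", "payload", "file_content", "metadata", optional "python_source") come from the code; all values here are invented.

from pathlib import Path

from tritonparse.tools.extract_irs import extract_irs

event = {
    "event_type": "compilation",
    "payload": {
        "metadata": {"name": "add_kernel"},
        "file_content": {
            "add_kernel.ttir": "// invented TTIR text",
            "add_kernel.ptx": "// invented PTX text",
        },
    },
}

# Writes add_kernel.ttir and add_kernel.ptx into ./irs and returns 2.
print(extract_irs(event, Path("./irs")))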
tritonparse/tools/format_fix.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+"""
+Format fix script for tritonparse project.
+
+This script runs all linter tools to format and fix code issues:
+- usort: Import sorting
+- ruff: Linting only
+- black: Code formatting
+
+Usage:
+    python -m tritonparse.tools.format_fix [options]
+
+Options:
+    --check-only    Only check for issues, don't fix them
+    --verbose       Verbose output
+    --help          Show this help message
+"""
+
+import argparse
+import subprocess
+import sys
+
+
+def run_command(cmd: list[str], verbose: bool = False) -> bool:
+    """Run a command and return success status."""
+    if verbose:
+        print(f"Running: {' '.join(cmd)}")
+
+    try:
+        result = subprocess.run(cmd, capture_output=True, text=True, check=False)
+
+        if result.returncode != 0:
+            if verbose:
+                print(f"Command failed with return code {result.returncode}")
+            if result.stdout:
+                print("STDOUT:", result.stdout)
+            if result.stderr:
+                print("STDERR:", result.stderr)
+            return False
+
+        if verbose and result.stdout:
+            print(result.stdout)
+
+        return True
+    except Exception as e:
+        if verbose:
+            print(f"Error running command: {e}")
+        return False
+
+
+def run_usort(check_only: bool = False, verbose: bool = False) -> bool:
+    """Run usort for import sorting."""
+    cmd = ["usort"]
+
+    if check_only:
+        cmd.extend(["check", "."])
+    else:
+        cmd.extend(["format", "."])
+
+    return run_command(cmd, verbose)
+
+
+def run_ruff_check(check_only: bool = False, verbose: bool = False) -> bool:
+    """Run ruff for linting only."""
+    cmd = ["ruff", "check", "."]
+
+    if check_only:
+        cmd.append("--diff")
+    else:
+        cmd.append("--fix")
+
+    return run_command(cmd, verbose)
+
+
+def run_black(check_only: bool = False, verbose: bool = False) -> bool:
+    """Run black for code formatting."""
+    cmd = ["black"]
+
+    if check_only:
+        cmd.extend(["--check", "--diff", "."])
+    else:
+        cmd.append(".")
+
+    return run_command(cmd, verbose)
+
+
+def main():
+    """Main function."""
+    parser = argparse.ArgumentParser(
+        description="Format fix script for tritonparse project",
+        epilog="""
+Examples:
+  # Fix all formatting issues
+  python -m tritonparse.tools.format_fix
+
+  # Check for issues without fixing
+  python -m tritonparse.tools.format_fix --check-only
+
+  # Verbose output
+  python -m tritonparse.tools.format_fix --verbose
+        """,
+    )
+
+    parser.add_argument(
+        "--check-only",
+        action="store_true",
+        help="Only check for issues, don't fix them",
+    )
+    parser.add_argument("--verbose", action="store_true", help="Verbose output")
+
+    args = parser.parse_args()
+
+    # Run formatters on the entire project
+    success = True
+
+    # 1. Run usort for import sorting
+    print("Running usort for import sorting...")
+    if not run_usort(args.check_only, args.verbose):
+        print("❌ usort failed")
+        success = False
+    else:
+        print("✅ usort completed")
+
+    # 2. Run ruff for linting only
+    print("Running ruff for linting...")
+    if not run_ruff_check(args.check_only, args.verbose):
+        print("❌ ruff linting failed")
+        success = False
+    else:
+        print("✅ ruff linting completed")
+
+    # 3. Run black for code formatting
+    print("Running black for code formatting...")
+    if not run_black(args.check_only, args.verbose):
+        print("❌ black failed")
+        success = False
+    else:
+        print("✅ black completed")
+
+    if success:
+        print("\n🎉 All formatting tools completed successfully!")
+        return 0
+    else:
+        print("\n❌ Some formatting tools failed. Please check the output above.")
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
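Each helper above only swaps flags on a fixed command line: in --check-only mode the script runs "usort check .", "ruff check . --diff", and "black --check --diff .", otherwise "usort format .", "ruff check . --fix", and "black .". run_command() itself is tool-agnostic, so a quick way to probe it directly (the command here is arbitrary; any argv-style list works):

import sys

from tritonparse.tools.format_fix import run_command

# verbose=True echoes the command line and its captured stdout.
ok = run_command([sys.executable, "--version"], verbose=True)
print("succeeded" if ok else "failed")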
tritonparse/tools/load_tensor.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+"""
+Simple tensor loading utility for tritonparse saved tensors.
+Usage:
+    import tritonparse.tools.load_tensor as load_tensor
+    tensor = load_tensor.load_tensor(tensor_file_path, device)
+"""
+
+import gzip
+import hashlib
+import io
+from pathlib import Path
+from typing import Union
+
+import torch
+
+
+def load_tensor(tensor_file_path: Union[str, Path], device: str = None) -> torch.Tensor:
+    """
+    Load a tensor from its file path and verify its integrity using the hash in the filename.
+
+    Args:
+        tensor_file_path (str | Path): Direct path to the tensor file. Supports both:
+            - .bin.gz: gzip-compressed tensor (hash is of uncompressed data)
+            - .bin: uncompressed tensor (for backward compatibility)
+        device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
+            If None, keeps the tensor on its original device.
+
+    Returns:
+        torch.Tensor: The loaded tensor (moved to the specified device if provided)
+
+    Raises:
+        FileNotFoundError: If the tensor file doesn't exist
+        RuntimeError: If the tensor cannot be loaded
+        ValueError: If the computed hash doesn't match the filename hash
+    """
+    blob_path = Path(tensor_file_path)
+
+    if not blob_path.exists():
+        raise FileNotFoundError(f"Tensor blob not found: {blob_path}")
+
+    # Detect compression by file extension
+    is_compressed = blob_path.name.endswith(".bin.gz")
+
+    # Read file contents (decompress if needed)
+    try:
+        with open(blob_path, "rb") as f:
+            file_obj = gzip.GzipFile(fileobj=f, mode="rb") if is_compressed else f
+            file_contents = file_obj.read()
+    except (OSError, gzip.BadGzipFile) as e:
+        if is_compressed:
+            raise RuntimeError(f"Failed to decompress gzip file {blob_path}: {str(e)}")
+        else:
+            raise RuntimeError(f"Failed to read file {blob_path}: {str(e)}")
+
+    # Extract expected hash from filename
+    # abc123.bin.gz -> abc123 or abc123.bin -> abc123
+    expected_hash = blob_path.name.removesuffix(".bin.gz" if is_compressed else ".bin")
+
+    # Compute hash of uncompressed data
+    computed_hash = hashlib.blake2b(file_contents).hexdigest()
+
+    # Verify hash matches filename
+    if computed_hash != expected_hash:
+        raise ValueError(
+            f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
+        )
+
+    try:
+        # Load the tensor from memory buffer
+        tensor = torch.load(io.BytesIO(file_contents), map_location=device)
+        return tensor
+    except Exception as e:
+        raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}")
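load_tensor() trusts nothing but the filename: the stem must equal the blake2b hex digest of the uncompressed torch.save bytes. A round-trip sketch of that convention (the tensor value and file name here are invented):

import gzip
import hashlib
import io
from pathlib import Path

import torch

from tritonparse.tools.load_tensor import load_tensor

# Serialize a tensor the way a <hash>.bin.gz blob stores it.
buf = io.BytesIO()
torch.save(torch.arange(4), buf)
payload = buf.getvalue()

# The filename stem is the digest of the *uncompressed* bytes.
digest = hashlib.blake2b(payload).hexdigest()
blob = Path(f"{digest}.bin.gz")
blob.write_bytes(gzip.compress(payload))

print(load_tensor(blob, device="cpu"))  # tensor([0, 1, 2, 3])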