tritonparse-0.3.2.dev20251210071601-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tritonparse might be problematic.

Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0

tritonparse/tools/decompress_bin_ndjson.py
@@ -0,0 +1,120 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ """
+ Script to decompress .bin.ndjson files back to regular .ndjson format.
+
+ The .bin.ndjson format stores each JSON record as a separate gzip member,
+ concatenated in sequence within a single binary file. This script uses
+ gzip.open() which automatically handles member concatenation to read
+ the compressed file and write out the original NDJSON format.
+
+ Usage:
+     python decompress_bin_ndjson.py trace.bin.ndjson
+ """
+
+ import argparse
+ import gzip
+ import sys
+ from pathlib import Path
+
+
+ def decompress_bin_ndjson(input_file: str, output_file: str = None) -> None:
+     """
+     Decompress a .bin.ndjson file to regular .ndjson format.
+
+     Args:
+         input_file: Path to the .bin.ndjson file
+         output_file: Path for the output .ndjson file (optional)
+     """
+     input_path = Path(input_file)
+
+     # Validate input file
+     if not input_path.exists():
+         print(f"Error: Input file '{input_file}' does not exist", file=sys.stderr)
+         return
+
+     if not input_path.name.endswith(".bin.ndjson"):
+         print(f"Warning: Input file '{input_file}' doesn't have .bin.ndjson extension")
+
+     # Determine output file path
+     if output_file is None:
+         if input_path.name.endswith(".bin.ndjson"):
+             # Replace .bin.ndjson with .ndjson
+             output_file = str(input_path.with_suffix("").with_suffix(".ndjson"))
+         else:
+             # Add .decompressed.ndjson suffix
+             output_file = str(input_path.with_suffix(".decompressed.ndjson"))
+
+     output_path = Path(output_file)
+
+     try:
+         line_count = 0
+         # Because we use NDJSON format, each line is a complete JSON record.
+         # It is guaranteed here https://github.com/meta-pytorch/tritonparse/blob/
+         # c8dcc2a94ac10ede4342dba7456f6ebd8409b95d/tritonparse/structured_logging.py#L320
+         with gzip.open(input_path, "rt", encoding="utf-8") as compressed_file:
+             with open(output_path, "w", encoding="utf-8") as output:
+                 for line in compressed_file:
+                     # gzip.open automatically handles member concatenation
+                     # Each line is already a complete JSON record with newline
+                     output.write(line)
+                     line_count += 1
+
+         # Get file sizes for comparison
+         input_size = input_path.stat().st_size
+         output_size = output_path.stat().st_size
+         compression_ratio = (
+             (1 - input_size / output_size) * 100 if output_size > 0 else 0
+         )
+
+         print(f"Successfully decompressed '{input_file}' to '{output_file}'")
+         print(f"  Input size: {input_size:,} bytes")
+         print(f"  Output size: {output_size:,} bytes")
+         print(f"  Compression ratio: {compression_ratio:.1f}%")
+         print(f"  Records processed: {line_count:,}")
+
+     except gzip.BadGzipFile as e:
+         print(f"Error: Invalid gzip format in '{input_file}': {e}", file=sys.stderr)
+     except UnicodeDecodeError as e:
+         print(f"Error: Unicode decode error in '{input_file}': {e}", file=sys.stderr)
+     except Exception as e:
+         print(f"Error: Failed to decompress '{input_file}': {e}", file=sys.stderr)
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Decompress .bin.ndjson files to regular .ndjson format",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   %(prog)s trace.bin.ndjson
+   %(prog)s trace.bin.ndjson -o output.ndjson
+   %(prog)s /logs/dedicated_log_triton_trace_user_.bin.ndjson
+         """,
+     )
+
+     parser.add_argument("input_file", help="Input .bin.ndjson file to decompress")
+
+     parser.add_argument(
+         "-o",
+         "--output",
+         help="Output .ndjson file path (default: replace .bin.ndjson with .ndjson)",
+     )
+
+     parser.add_argument(
+         "-v", "--verbose", action="store_true", help="Enable verbose output"
+     )
+
+     args = parser.parse_args()
+
+     if args.verbose:
+         print(f"Decompressing: {args.input_file}")
+         if args.output:
+             print(f"Output file: {args.output}")
+
+     decompress_bin_ndjson(args.input_file, args.output)
+
+
+ if __name__ == "__main__":
+     main()
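
The module docstring above describes the .bin.ndjson layout. As a rough illustration of that layout (not the package's actual writer, which lives in tritonparse/structured_logging.py; the record contents here are made up), a file in this format can be produced by compressing each NDJSON line as its own gzip member and concatenating the members:

    import gzip
    import json

    records = [
        {"event_type": "compilation", "payload": {"metadata": {"name": "add_kernel"}}},
        {"event_type": "launch", "payload": {}},
    ]

    with open("trace.bin.ndjson", "wb") as f:
        for record in records:
            line = json.dumps(record) + "\n"
            # gzip.compress() emits a standalone gzip member; gzip.open() later
            # reads the concatenated members back as one continuous text stream.
            f.write(gzip.compress(line.encode("utf-8")))

Running decompress_bin_ndjson.py on the resulting file should reproduce the two JSON lines unchanged.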

tritonparse/tools/disasm.py
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ import re
+ import subprocess
+
+ # Regex patterns for nvdisasm output
+ NVDISASM_FNAME_RE = re.compile(r"^\s*\.global\s+(\w+)")
+
+
+ def path_to_nvdisasm():
+     from triton import knobs
+
+     return knobs.nvidia.nvdisasm.path
+
+
+ def is_nvdisasm_available():
+     try:
+         if path_to_nvdisasm():
+             return True
+         else:
+             return False
+     except RuntimeError:
+         return False
+
+
+ def extract(file_path):
+     """Extract SASS from CUBIN using nvdisasm.
+
+     nvdisasm output is much cleaner than cuobjdump:
+     - Single line per instruction (no encoding lines)
+     - Labels are already symbolized (.L_x_0 instead of addresses)
+     - Source line information is included
+     - No need for complex address remapping
+
+     nvdisasm Documentation:
+     https://docs.nvidia.com/cuda/cuda-binary-utilities/index.html
+     """
+     nvdisasm = path_to_nvdisasm()
+     args = [nvdisasm, "-c", "-gp", "-g", "-gi", file_path]
+     sass_str = subprocess.check_output(args)
+     sass_lines = sass_str.splitlines()
+     line_idx = 0
+
+     while line_idx < len(sass_lines):
+         line = sass_lines[line_idx].decode()
+
+         # Find function definition (.global function_name)
+         while NVDISASM_FNAME_RE.match(line) is None:
+             line_idx += 1
+             if line_idx >= len(sass_lines):
+                 return None
+             line = sass_lines[line_idx].decode()
+
+         # Extract function name
+         match = NVDISASM_FNAME_RE.match(line)
+         if match is None:
+             return None
+         fname = match.group(1)
+         ret = f"Function:{fname}\n"
+
+         # Find the actual start of function content (.text.kernel_name:)
+         text_section_pattern = f".text.{fname}:"
+         line_idx += 1
+         while line_idx < len(sass_lines):
+             line = sass_lines[line_idx].decode().strip()
+             if line == text_section_pattern:
+                 line_idx += 1  # Move past the .text.kernel_name: line
+                 break
+             line_idx += 1
+
+         # Process all lines until next .headerflags or end of file
+         while line_idx < len(sass_lines):
+             line = sass_lines[line_idx].decode().rstrip()
+
+             # Stop if we encounter next function's headerflags
+             if line.strip().startswith(".headerflags"):
+                 break
+             ret += line + "\n"
+             line_idx += 1
+
+         ret += "\n"
+         return ret
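
A minimal usage sketch for the helpers above (the CUBIN path is a placeholder, and this assumes Triton and the CUDA binary utilities are installed; extract() returns None if no .global symbol is found):

    from tritonparse.tools import disasm

    cubin_path = "add_kernel.cubin"  # placeholder path to a Triton-generated CUBIN

    if disasm.is_nvdisasm_available():
        sass = disasm.extract(cubin_path)
        if sass is not None:
            # "Function:<name>" header followed by the SASS listing for that kernel
            print(sass)
    else:
        print("nvdisasm not available")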

tritonparse/tools/extract_irs.py
@@ -0,0 +1,244 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ """
+ Extract IR files from NDJSON trace logs.
+
+ This script extracts intermediate representation (IR) files from a Triton trace NDJSON file.
+ For compilation events, it extracts the IR files (ttir, ttgir, llir, ptx, etc.) contained in
+ the file_content field and saves them as individual files.
+
+ Example:
+     Extract IRs from line 0 (first line) of the NDJSON file:
+     python extract_irs.py -i logs.ndjson --line 0 -o output_folder
+
+     Extract from line 5:
+     python extract_irs.py -i logs.ndjson --line 5 -o ./irs
+
+ Usage:
+     python extract_irs.py -i <input.ndjson> --line <line_number> -o <output_folder>
+ """
+
+ import argparse
+ import json
+ import sys
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+
+ def read_ndjson_line(file_path: Path, line_number: int) -> Optional[Dict[str, Any]]:
+     """
+     Read a specific line from an NDJSON file (0-based indexing).
+
+     Args:
+         file_path: Path to the NDJSON file
+         line_number: Line number to read (0-based, where 0 = first line)
+
+     Returns:
+         Parsed JSON object from the specified line, or None if line doesn't exist
+
+     Raises:
+         FileNotFoundError: If the input file doesn't exist
+         json.JSONDecodeError: If the line contains invalid JSON
+     """
+     if not file_path.exists():
+         raise FileNotFoundError(f"File not found: {file_path}")
+
+     try:
+         with open(file_path, "r", encoding="utf-8") as f:
+             for current_line_num, line in enumerate(f):
+                 if current_line_num == line_number:
+                     line = line.strip()
+                     if not line:
+                         print(f"Warning: Line {line_number} is empty", file=sys.stderr)
+                         return None
+                     return json.loads(line)
+
+         print(
+             f"Error: Line {line_number} not found in file (file has fewer lines)",
+             file=sys.stderr,
+         )
+         return None
+
+     except json.JSONDecodeError as e:
+         print(f"Error: Invalid JSON on line {line_number}: {e}", file=sys.stderr)
+         raise
+
+
+ def extract_irs(
+     json_obj: Dict[str, Any], output_dir: Path, kernel_name: Optional[str] = None
+ ) -> int:
+     """
+     Extract IR files from a JSON object and save them to the output directory.
+
+     Args:
+         json_obj: Parsed JSON object from the NDJSON file
+         output_dir: Directory to save the extracted IR files
+         kernel_name: Optional kernel name to use for file naming (overrides metadata.name)
+
+     Returns:
+         Number of files extracted
+
+     Raises:
+         ValueError: If the JSON object is not a compilation event or missing required fields
+     """
+     # Validate that this is a compilation event
+     event_type = json_obj.get("event_type")
+     if event_type != "compilation":
+         raise ValueError(f"Not a compilation event (event_type: {event_type})")
+
+     payload = json_obj.get("payload")
+     if not payload:
+         raise ValueError("Missing 'payload' field in JSON object")
+
+     # Get file_content
+     file_content = payload.get("file_content")
+     if not file_content:
+         raise ValueError("Missing 'file_content' field in payload")
+
+     # Determine kernel name
+     if kernel_name is None:
+         metadata = payload.get("metadata", {})
+         kernel_name = metadata.get("name", "kernel")
+
+     # Create output directory if it doesn't exist
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Extract each IR file
+     files_extracted = 0
+     for file_key, content in file_content.items():
+         # Determine file extension from the key
+         # file_key is typically like "embedding_forward_kernel.ttir"
+         # We want to extract just the extension
+         if "." in file_key:
+             extension = file_key.split(".")[-1]
+         else:
+             extension = "txt"
+
+         # Create output filename
+         output_filename = f"{kernel_name}.{extension}"
+         output_path = output_dir / output_filename
+
+         # Write content to file
+         try:
+             with open(output_path, "w", encoding="utf-8") as f:
+                 f.write(content)
+             print(f"Extracted: {output_path}")
+             files_extracted += 1
+         except OSError as e:
+             print(f"Error writing file {output_path}: {e}", file=sys.stderr)
+
+     # Optionally extract Python source code
+     python_source = payload.get("python_source")
+     if python_source and isinstance(python_source, dict):
+         source_code = python_source.get("code")
+         if source_code:
+             output_path = output_dir / f"{kernel_name}_source.py"
+             try:
+                 with open(output_path, "w", encoding="utf-8") as f:
+                     # Add header comment with file path and line range
+                     file_path_info = python_source.get("file_path", "unknown")
+                     start_line = python_source.get("start_line", "?")
+                     end_line = python_source.get("end_line", "?")
+                     f.write(f"# Source: {file_path_info}\n")
+                     f.write(f"# Lines: {start_line}-{end_line}\n\n")
+                     f.write(source_code)
+                 print(f"Extracted Python source: {output_path}")
+                 files_extracted += 1
+             except OSError as e:
+                 print(
+                     f"Error writing Python source file {output_path}: {e}",
+                     file=sys.stderr,
+                 )
+
+     return files_extracted
+
+
+ def main():
+     """Main function to handle command line arguments and orchestrate IR extraction."""
+     parser = argparse.ArgumentParser(
+         description="Extract IR files from Triton trace NDJSON logs",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   Extract IRs from line 0 (first line):
+   python extract_irs.py -i logs.ndjson --line 0 -o output_folder
+
+   Extract from line 5:
+   python extract_irs.py -i logs.ndjson --line 5 -o ./irs
+
+   Specify custom kernel name:
+   python extract_irs.py -i logs.ndjson --line 0 -o ./irs --kernel-name my_kernel
+         """,
+     )
+
+     parser.add_argument(
+         "-i", "--input", type=str, required=True, help="Path to the input NDJSON file"
+     )
+
+     parser.add_argument(
+         "--line",
+         type=int,
+         required=True,
+         help="Line number to extract (0-based indexing, where 0 = first line)",
+     )
+
+     parser.add_argument(
+         "-o",
+         "--output",
+         type=str,
+         required=True,
+         help="Output directory to save extracted IR files",
+     )
+
+     parser.add_argument(
+         "--kernel-name",
+         type=str,
+         help="Custom kernel name for output files (default: use metadata.name from JSON)",
+     )
+
+     args = parser.parse_args()
+
+     # Validate line number
+     if args.line < 0:
+         print(
+             f"Error: Line number must be non-negative (got {args.line})",
+             file=sys.stderr,
+         )
+         sys.exit(1)
+
+     # Convert to Path objects
+     input_path = Path(args.input)
+     output_dir = Path(args.output)
+
+     try:
+         # Read the specified line
+         print(f"Reading line {args.line} from {input_path}...")
+         json_obj = read_ndjson_line(input_path, args.line)
+
+         if json_obj is None:
+             print("Error: Failed to read JSON from specified line", file=sys.stderr)
+             sys.exit(1)
+
+         # Extract IRs
+         print(f"Extracting IRs to {output_dir}...")
+         num_files = extract_irs(json_obj, output_dir, args.kernel_name)
+
+         print(f"\nSuccess! Extracted {num_files} file(s) to {output_dir}")
+
+     except FileNotFoundError as e:
+         print(f"Error: {e}", file=sys.stderr)
+         sys.exit(1)
+     except ValueError as e:
+         print(f"Error: {e}", file=sys.stderr)
+         sys.exit(1)
+     except json.JSONDecodeError as e:
+         print(f"Error: Failed to parse JSON - {e}", file=sys.stderr)
+         sys.exit(1)
+     except Exception as e:
+         print(f"Unexpected error: {e}", file=sys.stderr)
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
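
To make the expected input concrete, here is a hand-written compilation event in the shape extract_irs() consumes (the IR strings and source snippet are stand-ins; real trace events carry the full text):

    from pathlib import Path

    from tritonparse.tools.extract_irs import extract_irs

    event = {
        "event_type": "compilation",
        "payload": {
            "metadata": {"name": "add_kernel"},
            "file_content": {
                "add_kernel.ttir": "module { ... }",  # stand-in TTIR
                "add_kernel.ptx": "// PTX ...",  # stand-in PTX
            },
            "python_source": {
                "code": "def add_kernel(...): ...",
                "file_path": "example.py",
                "start_line": 10,
                "end_line": 20,
            },
        },
    }

    # Writes add_kernel.ttir, add_kernel.ptx, and add_kernel_source.py into ./irs
    count = extract_irs(event, Path("./irs"))
    assert count == 3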

tritonparse/tools/format_fix.py
@@ -0,0 +1,151 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ """
+ Format fix script for tritonparse project.
+
+ This script runs all linter tools to format and fix code issues:
+ - usort: Import sorting
+ - ruff: Linting only
+ - black: Code formatting
+
+ Usage:
+     python -m tritonparse.tools.format_fix [options]
+
+ Options:
+     --check-only    Only check for issues, don't fix them
+     --verbose       Verbose output
+     --help          Show this help message
+ """
+
+ import argparse
+ import subprocess
+ import sys
+
+
+ def run_command(cmd: list[str], verbose: bool = False) -> bool:
+     """Run a command and return success status."""
+     if verbose:
+         print(f"Running: {' '.join(cmd)}")
+
+     try:
+         result = subprocess.run(cmd, capture_output=True, text=True, check=False)
+
+         if result.returncode != 0:
+             if verbose:
+                 print(f"Command failed with return code {result.returncode}")
+                 if result.stdout:
+                     print("STDOUT:", result.stdout)
+                 if result.stderr:
+                     print("STDERR:", result.stderr)
+             return False
+
+         if verbose and result.stdout:
+             print(result.stdout)
+
+         return True
+     except Exception as e:
+         if verbose:
+             print(f"Error running command: {e}")
+         return False
+
+
+ def run_usort(check_only: bool = False, verbose: bool = False) -> bool:
+     """Run usort for import sorting."""
+     cmd = ["usort"]
+
+     if check_only:
+         cmd.extend(["check", "."])
+     else:
+         cmd.extend(["format", "."])
+
+     return run_command(cmd, verbose)
+
+
+ def run_ruff_check(check_only: bool = False, verbose: bool = False) -> bool:
+     """Run ruff for linting only."""
+     cmd = ["ruff", "check", "."]
+
+     if check_only:
+         cmd.append("--diff")
+     else:
+         cmd.append("--fix")
+
+     return run_command(cmd, verbose)
+
+
+ def run_black(check_only: bool = False, verbose: bool = False) -> bool:
+     """Run black for code formatting."""
+     cmd = ["black"]
+
+     if check_only:
+         cmd.extend(["--check", "--diff", "."])
+     else:
+         cmd.append(".")
+
+     return run_command(cmd, verbose)
+
+
+ def main():
+     """Main function."""
+     parser = argparse.ArgumentParser(
+         description="Format fix script for tritonparse project",
+         epilog="""
+ Examples:
+   # Fix all formatting issues
+   python -m tritonparse.tools.format_fix
+
+   # Check for issues without fixing
+   python -m tritonparse.tools.format_fix --check-only
+
+   # Verbose output
+   python -m tritonparse.tools.format_fix --verbose
+         """,
+     )
+
+     parser.add_argument(
+         "--check-only",
+         action="store_true",
+         help="Only check for issues, don't fix them",
+     )
+     parser.add_argument("--verbose", action="store_true", help="Verbose output")
+
+     args = parser.parse_args()
+
+     # Run formatters on the entire project
+     success = True
+
+     # 1. Run usort for import sorting
+     print("Running usort for import sorting...")
+     if not run_usort(args.check_only, args.verbose):
+         print("❌ usort failed")
+         success = False
+     else:
+         print("✅ usort completed")
+
+     # 2. Run ruff for linting only
+     print("Running ruff for linting...")
+     if not run_ruff_check(args.check_only, args.verbose):
+         print("❌ ruff linting failed")
+         success = False
+     else:
+         print("✅ ruff linting completed")
+
+     # 3. Run black for code formatting
+     print("Running black for code formatting...")
+     if not run_black(args.check_only, args.verbose):
+         print("❌ black failed")
+         success = False
+     else:
+         print("✅ black completed")
+
+     if success:
+         print("\n🎉 All formatting tools completed successfully!")
+         return 0
+     else:
+         print("\n❌ Some formatting tools failed. Please check the output above.")
+         return 1
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
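
For reference, a small sketch of driving the same helpers programmatically rather than through the CLI (fix mode; assumes usort, ruff, and black are installed on PATH):

    from tritonparse.tools.format_fix import run_black, run_ruff_check, run_usort

    # Equivalent to running: usort format . ; ruff check . --fix ; black .
    results = [run_usort(verbose=True), run_ruff_check(verbose=True), run_black(verbose=True)]
    print("all clean" if all(results) else "some tool reported issues")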

tritonparse/tools/load_tensor.py
@@ -0,0 +1,76 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ """
+ Simple tensor loading utility for tritonparse saved tensors.
+ Usage:
+     import tritonparse.tools.load_tensor as load_tensor
+     tensor = load_tensor.load_tensor(tensor_file_path, device)
+ """
+
+ import gzip
+ import hashlib
+ import io
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import torch
+
+
+ def load_tensor(tensor_file_path: Union[str, Path], device: Optional[str] = None) -> torch.Tensor:
+     """
+     Load a tensor from its file path and verify its integrity using the hash in the filename.
+
+     Args:
+         tensor_file_path (str | Path): Direct path to the tensor file. Supports both:
+             - .bin.gz: gzip-compressed tensor (hash is of uncompressed data)
+             - .bin: uncompressed tensor (for backward compatibility)
+         device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
+             If None, keeps the tensor on its original device.
+
+     Returns:
+         torch.Tensor: The loaded tensor (moved to the specified device if provided)
+
+     Raises:
+         FileNotFoundError: If the tensor file doesn't exist
+         RuntimeError: If the tensor cannot be loaded
+         ValueError: If the computed hash doesn't match the filename hash
+     """
+     blob_path = Path(tensor_file_path)
+
+     if not blob_path.exists():
+         raise FileNotFoundError(f"Tensor blob not found: {blob_path}")
+
+     # Detect compression by file extension
+     is_compressed = blob_path.name.endswith(".bin.gz")
+
+     # Read file contents (decompress if needed)
+     try:
+         with open(blob_path, "rb") as f:
+             file_obj = gzip.GzipFile(fileobj=f, mode="rb") if is_compressed else f
+             file_contents = file_obj.read()
+     except (OSError, gzip.BadGzipFile) as e:
+         if is_compressed:
+             raise RuntimeError(f"Failed to decompress gzip file {blob_path}: {str(e)}")
+         else:
+             raise RuntimeError(f"Failed to read file {blob_path}: {str(e)}")
+
+     # Extract expected hash from filename
+     # abc123.bin.gz -> abc123 or abc123.bin -> abc123
+     expected_hash = blob_path.name.removesuffix(".bin.gz" if is_compressed else ".bin")
+
+     # Compute hash of uncompressed data
+     computed_hash = hashlib.blake2b(file_contents).hexdigest()
+
+     # Verify hash matches filename
+     if computed_hash != expected_hash:
+         raise ValueError(
+             f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
+         )
+
+     try:
+         # Load the tensor from memory buffer
+         tensor = torch.load(io.BytesIO(file_contents), map_location=device)
+         return tensor
+     except Exception as e:
+         raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}")