tritonparse 0.3.1.dev20251020071524__py3-none-any.whl → 0.3.1.dev20251021071528__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/reproducer/function_extractor.py +220 -0
- tritonparse/reproducer/placeholder_replacer.py +11 -0
- tritonparse/reproducer/templates/example.py +3 -363
- tritonparse/reproducer/utils.py +207 -113
- {tritonparse-0.3.1.dev20251020071524.dist-info → tritonparse-0.3.1.dev20251021071528.dist-info}/METADATA +1 -1
- {tritonparse-0.3.1.dev20251020071524.dist-info → tritonparse-0.3.1.dev20251021071528.dist-info}/RECORD +10 -9
- {tritonparse-0.3.1.dev20251020071524.dist-info → tritonparse-0.3.1.dev20251021071528.dist-info}/WHEEL +0 -0
- {tritonparse-0.3.1.dev20251020071524.dist-info → tritonparse-0.3.1.dev20251021071528.dist-info}/entry_points.txt +0 -0
- {tritonparse-0.3.1.dev20251020071524.dist-info → tritonparse-0.3.1.dev20251021071528.dist-info}/licenses/LICENSE +0 -0
- {tritonparse-0.3.1.dev20251020071524.dist-info → tritonparse-0.3.1.dev20251021071528.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Function extractor for reproducer utility functions.
|
|
3
|
+
|
|
4
|
+
This module extracts utility functions from utils.py and load_tensor.py
|
|
5
|
+
using AST parsing, and generates standalone code for reproducers.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import ast
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def extract_utility_functions() -> str:
    """
    Extract all utility functions needed for the reproducer template.

    Uses AST parsing to extract functions and constants from source files
    without importing them (avoiding potential side effects).

    The returned code contains, separated by blank lines: the import header,
    the ``TRITON_KERNELS_CUSTOM_TYPES`` assignment from ``utils.py``,
    ``load_tensor`` from ``tools/load_tensor.py``, and the helper functions
    from ``utils.py`` listed below (in dependency order).

    Returns:
        str: Complete Python code including imports and all utility functions.

    Raises:
        FileNotFoundError: If ``utils.py`` or ``tools/load_tensor.py`` is missing.
        SyntaxError: If either source file cannot be parsed.
        ValueError: If a required function or the
            ``TRITON_KERNELS_CUSTOM_TYPES`` assignment is not found.
    """
    # Prepare file paths
    base_dir = Path(__file__).parent
    utils_path = base_dir / "utils.py"
    load_tensor_path = base_dir.parent / "tools" / "load_tensor.py"

    # Parse source files
    utils_tree, utils_lines = _parse_source_file(utils_path)
    load_tensor_tree, load_tensor_lines = _parse_source_file(load_tensor_path)

    # Define what to extract (in dependency order)
    utils_function_names = [
        "_get_triton_tensor_types",
        "create_args_from_json_file",
        "create_args_from_json",
        "_apply_stride_and_offset",
        "_create_base_tensor",
        "_create_tensor",
        "_create_arg_from_info",
    ]
    load_tensor_function_names = [
        "load_tensor",
    ]

    # The extracted _create_arg_from_info reads this module-level flag.
    constant_name = "TRITON_KERNELS_CUSTOM_TYPES"
    constant = _extract_assignment(utils_tree, utils_lines, constant_name)
    if constant is None:
        # Fail at generation time instead of emitting a reproducer that only
        # hits a NameError once it meets a triton_kernels argument.
        raise ValueError(
            f"Assignment '{constant_name}' not found at module level in {utils_path}."
        )

    # Imports first, then the constant, then functions in dependency order.
    extracted_parts = [_generate_imports(), constant]
    extracted_parts.extend(
        _extract_functions(
            load_tensor_tree, load_tensor_lines, load_tensor_function_names
        )
    )
    extracted_parts.extend(
        _extract_functions(utils_tree, utils_lines, utils_function_names)
    )

    # Combine all parts
    return "\n\n".join(extracted_parts)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _parse_source_file(file_path: Path) -> tuple[ast.Module, list[str]]:
|
|
76
|
+
"""
|
|
77
|
+
Parse a Python source file and return its AST and source lines.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
file_path: Path to the Python source file
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
tuple: (AST tree, list of source code lines)
|
|
84
|
+
|
|
85
|
+
Raises:
|
|
86
|
+
FileNotFoundError: If the source file doesn't exist
|
|
87
|
+
SyntaxError: If the source file has syntax errors
|
|
88
|
+
"""
|
|
89
|
+
try:
|
|
90
|
+
source_code = file_path.read_text(encoding="utf-8")
|
|
91
|
+
tree = ast.parse(source_code, filename=str(file_path))
|
|
92
|
+
except FileNotFoundError as e:
|
|
93
|
+
raise FileNotFoundError(f"Source file not found: {file_path}") from e
|
|
94
|
+
except SyntaxError as e:
|
|
95
|
+
raise SyntaxError(f"Failed to parse {file_path}: {e}") from e
|
|
96
|
+
|
|
97
|
+
lines = source_code.splitlines()
|
|
98
|
+
return tree, lines
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _extract_assignment(
|
|
102
|
+
tree: ast.Module, lines: list[str], var_name: str
|
|
103
|
+
) -> str | None:
|
|
104
|
+
"""
|
|
105
|
+
Extract a module-level assignment statement by variable name.
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
tree: AST tree of the source file
|
|
109
|
+
lines: Source code lines
|
|
110
|
+
var_name: Name of the variable to extract
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
Complete assignment statement source code, or None if not found
|
|
114
|
+
|
|
115
|
+
Example:
|
|
116
|
+
Extracts:
|
|
117
|
+
TRITON_KERNELS_CUSTOM_TYPES = (
|
|
118
|
+
importlib.util.find_spec("triton_kernels") is not None
|
|
119
|
+
and importlib.util.find_spec("triton_kernels.tensor") is not None
|
|
120
|
+
)
|
|
121
|
+
"""
|
|
122
|
+
# Search only at module level
|
|
123
|
+
for node in tree.body:
|
|
124
|
+
if isinstance(node, ast.Assign):
|
|
125
|
+
for target in node.targets:
|
|
126
|
+
if isinstance(target, ast.Name) and target.id == var_name:
|
|
127
|
+
# Found it! Extract source code using line numbers
|
|
128
|
+
start_line = node.lineno - 1 # Convert to 0-based index
|
|
129
|
+
end_line = node.end_lineno # Already suitable for slicing
|
|
130
|
+
assignment_lines = lines[start_line:end_line]
|
|
131
|
+
return "\n".join(assignment_lines)
|
|
132
|
+
return None
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _extract_function(tree: ast.Module, lines: list[str], func_name: str) -> str | None:
|
|
136
|
+
"""
|
|
137
|
+
Extract a function definition by name, including decorators.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
tree: AST tree of the source file
|
|
141
|
+
lines: Source code lines
|
|
142
|
+
func_name: Name of the function to extract
|
|
143
|
+
|
|
144
|
+
Returns:
|
|
145
|
+
Complete function source code including decorators, or None if not found
|
|
146
|
+
|
|
147
|
+
Example:
|
|
148
|
+
Extracts:
|
|
149
|
+
@lru_cache(maxsize=1)
|
|
150
|
+
def _get_triton_tensor_types():
|
|
151
|
+
'''Docstring'''
|
|
152
|
+
...
|
|
153
|
+
"""
|
|
154
|
+
# Walk the entire tree (handles nested functions if needed)
|
|
155
|
+
for node in ast.walk(tree):
|
|
156
|
+
if isinstance(node, ast.FunctionDef) and node.name == func_name:
|
|
157
|
+
# If function has decorators, start from the first decorator
|
|
158
|
+
if node.decorator_list:
|
|
159
|
+
start_line = node.decorator_list[0].lineno - 1
|
|
160
|
+
else:
|
|
161
|
+
start_line = node.lineno - 1
|
|
162
|
+
|
|
163
|
+
end_line = node.end_lineno
|
|
164
|
+
func_lines = lines[start_line:end_line]
|
|
165
|
+
return "\n".join(func_lines)
|
|
166
|
+
return None
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _extract_functions(
    tree: ast.Module, lines: list[str], func_names: list[str]
) -> list[str]:
    """
    Extract multiple functions from a source file.

    Args:
        tree: AST tree of the source file
        lines: Source code lines
        func_names: List of function names to extract

    Returns:
        List of function source codes in the same order as func_names

    Raises:
        ValueError: If any function is not found; the message lists the
            top-level functions that do exist to help diagnose renames.
    """
    extracted = []
    for func_name in func_names:
        func_source = _extract_function(tree, lines, func_name)
        if func_source is None:
            available = sorted(
                node.name
                for node in tree.body
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            )
            raise ValueError(
                f"Function '{func_name}' not found in source; it may have been "
                f"renamed or removed. Top-level functions available: "
                f"{', '.join(available) or '(none)'}"
            )
        extracted.append(func_source)
    return extracted
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _generate_imports() -> str:
|
|
199
|
+
"""
|
|
200
|
+
Generate the import statements needed for the extracted functions.
|
|
201
|
+
|
|
202
|
+
Returns:
|
|
203
|
+
str: Import statements as a single string
|
|
204
|
+
"""
|
|
205
|
+
imports = [
|
|
206
|
+
"import gzip",
|
|
207
|
+
"import hashlib",
|
|
208
|
+
"import importlib",
|
|
209
|
+
"import importlib.util",
|
|
210
|
+
"import io",
|
|
211
|
+
"import json",
|
|
212
|
+
"import logging",
|
|
213
|
+
"import sys",
|
|
214
|
+
"from functools import lru_cache",
|
|
215
|
+
"from pathlib import Path",
|
|
216
|
+
"from typing import Union",
|
|
217
|
+
"",
|
|
218
|
+
"import torch",
|
|
219
|
+
]
|
|
220
|
+
return "\n".join(imports)
|
|
@@ -2,6 +2,7 @@ from abc import ABC
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Dict, Protocol
|
|
4
4
|
|
|
5
|
+
from tritonparse.reproducer.function_extractor import extract_utility_functions
|
|
5
6
|
from tritonparse.reproducer.ingestion.ndjson import ContextBundle
|
|
6
7
|
from tritonparse.reproducer.types import KernelImportMode
|
|
7
8
|
from tritonparse.reproducer.utils import (
|
|
@@ -82,6 +83,9 @@ class DefaultPlaceholderReplacer(PlaceholderReplacer):
|
|
|
82
83
|
)
|
|
83
84
|
self.register("# {{KERNEL_SYSPATH_PLACEHOLDER}}", self._replace_kernel_syspath)
|
|
84
85
|
self.register("# {{KERNEL_IMPORT_PLACEHOLDER}}", self._replace_kernel_import)
|
|
86
|
+
self.register(
|
|
87
|
+
"# {{UTILITY_FUNCTIONS_PLACEHOLDER}}", self._replace_utility_functions
|
|
88
|
+
)
|
|
85
89
|
self.register(
|
|
86
90
|
"# {{KERNEL_INVOCATION_PLACEHOLDER}}", self._replace_kernel_invocation
|
|
87
91
|
)
|
|
@@ -217,6 +221,13 @@ triton.autotune = _patched_autotune
|
|
|
217
221
|
else:
|
|
218
222
|
raise ValueError(f"Unknown kernel_import mode: {kernel_import}")
|
|
219
223
|
|
|
224
|
+
def _replace_utility_functions(
    self, code: str, context_bundle: ContextBundle, **kwargs
) -> str:
    """
    Fill the utility-functions placeholder with standalone helper source.

    The helper source is regenerated via ``extract_utility_functions`` on every
    call; ``context_bundle`` and ``kwargs`` are not used here.
    """
    helpers_source = extract_utility_functions()
    placeholder = "# {{UTILITY_FUNCTIONS_PLACEHOLDER}}"
    return code.replace(placeholder, helpers_source)
|
|
230
|
+
|
|
220
231
|
def _replace_kernel_invocation(
|
|
221
232
|
self, code: str, context_bundle: ContextBundle, **kwargs
|
|
222
233
|
) -> str:
|
|
@@ -3,18 +3,6 @@ This file is automatically generated by TritonParse reproducer.
|
|
|
3
3
|
It contains a smallest testing example for a Triton kernel.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
import gzip
|
|
7
|
-
import hashlib
|
|
8
|
-
import importlib
|
|
9
|
-
import importlib.util
|
|
10
|
-
import io
|
|
11
|
-
import json
|
|
12
|
-
import logging
|
|
13
|
-
import sys
|
|
14
|
-
from functools import lru_cache
|
|
15
|
-
from pathlib import Path
|
|
16
|
-
from typing import Union
|
|
17
|
-
|
|
18
6
|
import torch
|
|
19
7
|
|
|
20
8
|
# {{IR_OVERRIDE_SETUP_PLACEHOLDER}}
|
|
@@ -23,361 +11,13 @@ import torch
|
|
|
23
11
|
|
|
24
12
|
# {{KERNEL_IMPORT_PLACEHOLDER}}
|
|
25
13
|
|
|
26
|
-
|
|
27
|
-
importlib.util.find_spec("triton_kernels") is not None
|
|
28
|
-
and importlib.util.find_spec("triton_kernels.tensor") is not None
|
|
29
|
-
)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
@lru_cache(maxsize=1)
|
|
33
|
-
def _get_triton_tensor_types():
|
|
34
|
-
"""
|
|
35
|
-
Import and cache Triton custom tensor types.
|
|
36
|
-
|
|
37
|
-
Returns:
|
|
38
|
-
tuple: (Tensor, Storage, StridedLayout) classes from triton_kernels.tensor.
|
|
39
|
-
|
|
40
|
-
Raises:
|
|
41
|
-
ImportError: If the optional module 'triton_kernels.tensor' is not available.
|
|
42
|
-
"""
|
|
43
|
-
mod = importlib.import_module("triton_kernels.tensor")
|
|
44
|
-
return (
|
|
45
|
-
mod.Tensor,
|
|
46
|
-
mod.Storage,
|
|
47
|
-
mod.StridedLayout,
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def load_tensor(
    tensor_file_path: Union[str, Path], device: Union[str, None] = None
) -> torch.Tensor:
    """
    Load a tensor from its file path and verify its integrity using the hash in the filename.

    Args:
        tensor_file_path (str | Path): Direct path to the tensor file. Supports both:
            - .bin.gz: gzip-compressed tensor (hash is of uncompressed data)
            - .bin: uncompressed tensor (for backward compatibility)
        device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
            If None, keeps the tensor on its original device.

    Returns:
        torch.Tensor: The loaded tensor (moved to the specified device if provided)

    Raises:
        FileNotFoundError: If the tensor file doesn't exist
        RuntimeError: If the file cannot be read or decompressed (including a
            truncated or corrupted gzip stream), or the tensor cannot be loaded
        ValueError: If the computed hash doesn't match the filename hash
    """
    import zlib

    blob_path = Path(tensor_file_path)

    if not blob_path.exists():
        raise FileNotFoundError(f"Tensor blob not found: {blob_path}")

    # Detect compression by file extension
    is_compressed = blob_path.name.endswith(".bin.gz")

    # Read file contents (decompress if needed). gzip reports a truncated
    # stream as EOFError and corrupt deflate data as zlib.error; neither is an
    # OSError, so catch them explicitly to honor the RuntimeError contract.
    try:
        if is_compressed:
            with gzip.open(blob_path, "rb") as f:
                file_contents = f.read()
        else:
            file_contents = blob_path.read_bytes()
    except (OSError, EOFError, zlib.error) as e:
        action = "decompress gzip file" if is_compressed else "read file"
        raise RuntimeError(f"Failed to {action} {blob_path}: {e}") from e

    # Extract expected hash from filename
    # abc123.bin.gz -> abc123 or abc123.bin -> abc123
    expected_hash = blob_path.name.removesuffix(".bin.gz" if is_compressed else ".bin")

    # Compute hash of uncompressed data
    computed_hash = hashlib.blake2b(file_contents).hexdigest()

    # Verify hash matches filename
    if computed_hash != expected_hash:
        raise ValueError(
            f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
        )

    try:
        # NOTE(review): torch.load unpickles data (subject to the installed
        # torch's weights_only default); only load blobs from trusted traces.
        return torch.load(io.BytesIO(file_contents), map_location=device)
    except Exception as e:
        raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}") from e
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def create_args_from_json_file(json_path):
    """
    Load a reproducer JSON file and build the kernel grid and arguments.

    Args:
        json_path (str | Path): Path to the JSON file describing the kernel launch.

    Returns:
        tuple[list, dict]: Grid specification list and map of argument name to
        value, as produced by ``create_args_from_json``.
    """
    # JSON text is UTF-8; don't let the platform locale pick the codec.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return create_args_from_json(data)
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
def create_args_from_json(data):
    """
    Build the kernel grid and argument dictionary from parsed reproducer JSON.

    Args:
        data (dict | list): Parsed JSON describing the kernel launch, either a
            dict or a single-element list wrapping that dict.

    Returns:
        tuple[list, dict]: Grid specification list (``[]`` when absent) and a
        map of argument name to the value built by ``_create_arg_from_info``.

    Exits:
        Prints an error to stderr and calls ``sys.exit(1)`` when ``data`` is a
        list whose length is not 1, or is neither a list nor a dict.
    """
    # Handle data format validation and extraction
    if isinstance(data, list):
        if len(data) != 1:
            print(
                f"Error: Expected single element list, got list with {len(data)} elements",
                file=sys.stderr,
            )
            sys.exit(1)
        data = data[0]
    elif not isinstance(data, dict):
        print(f"Error: Expected list or dict, got {type(data)}", file=sys.stderr)
        sys.exit(1)

    grid = data.get("grid", [])
    extracted_args = data.get("extracted_args", {})
    args_dict = {
        arg_name: _create_arg_from_info(arg_info)
        for arg_name, arg_info in extracted_args.items()
    }
    return grid, args_dict
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
|
|
149
|
-
"""
|
|
150
|
-
Apply custom stride and storage offset to a tensor if needed.
|
|
151
|
-
|
|
152
|
-
Args:
|
|
153
|
-
tensor: The base contiguous tensor
|
|
154
|
-
shape: The desired shape
|
|
155
|
-
stride: The desired stride (or None for contiguous)
|
|
156
|
-
storage_offset: The desired storage offset
|
|
157
|
-
|
|
158
|
-
Returns:
|
|
159
|
-
torch.Tensor: The strided tensor view or original tensor if contiguous
|
|
160
|
-
"""
|
|
161
|
-
if stride is None:
|
|
162
|
-
return tensor
|
|
163
|
-
|
|
164
|
-
# Calculate expected contiguous stride
|
|
165
|
-
expected_contiguous_stride = []
|
|
166
|
-
s = 1
|
|
167
|
-
for dim_size in reversed(shape):
|
|
168
|
-
expected_contiguous_stride.insert(0, s)
|
|
169
|
-
s *= dim_size
|
|
170
|
-
|
|
171
|
-
# If stride matches contiguous stride and no storage offset, return as-is
|
|
172
|
-
if tuple(stride) == tuple(expected_contiguous_stride) and storage_offset == 0:
|
|
173
|
-
return tensor
|
|
174
|
-
|
|
175
|
-
# Calculate required storage size
|
|
176
|
-
if len(shape) > 0 and len(stride) > 0:
|
|
177
|
-
max_offset = storage_offset
|
|
178
|
-
for dim_stride, dim_size in zip(stride, shape):
|
|
179
|
-
if dim_size > 0:
|
|
180
|
-
max_offset += dim_stride * (dim_size - 1)
|
|
181
|
-
storage_size = max_offset + 1
|
|
182
|
-
else:
|
|
183
|
-
storage_size = storage_offset + 1
|
|
184
|
-
|
|
185
|
-
# Create larger storage tensor and create strided view
|
|
186
|
-
storage_tensor = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device)
|
|
187
|
-
|
|
188
|
-
# Create strided view
|
|
189
|
-
strided_view = storage_tensor.as_strided(
|
|
190
|
-
size=shape, stride=stride, storage_offset=storage_offset
|
|
191
|
-
)
|
|
192
|
-
|
|
193
|
-
# Copy data from the base tensor into the strided layout
|
|
194
|
-
strided_view.copy_(tensor.flatten()[: strided_view.numel()].view(shape))
|
|
195
|
-
|
|
196
|
-
return strided_view
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
def _create_base_tensor(arg_info) -> torch.Tensor:
|
|
200
|
-
if arg_info.get("blob_path"):
|
|
201
|
-
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
|
|
202
|
-
|
|
203
|
-
# Extract basic tensor properties
|
|
204
|
-
dtype_str = arg_info.get("dtype")
|
|
205
|
-
try:
|
|
206
|
-
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
207
|
-
except AttributeError:
|
|
208
|
-
logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
|
|
209
|
-
torch_dtype = torch.float32
|
|
210
|
-
|
|
211
|
-
shape = arg_info.get("shape", [])
|
|
212
|
-
device = arg_info.get("device", "cpu")
|
|
213
|
-
|
|
214
|
-
# Extract statistical information if available
|
|
215
|
-
mean = arg_info.get("mean")
|
|
216
|
-
std = arg_info.get("std")
|
|
217
|
-
min_val = arg_info.get("min")
|
|
218
|
-
max_val = arg_info.get("max")
|
|
219
|
-
has_stats = (
|
|
220
|
-
mean is not None
|
|
221
|
-
and std is not None
|
|
222
|
-
and min_val is not None
|
|
223
|
-
and max_val is not None
|
|
224
|
-
)
|
|
225
|
-
|
|
226
|
-
if arg_info.get("tensor_capture_error", False):
|
|
227
|
-
logging.error(
|
|
228
|
-
f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
|
|
229
|
-
)
|
|
230
|
-
|
|
231
|
-
# Use a dummy tensor to check properties of the dtype
|
|
232
|
-
tensor_props = torch.empty(0, dtype=torch_dtype)
|
|
233
|
-
|
|
234
|
-
# Case 1: Floating point types
|
|
235
|
-
if tensor_props.is_floating_point():
|
|
236
|
-
if has_stats:
|
|
237
|
-
# Generate tensor with statistical properties matching original data
|
|
238
|
-
if std == 0 or min_val == max_val:
|
|
239
|
-
# Constant tensor
|
|
240
|
-
return torch.full(shape, mean, dtype=torch_dtype, device=device)
|
|
241
|
-
# Generate normal distribution with mean and std, then clamp to [min, max]
|
|
242
|
-
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
243
|
-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
244
|
-
return tensor.to(torch_dtype)
|
|
245
|
-
else:
|
|
246
|
-
# Fallback to original random generation
|
|
247
|
-
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
248
|
-
tmp = torch.rand(shape, dtype=torch.float32, device=device)
|
|
249
|
-
return tmp.to(torch_dtype)
|
|
250
|
-
else:
|
|
251
|
-
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
252
|
-
|
|
253
|
-
# Case 2: Integer types
|
|
254
|
-
elif torch_dtype in [
|
|
255
|
-
torch.int8,
|
|
256
|
-
torch.int16,
|
|
257
|
-
torch.int32,
|
|
258
|
-
torch.int64,
|
|
259
|
-
torch.uint8,
|
|
260
|
-
torch.bool,
|
|
261
|
-
]:
|
|
262
|
-
if has_stats and torch_dtype != torch.bool:
|
|
263
|
-
# Generate tensor with statistical properties, then round for integers
|
|
264
|
-
if std == 0 or min_val == max_val:
|
|
265
|
-
# Constant tensor
|
|
266
|
-
return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
|
|
267
|
-
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
268
|
-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
269
|
-
return torch.round(tensor).to(torch_dtype)
|
|
270
|
-
else:
|
|
271
|
-
# Fallback to original random generation
|
|
272
|
-
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
273
|
-
|
|
274
|
-
# Case 3: Complex numbers need special handling
|
|
275
|
-
elif tensor_props.is_complex():
|
|
276
|
-
# Complex types: fallback to original logic for now
|
|
277
|
-
# TODO: Could be improved to use statistical info if available
|
|
278
|
-
float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
|
|
279
|
-
real_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
280
|
-
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
281
|
-
return torch.complex(real_part, imag_part)
|
|
282
|
-
|
|
283
|
-
# Case 4: Handle other unsigned integers (like uint32) which fail with random_()
|
|
284
|
-
elif "uint" in str(torch_dtype):
|
|
285
|
-
if has_stats:
|
|
286
|
-
# Generate tensor with statistical properties for unsigned integers
|
|
287
|
-
if std == 0 or min_val == max_val:
|
|
288
|
-
return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
|
|
289
|
-
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
290
|
-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
291
|
-
return torch.round(tensor).to(torch_dtype)
|
|
292
|
-
else:
|
|
293
|
-
# Fallback to original random generation
|
|
294
|
-
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
|
|
295
|
-
|
|
296
|
-
# Case 5: If we don't know how to handle the type, raise an error
|
|
297
|
-
else:
|
|
298
|
-
raise NotImplementedError(
|
|
299
|
-
f"Random data generation not implemented for dtype: {torch_dtype}"
|
|
300
|
-
)
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
def _create_tensor(arg_info) -> torch.Tensor:
    """
    Build a tensor argument, honoring any recorded stride and storage offset.

    Args:
        arg_info (dict): Tensor description from the reproducer JSON.

    Returns:
        torch.Tensor: The tensor from ``_create_base_tensor``, passed through
        ``_apply_stride_and_offset`` with the recorded ``shape``, ``stride``
        and ``storage_offset`` (defaulting to ``[]``, None and 0).
    """
    base = _create_base_tensor(arg_info)
    return _apply_stride_and_offset(
        base,
        arg_info.get("shape", []),
        arg_info.get("stride"),
        arg_info.get("storage_offset", 0),
    )
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
def _create_arg_from_info(arg_info):
|
|
314
|
-
"""
|
|
315
|
-
Recursively construct a kernel argument from its JSON schema.
|
|
316
|
-
|
|
317
|
-
Args:
|
|
318
|
-
arg_info (dict): JSON object describing a single argument, including
|
|
319
|
-
fields like 'type', 'value', 'dtype', 'shape', 'device', etc.
|
|
320
|
-
|
|
321
|
-
Returns:
|
|
322
|
-
Any: The constructed Python object suitable for kernel invocation.
|
|
323
|
-
|
|
324
|
-
Raises:
|
|
325
|
-
RuntimeError: When required optional dependencies are missing.
|
|
326
|
-
NotImplementedError: When a dtype or type is not supported yet.
|
|
327
|
-
"""
|
|
328
|
-
arg_type = arg_info.get("type")
|
|
329
|
-
|
|
330
|
-
if arg_type == "NoneType":
|
|
331
|
-
return None
|
|
332
|
-
|
|
333
|
-
if arg_type in ["int", "bool", "str", "float"]:
|
|
334
|
-
return arg_info.get("value")
|
|
335
|
-
|
|
336
|
-
elif arg_type == "tensor":
|
|
337
|
-
return _create_tensor(arg_info)
|
|
338
|
-
|
|
339
|
-
elif arg_type == "triton_kernels.tensor.Tensor":
|
|
340
|
-
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
341
|
-
raise RuntimeError(
|
|
342
|
-
"Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Tensor."
|
|
343
|
-
)
|
|
344
|
-
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
|
|
345
|
-
storage = _create_arg_from_info(arg_info.get("storage"))
|
|
346
|
-
dtype_str = arg_info.get("dtype")
|
|
347
|
-
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
348
|
-
return Tensor(
|
|
349
|
-
storage=storage,
|
|
350
|
-
shape=arg_info.get("shape"),
|
|
351
|
-
shape_max=arg_info.get("shape_max"),
|
|
352
|
-
dtype=torch_dtype,
|
|
353
|
-
)
|
|
354
|
-
|
|
355
|
-
elif arg_type == "triton_kernels.tensor.Storage":
|
|
356
|
-
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
357
|
-
raise RuntimeError(
|
|
358
|
-
"Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Storage."
|
|
359
|
-
)
|
|
360
|
-
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
|
|
361
|
-
data = _create_arg_from_info(arg_info.get("data"))
|
|
362
|
-
layout = _create_arg_from_info(arg_info.get("layout"))
|
|
363
|
-
return Storage(data=data, layout=layout)
|
|
364
|
-
|
|
365
|
-
elif arg_type == "StridedLayout":
|
|
366
|
-
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
367
|
-
raise RuntimeError(
|
|
368
|
-
"Optional dependency 'triton_kernels.tensor' is not installed; cannot construct StridedLayout."
|
|
369
|
-
)
|
|
370
|
-
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
|
|
371
|
-
return StridedLayout(shape=arg_info.get("initial_shape"))
|
|
372
|
-
else:
|
|
373
|
-
print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
|
|
374
|
-
return None
|
|
14
|
+
# {{UTILITY_FUNCTIONS_PLACEHOLDER}}
|
|
375
15
|
|
|
376
16
|
|
|
377
17
|
if __name__ == "__main__":
|
|
378
|
-
script_dir = Path(__file__).resolve().parent
|
|
18
|
+
script_dir = Path(__file__).resolve().parent # noqa: F821
|
|
379
19
|
json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}"
|
|
380
|
-
grid, args_dict = create_args_from_json_file(str(json_file))
|
|
20
|
+
grid, args_dict = create_args_from_json_file(str(json_file)) # noqa: F821
|
|
381
21
|
|
|
382
22
|
print("Generated kernel arguments dictionary:")
|
|
383
23
|
for name, arg in args_dict.items():
|
tritonparse/reproducer/utils.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import importlib
|
|
2
2
|
import importlib.util
|
|
3
3
|
import json
|
|
4
|
+
import logging
|
|
4
5
|
import sys
|
|
5
6
|
from datetime import datetime
|
|
6
7
|
from functools import lru_cache
|
|
@@ -27,9 +28,9 @@ def _get_triton_tensor_types():
|
|
|
27
28
|
)
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
def
|
|
31
|
+
def create_args_from_json_file(json_path):
|
|
31
32
|
"""
|
|
32
|
-
|
|
33
|
+
Load and parse a reproducer JSON file.
|
|
33
34
|
|
|
34
35
|
Args:
|
|
35
36
|
json_path (str): Path to the JSON file describing the kernel launch.
|
|
@@ -39,6 +40,19 @@ def create_args_from_json(json_path):
|
|
|
39
40
|
"""
|
|
40
41
|
with open(json_path, "r") as f:
|
|
41
42
|
data = json.load(f)
|
|
43
|
+
return create_args_from_json(data)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def create_args_from_json(data):
|
|
47
|
+
"""
|
|
48
|
+
Parse a reproducer JSON and build kernel grid and argument dictionary.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
data (dict | list): JSON data describing the kernel launch.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
tuple[list, dict]: Grid specification list and map of argument name to value.
|
|
55
|
+
"""
|
|
42
56
|
# Handle data format validation and extraction
|
|
43
57
|
if isinstance(data, list):
|
|
44
58
|
if len(data) != 1:
|
|
@@ -61,6 +75,192 @@ def create_args_from_json(json_path):
|
|
|
61
75
|
return grid, args_dict
|
|
62
76
|
|
|
63
77
|
|
|
78
|
+
def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
|
|
79
|
+
"""
|
|
80
|
+
Apply custom stride and storage offset to a tensor if needed.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
tensor: The base contiguous tensor
|
|
84
|
+
shape: The desired shape
|
|
85
|
+
stride: The desired stride (or None for contiguous)
|
|
86
|
+
storage_offset: The desired storage offset
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
torch.Tensor: The strided tensor view or original tensor if contiguous
|
|
90
|
+
"""
|
|
91
|
+
if stride is None:
|
|
92
|
+
return tensor
|
|
93
|
+
|
|
94
|
+
# Calculate expected contiguous stride
|
|
95
|
+
expected_contiguous_stride = []
|
|
96
|
+
s = 1
|
|
97
|
+
for dim_size in reversed(shape):
|
|
98
|
+
expected_contiguous_stride.insert(0, s)
|
|
99
|
+
s *= dim_size
|
|
100
|
+
|
|
101
|
+
# If stride matches contiguous stride and no storage offset, return as-is
|
|
102
|
+
if tuple(stride) == tuple(expected_contiguous_stride) and storage_offset == 0:
|
|
103
|
+
return tensor
|
|
104
|
+
|
|
105
|
+
# Calculate required storage size
|
|
106
|
+
if len(shape) > 0 and len(stride) > 0:
|
|
107
|
+
max_offset = storage_offset
|
|
108
|
+
for dim_stride, dim_size in zip(stride, shape):
|
|
109
|
+
if dim_size > 0:
|
|
110
|
+
max_offset += dim_stride * (dim_size - 1)
|
|
111
|
+
storage_size = max_offset + 1
|
|
112
|
+
else:
|
|
113
|
+
storage_size = storage_offset + 1
|
|
114
|
+
|
|
115
|
+
# Create larger storage tensor and create strided view
|
|
116
|
+
storage_tensor = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device)
|
|
117
|
+
|
|
118
|
+
# Create strided view
|
|
119
|
+
strided_view = storage_tensor.as_strided(
|
|
120
|
+
size=shape, stride=stride, storage_offset=storage_offset
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
# Copy data from the base tensor into the strided layout
|
|
124
|
+
strided_view.copy_(tensor.flatten()[: strided_view.numel()].view(shape))
|
|
125
|
+
|
|
126
|
+
return strided_view
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _create_base_tensor(arg_info) -> torch.Tensor:
|
|
130
|
+
"""
|
|
131
|
+
Create a base tensor without stride/offset modifications.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
arg_info (dict): Argument information including dtype, shape, device, etc.
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
torch.Tensor: The created base tensor
|
|
138
|
+
"""
|
|
139
|
+
if arg_info.get("blob_path"):
|
|
140
|
+
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
|
|
141
|
+
|
|
142
|
+
# Extract basic tensor properties
|
|
143
|
+
dtype_str = arg_info.get("dtype")
|
|
144
|
+
try:
|
|
145
|
+
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
146
|
+
except AttributeError:
|
|
147
|
+
logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
|
|
148
|
+
torch_dtype = torch.float32
|
|
149
|
+
|
|
150
|
+
shape = arg_info.get("shape", [])
|
|
151
|
+
device = arg_info.get("device", "cpu")
|
|
152
|
+
# Normalize cuda device to cuda:0
|
|
153
|
+
if isinstance(device, str) and device.startswith("cuda"):
|
|
154
|
+
device = "cuda:0"
|
|
155
|
+
|
|
156
|
+
# Extract statistical information if available
|
|
157
|
+
mean = arg_info.get("mean")
|
|
158
|
+
std = arg_info.get("std")
|
|
159
|
+
min_val = arg_info.get("min")
|
|
160
|
+
max_val = arg_info.get("max")
|
|
161
|
+
has_stats = (
|
|
162
|
+
mean is not None
|
|
163
|
+
and std is not None
|
|
164
|
+
and min_val is not None
|
|
165
|
+
and max_val is not None
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
if arg_info.get("tensor_capture_error", False):
|
|
169
|
+
logging.error(
|
|
170
|
+
f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
# Use a dummy tensor to check properties of the dtype
|
|
174
|
+
tensor_props = torch.empty(0, dtype=torch_dtype)
|
|
175
|
+
|
|
176
|
+
# Case 1: Floating point types
|
|
177
|
+
if tensor_props.is_floating_point():
|
|
178
|
+
if has_stats:
|
|
179
|
+
# Generate tensor with statistical properties matching original data
|
|
180
|
+
if std == 0 or min_val == max_val:
|
|
181
|
+
# Constant tensor
|
|
182
|
+
return torch.full(shape, mean, dtype=torch_dtype, device=device)
|
|
183
|
+
# Generate normal distribution with mean and std, then clamp to [min, max]
|
|
184
|
+
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
185
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
186
|
+
return tensor.to(torch_dtype)
|
|
187
|
+
else:
|
|
188
|
+
# Fallback to original random generation
|
|
189
|
+
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
190
|
+
tmp = torch.rand(shape, dtype=torch.float32, device=device)
|
|
191
|
+
return tmp.to(torch_dtype)
|
|
192
|
+
else:
|
|
193
|
+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
194
|
+
|
|
195
|
+
# Case 2: Integer types
|
|
196
|
+
elif torch_dtype in [
|
|
197
|
+
torch.int8,
|
|
198
|
+
torch.int16,
|
|
199
|
+
torch.int32,
|
|
200
|
+
torch.int64,
|
|
201
|
+
torch.uint8,
|
|
202
|
+
torch.bool,
|
|
203
|
+
]:
|
|
204
|
+
if has_stats and torch_dtype != torch.bool:
|
|
205
|
+
# Generate tensor with statistical properties, then round for integers
|
|
206
|
+
if std == 0 or min_val == max_val:
|
|
207
|
+
# Constant tensor
|
|
208
|
+
return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
|
|
209
|
+
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
210
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
211
|
+
return torch.round(tensor).to(torch_dtype)
|
|
212
|
+
else:
|
|
213
|
+
# Fallback to original random generation
|
|
214
|
+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
215
|
+
|
|
216
|
+
# Case 3: Complex numbers need special handling
|
|
217
|
+
elif tensor_props.is_complex():
|
|
218
|
+
# Complex types: fallback to original logic for now
|
|
219
|
+
# TODO: Could be improved to use statistical info if available
|
|
220
|
+
float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
|
|
221
|
+
real_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
222
|
+
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
223
|
+
return torch.complex(real_part, imag_part)
|
|
224
|
+
|
|
225
|
+
# Case 4: Handle other unsigned integers (like uint32) which fail with random_()
|
|
226
|
+
elif "uint" in str(torch_dtype):
|
|
227
|
+
if has_stats:
|
|
228
|
+
# Generate tensor with statistical properties for unsigned integers
|
|
229
|
+
if std == 0 or min_val == max_val:
|
|
230
|
+
return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
|
|
231
|
+
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
232
|
+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
233
|
+
return torch.round(tensor).to(torch_dtype)
|
|
234
|
+
else:
|
|
235
|
+
# Fallback to original random generation
|
|
236
|
+
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
|
|
237
|
+
|
|
238
|
+
# Case 5: If we don't know how to handle the type, raise an error
|
|
239
|
+
else:
|
|
240
|
+
raise NotImplementedError(
|
|
241
|
+
f"Random data generation not implemented for dtype: {torch_dtype}"
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _create_tensor(arg_info) -> torch.Tensor:
    """
    Build the tensor described by ``arg_info``, including its memory layout.

    Values come from ``_create_base_tensor``. The recorded ``shape``, ``stride``
    and ``storage_offset`` (defaulting to ``[]``, ``None`` and ``0`` when absent)
    are then handed to ``_apply_stride_and_offset``, which returns either the
    base tensor itself or a strided view holding its values.

    Args:
        arg_info (dict): Argument information including dtype, shape, stride,
            storage_offset, etc.

    Returns:
        torch.Tensor: The created tensor with applied stride/offset.
    """
    base = _create_base_tensor(arg_info)
    return _apply_stride_and_offset(
        base,
        arg_info.get("shape", []),
        arg_info.get("stride"),
        arg_info.get("storage_offset", 0),
    )
|
|
262
|
+
|
|
263
|
+
|
|
64
264
|
def _create_arg_from_info(arg_info):
|
|
65
265
|
"""
|
|
66
266
|
Recursively construct a kernel argument from its JSON schema.
|
|
@@ -78,120 +278,14 @@ def _create_arg_from_info(arg_info):
|
|
|
78
278
|
"""
|
|
79
279
|
arg_type = arg_info.get("type")
|
|
80
280
|
|
|
81
|
-
if arg_type
|
|
281
|
+
if arg_type == "NoneType":
|
|
282
|
+
return None
|
|
283
|
+
|
|
284
|
+
if arg_type in ["int", "bool", "str", "float"]:
|
|
82
285
|
return arg_info.get("value")
|
|
83
286
|
|
|
84
287
|
elif arg_type == "tensor":
|
|
85
|
-
|
|
86
|
-
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
|
|
87
|
-
|
|
88
|
-
# Extract basic tensor properties
|
|
89
|
-
dtype_str = arg_info.get("dtype")
|
|
90
|
-
try:
|
|
91
|
-
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
92
|
-
except AttributeError:
|
|
93
|
-
logger.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
|
|
94
|
-
torch_dtype = torch.float32
|
|
95
|
-
|
|
96
|
-
shape = arg_info.get("shape", [])
|
|
97
|
-
device = arg_info.get("device", "cpu")
|
|
98
|
-
|
|
99
|
-
# Extract statistical information if available
|
|
100
|
-
mean = arg_info.get("mean")
|
|
101
|
-
std = arg_info.get("std")
|
|
102
|
-
min_val = arg_info.get("min")
|
|
103
|
-
max_val = arg_info.get("max")
|
|
104
|
-
has_stats = (
|
|
105
|
-
mean is not None
|
|
106
|
-
and std is not None
|
|
107
|
-
and min_val is not None
|
|
108
|
-
and max_val is not None
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
# Use a dummy tensor to check properties of the dtype
|
|
112
|
-
tensor_props = torch.empty(0, dtype=torch_dtype)
|
|
113
|
-
|
|
114
|
-
# Case 1: Floating point types
|
|
115
|
-
if tensor_props.is_floating_point():
|
|
116
|
-
if has_stats:
|
|
117
|
-
# Generate tensor with statistical properties matching original data
|
|
118
|
-
if std == 0 or min_val == max_val:
|
|
119
|
-
# Constant tensor
|
|
120
|
-
return torch.full(shape, mean, dtype=torch_dtype, device=device)
|
|
121
|
-
# Generate normal distribution with mean and std, then clamp to [min, max]
|
|
122
|
-
tensor = (
|
|
123
|
-
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
124
|
-
)
|
|
125
|
-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
126
|
-
return tensor.to(torch_dtype)
|
|
127
|
-
else:
|
|
128
|
-
# Fallback to original random generation
|
|
129
|
-
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
130
|
-
tmp = torch.rand(shape, dtype=torch.float32, device=device)
|
|
131
|
-
return tmp.to(torch_dtype)
|
|
132
|
-
else:
|
|
133
|
-
return torch.empty(
|
|
134
|
-
shape, dtype=torch_dtype, device=device
|
|
135
|
-
).random_()
|
|
136
|
-
|
|
137
|
-
# Case 2: Integer types
|
|
138
|
-
elif torch_dtype in [
|
|
139
|
-
torch.int8,
|
|
140
|
-
torch.int16,
|
|
141
|
-
torch.int32,
|
|
142
|
-
torch.int64,
|
|
143
|
-
torch.uint8,
|
|
144
|
-
torch.bool,
|
|
145
|
-
]:
|
|
146
|
-
if has_stats and torch_dtype != torch.bool:
|
|
147
|
-
# Generate tensor with statistical properties, then round for integers
|
|
148
|
-
if std == 0 or min_val == max_val:
|
|
149
|
-
# Constant tensor
|
|
150
|
-
return torch.full(
|
|
151
|
-
shape, int(mean), dtype=torch_dtype, device=device
|
|
152
|
-
)
|
|
153
|
-
tensor = (
|
|
154
|
-
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
155
|
-
)
|
|
156
|
-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
157
|
-
return torch.round(tensor).to(torch_dtype)
|
|
158
|
-
else:
|
|
159
|
-
# Fallback to original random generation
|
|
160
|
-
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
161
|
-
|
|
162
|
-
# Case 3: Complex numbers need special handling
|
|
163
|
-
elif tensor_props.is_complex():
|
|
164
|
-
# Complex types: fallback to original logic for now
|
|
165
|
-
# TODO: Could be improved to use statistical info if available
|
|
166
|
-
float_dtype = (
|
|
167
|
-
torch.float32 if torch_dtype == torch.complex64 else torch.float64
|
|
168
|
-
)
|
|
169
|
-
real_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
170
|
-
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
171
|
-
return torch.complex(real_part, imag_part)
|
|
172
|
-
|
|
173
|
-
# Case 4: Handle other unsigned integers (like uint32) which fail with random_()
|
|
174
|
-
elif "uint" in str(torch_dtype):
|
|
175
|
-
if has_stats:
|
|
176
|
-
# Generate tensor with statistical properties for unsigned integers
|
|
177
|
-
if std == 0 or min_val == max_val:
|
|
178
|
-
return torch.full(
|
|
179
|
-
shape, int(mean), dtype=torch_dtype, device=device
|
|
180
|
-
)
|
|
181
|
-
tensor = (
|
|
182
|
-
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
|
|
183
|
-
)
|
|
184
|
-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
|
|
185
|
-
return torch.round(tensor).to(torch_dtype)
|
|
186
|
-
else:
|
|
187
|
-
# Fallback to original random generation
|
|
188
|
-
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
|
|
189
|
-
|
|
190
|
-
# Case 5: If we don't know how to handle the type, raise an error
|
|
191
|
-
else:
|
|
192
|
-
raise NotImplementedError(
|
|
193
|
-
f"Random data generation not implemented for dtype: {torch_dtype}"
|
|
194
|
-
)
|
|
288
|
+
return _create_tensor(arg_info)
|
|
195
289
|
|
|
196
290
|
elif arg_type == "triton_kernels.tensor.Tensor":
|
|
197
291
|
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tritonparse
|
|
3
|
-
Version: 0.3.1.
|
|
3
|
+
Version: 0.3.1.dev20251021071528
|
|
4
4
|
Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
|
|
5
5
|
Author-email: Yueming Hao <yhao@meta.com>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -16,13 +16,14 @@ tritonparse/trace_processor.py,sha256=brQBt26jdB6-quJXP5-warp2j31JSjOOFJa5ayiUZ5
|
|
|
16
16
|
tritonparse/utils.py,sha256=Jnlptcd79llSDev-_1XyyOnv2izUqv0PEL74A8GF2tc,4565
|
|
17
17
|
tritonparse/reproducer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
18
|
tritonparse/reproducer/cli.py,sha256=MqYuuAP-uAIWWLQxixwyyBHJGSaQdG3xXGaVjERTLX8,1522
|
|
19
|
+
tritonparse/reproducer/function_extractor.py,sha256=pB5b52Xlk9-fe-Gs-z_Rn2HkZd1bC9qHTw1JZc_epBM,6588
|
|
19
20
|
tritonparse/reproducer/orchestrator.py,sha256=ZGdsiOZg2Bg6o0cqZSCGx_v6mdzR4yZX7S3SVQjZ9-c,3182
|
|
20
|
-
tritonparse/reproducer/placeholder_replacer.py,sha256=
|
|
21
|
+
tritonparse/reproducer/placeholder_replacer.py,sha256=TvqQIrOubmBHJ_pl0ZvkpP4dIDiMYxUll1q9MIbzZco,9552
|
|
21
22
|
tritonparse/reproducer/types.py,sha256=AfVl83zoJZQ58JJoplCcMC51gK-M-OKcafatYEIGgW0,509
|
|
22
|
-
tritonparse/reproducer/utils.py,sha256=
|
|
23
|
+
tritonparse/reproducer/utils.py,sha256=LHmkM9EEAn9BwNyhHuCEp1tm7omeILbhkAcOmN-iLrk,16544
|
|
23
24
|
tritonparse/reproducer/ingestion/ndjson.py,sha256=pEujTl5xXW2E2DEW8ngxXQ8qP9oawb90wBVTWHDs1jk,7372
|
|
24
25
|
tritonparse/reproducer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
-
tritonparse/reproducer/templates/example.py,sha256=
|
|
26
|
+
tritonparse/reproducer/templates/example.py,sha256=jR3c8_d7fAFJYaj1DuUuthnI4Xd-_606bWDRdUPMNyo,785
|
|
26
27
|
tritonparse/reproducer/templates/loader.py,sha256=HqjfThdDVg7q2bYWry78sIaVRkUpkcA8KQDt83YrlVE,1920
|
|
27
28
|
tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
28
29
|
tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
|
|
@@ -31,9 +32,9 @@ tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5H
|
|
|
31
32
|
tritonparse/tools/load_tensor.py,sha256=94-TiSYlpXJx4MPmGK1ovmZlTt56Q_B3KQeCPaA6Cnw,2734
|
|
32
33
|
tritonparse/tools/prettify_ndjson.py,sha256=r2YlHwFDTHgML7KljRmMsHaDg29q8gOQAgyDKWJhxRM,11062
|
|
33
34
|
tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
|
|
34
|
-
tritonparse-0.3.1.
|
|
35
|
-
tritonparse-0.3.1.
|
|
36
|
-
tritonparse-0.3.1.
|
|
37
|
-
tritonparse-0.3.1.
|
|
38
|
-
tritonparse-0.3.1.
|
|
39
|
-
tritonparse-0.3.1.
|
|
35
|
+
tritonparse-0.3.1.dev20251021071528.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
|
|
36
|
+
tritonparse-0.3.1.dev20251021071528.dist-info/METADATA,sha256=zeBrpftm2P8cY-iyCCyxc80qriG_rC7apZmRKy8CTKE,8278
|
|
37
|
+
tritonparse-0.3.1.dev20251021071528.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
38
|
+
tritonparse-0.3.1.dev20251021071528.dist-info/entry_points.txt,sha256=wEXdaieDoRRCCdhEv2p_C68iytnaXU_2pwt5CqjfbWY,56
|
|
39
|
+
tritonparse-0.3.1.dev20251021071528.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
|
|
40
|
+
tritonparse-0.3.1.dev20251021071528.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|