tritonparse 0.3.1.dev20251020071524__py3-none-any.whl → 0.3.1.dev20251021071528__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic; the original registry page linked to more details.

@@ -0,0 +1,220 @@
1
+ """
2
+ Function extractor for reproducer utility functions.
3
+
4
+ This module extracts utility functions from utils.py and load_tensor.py
5
+ using AST parsing, and generates standalone code for reproducers.
6
+ """
7
+
8
+ import ast
9
+ from pathlib import Path
10
+
11
+
12
def extract_utility_functions() -> str:
    """
    Extract all utility functions needed for the reproducer template.

    Uses AST parsing to extract functions and constants from source files
    without importing them (avoiding potential side effects).

    The generated code consists of, in order: the import header, the
    TRITON_KERNELS_CUSTOM_TYPES assignment from utils.py, ``load_tensor`` from
    tools/load_tensor.py, and the tensor/argument helpers from utils.py.

    Returns:
        str: Complete Python code including imports and all utility functions,
        with the individual parts separated by blank lines.

    Raises:
        FileNotFoundError: If utils.py or tools/load_tensor.py cannot be found.
        SyntaxError: If either source file fails to parse.
        ValueError: If a required function, or the TRITON_KERNELS_CUSTOM_TYPES
            assignment, is missing from its source file.
    """
    # Source files are resolved relative to this module's location.
    base_dir = Path(__file__).parent
    utils_path = base_dir / "utils.py"
    load_tensor_path = base_dir.parent / "tools" / "load_tensor.py"

    # Parse source files
    utils_tree, utils_lines = _parse_source_file(utils_path)
    load_tensor_tree, load_tensor_lines = _parse_source_file(load_tensor_path)

    # Names to copy from each file, in the order they are emitted.
    utils_function_names = [
        "_get_triton_tensor_types",
        "create_args_from_json_file",
        "create_args_from_json",
        "_apply_stride_and_offset",
        "_create_base_tensor",
        "_create_tensor",
        "_create_arg_from_info",
    ]

    load_tensor_function_names = [
        "load_tensor",
    ]

    # Required imports come first.
    extracted_parts = [_generate_imports()]

    # The extracted helpers read this flag (_create_arg_from_info checks it
    # before building triton_kernels types). Emitting code without it would
    # only surface later as a NameError inside the generated reproducer, so
    # report it here, the same way _extract_functions reports missing functions.
    constant_name = "TRITON_KERNELS_CUSTOM_TYPES"
    constant = _extract_assignment(utils_tree, utils_lines, constant_name)
    if constant is None:
        raise ValueError(
            f"Assignment '{constant_name}' not found in {utils_path}. "
            "It might have been renamed or removed."
        )
    extracted_parts.append(constant)

    # Extract load_tensor functions
    extracted_parts.extend(
        _extract_functions(
            load_tensor_tree, load_tensor_lines, load_tensor_function_names
        )
    )

    # Extract utils functions
    extracted_parts.extend(
        _extract_functions(utils_tree, utils_lines, utils_function_names)
    )

    # Combine all parts
    return "\n\n".join(extracted_parts)
73
+
74
+
75
def _parse_source_file(file_path: Path) -> tuple[ast.Module, list[str]]:
    """
    Parse a Python source file and return its AST and source lines.

    The returned lines are indexed consistently with the AST's ``lineno`` and
    ``end_lineno`` attributes: ``lines[node.lineno - 1]`` is the first line of
    ``node``.

    Args:
        file_path: Path to the Python source file

    Returns:
        tuple: (AST tree, list of source code lines)

    Raises:
        FileNotFoundError: If the source file doesn't exist
        SyntaxError: If the source file has syntax errors
    """
    try:
        source_code = file_path.read_text(encoding="utf-8")
        tree = ast.parse(source_code, filename=str(file_path))
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Source file not found: {file_path}") from e
    except SyntaxError as e:
        raise SyntaxError(f"Failed to parse {file_path}: {e}") from e

    # Split on "\n" only. read_text() has already translated "\r\n" and "\r"
    # to "\n" (universal newlines), which matches how the Python tokenizer
    # numbers lines. str.splitlines() would additionally break on form feed,
    # \x1c-\x1e, \x85, \u2028 and \u2029. Python does not treat those as line
    # terminators, so splitlines() would shift every index after them away
    # from node.lineno.
    lines = source_code.split("\n")
    return tree, lines
99
+
100
+
101
+ def _extract_assignment(
102
+ tree: ast.Module, lines: list[str], var_name: str
103
+ ) -> str | None:
104
+ """
105
+ Extract a module-level assignment statement by variable name.
106
+
107
+ Args:
108
+ tree: AST tree of the source file
109
+ lines: Source code lines
110
+ var_name: Name of the variable to extract
111
+
112
+ Returns:
113
+ Complete assignment statement source code, or None if not found
114
+
115
+ Example:
116
+ Extracts:
117
+ TRITON_KERNELS_CUSTOM_TYPES = (
118
+ importlib.util.find_spec("triton_kernels") is not None
119
+ and importlib.util.find_spec("triton_kernels.tensor") is not None
120
+ )
121
+ """
122
+ # Search only at module level
123
+ for node in tree.body:
124
+ if isinstance(node, ast.Assign):
125
+ for target in node.targets:
126
+ if isinstance(target, ast.Name) and target.id == var_name:
127
+ # Found it! Extract source code using line numbers
128
+ start_line = node.lineno - 1 # Convert to 0-based index
129
+ end_line = node.end_lineno # Already suitable for slicing
130
+ assignment_lines = lines[start_line:end_line]
131
+ return "\n".join(assignment_lines)
132
+ return None
133
+
134
+
135
+ def _extract_function(tree: ast.Module, lines: list[str], func_name: str) -> str | None:
136
+ """
137
+ Extract a function definition by name, including decorators.
138
+
139
+ Args:
140
+ tree: AST tree of the source file
141
+ lines: Source code lines
142
+ func_name: Name of the function to extract
143
+
144
+ Returns:
145
+ Complete function source code including decorators, or None if not found
146
+
147
+ Example:
148
+ Extracts:
149
+ @lru_cache(maxsize=1)
150
+ def _get_triton_tensor_types():
151
+ '''Docstring'''
152
+ ...
153
+ """
154
+ # Walk the entire tree (handles nested functions if needed)
155
+ for node in ast.walk(tree):
156
+ if isinstance(node, ast.FunctionDef) and node.name == func_name:
157
+ # If function has decorators, start from the first decorator
158
+ if node.decorator_list:
159
+ start_line = node.decorator_list[0].lineno - 1
160
+ else:
161
+ start_line = node.lineno - 1
162
+
163
+ end_line = node.end_lineno
164
+ func_lines = lines[start_line:end_line]
165
+ return "\n".join(func_lines)
166
+ return None
167
+
168
+
169
def _extract_functions(
    tree: ast.Module, lines: list[str], func_names: list[str]
) -> list[str]:
    """
    Extract several function definitions, preserving the requested order.

    Args:
        tree: Parsed module to search.
        lines: Source lines of that module.
        func_names: Names of the functions to extract.

    Returns:
        One source string per entry in ``func_names``, in the same order.

    Raises:
        ValueError: For the first name (in ``func_names`` order) that has no
            matching definition.
    """

    def _require(name: str) -> str:
        # Look up one function; a missing definition aborts the whole batch.
        source = _extract_function(tree, lines, name)
        if source is None:
            raise ValueError(
                f"Function '{name}' not found in source. "
                "Available functions might have been renamed or removed."
            )
        return source

    return [_require(name) for name in func_names]
196
+
197
+
198
def _generate_imports() -> str:
    """
    Build the import header that precedes the extracted helper functions.

    Returns:
        str: Newline-separated import statements. Standard-library imports
        come first, followed by a blank line and then ``import torch``.
    """
    plain_stdlib = ("gzip", "hashlib", "importlib", "importlib.util", "io", "json", "logging", "sys")
    header = [f"import {module}" for module in plain_stdlib]
    header += [
        "from functools import lru_cache",
        "from pathlib import Path",
        "from typing import Union",
        "",  # blank line separating stdlib from third-party imports
        "import torch",
    ]
    return "\n".join(header)
@@ -2,6 +2,7 @@ from abc import ABC
2
2
 
3
3
  from typing import Any, Dict, Protocol
4
4
 
5
+ from tritonparse.reproducer.function_extractor import extract_utility_functions
5
6
  from tritonparse.reproducer.ingestion.ndjson import ContextBundle
6
7
  from tritonparse.reproducer.types import KernelImportMode
7
8
  from tritonparse.reproducer.utils import (
@@ -82,6 +83,9 @@ class DefaultPlaceholderReplacer(PlaceholderReplacer):
82
83
  )
83
84
  self.register("# {{KERNEL_SYSPATH_PLACEHOLDER}}", self._replace_kernel_syspath)
84
85
  self.register("# {{KERNEL_IMPORT_PLACEHOLDER}}", self._replace_kernel_import)
86
+ self.register(
87
+ "# {{UTILITY_FUNCTIONS_PLACEHOLDER}}", self._replace_utility_functions
88
+ )
85
89
  self.register(
86
90
  "# {{KERNEL_INVOCATION_PLACEHOLDER}}", self._replace_kernel_invocation
87
91
  )
@@ -217,6 +221,13 @@ triton.autotune = _patched_autotune
217
221
  else:
218
222
  raise ValueError(f"Unknown kernel_import mode: {kernel_import}")
219
223
 
224
def _replace_utility_functions(
    self, code: str, context_bundle: ContextBundle, **kwargs
) -> str:
    """
    Fill the utility-functions placeholder in the reproducer template.

    The placeholder comment is swapped for the standalone helper source
    produced by extract_utility_functions(). context_bundle and kwargs are
    accepted to match the other replacer methods' signature and are unused here.
    """
    placeholder = "# {{UTILITY_FUNCTIONS_PLACEHOLDER}}"
    helpers_source = extract_utility_functions()
    return code.replace(placeholder, helpers_source)
230
+
220
231
  def _replace_kernel_invocation(
221
232
  self, code: str, context_bundle: ContextBundle, **kwargs
222
233
  ) -> str:
@@ -3,18 +3,6 @@ This file is automatically generated by TritonParse reproducer.
3
3
  It contains a smallest testing example for a Triton kernel.
4
4
  """
5
5
 
6
- import gzip
7
- import hashlib
8
- import importlib
9
- import importlib.util
10
- import io
11
- import json
12
- import logging
13
- import sys
14
- from functools import lru_cache
15
- from pathlib import Path
16
- from typing import Union
17
-
18
6
  import torch
19
7
 
20
8
  # {{IR_OVERRIDE_SETUP_PLACEHOLDER}}
@@ -23,361 +11,13 @@ import torch
23
11
 
24
12
  # {{KERNEL_IMPORT_PLACEHOLDER}}
25
13
 
26
- TRITON_KERNELS_CUSTOM_TYPES = (
27
- importlib.util.find_spec("triton_kernels") is not None
28
- and importlib.util.find_spec("triton_kernels.tensor") is not None
29
- )
30
-
31
-
32
- @lru_cache(maxsize=1)
33
- def _get_triton_tensor_types():
34
- """
35
- Import and cache Triton custom tensor types.
36
-
37
- Returns:
38
- tuple: (Tensor, Storage, StridedLayout) classes from triton_kernels.tensor.
39
-
40
- Raises:
41
- ImportError: If the optional module 'triton_kernels.tensor' is not available.
42
- """
43
- mod = importlib.import_module("triton_kernels.tensor")
44
- return (
45
- mod.Tensor,
46
- mod.Storage,
47
- mod.StridedLayout,
48
- )
49
-
50
-
51
- def load_tensor(tensor_file_path: Union[str, Path], device: str = None) -> torch.Tensor:
52
- """
53
- Load a tensor from its file path and verify its integrity using the hash in the filename.
54
-
55
- Args:
56
- tensor_file_path (str | Path): Direct path to the tensor file. Supports both:
57
- - .bin.gz: gzip-compressed tensor (hash is of uncompressed data)
58
- - .bin: uncompressed tensor (for backward compatibility)
59
- device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
60
- If None, keeps the tensor on its original device.
61
-
62
- Returns:
63
- torch.Tensor: The loaded tensor (moved to the specified device if provided)
64
-
65
- Raises:
66
- FileNotFoundError: If the tensor file doesn't exist
67
- RuntimeError: If the tensor cannot be loaded
68
- ValueError: If the computed hash doesn't match the filename hash
69
- """
70
- blob_path = Path(tensor_file_path)
71
-
72
- if not blob_path.exists():
73
- raise FileNotFoundError(f"Tensor blob not found: {blob_path}")
74
-
75
- # Detect compression by file extension
76
- is_compressed = blob_path.name.endswith(".bin.gz")
77
-
78
- # Read file contents (decompress if needed)
79
- try:
80
- with open(blob_path, "rb") as f:
81
- file_obj = gzip.GzipFile(fileobj=f, mode="rb") if is_compressed else f
82
- file_contents = file_obj.read()
83
- except (OSError, gzip.BadGzipFile) as e:
84
- if is_compressed:
85
- raise RuntimeError(f"Failed to decompress gzip file {blob_path}: {str(e)}")
86
- else:
87
- raise RuntimeError(f"Failed to read file {blob_path}: {str(e)}")
88
-
89
- # Extract expected hash from filename
90
- # abc123.bin.gz -> abc123 or abc123.bin -> abc123
91
- expected_hash = blob_path.name.removesuffix(".bin.gz" if is_compressed else ".bin")
92
-
93
- # Compute hash of uncompressed data
94
- computed_hash = hashlib.blake2b(file_contents).hexdigest()
95
-
96
- # Verify hash matches filename
97
- if computed_hash != expected_hash:
98
- raise ValueError(
99
- f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
100
- )
101
-
102
- try:
103
- # Load the tensor from memory buffer
104
- tensor = torch.load(io.BytesIO(file_contents), map_location=device)
105
- return tensor
106
- except Exception as e:
107
- raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}")
108
-
109
-
110
- def create_args_from_json_file(json_path):
111
- with open(json_path, "r") as f:
112
- data = json.load(f)
113
- return create_args_from_json(data)
114
-
115
-
116
- def create_args_from_json(data):
117
- """
118
- Parse a reproducer JSON and build kernel grid and argument dictionary.
119
-
120
- Args:
121
- json_path (str): Path to the JSON file describing the kernel launch.
122
-
123
- Returns:
124
- tuple[list, dict]: Grid specification list and map of argument name to value.
125
- """
126
- # Handle data format validation and extraction
127
- if isinstance(data, list):
128
- if len(data) != 1:
129
- print(
130
- f"Error: Expected single element list, got list with {len(data)} elements"
131
- )
132
- sys.exit(1)
133
- data = data[0]
134
- elif not isinstance(data, dict):
135
- print(f"Error: Expected list or dict, got {type(data)}")
136
- sys.exit(1)
137
-
138
- grid = data.get("grid", [])
139
- args_dict = {}
140
- extracted_args = data.get("extracted_args", {})
141
-
142
- for arg_name, arg_info in extracted_args.items():
143
- args_dict[arg_name] = _create_arg_from_info(arg_info)
144
-
145
- return grid, args_dict
146
-
147
-
148
- def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
149
- """
150
- Apply custom stride and storage offset to a tensor if needed.
151
-
152
- Args:
153
- tensor: The base contiguous tensor
154
- shape: The desired shape
155
- stride: The desired stride (or None for contiguous)
156
- storage_offset: The desired storage offset
157
-
158
- Returns:
159
- torch.Tensor: The strided tensor view or original tensor if contiguous
160
- """
161
- if stride is None:
162
- return tensor
163
-
164
- # Calculate expected contiguous stride
165
- expected_contiguous_stride = []
166
- s = 1
167
- for dim_size in reversed(shape):
168
- expected_contiguous_stride.insert(0, s)
169
- s *= dim_size
170
-
171
- # If stride matches contiguous stride and no storage offset, return as-is
172
- if tuple(stride) == tuple(expected_contiguous_stride) and storage_offset == 0:
173
- return tensor
174
-
175
- # Calculate required storage size
176
- if len(shape) > 0 and len(stride) > 0:
177
- max_offset = storage_offset
178
- for dim_stride, dim_size in zip(stride, shape):
179
- if dim_size > 0:
180
- max_offset += dim_stride * (dim_size - 1)
181
- storage_size = max_offset + 1
182
- else:
183
- storage_size = storage_offset + 1
184
-
185
- # Create larger storage tensor and create strided view
186
- storage_tensor = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device)
187
-
188
- # Create strided view
189
- strided_view = storage_tensor.as_strided(
190
- size=shape, stride=stride, storage_offset=storage_offset
191
- )
192
-
193
- # Copy data from the base tensor into the strided layout
194
- strided_view.copy_(tensor.flatten()[: strided_view.numel()].view(shape))
195
-
196
- return strided_view
197
-
198
-
199
- def _create_base_tensor(arg_info) -> torch.Tensor:
200
- if arg_info.get("blob_path"):
201
- return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
202
-
203
- # Extract basic tensor properties
204
- dtype_str = arg_info.get("dtype")
205
- try:
206
- torch_dtype = getattr(torch, dtype_str.split(".")[-1])
207
- except AttributeError:
208
- logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
209
- torch_dtype = torch.float32
210
-
211
- shape = arg_info.get("shape", [])
212
- device = arg_info.get("device", "cpu")
213
-
214
- # Extract statistical information if available
215
- mean = arg_info.get("mean")
216
- std = arg_info.get("std")
217
- min_val = arg_info.get("min")
218
- max_val = arg_info.get("max")
219
- has_stats = (
220
- mean is not None
221
- and std is not None
222
- and min_val is not None
223
- and max_val is not None
224
- )
225
-
226
- if arg_info.get("tensor_capture_error", False):
227
- logging.error(
228
- f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
229
- )
230
-
231
- # Use a dummy tensor to check properties of the dtype
232
- tensor_props = torch.empty(0, dtype=torch_dtype)
233
-
234
- # Case 1: Floating point types
235
- if tensor_props.is_floating_point():
236
- if has_stats:
237
- # Generate tensor with statistical properties matching original data
238
- if std == 0 or min_val == max_val:
239
- # Constant tensor
240
- return torch.full(shape, mean, dtype=torch_dtype, device=device)
241
- # Generate normal distribution with mean and std, then clamp to [min, max]
242
- tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
243
- tensor = torch.clamp(tensor, min=min_val, max=max_val)
244
- return tensor.to(torch_dtype)
245
- else:
246
- # Fallback to original random generation
247
- if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
248
- tmp = torch.rand(shape, dtype=torch.float32, device=device)
249
- return tmp.to(torch_dtype)
250
- else:
251
- return torch.empty(shape, dtype=torch_dtype, device=device).random_()
252
-
253
- # Case 2: Integer types
254
- elif torch_dtype in [
255
- torch.int8,
256
- torch.int16,
257
- torch.int32,
258
- torch.int64,
259
- torch.uint8,
260
- torch.bool,
261
- ]:
262
- if has_stats and torch_dtype != torch.bool:
263
- # Generate tensor with statistical properties, then round for integers
264
- if std == 0 or min_val == max_val:
265
- # Constant tensor
266
- return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
267
- tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
268
- tensor = torch.clamp(tensor, min=min_val, max=max_val)
269
- return torch.round(tensor).to(torch_dtype)
270
- else:
271
- # Fallback to original random generation
272
- return torch.empty(shape, dtype=torch_dtype, device=device).random_()
273
-
274
- # Case 3: Complex numbers need special handling
275
- elif tensor_props.is_complex():
276
- # Complex types: fallback to original logic for now
277
- # TODO: Could be improved to use statistical info if available
278
- float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
279
- real_part = torch.rand(shape, dtype=float_dtype, device=device)
280
- imag_part = torch.rand(shape, dtype=float_dtype, device=device)
281
- return torch.complex(real_part, imag_part)
282
-
283
- # Case 4: Handle other unsigned integers (like uint32) which fail with random_()
284
- elif "uint" in str(torch_dtype):
285
- if has_stats:
286
- # Generate tensor with statistical properties for unsigned integers
287
- if std == 0 or min_val == max_val:
288
- return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
289
- tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
290
- tensor = torch.clamp(tensor, min=min_val, max=max_val)
291
- return torch.round(tensor).to(torch_dtype)
292
- else:
293
- # Fallback to original random generation
294
- return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
295
-
296
- # Case 5: If we don't know how to handle the type, raise an error
297
- else:
298
- raise NotImplementedError(
299
- f"Random data generation not implemented for dtype: {torch_dtype}"
300
- )
301
-
302
-
303
- def _create_tensor(arg_info) -> torch.Tensor:
304
- tensor = _create_base_tensor(arg_info)
305
-
306
- # Apply stride and storage offset if needed
307
- shape = arg_info.get("shape", [])
308
- stride = arg_info.get("stride")
309
- storage_offset = arg_info.get("storage_offset", 0)
310
- return _apply_stride_and_offset(tensor, shape, stride, storage_offset)
311
-
312
-
313
- def _create_arg_from_info(arg_info):
314
- """
315
- Recursively construct a kernel argument from its JSON schema.
316
-
317
- Args:
318
- arg_info (dict): JSON object describing a single argument, including
319
- fields like 'type', 'value', 'dtype', 'shape', 'device', etc.
320
-
321
- Returns:
322
- Any: The constructed Python object suitable for kernel invocation.
323
-
324
- Raises:
325
- RuntimeError: When required optional dependencies are missing.
326
- NotImplementedError: When a dtype or type is not supported yet.
327
- """
328
- arg_type = arg_info.get("type")
329
-
330
- if arg_type == "NoneType":
331
- return None
332
-
333
- if arg_type in ["int", "bool", "str", "float"]:
334
- return arg_info.get("value")
335
-
336
- elif arg_type == "tensor":
337
- return _create_tensor(arg_info)
338
-
339
- elif arg_type == "triton_kernels.tensor.Tensor":
340
- if not TRITON_KERNELS_CUSTOM_TYPES:
341
- raise RuntimeError(
342
- "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Tensor."
343
- )
344
- Tensor, Storage, StridedLayout = _get_triton_tensor_types()
345
- storage = _create_arg_from_info(arg_info.get("storage"))
346
- dtype_str = arg_info.get("dtype")
347
- torch_dtype = getattr(torch, dtype_str.split(".")[-1])
348
- return Tensor(
349
- storage=storage,
350
- shape=arg_info.get("shape"),
351
- shape_max=arg_info.get("shape_max"),
352
- dtype=torch_dtype,
353
- )
354
-
355
- elif arg_type == "triton_kernels.tensor.Storage":
356
- if not TRITON_KERNELS_CUSTOM_TYPES:
357
- raise RuntimeError(
358
- "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Storage."
359
- )
360
- Tensor, Storage, StridedLayout = _get_triton_tensor_types()
361
- data = _create_arg_from_info(arg_info.get("data"))
362
- layout = _create_arg_from_info(arg_info.get("layout"))
363
- return Storage(data=data, layout=layout)
364
-
365
- elif arg_type == "StridedLayout":
366
- if not TRITON_KERNELS_CUSTOM_TYPES:
367
- raise RuntimeError(
368
- "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct StridedLayout."
369
- )
370
- Tensor, Storage, StridedLayout = _get_triton_tensor_types()
371
- return StridedLayout(shape=arg_info.get("initial_shape"))
372
- else:
373
- print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
374
- return None
14
+ # {{UTILITY_FUNCTIONS_PLACEHOLDER}}
375
15
 
376
16
 
377
17
  if __name__ == "__main__":
378
- script_dir = Path(__file__).resolve().parent
18
+ script_dir = Path(__file__).resolve().parent # noqa: F821
379
19
  json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}"
380
- grid, args_dict = create_args_from_json_file(str(json_file))
20
+ grid, args_dict = create_args_from_json_file(str(json_file)) # noqa: F821
381
21
 
382
22
  print("Generated kernel arguments dictionary:")
383
23
  for name, arg in args_dict.items():
@@ -1,6 +1,7 @@
1
1
  import importlib
2
2
  import importlib.util
3
3
  import json
4
+ import logging
4
5
  import sys
5
6
  from datetime import datetime
6
7
  from functools import lru_cache
@@ -27,9 +28,9 @@ def _get_triton_tensor_types():
27
28
  )
28
29
 
29
30
 
30
- def create_args_from_json(json_path):
31
+ def create_args_from_json_file(json_path):
31
32
  """
32
- Parse a reproducer JSON and build kernel grid and argument dictionary.
33
+ Load and parse a reproducer JSON file.
33
34
 
34
35
  Args:
35
36
  json_path (str): Path to the JSON file describing the kernel launch.
@@ -39,6 +40,19 @@ def create_args_from_json(json_path):
39
40
  """
40
41
  with open(json_path, "r") as f:
41
42
  data = json.load(f)
43
+ return create_args_from_json(data)
44
+
45
+
46
+ def create_args_from_json(data):
47
+ """
48
+ Parse a reproducer JSON and build kernel grid and argument dictionary.
49
+
50
+ Args:
51
+ data (dict | list): JSON data describing the kernel launch.
52
+
53
+ Returns:
54
+ tuple[list, dict]: Grid specification list and map of argument name to value.
55
+ """
42
56
  # Handle data format validation and extraction
43
57
  if isinstance(data, list):
44
58
  if len(data) != 1:
@@ -61,6 +75,192 @@ def create_args_from_json(json_path):
61
75
  return grid, args_dict
62
76
 
63
77
 
78
def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
    """
    Re-materialize ``tensor`` with a custom stride and storage offset.

    Args:
        tensor: The base contiguous tensor holding the values.
        shape: The desired shape.
        stride: The desired stride, or None to keep ``tensor`` unchanged.
        storage_offset: The desired storage offset.

    Returns:
        torch.Tensor: ``tensor`` itself when no custom layout is required.
        Otherwise, a strided view over a freshly allocated buffer that holds
        a copy of ``tensor``'s values.
    """
    if stride is None:
        return tensor

    # Row-major (C-contiguous) strides for `shape`; the innermost dim has stride 1.
    contiguous_stride = [1] * len(shape)
    for dim in range(len(shape) - 2, -1, -1):
        contiguous_stride[dim] = contiguous_stride[dim + 1] * shape[dim + 1]

    already_contiguous = tuple(stride) == tuple(contiguous_stride)
    if already_contiguous and storage_offset == 0:
        return tensor

    # Highest storage index the view touches; zero-size dims contribute nothing
    # (and an empty shape/stride leaves just the offset element).
    farthest_index = storage_offset + sum(
        dim_stride * (dim_size - 1)
        for dim_stride, dim_size in zip(stride, shape)
        if dim_size > 0
    )
    backing = torch.empty(farthest_index + 1, dtype=tensor.dtype, device=tensor.device)

    view = backing.as_strided(size=shape, stride=stride, storage_offset=storage_offset)
    # Copy the base values into the strided layout.
    view.copy_(tensor.flatten()[: view.numel()].view(shape))
    return view
127
+
128
+
129
+ def _create_base_tensor(arg_info) -> torch.Tensor:
130
+ """
131
+ Create a base tensor without stride/offset modifications.
132
+
133
+ Args:
134
+ arg_info (dict): Argument information including dtype, shape, device, etc.
135
+
136
+ Returns:
137
+ torch.Tensor: The created base tensor
138
+ """
139
+ if arg_info.get("blob_path"):
140
+ return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
141
+
142
+ # Extract basic tensor properties
143
+ dtype_str = arg_info.get("dtype")
144
+ try:
145
+ torch_dtype = getattr(torch, dtype_str.split(".")[-1])
146
+ except AttributeError:
147
+ logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
148
+ torch_dtype = torch.float32
149
+
150
+ shape = arg_info.get("shape", [])
151
+ device = arg_info.get("device", "cpu")
152
+ # Normalize cuda device to cuda:0
153
+ if isinstance(device, str) and device.startswith("cuda"):
154
+ device = "cuda:0"
155
+
156
+ # Extract statistical information if available
157
+ mean = arg_info.get("mean")
158
+ std = arg_info.get("std")
159
+ min_val = arg_info.get("min")
160
+ max_val = arg_info.get("max")
161
+ has_stats = (
162
+ mean is not None
163
+ and std is not None
164
+ and min_val is not None
165
+ and max_val is not None
166
+ )
167
+
168
+ if arg_info.get("tensor_capture_error", False):
169
+ logging.error(
170
+ f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
171
+ )
172
+
173
+ # Use a dummy tensor to check properties of the dtype
174
+ tensor_props = torch.empty(0, dtype=torch_dtype)
175
+
176
+ # Case 1: Floating point types
177
+ if tensor_props.is_floating_point():
178
+ if has_stats:
179
+ # Generate tensor with statistical properties matching original data
180
+ if std == 0 or min_val == max_val:
181
+ # Constant tensor
182
+ return torch.full(shape, mean, dtype=torch_dtype, device=device)
183
+ # Generate normal distribution with mean and std, then clamp to [min, max]
184
+ tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
185
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
186
+ return tensor.to(torch_dtype)
187
+ else:
188
+ # Fallback to original random generation
189
+ if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
190
+ tmp = torch.rand(shape, dtype=torch.float32, device=device)
191
+ return tmp.to(torch_dtype)
192
+ else:
193
+ return torch.empty(shape, dtype=torch_dtype, device=device).random_()
194
+
195
+ # Case 2: Integer types
196
+ elif torch_dtype in [
197
+ torch.int8,
198
+ torch.int16,
199
+ torch.int32,
200
+ torch.int64,
201
+ torch.uint8,
202
+ torch.bool,
203
+ ]:
204
+ if has_stats and torch_dtype != torch.bool:
205
+ # Generate tensor with statistical properties, then round for integers
206
+ if std == 0 or min_val == max_val:
207
+ # Constant tensor
208
+ return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
209
+ tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
210
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
211
+ return torch.round(tensor).to(torch_dtype)
212
+ else:
213
+ # Fallback to original random generation
214
+ return torch.empty(shape, dtype=torch_dtype, device=device).random_()
215
+
216
+ # Case 3: Complex numbers need special handling
217
+ elif tensor_props.is_complex():
218
+ # Complex types: fallback to original logic for now
219
+ # TODO: Could be improved to use statistical info if available
220
+ float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
221
+ real_part = torch.rand(shape, dtype=float_dtype, device=device)
222
+ imag_part = torch.rand(shape, dtype=float_dtype, device=device)
223
+ return torch.complex(real_part, imag_part)
224
+
225
+ # Case 4: Handle other unsigned integers (like uint32) which fail with random_()
226
+ elif "uint" in str(torch_dtype):
227
+ if has_stats:
228
+ # Generate tensor with statistical properties for unsigned integers
229
+ if std == 0 or min_val == max_val:
230
+ return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
231
+ tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
232
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
233
+ return torch.round(tensor).to(torch_dtype)
234
+ else:
235
+ # Fallback to original random generation
236
+ return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
237
+
238
+ # Case 5: If we don't know how to handle the type, raise an error
239
+ else:
240
+ raise NotImplementedError(
241
+ f"Random data generation not implemented for dtype: {torch_dtype}"
242
+ )
243
+
244
+
245
def _create_tensor(arg_info) -> torch.Tensor:
    """
    Build a tensor argument, honoring any recorded stride and storage offset.

    Args:
        arg_info (dict): Argument description. The full dict is passed to
            _create_base_tensor. "shape", "stride" and "storage_offset" are
            read here to lay out the result.

    Returns:
        torch.Tensor: The base tensor, re-laid-out by _apply_stride_and_offset
        when a stride was recorded.
    """
    base = _create_base_tensor(arg_info)
    return _apply_stride_and_offset(
        base,
        arg_info.get("shape", []),
        arg_info.get("stride"),
        arg_info.get("storage_offset", 0),
    )
262
+
263
+
64
264
  def _create_arg_from_info(arg_info):
65
265
  """
66
266
  Recursively construct a kernel argument from its JSON schema.
@@ -78,120 +278,14 @@ def _create_arg_from_info(arg_info):
78
278
  """
79
279
  arg_type = arg_info.get("type")
80
280
 
81
- if arg_type in ["int", "bool"]:
281
+ if arg_type == "NoneType":
282
+ return None
283
+
284
+ if arg_type in ["int", "bool", "str", "float"]:
82
285
  return arg_info.get("value")
83
286
 
84
287
  elif arg_type == "tensor":
85
- if arg_info.get("blob_path"):
86
- return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
87
-
88
- # Extract basic tensor properties
89
- dtype_str = arg_info.get("dtype")
90
- try:
91
- torch_dtype = getattr(torch, dtype_str.split(".")[-1])
92
- except AttributeError:
93
- logger.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
94
- torch_dtype = torch.float32
95
-
96
- shape = arg_info.get("shape", [])
97
- device = arg_info.get("device", "cpu")
98
-
99
- # Extract statistical information if available
100
- mean = arg_info.get("mean")
101
- std = arg_info.get("std")
102
- min_val = arg_info.get("min")
103
- max_val = arg_info.get("max")
104
- has_stats = (
105
- mean is not None
106
- and std is not None
107
- and min_val is not None
108
- and max_val is not None
109
- )
110
-
111
- # Use a dummy tensor to check properties of the dtype
112
- tensor_props = torch.empty(0, dtype=torch_dtype)
113
-
114
- # Case 1: Floating point types
115
- if tensor_props.is_floating_point():
116
- if has_stats:
117
- # Generate tensor with statistical properties matching original data
118
- if std == 0 or min_val == max_val:
119
- # Constant tensor
120
- return torch.full(shape, mean, dtype=torch_dtype, device=device)
121
- # Generate normal distribution with mean and std, then clamp to [min, max]
122
- tensor = (
123
- torch.randn(shape, dtype=torch.float32, device=device) * std + mean
124
- )
125
- tensor = torch.clamp(tensor, min=min_val, max=max_val)
126
- return tensor.to(torch_dtype)
127
- else:
128
- # Fallback to original random generation
129
- if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
130
- tmp = torch.rand(shape, dtype=torch.float32, device=device)
131
- return tmp.to(torch_dtype)
132
- else:
133
- return torch.empty(
134
- shape, dtype=torch_dtype, device=device
135
- ).random_()
136
-
137
- # Case 2: Integer types
138
- elif torch_dtype in [
139
- torch.int8,
140
- torch.int16,
141
- torch.int32,
142
- torch.int64,
143
- torch.uint8,
144
- torch.bool,
145
- ]:
146
- if has_stats and torch_dtype != torch.bool:
147
- # Generate tensor with statistical properties, then round for integers
148
- if std == 0 or min_val == max_val:
149
- # Constant tensor
150
- return torch.full(
151
- shape, int(mean), dtype=torch_dtype, device=device
152
- )
153
- tensor = (
154
- torch.randn(shape, dtype=torch.float32, device=device) * std + mean
155
- )
156
- tensor = torch.clamp(tensor, min=min_val, max=max_val)
157
- return torch.round(tensor).to(torch_dtype)
158
- else:
159
- # Fallback to original random generation
160
- return torch.empty(shape, dtype=torch_dtype, device=device).random_()
161
-
162
- # Case 3: Complex numbers need special handling
163
- elif tensor_props.is_complex():
164
- # Complex types: fallback to original logic for now
165
- # TODO: Could be improved to use statistical info if available
166
- float_dtype = (
167
- torch.float32 if torch_dtype == torch.complex64 else torch.float64
168
- )
169
- real_part = torch.rand(shape, dtype=float_dtype, device=device)
170
- imag_part = torch.rand(shape, dtype=float_dtype, device=device)
171
- return torch.complex(real_part, imag_part)
172
-
173
- # Case 4: Handle other unsigned integers (like uint32) which fail with random_()
174
- elif "uint" in str(torch_dtype):
175
- if has_stats:
176
- # Generate tensor with statistical properties for unsigned integers
177
- if std == 0 or min_val == max_val:
178
- return torch.full(
179
- shape, int(mean), dtype=torch_dtype, device=device
180
- )
181
- tensor = (
182
- torch.randn(shape, dtype=torch.float32, device=device) * std + mean
183
- )
184
- tensor = torch.clamp(tensor, min=min_val, max=max_val)
185
- return torch.round(tensor).to(torch_dtype)
186
- else:
187
- # Fallback to original random generation
188
- return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
189
-
190
- # Case 5: If we don't know how to handle the type, raise an error
191
- else:
192
- raise NotImplementedError(
193
- f"Random data generation not implemented for dtype: {torch_dtype}"
194
- )
288
+ return _create_tensor(arg_info)
195
289
 
196
290
  elif arg_type == "triton_kernels.tensor.Tensor":
197
291
  if not TRITON_KERNELS_CUSTOM_TYPES:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tritonparse
3
- Version: 0.3.1.dev20251020071524
3
+ Version: 0.3.1.dev20251021071528
4
4
  Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
5
5
  Author-email: Yueming Hao <yhao@meta.com>
6
6
  License-Expression: BSD-3-Clause
@@ -16,13 +16,14 @@ tritonparse/trace_processor.py,sha256=brQBt26jdB6-quJXP5-warp2j31JSjOOFJa5ayiUZ5
16
16
  tritonparse/utils.py,sha256=Jnlptcd79llSDev-_1XyyOnv2izUqv0PEL74A8GF2tc,4565
17
17
  tritonparse/reproducer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  tritonparse/reproducer/cli.py,sha256=MqYuuAP-uAIWWLQxixwyyBHJGSaQdG3xXGaVjERTLX8,1522
19
+ tritonparse/reproducer/function_extractor.py,sha256=pB5b52Xlk9-fe-Gs-z_Rn2HkZd1bC9qHTw1JZc_epBM,6588
19
20
  tritonparse/reproducer/orchestrator.py,sha256=ZGdsiOZg2Bg6o0cqZSCGx_v6mdzR4yZX7S3SVQjZ9-c,3182
20
- tritonparse/reproducer/placeholder_replacer.py,sha256=YPspknFkoZ1WLHQBrJefSlp4RerbdX9LpBe-QSCmMRs,9026
21
+ tritonparse/reproducer/placeholder_replacer.py,sha256=TvqQIrOubmBHJ_pl0ZvkpP4dIDiMYxUll1q9MIbzZco,9552
21
22
  tritonparse/reproducer/types.py,sha256=AfVl83zoJZQ58JJoplCcMC51gK-M-OKcafatYEIGgW0,509
22
- tritonparse/reproducer/utils.py,sha256=UTclw48vH49g6Z2ljJL5DOZ6Rl4UDudyr0PeUySa3p8,13857
23
+ tritonparse/reproducer/utils.py,sha256=LHmkM9EEAn9BwNyhHuCEp1tm7omeILbhkAcOmN-iLrk,16544
23
24
  tritonparse/reproducer/ingestion/ndjson.py,sha256=pEujTl5xXW2E2DEW8ngxXQ8qP9oawb90wBVTWHDs1jk,7372
24
25
  tritonparse/reproducer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
- tritonparse/reproducer/templates/example.py,sha256=QAyykFvmd1nI4ZvCjkJfrpAs2XBsMnFOK6WzhXxQ1uI,13963
26
+ tritonparse/reproducer/templates/example.py,sha256=jR3c8_d7fAFJYaj1DuUuthnI4Xd-_606bWDRdUPMNyo,785
26
27
  tritonparse/reproducer/templates/loader.py,sha256=HqjfThdDVg7q2bYWry78sIaVRkUpkcA8KQDt83YrlVE,1920
27
28
  tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
29
  tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
@@ -31,9 +32,9 @@ tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5H
31
32
  tritonparse/tools/load_tensor.py,sha256=94-TiSYlpXJx4MPmGK1ovmZlTt56Q_B3KQeCPaA6Cnw,2734
32
33
  tritonparse/tools/prettify_ndjson.py,sha256=r2YlHwFDTHgML7KljRmMsHaDg29q8gOQAgyDKWJhxRM,11062
33
34
  tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
34
- tritonparse-0.3.1.dev20251020071524.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
35
- tritonparse-0.3.1.dev20251020071524.dist-info/METADATA,sha256=gXkcfOLHKFIqbKhd4AQjj7v_UthKrA0SXjKKxnJv728,8278
36
- tritonparse-0.3.1.dev20251020071524.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
37
- tritonparse-0.3.1.dev20251020071524.dist-info/entry_points.txt,sha256=wEXdaieDoRRCCdhEv2p_C68iytnaXU_2pwt5CqjfbWY,56
38
- tritonparse-0.3.1.dev20251020071524.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
39
- tritonparse-0.3.1.dev20251020071524.dist-info/RECORD,,
35
+ tritonparse-0.3.1.dev20251021071528.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
36
+ tritonparse-0.3.1.dev20251021071528.dist-info/METADATA,sha256=zeBrpftm2P8cY-iyCCyxc80qriG_rC7apZmRKy8CTKE,8278
37
+ tritonparse-0.3.1.dev20251021071528.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
38
+ tritonparse-0.3.1.dev20251021071528.dist-info/entry_points.txt,sha256=wEXdaieDoRRCCdhEv2p_C68iytnaXU_2pwt5CqjfbWY,56
39
+ tritonparse-0.3.1.dev20251021071528.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
40
+ tritonparse-0.3.1.dev20251021071528.dist-info/RECORD,,