tritonparse 0.2.4.dev20250922071528__py3-none-any.whl → 0.2.4.dev20250924071525__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


tritonparse/reproducer/cli.py ADDED
@@ -0,0 +1,27 @@
+ import argparse
+
+
+ def _add_reproducer_args(parser: argparse.ArgumentParser) -> None:
+     """Add common arguments for the reproducer to a parser."""
+     parser.add_argument("input", help="Path to the ndjson/ndjson.gz log file")
+     parser.add_argument(
+         "--line-index",
+         type=int,
+         help="The line number of the launch event in the input file to reproduce.",
+     )
+     parser.add_argument(
+         "--out-dir",
+         default="repro_output",
+         help=(
+             "Directory to save the reproducer script and context JSON. Defaults to "
+             "'repro_output/<kernel_name>/' if not provided."
+         ),
+     )
+     parser.add_argument(
+         "--template",
+         default="example",
+         help=(
+             "Template name (builtin, without .py) or a filesystem path to a .py file. "
+             "Defaults to 'example'."
+         ),
+     )
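
The new cli.py only registers the shared argument group; it does not build a parser itself. A minimal sketch of how these arguments could be wired up and parsed — the standalone parser below is illustrative, not the package's actual entry point:

    import argparse

    from tritonparse.reproducer.cli import _add_reproducer_args

    # Hypothetical standalone parser; the real CLI may compose this differently.
    parser = argparse.ArgumentParser(description="Generate a Triton kernel reproducer")
    _add_reproducer_args(parser)
    args = parser.parse_args(["trace.ndjson", "--line-index", "42"])
    print(args.input, args.line_index, args.out_dir, args.template)
    # -> trace.ndjson 42 repro_output example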
tritonparse/reproducer/ingestion/ndjson.py ADDED
@@ -0,0 +1,235 @@
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from tritonparse.tp_logger import logger
+
+ # Sentinel object to mark arguments that should be skipped during processing
+ _SKIP = object()
+
+
+ @dataclass
+ class KernelInfo:
+     """Information about a Triton kernel extracted from compilation events."""
+
+     file_path: str
+     function_name: str
+     source_code: str
+     call_stack: List[Dict[str, Any]]
+
+
+ @dataclass
+ class ContextBundle:
+     """Bundle of all context information needed to reproduce a kernel launch."""
+
+     kernel_info: KernelInfo
+     compile: Dict[str, Any]
+     launch: Dict[str, Any]
+     args: Dict[str, Any]
+     tensor_args: Dict[str, Any]
+     raw_launch_event: Dict[str, Any]
+     raw_comp_event: Dict[str, Any]
+
+
+ def get_launch_and_compilation_events(
+     events: List[Dict[str, Any]], line_index: Optional[int] = None
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     Extract launch and compilation events from the event list.
+
+     Args:
+         events: List of parsed event dictionaries.
+         line_index: Index of the launch event to process.
+
+     Returns:
+         Tuple of (launch_event, compilation_event).
+
+     Raises:
+         ValueError: If the event at line_index is not a launch event.
+         RuntimeError: If compilation event cannot be found or is ambiguous.
+     """
+     if line_index is None or line_index >= len(events):
+         raise ValueError(f"Invalid line_index: {line_index}")
+
+     launch_event = events[line_index]
+     if launch_event["event_type"] != "launch":
+         raise ValueError(f"Event at index {line_index} is not a launch event")
+
+     comp_meta = launch_event.get("compilation_metadata", {})
+     comp_hash = comp_meta.get("hash")
+     if not comp_hash:
+         raise RuntimeError("Could not find compilation hash in launch event.")
+
+     comp_event = None
+     for event in events:
+         if (
+             event["event_type"] == "compilation"
+             and event.get("payload", {}).get("metadata", {}).get("hash") == comp_hash
+         ):
+             comp_event = event
+             break
+     if not comp_event:
+         raise RuntimeError(f"Could not find compilation event for hash {comp_hash}.")
+     return launch_event, comp_event
+
+
+ def get_kernel_info(comp_event: Dict[str, Any]) -> KernelInfo:
+     """
+     Extract kernel information from a compilation event.
+
+     Args:
+         comp_event: Compilation event dictionary containing kernel metadata.
+
+     Returns:
+         KernelInfo object with extracted kernel details.
+
+     Raises:
+         RuntimeError: If file path or function name cannot be resolved.
+     """
+     payload = comp_event.get("payload") or {}
+     py_source = payload.get("python_source") or {}
+     code = py_source.get("code", "")
+
+     # Extract file path and function name
+     file_path = py_source.get("file_path")
+     # The function name is in the compilation metadata payload
+     func_name = (comp_event.get("payload", {}).get("metadata") or {}).get("name")
+
+     # Find '@triton.jit' decorator and slice the string from there
+     jit_marker = "@triton.jit"
+     jit_pos = code.find(jit_marker)
+     if jit_pos != -1:
+         code = code[jit_pos:]
+         logger.debug("Extracted kernel source starting from '@triton.jit'.")
+
+     if not file_path or not func_name:
+         raise RuntimeError(
+             "Could not resolve kernel file path or function name from compilation event."
+             " The import-based strategy cannot proceed."
+         )
+     return KernelInfo(file_path, func_name, code, comp_event.get("stack", []))
+
+
+ def _decode_arg(raw: Any) -> Any:
+     """
+     Decode a raw argument value from event data.
+
+     Args:
+         raw: Raw argument value from event data.
+
+     Returns:
+         Decoded argument value, or _SKIP sentinel for tensors.
+     """
+     if not isinstance(raw, dict):
+         return raw
+     t = raw.get("type")
+     if t == "tensor":
+         return _SKIP
+     if t == "NoneType":
+         return None
+     return raw.get("value", raw.get("repr"))
+
+
+ def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Pack argument values into a standardized format.
+
+     Args:
+         args: Dictionary of argument names to values.
+
+     Returns:
+         Dictionary with packed argument information including type and metadata.
+     """
+     packed = {}
+     for k, v in args.items():
+         t = v.get("type") if isinstance(v, dict) else None
+         if t == "tensor":
+             packed[k] = {
+                 "type": "tensor",
+                 "shape": v.get("shape") if isinstance(v, dict) else None,
+                 "dtype": v.get("dtype") if isinstance(v, dict) else None,
+                 "device": v.get("device") if isinstance(v, dict) else None,
+                 "stride": v.get("stride") if isinstance(v, dict) else None,
+                 "is_contiguous": (
+                     v.get("is_contiguous") if isinstance(v, dict) else None
+                 ),
+                 "numel": v.get("numel") if isinstance(v, dict) else None,
+             }
+         else:
+             # scalar / NoneType etc
+             if isinstance(v, dict):
+                 packed[k] = {
+                     "type": v.get("type"),
+                     "value": v.get("value", v.get("repr")),
+                 }
+             else:
+                 packed[k] = {
+                     "type": None,
+                     "value": v,
+                 }
+     return packed
+
+
+ def build_context_bundle(
+     events: List[Dict[str, Any]], line_index: Optional[int] = None
+ ):
+     """
+     Build a complete context bundle from events and line index.
+
+     Args:
+         events: List of parsed event dictionaries.
+         line_index: Index of the launch event to process.
+
+     Returns:
+         ContextBundle containing all information needed to reproduce the kernel launch.
+
+     Raises:
+         ValueError: If line_index is invalid or event is not a launch event.
+         RuntimeError: If compilation event cannot be found.
+     """
+     launch_event, comp_event = get_launch_and_compilation_events(events, line_index)
+     kernel_info = get_kernel_info(comp_event)
+     grid = launch_event.get("grid")
+     extracted_args = launch_event.get("extracted_args", {})
+     comp_meta = launch_event.get("compilation_metadata", {})
+
+     # Compile metadata subset we care about
+     compile_block = {
+         "num_warps": comp_meta.get("num_warps"),
+         "num_stages": comp_meta.get("num_stages"),
+         "arch": comp_meta.get("arch"),
+         "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
+         "triton_version": comp_meta.get("triton_version"),
+         "hash": comp_meta.get("hash"),
+     }
+
+     # kwargs: include constexpr + explicit scalars used for launch (skip tensor args)
+     kwargs = {}
+     for k, v in extracted_args.items():
+         val = _decode_arg(v)
+         if val is _SKIP:
+             continue
+         kwargs[k] = val
+
+     # tensor args: only tensors
+     raw_tensor_args = {
+         k: v
+         for k, v in extracted_args.items()
+         if isinstance(v, dict) and v.get("type") == "tensor"
+     }
+
+     primitive_args = _pack_args(extracted_args)
+     tensor_args = _pack_args(raw_tensor_args)
+     launch_block = {
+         "grid": grid,
+         "kwargs": kwargs,
+     }
+
+     return ContextBundle(
+         kernel_info,
+         compile_block,
+         launch_block,
+         primitive_args,
+         tensor_args,
+         launch_event,
+         comp_event,
+     )
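
Taken together, ingestion/ndjson.py turns a parsed NDJSON trace into a ContextBundle keyed off a single launch event. A minimal usage sketch, assuming a trace file and a valid launch-event index (both placeholders here):

    from pathlib import Path

    from tritonparse.reproducer.ingestion.ndjson import build_context_bundle
    from tritonparse.tools.prettify_ndjson import load_ndjson

    events = load_ndjson(Path("trace.ndjson"), save_irs=True)  # hypothetical trace
    bundle = build_context_bundle(events, line_index=42)
    print(bundle.kernel_info.function_name)  # kernel name from the compilation event
    print(bundle.launch["grid"])             # grid recorded at launch time
    print(sorted(bundle.tensor_args))        # names of the tensor arguments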
tritonparse/reproducer/orchestrator.py ADDED
@@ -0,0 +1,63 @@
+ from pathlib import Path
+
+ from tritonparse.reproducer.ingestion.ndjson import build_context_bundle
+ from tritonparse.reproducer.templates.loader import load_template_code
+ from tritonparse.reproducer.utils import (
+     _generate_import_statements,
+     _generate_invocation_snippet,
+     _parse_kernel_signature,
+     determine_output_paths,
+ )
+
+ from tritonparse.tools.prettify_ndjson import load_ndjson, save_prettified_json
+ from tritonparse.tp_logger import logger
+
+
+ def reproduce(
+     input_path: str,
+     line_index: int,
+     out_dir: str,
+     template: str,
+ ):
+     """
+     Generate a reproducer script from NDJSON trace file.
+
+     Args:
+         input_path: Path to the NDJSON trace file.
+         line_index: Line index of the launch event to reproduce.
+         out_dir: Output directory for reproducer files.
+     """
+     logger.debug(f"Building bundle from {input_path} at line {line_index}")
+     events = load_ndjson(Path(input_path), save_irs=True)
+     logger.debug(f"Loaded {len(events)} events")
+
+     # Build context bundle from the specified launch event
+     context_bundle = build_context_bundle(events, line_index)
+     logger.debug(
+         f"Built context bundle for kernel: {context_bundle.kernel_info.function_name}"
+     )
+     out_py_path, temp_json_path = determine_output_paths(
+         out_dir, context_bundle.kernel_info.function_name
+     )
+     save_prettified_json(context_bundle.raw_launch_event, temp_json_path)
+     logger.debug("Loading reproducer template.")
+     template_code = load_template_code(template)
+     final_code = template_code.replace(
+         "{{JSON_FILE_NAME_PLACEHOLDER}}", temp_json_path.name
+     )
+     sys_stmt, import_statement = _generate_import_statements(context_bundle.kernel_info)
+     final_code = final_code.replace("# {{KERNEL_SYSPATH_PLACEHOLDER}}", sys_stmt)
+     final_code = final_code.replace("# {{KERNEL_IMPORT_PLACEHOLDER}}", import_statement)
+     source_code = context_bundle.kernel_info.source_code
+     pos_args, kw_args = _parse_kernel_signature(source_code)
+     invocation_snippet = _generate_invocation_snippet(pos_args, kw_args)
+     final_code = final_code.replace(
+         "# {{KERNEL_INVOCATION_PLACEHOLDER}}", invocation_snippet
+     )
+     out_py_path.write_text(final_code, encoding="utf-8")
+     logger.info(
+         "REPRODUCER_OUTPUT script=%s json=%s kernel=%s",
+         str(out_py_path.resolve()),
+         str(temp_json_path.resolve()),
+         context_bundle.kernel_info.function_name,
+     )
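
reproduce() is the one-call API over ingestion, template loading, and placeholder substitution. A minimal sketch, assuming the trace path and index point at a real launch event:

    from tritonparse.reproducer.orchestrator import reproduce

    reproduce(
        input_path="trace.ndjson",  # hypothetical trace file
        line_index=42,
        out_dir="repro_output",
        template="example",
    )
    # Writes repro_output/<kernel_name>/repro_<timestamp>.py and the matching
    # repro_context_<timestamp>.json next to it.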
tritonparse/reproducer/templates/example.py ADDED
@@ -0,0 +1,239 @@
+ import hashlib
+ import importlib
+ import json
+ import sys
+ from functools import lru_cache
+ from pathlib import Path
+
+ import torch
+
+ # {{KERNEL_SYSPATH_PLACEHOLDER}}
+
+ # {{KERNEL_IMPORT_PLACEHOLDER}}
+
+ TRITON_KERNELS_CUSTOM_TYPES = (
+     importlib.util.find_spec("triton_kernels") is not None
+     and importlib.util.find_spec("triton_kernels.tensor") is not None
+ )
+
+
+ @lru_cache(maxsize=1)
+ def _get_triton_tensor_types():
+     """
+     Import and cache Triton custom tensor types.
+
+     Returns:
+         tuple: (Tensor, Storage, StridedLayout) classes from triton_kernels.tensor.
+
+     Raises:
+         ImportError: If the optional module 'triton_kernels.tensor' is not available.
+     """
+     mod = importlib.import_module("triton_kernels.tensor")
+     return (
+         mod.Tensor,
+         mod.Storage,
+         mod.StridedLayout,
+     )
+
+
+ def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
+     """
+     Load a tensor from its file path and verify its integrity using the hash in the filename.
+
+     Args:
+         tensor_file_path (str): Direct path to the tensor .bin file. The filename should be
+             the hash of the file contents followed by .bin extension.
+         device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
+             If None, keeps the tensor on its original device.
+
+     Returns:
+         torch.Tensor: The loaded tensor (moved to the specified device if provided)
+
+     Raises:
+         FileNotFoundError: If the tensor file doesn't exist
+         RuntimeError: If the tensor cannot be loaded
+         ValueError: If the computed hash doesn't match the filename hash
+     """
+     blob_path = Path(tensor_file_path)
+
+     if not blob_path.exists():
+         raise FileNotFoundError(f"Tensor blob not found: {blob_path}")
+
+     # Extract expected hash from filename (remove .bin extension)
+     expected_hash = blob_path.stem
+
+     # Compute actual hash of file contents
+     with open(blob_path, "rb") as f:
+         file_contents = f.read()
+     computed_hash = hashlib.blake2b(file_contents).hexdigest()
+
+     # Verify hash matches filename
+     if computed_hash != expected_hash:
+         raise ValueError(
+             f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
+         )
+
+     try:
+         # Load the tensor using torch.load (tensors are saved with torch.save)
+         # If device is None, keep tensor on its original device, otherwise move to specified device
+         tensor = torch.load(blob_path, map_location=device)
+         return tensor
+     except Exception as e:
+         raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}") from e
+
+
+ def create_args_from_json(json_path):
+     """
+     Parse a reproducer JSON and build kernel grid and argument dictionary.
+
+     Args:
+         json_path (str): Path to the JSON file describing the kernel launch.
+
+     Returns:
+         tuple[list, dict]: Grid specification list and map of argument name to value.
+     """
+     with open(json_path, "r") as f:
+         data = json.load(f)
+     # Handle data format validation and extraction
+     if isinstance(data, list):
+         if len(data) != 1:
+             print(
+                 f"Error: Expected single element list, got list with {len(data)} elements"
+             )
+             sys.exit(1)
+         data = data[0]
+     elif not isinstance(data, dict):
+         print(f"Error: Expected list or dict, got {type(data)}")
+         sys.exit(1)
+
+     grid = data.get("grid", [])
+     args_dict = {}
+     extracted_args = data.get("extracted_args", {})
+
+     for arg_name, arg_info in extracted_args.items():
+         args_dict[arg_name] = _create_arg_from_info(arg_info)
+
+     return grid, args_dict
+
+
+ def _create_arg_from_info(arg_info):
+     """
+     Recursively construct a kernel argument from its JSON schema.
+
+     Args:
+         arg_info (dict): JSON object describing a single argument, including
+             fields like 'type', 'value', 'dtype', 'shape', 'device', etc.
+
+     Returns:
+         Any: The constructed Python object suitable for kernel invocation.
+
+     Raises:
+         RuntimeError: When required optional dependencies are missing.
+         NotImplementedError: When a dtype or type is not supported yet.
+     """
+     arg_type = arg_info.get("type")
+
+     if arg_type in ["int", "bool"]:
+         return arg_info.get("value")
+
+     elif arg_type == "tensor":
+         if arg_info.get("blob_path"):
+             return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
+         dtype_str = arg_info.get("dtype")
+         try:
+             torch_dtype = getattr(torch, dtype_str.split(".")[-1])
+         except AttributeError:
+             torch_dtype = torch.float32
+
+         shape = arg_info.get("shape", [])
+         device = arg_info.get("device", "cpu")
+
+         # Use a dummy tensor to check properties of the dtype
+         tensor_props = torch.empty(0, dtype=torch_dtype)
+
+         # Case 1: Floating point, signed integers, uint8, and bool are supported by random_()
+         if tensor_props.is_floating_point():
+             if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                 tmp = torch.rand(shape, dtype=torch.float32, device=device)
+                 return tmp.to(torch_dtype)
+             else:
+                 return torch.empty(shape, dtype=torch_dtype, device=device).random_()
+         elif torch_dtype in [
+             torch.int8,
+             torch.int16,
+             torch.int32,
+             torch.int64,
+             torch.uint8,
+             torch.bool,
+         ]:
+             return torch.empty(shape, dtype=torch_dtype, device=device).random_()
+         # Case 2: Complex numbers need special handling
+         elif tensor_props.is_complex():
+             float_dtype = (
+                 torch.float32 if torch_dtype == torch.complex64 else torch.float64
+             )
+             real_part = torch.rand(shape, dtype=float_dtype, device=device)
+             imag_part = torch.rand(shape, dtype=float_dtype, device=device)
+             return torch.complex(real_part, imag_part)
+
+         # Case 3: Handle other unsigned integers (like uint32) which fail with random_()
+         elif "uint" in str(torch_dtype):
+             return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
+         # Case 4: If we don't know how to handle the type, raise an error
+         else:
+             raise NotImplementedError(
+                 f"Random data generation not implemented for dtype: {torch_dtype}"
+             )
+
+     elif arg_type == "triton_kernels.tensor.Tensor":
+         if not TRITON_KERNELS_CUSTOM_TYPES:
+             raise RuntimeError(
+                 "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Tensor."
+             )
+         Tensor, Storage, StridedLayout = _get_triton_tensor_types()
+         storage = _create_arg_from_info(arg_info.get("storage"))
+         dtype_str = arg_info.get("dtype")
+         torch_dtype = getattr(torch, dtype_str.split(".")[-1])
+         return Tensor(
+             storage=storage,
+             shape=arg_info.get("shape"),
+             shape_max=arg_info.get("shape_max"),
+             dtype=torch_dtype,
+         )
+
+     elif arg_type == "triton_kernels.tensor.Storage":
+         if not TRITON_KERNELS_CUSTOM_TYPES:
+             raise RuntimeError(
+                 "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Storage."
+             )
+         Tensor, Storage, StridedLayout = _get_triton_tensor_types()
+         data = _create_arg_from_info(arg_info.get("data"))
+         layout = _create_arg_from_info(arg_info.get("layout"))
+         return Storage(data=data, layout=layout)
+
+     elif arg_type == "StridedLayout":
+         if not TRITON_KERNELS_CUSTOM_TYPES:
+             raise RuntimeError(
+                 "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct StridedLayout."
+             )
+         Tensor, Storage, StridedLayout = _get_triton_tensor_types()
+         return StridedLayout(shape=arg_info.get("initial_shape"))
+     else:
+         print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
+         return None
+
+
+ if __name__ == "__main__":
+     script_dir = Path(__file__).resolve().parent
+     json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}"
+     grid, args_dict = create_args_from_json(str(json_file))
+
+     print("Generated kernel arguments dictionary:")
+     for name, arg in args_dict.items():
+         print(f"  {name}: {arg}")
+     print(f"Grid: {grid}")
+
+     # {{KERNEL_INVOCATION_PLACEHOLDER}}
+
+     torch.cuda.synchronize()
+     print("Kernel execution finished.")
tritonparse/reproducer/templates/loader.py ADDED
@@ -0,0 +1,57 @@
+ from importlib.resources import files as pkg_files
+ from pathlib import Path
+ from typing import List
+
+
+ BUILTIN_TEMPLATES_PACKAGE = "tritonparse.reproducer.templates"
+
+
+ def _is_path_like(template_arg: str) -> bool:
+     return "/" in template_arg or "\\" in template_arg or template_arg.endswith(".py")
+
+
+ def _read_file_text(path: Path) -> str:
+     p = path.expanduser().resolve()
+     if not p.exists() or not p.is_file():
+         raise FileNotFoundError(f"Template not found: {p}")
+     return p.read_text(encoding="utf-8")
+
+
+ def _read_builtin_template_text(name: str) -> str:
+     resource = pkg_files(BUILTIN_TEMPLATES_PACKAGE).joinpath(f"{name}.py")
+     # resource may not exist if an invalid name is provided
+     try:
+         with resource.open("r", encoding="utf-8") as f:
+             return f.read()
+     except FileNotFoundError as exc:
+         available = ", ".join(list_builtin_templates())
+         raise FileNotFoundError(
+             f"Builtin template '{name}' not found. Available: {available}"
+         ) from exc
+
+
+ def list_builtin_templates() -> List[str]:
+     """
+     Return the list of available builtin template names (without .py suffix).
+     """
+     names: List[str] = []
+     for entry in pkg_files(BUILTIN_TEMPLATES_PACKAGE).iterdir():
+         try:
+             if entry.is_file():
+                 filename = entry.name
+                 if filename.endswith(".py") and not filename.startswith("__"):
+                     names.append(filename[:-3])
+         except (OSError, FileNotFoundError):
+             # Defensive: in case entry access fails in some environments
+             continue
+     names.sort()
+     return names
+
+
+ def load_template_code(template_arg: str) -> str:
+     """
+     Load template code by name (builtin, without .py) or by filesystem path.
+     """
+     if _is_path_like(template_arg):
+         return _read_file_text(Path(template_arg))
+     return _read_builtin_template_text(template_arg)
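
Usage mirrors the --template flag's contract: a bare name resolves against the builtin templates package, while anything path-like is read from disk. A short sketch:

    from tritonparse.reproducer.templates.loader import (
        list_builtin_templates,
        load_template_code,
    )

    print(list_builtin_templates())            # e.g. ['example']
    code = load_template_code("example")       # builtin template, by name
    code = load_template_code("./my_tmpl.py")  # path-like, read from the filesystem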
tritonparse/reproducer/utils.py CHANGED
@@ -2,25 +2,40 @@ import importlib
  import importlib.util
  import json
  import sys
+ from datetime import datetime
  from functools import lru_cache
+ from pathlib import Path

  import torch

+ from tritonparse.tools.load_tensor import load_tensor
+ from tritonparse.tp_logger import logger
+
  TRITON_KERNELS_CUSTOM_TYPES = (
-     importlib.util.find_spec("triton_kernels.tensor") is not None
+     importlib.util.find_spec("triton_kernels") is not None
+     and importlib.util.find_spec("triton_kernels.tensor") is not None
  )


+ @lru_cache(maxsize=1)
+ def _get_triton_tensor_types():
+     mod = importlib.import_module("triton_kernels.tensor")
+     return (
+         mod.Tensor,
+         mod.Storage,
+         mod.StridedLayout,
+     )
+
+
  def create_args_from_json(json_path):
      """
-     Creates a list of arguments for a kernel launch from a JSON file.
+     Parse a reproducer JSON and build kernel grid and argument dictionary.

      Args:
-         json_path (str): The path to the JSON file containing the kernel
-             launch information.
+         json_path (str): Path to the JSON file describing the kernel launch.

      Returns:
-         tuple: A tuple containing the grid and a dictionary of arguments.
+         tuple[list, dict]: Grid specification list and map of argument name to value.
      """
      with open(json_path, "r") as f:
          data = json.load(f)
@@ -46,19 +61,20 @@ def create_args_from_json(json_path):
      return grid, args_dict


- @lru_cache(maxsize=1)
- def _get_triton_tensor_types():
-     mod = importlib.import_module("triton_kernels.tensor")
-     return (
-         getattr(mod, "Tensor"),
-         getattr(mod, "Storage"),
-         getattr(mod, "StridedLayout"),
-     )
-
-
  def _create_arg_from_info(arg_info):
      """
-     Recursively creates a kernel argument from its JSON info dictionary.
+     Recursively construct a kernel argument from its JSON schema.
+
+     Args:
+         arg_info (dict): JSON object describing a single argument, including
+             fields like 'type', 'value', 'dtype', 'shape', 'device', etc.
+
+     Returns:
+         Any: The constructed Python object suitable for kernel invocation.
+
+     Raises:
+         RuntimeError: When required optional dependencies are missing.
+         NotImplementedError: When a dtype or type is not supported yet.
      """
      arg_type = arg_info.get("type")

@@ -66,6 +82,8 @@ def _create_arg_from_info(arg_info):
          return arg_info.get("value")

      elif arg_type == "tensor":
+         if arg_info.get("blob_path"):
+             return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
          dtype_str = arg_info.get("dtype")
          try:
              torch_dtype = getattr(torch, dtype_str.split(".")[-1])
@@ -79,7 +97,13 @@ def _create_arg_from_info(arg_info):
          tensor_props = torch.empty(0, dtype=torch_dtype)

          # Case 1: Floating point, signed integers, uint8, and bool are supported by random_()
-         if tensor_props.is_floating_point() or torch_dtype in [
+         if tensor_props.is_floating_point():
+             if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                 tmp = torch.rand(shape, dtype=torch.float32, device=device)
+                 return tmp.to(torch_dtype)
+             else:
+                 return torch.empty(shape, dtype=torch_dtype, device=device).random_()
+         elif torch_dtype in [
              torch.int8,
              torch.int16,
              torch.int32,
@@ -88,7 +112,6 @@ def _create_arg_from_info(arg_info):
              torch.bool,
          ]:
              return torch.empty(shape, dtype=torch_dtype, device=device).random_()
-
          # Case 2: Complex numbers need special handling
          elif tensor_props.is_complex():
              float_dtype = (
@@ -101,13 +124,11 @@ def _create_arg_from_info(arg_info):
          # Case 3: Handle other unsigned integers (like uint32) which fail with random_()
          elif "uint" in str(torch_dtype):
              return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
-
-         # Case 4: If we don't know how to handle the type, raise an error
+         # Case 4: If we don't know how to handle the type, raise an error
          else:
              raise NotImplementedError(
                  f"Random data generation not implemented for dtype: {torch_dtype}"
              )
-
      elif arg_type == "triton_kernels.tensor.Tensor":
          if not TRITON_KERNELS_CUSTOM_TYPES:
              raise RuntimeError(
@@ -145,3 +166,137 @@ def _create_arg_from_info(arg_info):
      else:
          print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
          return None
+
+
+ def determine_output_paths(out_dir: str, kernel_name: str):
+     """
+     Determine output file paths for reproducer script and context data.
+
+     Args:
+         out_dir: Output directory path. If empty, uses default location.
+         kernel_name: Name of the kernel for default directory naming.
+
+     Returns:
+         Tuple of (python_script_path, json_context_path) as Path objects.
+     """
+     timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+     output_directory = Path(out_dir) / kernel_name
+     output_directory.mkdir(parents=True, exist_ok=True)
+
+     out_py_path = output_directory / f"repro_{timestamp}.py"
+     temp_json_path = output_directory / f"repro_context_{timestamp}.json"
+
+     return out_py_path, temp_json_path
+
+
+ def _generate_import_statements(kernel_info) -> tuple[str, str]:
+     """
+     Generate (sys.path insertion statement, import statement) for the kernel.
+
+     Strategy:
+     - Always add the kernel file's parent directory to sys.path.
+     - If the filename (without .py) is a valid identifier, import using that
+       module name: `from <stem> import <func> as imported_kernel_function`.
+     - Otherwise, fall back to dynamic import via importlib.util and bind
+       `imported_kernel_function` from the loaded module.
+     """
+     file_path = Path(kernel_info.file_path)
+     function_name = kernel_info.function_name
+
+     if not file_path or not function_name:
+         raise ValueError("Kernel file path or function name missing from context.")
+
+     # Always add the file's parent directory to sys.path
+     sys_stmt = (
+         "import sys; p = r'" + str(file_path.parent) + "';\n"
+         "if p not in sys.path: sys.path.insert(0, p)"
+     )
+
+     module_name = file_path.with_suffix("").name
+     if module_name.isidentifier():
+         import_stmt = (
+             f"from {module_name} import {function_name} as imported_kernel_function"
+         )
+         logger.debug("Generated direct import statement: %s", import_stmt)
+         return sys_stmt, import_stmt
+
+     # Fallback: dynamic import when filename is not a valid identifier
+     import_stmt = (
+         "import importlib.util\n"
+         f"_spec = importlib.util.spec_from_file_location('kernel_mod', r'{str(file_path)}')\n"
+         "_mod = importlib.util.module_from_spec(_spec)\n"
+         "_spec.loader.exec_module(_mod)\n"
+         f"imported_kernel_function = getattr(_mod, '{function_name}')"
+     )
+     logger.debug("Generated dynamic import for file: %s", file_path)
+     return sys_stmt, import_stmt
+
+
+ def _parse_kernel_signature(kernel_source_code: str) -> tuple[list[str], list[str]]:
+     """
+     Parses a Triton kernel's source code to distinguish positional args
+     from keyword args (those with default values).
+     """
+     signature_lines = []
+     in_signature = False
+     for line in kernel_source_code.splitlines():
+         # Mark beginning of signature when function definition is found
+         if line.strip().startswith("def "):
+             in_signature = True
+         if in_signature:
+             # Strip comments and leading/trailing whitespace
+             clean_line = line.split("#")[0].strip()
+             signature_lines.append(clean_line)
+             # Stop capturing after the signature ends
+             if "):" in line:
+                 break
+
+     full_signature = "".join(signature_lines)
+     # Extract content between the first '(' and the last '):'
+     try:
+         params_str = full_signature[
+             full_signature.find("(") + 1 : full_signature.rfind("):")
+         ]
+     except IndexError as exc:
+         raise ValueError("Could not parse kernel signature.") from exc
+
+     # Clean up and split the parameters string
+     params = [p.strip() for p in params_str.replace("\n", "").split(",") if p.strip()]
+
+     positional_args = []
+     keyword_args = []
+
+     for param in params:
+         if "=" in param:
+             # Keyword arguments have a default value
+             arg_name = param.split("=")[0].strip()
+             keyword_args.append(arg_name)
+         else:
+             # Positional arguments do not have a default value
+             arg_name = param.split(":")[0].strip()
+             positional_args.append(arg_name)
+
+     logger.debug("Parsed positional args: %s", positional_args)
+     logger.debug("Parsed keyword args: %s", keyword_args)
+     return positional_args, keyword_args
+
+
+ def _generate_invocation_snippet(
+     positional_args: list[str], keyword_args: list[str]
+ ) -> str:
+     """Generates a single-line Python code snippet for kernel invocation."""
+     # Prepare positional args for direct injection into the call
+     pos_args_str = ", ".join([f'args_dict["{arg}"]' for arg in positional_args])
+
+     # Prepare keyword args for direct injection
+     kw_args_str = ", ".join([f'{arg}=args_dict["{arg}"]' for arg in keyword_args])
+
+     # Combine them, ensuring proper comma separation
+     all_args = []
+     if pos_args_str:
+         all_args.append(pos_args_str)
+     if kw_args_str:
+         all_args.append(kw_args_str)
+
+     # Create the single-line call
+     return f"imported_kernel_function[tuple(grid)]({', '.join(all_args)})"
tritonparse/tools/prettify_ndjson.py CHANGED
@@ -40,7 +40,7 @@ import argparse
  import json
  import sys
  from pathlib import Path
- from typing import Any, List
+ from typing import Any, List, Union


  def parse_line_ranges(lines_arg: str) -> set[int]:
@@ -174,7 +174,7 @@ def load_ndjson(
      except FileNotFoundError:
          print(f"Error: File '{file_path}' not found.", file=sys.stderr)
          raise
-     except Exception as e:
+     except (OSError, UnicodeDecodeError) as e:
          print(f"Error reading file '{file_path}': {e}", file=sys.stderr)
          raise

@@ -201,19 +201,21 @@ def load_ndjson(
      return json_objects


- def save_prettified_json(json_objects: List[Any], output_path: Path) -> None:
+ def save_prettified_json(
+     json_objects: Union[List[Any], Any], output_path: Path
+ ) -> None:
      """
-     Save list of JSON objects to a prettified JSON file.
+     Save JSON data to a prettified JSON file.

      Args:
-         json_objects: List of JSON objects to save
+         json_objects: Either a list of JSON objects or a single JSON-serializable object
          output_path: Path where to save the prettified JSON file
      """
      try:
          with open(output_path, "w", encoding="utf-8") as f:
              json.dump(json_objects, f, indent=2, ensure_ascii=False, sort_keys=True)
          print(f"Successfully converted to prettified JSON: {output_path}")
-     except Exception as e:
+     except OSError as e:
          print(f"Error writing to file '{output_path}': {e}", file=sys.stderr)
          raise

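The widened save_prettified_json signature is what lets the orchestrator dump a single launch event rather than a whole event list. A small sketch:

    from pathlib import Path

    from tritonparse.tools.prettify_ndjson import save_prettified_json

    # A single dict is now accepted, not just a list of events.
    save_prettified_json({"event_type": "launch", "grid": [8, 1, 1]}, Path("event.json"))
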
tritonparse/utils.py CHANGED
@@ -16,21 +16,15 @@ from .common import (
  )
  from .source_type import Source, SourceType

- # argument parser for OSS
- parser = None

-
- def init_parser():
-     global parser
-
-     parser = argparse.ArgumentParser(
-         description="analyze triton structured logs", conflict_handler="resolve"
-     )
-
-     # Add arguments for the parse command
+ def _add_parse_args(parser: argparse.ArgumentParser) -> None:
+     """Add common 'parse' subcommand arguments to a parser."""
      parser.add_argument(
          "source",
-         help="Source of torch logs to be analyzed. It is expected to path to a local directory or log",
+         help=(
+             "Source of torch logs to be analyzed. It is expected to path to a local "
+             "directory or log"
+         ),
      )
      parser.add_argument(
          "-o",
@@ -40,7 +34,9 @@ def init_parser():
      )
      parser.add_argument(
          "--overwrite",
-         help="Delete out directory if it already exists. Only does something if --out is set",
+         help=(
+             "Delete out directory if it already exists. Only does something if --out is set"
+         ),
          action="store_true",
      )
      parser.add_argument("-r", "--rank", help="Rank of logs to be analyzed", type=int)
@@ -54,7 +50,6 @@ def init_parser():
          from tritonparse.fb.utils import append_parser

          append_parser(parser)
-         return parser


  def oss_run(
@@ -113,12 +108,6 @@ def oss_run(
      print_parsed_files_summary(out_dir)


- def unified_parse_from_cli():
-     parser = init_parser()
-     args = parser.parse_args()
-     return unified_parse(**vars(args))
-
-
  def unified_parse(
      source: str,
      out: Optional[str] = None,
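
With init_parser and unified_parse_from_cli removed, argument registration is factored into _add_parse_args here and _add_reproducer_args in cli.py, presumably so a single entry point can mount them as subcommands. A sketch of such wiring — this is an assumption about the surrounding CLI, not code from this diff:

    import argparse

    from tritonparse.reproducer.cli import _add_reproducer_args
    from tritonparse.utils import _add_parse_args

    # Hypothetical top-level CLI composing both argument groups.
    parser = argparse.ArgumentParser(description="tritonparse")
    sub = parser.add_subparsers(dest="command", required=True)
    _add_parse_args(sub.add_parser("parse"))
    _add_reproducer_args(sub.add_parser("reproduce"))
    args = parser.parse_args()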
tritonparse-0.2.4.dev20250924071525.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tritonparse
- Version: 0.2.4.dev20250922071528
+ Version: 0.2.4.dev20250924071525
  Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
  Author-email: Yueming Hao <yhao@meta.com>
  License-Expression: BSD-3-Clause
tritonparse-0.2.4.dev20250924071525.dist-info/RECORD CHANGED
@@ -10,17 +10,22 @@ tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV
  tritonparse/structured_logging.py,sha256=7r9pv6miUdb8-CCZfj8SkD3XItzwPeONmszEL7TZak4,43949
  tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
  tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
- tritonparse/utils.py,sha256=wt61tpbkqjGqHh0c7Nr2WlOv7PbQssmjULd6uA6aAko,4475
+ tritonparse/utils.py,sha256=ujx9iUrpOthJ5vWzaNs6RXtqX0dp_GeozOaQLqlUDxg,4269
  tritonparse/reproducer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- tritonparse/reproducer/utils.py,sha256=VfMBwnTEZO8Ug9_ZRlZUVTMaMczDkviAykXpnK5dacU,5093
+ tritonparse/reproducer/cli.py,sha256=JvnAi1FKSpNa6eHEapRn9jdXxsj1vvyrrEEnfxTJYa8,871
+ tritonparse/reproducer/orchestrator.py,sha256=9SQ_rATY-s4r3BZQZdKLw7WYGz8IQJ1StPMgRbKAs5s,2456
+ tritonparse/reproducer/utils.py,sha256=qi4XTKk0pWV4hgYg_GPBISEfVXlrI6tZR0A5ZZbwVyo,11132
+ tritonparse/reproducer/ingestion/ndjson.py,sha256=pEujTl5xXW2E2DEW8ngxXQ8qP9oawb90wBVTWHDs1jk,7372
+ tritonparse/reproducer/templates/example.py,sha256=XWfXD4tDOiE213YlWWK1l1ZgXbK3BX61NnvuVTkO-S0,8595
+ tritonparse/reproducer/templates/loader.py,sha256=HqjfThdDVg7q2bYWry78sIaVRkUpkcA8KQDt83YrlVE,1920
  tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
  tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
  tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
- tritonparse/tools/prettify_ndjson.py,sha256=VOzVWoXpCbaAXYA4i_wBcQIHfh-JhAx7xR4cF_L8yDs,10928
+ tritonparse/tools/prettify_ndjson.py,sha256=YpJ7SFXTkZPZEXQeN1w5wkOf9pFrGqaqhhfHV7eobWA,10998
  tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
- tritonparse-0.2.4.dev20250922071528.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
- tritonparse-0.2.4.dev20250922071528.dist-info/METADATA,sha256=pm3r6Z1nR3gOJ35Ztyen1MhOvxKfPqz18_06ASPNYlc,6580
- tritonparse-0.2.4.dev20250922071528.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- tritonparse-0.2.4.dev20250922071528.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
- tritonparse-0.2.4.dev20250922071528.dist-info/RECORD,,
+ tritonparse-0.2.4.dev20250924071525.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
+ tritonparse-0.2.4.dev20250924071525.dist-info/METADATA,sha256=cxNEHWh9EoRq332ybQO3nsai4gC5eASCWvyPloW0gko,6580
+ tritonparse-0.2.4.dev20250924071525.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ tritonparse-0.2.4.dev20250924071525.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
+ tritonparse-0.2.4.dev20250924071525.dist-info/RECORD,,