tritonparse 0.2.4.dev20250922071528__py3-none-any.whl → 0.2.4.dev20250924071525__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/reproducer/cli.py +27 -0
- tritonparse/reproducer/ingestion/ndjson.py +235 -0
- tritonparse/reproducer/orchestrator.py +63 -0
- tritonparse/reproducer/templates/example.py +239 -0
- tritonparse/reproducer/templates/loader.py +57 -0
- tritonparse/reproducer/utils.py +176 -21
- tritonparse/tools/prettify_ndjson.py +8 -6
- tritonparse/utils.py +9 -20
- {tritonparse-0.2.4.dev20250922071528.dist-info → tritonparse-0.2.4.dev20250924071525.dist-info}/METADATA +1 -1
- {tritonparse-0.2.4.dev20250922071528.dist-info → tritonparse-0.2.4.dev20250924071525.dist-info}/RECORD +13 -8
- {tritonparse-0.2.4.dev20250922071528.dist-info → tritonparse-0.2.4.dev20250924071525.dist-info}/WHEEL +0 -0
- {tritonparse-0.2.4.dev20250922071528.dist-info → tritonparse-0.2.4.dev20250924071525.dist-info}/licenses/LICENSE +0 -0
- {tritonparse-0.2.4.dev20250922071528.dist-info → tritonparse-0.2.4.dev20250924071525.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _add_reproducer_args(parser: argparse.ArgumentParser) -> None:
|
|
5
|
+
"""Add common arguments for the reproducer to a parser."""
|
|
6
|
+
parser.add_argument("input", help="Path to the ndjson/ndjson.gz log file")
|
|
7
|
+
parser.add_argument(
|
|
8
|
+
"--line-index",
|
|
9
|
+
type=int,
|
|
10
|
+
help="The line number of the launch event in the input file to reproduce.",
|
|
11
|
+
)
|
|
12
|
+
parser.add_argument(
|
|
13
|
+
"--out-dir",
|
|
14
|
+
default="repro_output",
|
|
15
|
+
help=(
|
|
16
|
+
"Directory to save the reproducer script and context JSON. Defaults to "
|
|
17
|
+
"'repro_output/<kernel_name>/' if not provided."
|
|
18
|
+
),
|
|
19
|
+
)
|
|
20
|
+
parser.add_argument(
|
|
21
|
+
"--template",
|
|
22
|
+
default="example",
|
|
23
|
+
help=(
|
|
24
|
+
"Template name (builtin, without .py) or a filesystem path to a .py file. "
|
|
25
|
+
"Defaults to 'example'."
|
|
26
|
+
),
|
|
27
|
+
)
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
from tritonparse.tp_logger import logger
|
|
5
|
+
|
|
6
|
+
# Sentinel object to mark arguments that should be skipped during processing
|
|
7
|
+
_SKIP = object()
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class KernelInfo:
    """Description of a Triton kernel taken from a compilation event.

    Attributes:
        file_path: Path of the Python file that defines the kernel.
        function_name: Name of the kernel function.
        source_code: Python source text of the kernel.
        call_stack: Stack frames recorded with the compilation event.
    """

    file_path: str
    function_name: str
    source_code: str
    call_stack: List[Dict[str, Any]]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ContextBundle:
    """Everything collected for reproducing a single kernel launch.

    Attributes:
        kernel_info: Location, name and source of the kernel.
        compile: Selected compilation settings (warps, stages, arch, ...).
        launch: Launch parameters: the grid and the non-tensor kwargs.
        args: Packed description of every extracted argument.
        tensor_args: Packed description of the tensor arguments only.
        raw_launch_event: The launch event as read from the trace.
        raw_comp_event: The matching compilation event as read from the trace.
    """

    kernel_info: KernelInfo
    compile: Dict[str, Any]
    launch: Dict[str, Any]
    args: Dict[str, Any]
    tensor_args: Dict[str, Any]
    raw_launch_event: Dict[str, Any]
    raw_comp_event: Dict[str, Any]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_launch_and_compilation_events(
    events: List[Dict[str, Any]], line_index: Optional[int] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Extract launch and compilation events from the event list.

    The launch event is ``events[line_index]``; the compilation event is the
    first event of type "compilation" whose ``payload.metadata.hash`` equals
    the hash stored in the launch event's ``compilation_metadata``.

    Args:
        events: List of parsed event dictionaries.
        line_index: Zero-based index of the launch event inside ``events``.

    Returns:
        Tuple of (launch_event, compilation_event).

    Raises:
        ValueError: If line_index is None, negative or past the end of
            ``events``, or if the event at line_index is not a launch event.
        RuntimeError: If the launch event carries no compilation hash or no
            compilation event with that hash exists.
    """
    # Reject negative values explicitly: Python would otherwise silently
    # index from the end of the list and reproduce the wrong event.
    if line_index is None or not 0 <= line_index < len(events):
        raise ValueError(f"Invalid line_index: {line_index}")

    launch_event = events[line_index]
    if launch_event.get("event_type") != "launch":
        raise ValueError(f"Event at index {line_index} is not a launch event")

    # ``or {}`` also covers keys that are present but hold JSON null.
    comp_hash = (launch_event.get("compilation_metadata") or {}).get("hash")
    if not comp_hash:
        raise RuntimeError("Could not find compilation hash in launch event.")

    for event in events:
        if event.get("event_type") != "compilation":
            continue
        metadata = (event.get("payload") or {}).get("metadata") or {}
        if metadata.get("hash") == comp_hash:
            # First match wins; later duplicates are not inspected.
            return launch_event, event

    raise RuntimeError(f"Could not find compilation event for hash {comp_hash}.")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_kernel_info(comp_event: Dict[str, Any]) -> KernelInfo:
    """
    Extract kernel information from a compilation event.

    The file path and source come from ``payload.python_source``; the
    function name comes from ``payload.metadata.name``. If the source contains
    an ``@triton.jit`` decorator, everything before its first occurrence is
    dropped from the returned source.

    Args:
        comp_event: Compilation event dictionary containing kernel metadata.

    Returns:
        KernelInfo object with extracted kernel details.

    Raises:
        RuntimeError: If file path or function name cannot be resolved.
    """
    # ``or {}`` / ``or ""`` also cover keys that are present but JSON null.
    payload = comp_event.get("payload") or {}
    py_source = payload.get("python_source") or {}

    file_path = py_source.get("file_path")
    # The function name is in the compilation metadata payload
    func_name = (payload.get("metadata") or {}).get("name")
    if not file_path or not func_name:
        raise RuntimeError(
            "Could not resolve kernel file path or function name from compilation event."
            " The import-based strategy cannot proceed."
        )

    # Find '@triton.jit' decorator and slice the string from there
    code = py_source.get("code") or ""
    jit_pos = code.find("@triton.jit")
    if jit_pos != -1:
        code = code[jit_pos:]
        logger.debug("Extracted kernel source starting from '@triton.jit'.")

    return KernelInfo(file_path, func_name, code, comp_event.get("stack", []))
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _decode_arg(raw: Any) -> Any:
|
|
113
|
+
"""
|
|
114
|
+
Decode a raw argument value from event data.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
raw: Raw argument value from event data.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
Decoded argument value, or _SKIP sentinel for tensors.
|
|
121
|
+
"""
|
|
122
|
+
if not isinstance(raw, dict):
|
|
123
|
+
return raw
|
|
124
|
+
t = raw.get("type")
|
|
125
|
+
if t == "tensor":
|
|
126
|
+
return _SKIP
|
|
127
|
+
if t == "NoneType":
|
|
128
|
+
return None
|
|
129
|
+
return raw.get("value", raw.get("repr"))
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
|
|
133
|
+
"""
|
|
134
|
+
Pack argument values into a standardized format.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
args: Dictionary of argument names to values.
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
Dictionary with packed argument information including type and metadata.
|
|
141
|
+
"""
|
|
142
|
+
packed = {}
|
|
143
|
+
for k, v in args.items():
|
|
144
|
+
t = v.get("type") if isinstance(v, dict) else None
|
|
145
|
+
if t == "tensor":
|
|
146
|
+
packed[k] = {
|
|
147
|
+
"type": "tensor",
|
|
148
|
+
"shape": v.get("shape") if isinstance(v, dict) else None,
|
|
149
|
+
"dtype": v.get("dtype") if isinstance(v, dict) else None,
|
|
150
|
+
"device": v.get("device") if isinstance(v, dict) else None,
|
|
151
|
+
"stride": v.get("stride") if isinstance(v, dict) else None,
|
|
152
|
+
"is_contiguous": (
|
|
153
|
+
v.get("is_contiguous") if isinstance(v, dict) else None
|
|
154
|
+
),
|
|
155
|
+
"numel": v.get("numel") if isinstance(v, dict) else None,
|
|
156
|
+
}
|
|
157
|
+
else:
|
|
158
|
+
# scalar / NoneType etc
|
|
159
|
+
if isinstance(v, dict):
|
|
160
|
+
packed[k] = {
|
|
161
|
+
"type": v.get("type"),
|
|
162
|
+
"value": v.get("value", v.get("repr")),
|
|
163
|
+
}
|
|
164
|
+
else:
|
|
165
|
+
packed[k] = {
|
|
166
|
+
"type": None,
|
|
167
|
+
"value": v,
|
|
168
|
+
}
|
|
169
|
+
return packed
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def build_context_bundle(
    events: List[Dict[str, Any]], line_index: Optional[int] = None
) -> ContextBundle:
    """
    Build a complete context bundle from events and line index.

    Args:
        events: List of parsed event dictionaries.
        line_index: Index of the launch event to process.

    Returns:
        ContextBundle containing all information needed to reproduce the kernel launch.

    Raises:
        ValueError: If line_index is invalid or event is not a launch event.
        RuntimeError: If compilation event cannot be found.
    """
    launch_event, comp_event = get_launch_and_compilation_events(events, line_index)
    kernel_info = get_kernel_info(comp_event)

    # ``or {}`` also covers keys that are present but hold JSON null.
    extracted_args = launch_event.get("extracted_args") or {}
    comp_meta = launch_event.get("compilation_metadata") or {}

    # Compile metadata subset we care about
    compile_block = {
        "num_warps": comp_meta.get("num_warps"),
        "num_stages": comp_meta.get("num_stages"),
        "arch": comp_meta.get("arch"),
        "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
        "triton_version": comp_meta.get("triton_version"),
        "hash": comp_meta.get("hash"),
    }

    # Split the arguments in a single pass:
    #   kwargs          -> constexpr + explicit scalars used for launch
    #   raw_tensor_args -> tensor entries only (decoded as _SKIP)
    kwargs = {}
    raw_tensor_args = {}
    for name, raw in extracted_args.items():
        decoded = _decode_arg(raw)
        if decoded is _SKIP:
            raw_tensor_args[name] = raw
        else:
            kwargs[name] = decoded

    return ContextBundle(
        kernel_info=kernel_info,
        compile=compile_block,
        launch={"grid": launch_event.get("grid"), "kwargs": kwargs},
        # ``args`` describes every argument, tensors included.
        args=_pack_args(extracted_args),
        tensor_args=_pack_args(raw_tensor_args),
        raw_launch_event=launch_event,
        raw_comp_event=comp_event,
    )
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from tritonparse.reproducer.ingestion.ndjson import build_context_bundle
|
|
4
|
+
from tritonparse.reproducer.templates.loader import load_template_code
|
|
5
|
+
from tritonparse.reproducer.utils import (
|
|
6
|
+
_generate_import_statements,
|
|
7
|
+
_generate_invocation_snippet,
|
|
8
|
+
_parse_kernel_signature,
|
|
9
|
+
determine_output_paths,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from tritonparse.tools.prettify_ndjson import load_ndjson, save_prettified_json
|
|
13
|
+
from tritonparse.tp_logger import logger
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def reproduce(
    input_path: str,
    line_index: int,
    out_dir: str,
    template: str,
):
    """
    Generate a reproducer script from an NDJSON trace file.

    Loads the trace, builds the context bundle for the selected launch event,
    writes the raw launch event as JSON, fills the template placeholders and
    writes the resulting Python script next to the JSON file.

    Args:
        input_path: Path to the NDJSON trace file.
        line_index: Line index of the launch event to reproduce.
        out_dir: Output directory for reproducer files.
        template: Template name or path, as accepted by load_template_code.
    """
    logger.debug(f"Building bundle from {input_path} at line {line_index}")
    events = load_ndjson(Path(input_path), save_irs=True)
    logger.debug(f"Loaded {len(events)} events")

    # Build context bundle from the specified launch event
    context_bundle = build_context_bundle(events, line_index)
    kernel_info = context_bundle.kernel_info
    kernel_name = kernel_info.function_name
    logger.debug(f"Built context bundle for kernel: {kernel_name}")

    out_py_path, temp_json_path = determine_output_paths(out_dir, kernel_name)
    save_prettified_json(context_bundle.raw_launch_event, temp_json_path)

    logger.debug("Loading reproducer template.")
    template_code = load_template_code(template)
    sys_stmt, import_statement = _generate_import_statements(kernel_info)
    pos_args, kw_args = _parse_kernel_signature(kernel_info.source_code)
    invocation_snippet = _generate_invocation_snippet(pos_args, kw_args)

    # Placeholders are substituted in this fixed order.
    substitutions = (
        ("{{JSON_FILE_NAME_PLACEHOLDER}}", temp_json_path.name),
        ("# {{KERNEL_SYSPATH_PLACEHOLDER}}", sys_stmt),
        ("# {{KERNEL_IMPORT_PLACEHOLDER}}", import_statement),
        ("# {{KERNEL_INVOCATION_PLACEHOLDER}}", invocation_snippet),
    )
    final_code = template_code
    for placeholder, replacement in substitutions:
        final_code = final_code.replace(placeholder, replacement)

    out_py_path.write_text(final_code, encoding="utf-8")
    logger.info(
        "REPRODUCER_OUTPUT script=%s json=%s kernel=%s",
        str(out_py_path.resolve()),
        str(temp_json_path.resolve()),
        kernel_name,
    )
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import importlib
|
|
3
|
+
import json
|
|
4
|
+
import sys
|
|
5
|
+
from functools import lru_cache
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
# {{KERNEL_SYSPATH_PLACEHOLDER}}
|
|
11
|
+
|
|
12
|
+
# {{KERNEL_IMPORT_PLACEHOLDER}}
|
|
13
|
+
|
|
14
|
+
# True only when both the optional ``triton_kernels`` package and its
# ``tensor`` submodule can be located. ``all`` stops at the first miss, so the
# submodule is never probed when the parent package is absent.
TRITON_KERNELS_CUSTOM_TYPES = all(
    importlib.util.find_spec(module_name) is not None
    for module_name in ("triton_kernels", "triton_kernels.tensor")
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@lru_cache(maxsize=1)
|
|
21
|
+
def _get_triton_tensor_types():
|
|
22
|
+
"""
|
|
23
|
+
Import and cache Triton custom tensor types.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
tuple: (Tensor, Storage, StridedLayout) classes from triton_kernels.tensor.
|
|
27
|
+
|
|
28
|
+
Raises:
|
|
29
|
+
ImportError: If the optional module 'triton_kernels.tensor' is not available.
|
|
30
|
+
"""
|
|
31
|
+
mod = importlib.import_module("triton_kernels.tensor")
|
|
32
|
+
return (
|
|
33
|
+
mod.Tensor,
|
|
34
|
+
mod.Storage,
|
|
35
|
+
mod.StridedLayout,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
    """
    Load a tensor from its file path and verify its integrity using the hash in the filename.

    Args:
        tensor_file_path (str): Direct path to the tensor .bin file. The filename should be
                               the BLAKE2b hex digest of the file contents followed by the
                               .bin extension.
        device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
                               If None, keeps the tensor on its original device.

    Returns:
        torch.Tensor: The loaded tensor (moved to the specified device if provided)

    Raises:
        FileNotFoundError: If the path does not point to an existing regular file
        RuntimeError: If the tensor cannot be loaded
        ValueError: If the computed hash doesn't match the filename hash
    """
    blob_path = Path(tensor_file_path)

    # is_file() also rejects directories, which would otherwise surface as an
    # undocumented IsADirectoryError when opened below.
    if not blob_path.is_file():
        raise FileNotFoundError(f"Tensor blob not found: {blob_path}")

    # Extract expected hash from filename (remove .bin extension)
    expected_hash = blob_path.stem

    # Hash the file in fixed-size chunks so a large blob is never held in
    # memory in full just for verification.
    hasher = hashlib.blake2b()
    with open(blob_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            hasher.update(chunk)
    computed_hash = hasher.hexdigest()

    # Verify hash matches filename
    if computed_hash != expected_hash:
        raise ValueError(
            f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
        )

    try:
        # Tensors are saved with torch.save, so they are read back with
        # torch.load. If device is None the tensor stays on its original device.
        # NOTE(security): torch.load is pickle-based; the hash check only proves
        # the file matches its own name, not that it comes from a trusted
        # source. Only load blobs from traces you trust.
        return torch.load(blob_path, map_location=device)
    except Exception as e:
        raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}") from e
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def create_args_from_json(json_path):
    """
    Parse a reproducer JSON and build kernel grid and argument dictionary.

    The file may contain either a single JSON object or a list holding exactly
    one such object. Any other layout prints an error and terminates the
    process with exit status 1.

    Args:
        json_path (str): Path to the JSON file describing the kernel launch.

    Returns:
        tuple[list, dict]: Grid specification list and map of argument name to value.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Handle data format validation and extraction
    if isinstance(data, list):
        if len(data) != 1:
            print(
                f"Error: Expected single element list, got list with {len(data)} elements"
            )
            sys.exit(1)
        data = data[0]
    elif not isinstance(data, dict):
        print(f"Error: Expected list or dict, got {type(data)}")
        sys.exit(1)

    grid = data.get("grid", [])
    extracted_args = data.get("extracted_args", {})

    # Materialise each argument from its JSON description, keeping the order.
    args_dict = {
        arg_name: _create_arg_from_info(arg_info)
        for arg_name, arg_info in extracted_args.items()
    }

    return grid, args_dict
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _create_arg_from_info(arg_info):
|
|
120
|
+
"""
|
|
121
|
+
Recursively construct a kernel argument from its JSON schema.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
arg_info (dict): JSON object describing a single argument, including
|
|
125
|
+
fields like 'type', 'value', 'dtype', 'shape', 'device', etc.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
Any: The constructed Python object suitable for kernel invocation.
|
|
129
|
+
|
|
130
|
+
Raises:
|
|
131
|
+
RuntimeError: When required optional dependencies are missing.
|
|
132
|
+
NotImplementedError: When a dtype or type is not supported yet.
|
|
133
|
+
"""
|
|
134
|
+
arg_type = arg_info.get("type")
|
|
135
|
+
|
|
136
|
+
if arg_type in ["int", "bool"]:
|
|
137
|
+
return arg_info.get("value")
|
|
138
|
+
|
|
139
|
+
elif arg_type == "tensor":
|
|
140
|
+
if arg_info.get("blob_path"):
|
|
141
|
+
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
|
|
142
|
+
dtype_str = arg_info.get("dtype")
|
|
143
|
+
try:
|
|
144
|
+
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
145
|
+
except AttributeError:
|
|
146
|
+
torch_dtype = torch.float32
|
|
147
|
+
|
|
148
|
+
shape = arg_info.get("shape", [])
|
|
149
|
+
device = arg_info.get("device", "cpu")
|
|
150
|
+
|
|
151
|
+
# Use a dummy tensor to check properties of the dtype
|
|
152
|
+
tensor_props = torch.empty(0, dtype=torch_dtype)
|
|
153
|
+
|
|
154
|
+
# Case 1: Floating point, signed integers, uint8, and bool are supported by random_()
|
|
155
|
+
if tensor_props.is_floating_point():
|
|
156
|
+
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
157
|
+
tmp = torch.rand(shape, dtype=torch.float32, device=device)
|
|
158
|
+
return tmp.to(torch_dtype)
|
|
159
|
+
else:
|
|
160
|
+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
161
|
+
elif torch_dtype in [
|
|
162
|
+
torch.int8,
|
|
163
|
+
torch.int16,
|
|
164
|
+
torch.int32,
|
|
165
|
+
torch.int64,
|
|
166
|
+
torch.uint8,
|
|
167
|
+
torch.bool,
|
|
168
|
+
]:
|
|
169
|
+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
170
|
+
# Case 2: Complex numbers need special handling
|
|
171
|
+
elif tensor_props.is_complex():
|
|
172
|
+
float_dtype = (
|
|
173
|
+
torch.float32 if torch_dtype == torch.complex64 else torch.float64
|
|
174
|
+
)
|
|
175
|
+
real_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
176
|
+
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
|
|
177
|
+
return torch.complex(real_part, imag_part)
|
|
178
|
+
|
|
179
|
+
# Case 3: Handle other unsigned integers (like uint32) which fail with random_()
|
|
180
|
+
elif "uint" in str(torch_dtype):
|
|
181
|
+
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
|
|
182
|
+
# Case 4: If we don't know how to handle the type, raise an error
|
|
183
|
+
else:
|
|
184
|
+
raise NotImplementedError(
|
|
185
|
+
f"Random data generation not implemented for dtype: {torch_dtype}"
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
elif arg_type == "triton_kernels.tensor.Tensor":
|
|
189
|
+
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
190
|
+
raise RuntimeError(
|
|
191
|
+
"Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Tensor."
|
|
192
|
+
)
|
|
193
|
+
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
|
|
194
|
+
storage = _create_arg_from_info(arg_info.get("storage"))
|
|
195
|
+
dtype_str = arg_info.get("dtype")
|
|
196
|
+
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
197
|
+
return Tensor(
|
|
198
|
+
storage=storage,
|
|
199
|
+
shape=arg_info.get("shape"),
|
|
200
|
+
shape_max=arg_info.get("shape_max"),
|
|
201
|
+
dtype=torch_dtype,
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
elif arg_type == "triton_kernels.tensor.Storage":
|
|
205
|
+
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
206
|
+
raise RuntimeError(
|
|
207
|
+
"Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Storage."
|
|
208
|
+
)
|
|
209
|
+
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
|
|
210
|
+
data = _create_arg_from_info(arg_info.get("data"))
|
|
211
|
+
layout = _create_arg_from_info(arg_info.get("layout"))
|
|
212
|
+
return Storage(data=data, layout=layout)
|
|
213
|
+
|
|
214
|
+
elif arg_type == "StridedLayout":
|
|
215
|
+
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
216
|
+
raise RuntimeError(
|
|
217
|
+
"Optional dependency 'triton_kernels.tensor' is not installed; cannot construct StridedLayout."
|
|
218
|
+
)
|
|
219
|
+
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
|
|
220
|
+
return StridedLayout(shape=arg_info.get("initial_shape"))
|
|
221
|
+
else:
|
|
222
|
+
print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
|
|
223
|
+
return None
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
if __name__ == "__main__":
    # This file is a template: the JSON file name below and the invocation
    # marker further down are substituted textually by
    # tritonparse.reproducer.orchestrator.reproduce, so both must stay
    # exactly as written.
    script_dir = Path(__file__).resolve().parent
    # The context JSON is looked up next to this script.
    json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}"
    grid, args_dict = create_args_from_json(str(json_file))

    # Echo the reconstructed launch inputs before running the kernel.
    print("Generated kernel arguments dictionary:")
    for name, arg in args_dict.items():
        print(f" {name}: {arg}")
    print(f"Grid: {grid}")

    # {{KERNEL_INVOCATION_PLACEHOLDER}}

    # Wait for outstanding CUDA work before reporting completion.
    torch.cuda.synchronize()
    print("Kernel execution finished.")
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from importlib.resources import files as pkg_files
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
BUILTIN_TEMPLATES_PACKAGE = "tritonparse.reproducer.templates"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _is_path_like(template_arg: str) -> bool:
|
|
10
|
+
return "/" in template_arg or "\\" in template_arg or template_arg.endswith(".py")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _read_file_text(path: Path) -> str:
|
|
14
|
+
p = path.expanduser().resolve()
|
|
15
|
+
if not p.exists() or not p.is_file():
|
|
16
|
+
raise FileNotFoundError(f"Template not found: {p}")
|
|
17
|
+
return p.read_text(encoding="utf-8")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _read_builtin_template_text(name: str) -> str:
    """Read the builtin template ``<name>.py`` from the templates package.

    Raises:
        FileNotFoundError: If no such builtin template exists; the message
            lists the available builtin template names.
    """
    resource = pkg_files(BUILTIN_TEMPLATES_PACKAGE).joinpath(f"{name}.py")
    # The resource only fails on access, so an invalid name shows up here.
    try:
        return resource.read_text(encoding="utf-8")
    except FileNotFoundError as exc:
        available = ", ".join(list_builtin_templates())
        raise FileNotFoundError(
            f"Builtin template '{name}' not found. Available: {available}"
        ) from exc
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def list_builtin_templates() -> List[str]:
    """
    Return the list of available builtin template names (without .py suffix).

    Every ``*.py`` file in the builtin templates package is treated as a
    template, except dunder files (``__init__.py`` etc.) and this loader
    module itself, which lives in the same package but is not a template.
    The names are returned sorted.
    """
    # Stem of this module, so it can be left out of the listing.
    own_stem = Path(__file__).stem

    names: List[str] = []
    for entry in pkg_files(BUILTIN_TEMPLATES_PACKAGE).iterdir():
        try:
            if not entry.is_file():
                continue
        except OSError:
            # Defensive: in case entry access fails in some environments.
            # (FileNotFoundError is a subclass of OSError.)
            continue
        filename = entry.name
        if not filename.endswith(".py") or filename.startswith("__"):
            continue
        stem = filename[:-3]
        if stem == own_stem:
            continue
        names.append(stem)
    names.sort()
    return names
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_template_code(template_arg: str) -> str:
    """
    Return the source text of a reproducer template.

    ``template_arg`` is read from the filesystem when it looks like a path;
    otherwise it is treated as the name of a builtin template (without .py).
    """
    if not _is_path_like(template_arg):
        return _read_builtin_template_text(template_arg)
    return _read_file_text(Path(template_arg))
|
tritonparse/reproducer/utils.py
CHANGED
|
@@ -2,25 +2,40 @@ import importlib
|
|
|
2
2
|
import importlib.util
|
|
3
3
|
import json
|
|
4
4
|
import sys
|
|
5
|
+
from datetime import datetime
|
|
5
6
|
from functools import lru_cache
|
|
7
|
+
from pathlib import Path
|
|
6
8
|
|
|
7
9
|
import torch
|
|
8
10
|
|
|
11
|
+
from tritonparse.tools.load_tensor import load_tensor
|
|
12
|
+
from tritonparse.tp_logger import logger
|
|
13
|
+
|
|
9
14
|
TRITON_KERNELS_CUSTOM_TYPES = (
|
|
10
|
-
importlib.util.find_spec("triton_kernels
|
|
15
|
+
importlib.util.find_spec("triton_kernels") is not None
|
|
16
|
+
and importlib.util.find_spec("triton_kernels.tensor") is not None
|
|
11
17
|
)
|
|
12
18
|
|
|
13
19
|
|
|
20
|
+
@lru_cache(maxsize=1)
|
|
21
|
+
def _get_triton_tensor_types():
|
|
22
|
+
mod = importlib.import_module("triton_kernels.tensor")
|
|
23
|
+
return (
|
|
24
|
+
mod.Tensor,
|
|
25
|
+
mod.Storage,
|
|
26
|
+
mod.StridedLayout,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
14
30
|
def create_args_from_json(json_path):
|
|
15
31
|
"""
|
|
16
|
-
|
|
32
|
+
Parse a reproducer JSON and build kernel grid and argument dictionary.
|
|
17
33
|
|
|
18
34
|
Args:
|
|
19
|
-
json_path (str):
|
|
20
|
-
launch information.
|
|
35
|
+
json_path (str): Path to the JSON file describing the kernel launch.
|
|
21
36
|
|
|
22
37
|
Returns:
|
|
23
|
-
tuple:
|
|
38
|
+
tuple[list, dict]: Grid specification list and map of argument name to value.
|
|
24
39
|
"""
|
|
25
40
|
with open(json_path, "r") as f:
|
|
26
41
|
data = json.load(f)
|
|
@@ -46,19 +61,20 @@ def create_args_from_json(json_path):
|
|
|
46
61
|
return grid, args_dict
|
|
47
62
|
|
|
48
63
|
|
|
49
|
-
@lru_cache(maxsize=1)
|
|
50
|
-
def _get_triton_tensor_types():
|
|
51
|
-
mod = importlib.import_module("triton_kernels.tensor")
|
|
52
|
-
return (
|
|
53
|
-
getattr(mod, "Tensor"),
|
|
54
|
-
getattr(mod, "Storage"),
|
|
55
|
-
getattr(mod, "StridedLayout"),
|
|
56
|
-
)
|
|
57
|
-
|
|
58
|
-
|
|
59
64
|
def _create_arg_from_info(arg_info):
|
|
60
65
|
"""
|
|
61
|
-
Recursively
|
|
66
|
+
Recursively construct a kernel argument from its JSON schema.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
arg_info (dict): JSON object describing a single argument, including
|
|
70
|
+
fields like 'type', 'value', 'dtype', 'shape', 'device', etc.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Any: The constructed Python object suitable for kernel invocation.
|
|
74
|
+
|
|
75
|
+
Raises:
|
|
76
|
+
RuntimeError: When required optional dependencies are missing.
|
|
77
|
+
NotImplementedError: When a dtype or type is not supported yet.
|
|
62
78
|
"""
|
|
63
79
|
arg_type = arg_info.get("type")
|
|
64
80
|
|
|
@@ -66,6 +82,8 @@ def _create_arg_from_info(arg_info):
|
|
|
66
82
|
return arg_info.get("value")
|
|
67
83
|
|
|
68
84
|
elif arg_type == "tensor":
|
|
85
|
+
if arg_info.get("blob_path"):
|
|
86
|
+
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
|
|
69
87
|
dtype_str = arg_info.get("dtype")
|
|
70
88
|
try:
|
|
71
89
|
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
|
|
@@ -79,7 +97,13 @@ def _create_arg_from_info(arg_info):
|
|
|
79
97
|
tensor_props = torch.empty(0, dtype=torch_dtype)
|
|
80
98
|
|
|
81
99
|
# Case 1: Floating point, signed integers, uint8, and bool are supported by random_()
|
|
82
|
-
if tensor_props.is_floating_point()
|
|
100
|
+
if tensor_props.is_floating_point():
|
|
101
|
+
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
102
|
+
tmp = torch.rand(shape, dtype=torch.float32, device=device)
|
|
103
|
+
return tmp.to(torch_dtype)
|
|
104
|
+
else:
|
|
105
|
+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
106
|
+
elif torch_dtype in [
|
|
83
107
|
torch.int8,
|
|
84
108
|
torch.int16,
|
|
85
109
|
torch.int32,
|
|
@@ -88,7 +112,6 @@ def _create_arg_from_info(arg_info):
|
|
|
88
112
|
torch.bool,
|
|
89
113
|
]:
|
|
90
114
|
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
|
|
91
|
-
|
|
92
115
|
# Case 2: Complex numbers need special handling
|
|
93
116
|
elif tensor_props.is_complex():
|
|
94
117
|
float_dtype = (
|
|
@@ -101,13 +124,11 @@ def _create_arg_from_info(arg_info):
|
|
|
101
124
|
# Case 3: Handle other unsigned integers (like uint32) which fail with random_()
|
|
102
125
|
elif "uint" in str(torch_dtype):
|
|
103
126
|
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
|
|
104
|
-
|
|
105
|
-
# Case 4: If we don't know how to handle the type, raise an error
|
|
127
|
+
# Case 4: If we don't know how to handle the type, raise an error
|
|
106
128
|
else:
|
|
107
129
|
raise NotImplementedError(
|
|
108
130
|
f"Random data generation not implemented for dtype: {torch_dtype}"
|
|
109
131
|
)
|
|
110
|
-
|
|
111
132
|
elif arg_type == "triton_kernels.tensor.Tensor":
|
|
112
133
|
if not TRITON_KERNELS_CUSTOM_TYPES:
|
|
113
134
|
raise RuntimeError(
|
|
@@ -145,3 +166,137 @@ def _create_arg_from_info(arg_info):
|
|
|
145
166
|
else:
|
|
146
167
|
print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
|
|
147
168
|
return None
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def determine_output_paths(out_dir: str, kernel_name: str):
    """
    Determine output file paths for reproducer script and context data.

    The directory ``<out_dir>/<kernel_name>`` is created if it does not
    already exist. Both returned file names share one timestamp
    (``YYYYmmddHHMMSS``, local time) so the script and its context file can
    be matched up.

    Args:
        out_dir: Output directory path. If empty or None, the kernel
            directory is created relative to the current working directory.
        kernel_name: Name of the kernel, used as the sub-directory name.

    Returns:
        Tuple of (python_script_path, json_context_path) as Path objects.
    """
    # Path("") already resolves to the current directory; treating None the
    # same way avoids a TypeError from Path(None).
    base_directory = Path(out_dir) if out_dir else Path()
    output_directory = base_directory / kernel_name
    output_directory.mkdir(parents=True, exist_ok=True)

    # Compute the timestamp once so both file names carry the same value.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    out_py_path = output_directory / f"repro_{timestamp}.py"
    temp_json_path = output_directory / f"repro_context_{timestamp}.json"

    return out_py_path, temp_json_path
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _generate_import_statements(kernel_info) -> tuple[str, str]:
|
|
193
|
+
"""
|
|
194
|
+
Generate (sys.path insertion statement, import statement) for the kernel.
|
|
195
|
+
|
|
196
|
+
Strategy:
|
|
197
|
+
- Always add the kernel file's parent directory to sys.path.
|
|
198
|
+
- If the filename (without .py) is a valid identifier, import using that
|
|
199
|
+
module name: `from <stem> import <func> as imported_kernel_function`.
|
|
200
|
+
- Otherwise, fall back to dynamic import via importlib.util and bind
|
|
201
|
+
`imported_kernel_function` from the loaded module.
|
|
202
|
+
"""
|
|
203
|
+
file_path = Path(kernel_info.file_path)
|
|
204
|
+
function_name = kernel_info.function_name
|
|
205
|
+
|
|
206
|
+
if not file_path or not function_name:
|
|
207
|
+
raise ValueError("Kernel file path or function name missing from context.")
|
|
208
|
+
|
|
209
|
+
# Always add the file's parent directory to sys.path
|
|
210
|
+
sys_stmt = (
|
|
211
|
+
"import sys; p = r'" + str(file_path.parent) + "';\n"
|
|
212
|
+
"if p not in sys.path: sys.path.insert(0, p)"
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
module_name = file_path.with_suffix("").name
|
|
216
|
+
if module_name.isidentifier():
|
|
217
|
+
import_stmt = (
|
|
218
|
+
f"from {module_name} import {function_name} as imported_kernel_function"
|
|
219
|
+
)
|
|
220
|
+
logger.debug("Generated direct import statement: %s", import_stmt)
|
|
221
|
+
return sys_stmt, import_stmt
|
|
222
|
+
|
|
223
|
+
# Fallback: dynamic import when filename is not a valid identifier
|
|
224
|
+
import_stmt = (
|
|
225
|
+
"import importlib.util\n"
|
|
226
|
+
f"_spec = importlib.util.spec_from_file_location('kernel_mod', r'{str(file_path)}')\n"
|
|
227
|
+
"_mod = importlib.util.module_from_spec(_spec)\n"
|
|
228
|
+
"_spec.loader.exec_module(_mod)\n"
|
|
229
|
+
f"imported_kernel_function = getattr(_mod, '{function_name}')"
|
|
230
|
+
)
|
|
231
|
+
logger.debug("Generated dynamic import for file: %s", file_path)
|
|
232
|
+
return sys_stmt, import_stmt
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _parse_kernel_signature(kernel_source_code: str) -> tuple[list[str], list[str]]:
|
|
236
|
+
"""
|
|
237
|
+
Parses a Triton kernel's source code to distinguish positional args
|
|
238
|
+
from keyword args (those with default values).
|
|
239
|
+
"""
|
|
240
|
+
signature_lines = []
|
|
241
|
+
in_signature = False
|
|
242
|
+
for line in kernel_source_code.splitlines():
|
|
243
|
+
# Mark beginning of signature when function definition is found
|
|
244
|
+
if line.strip().startswith("def "):
|
|
245
|
+
in_signature = True
|
|
246
|
+
if in_signature:
|
|
247
|
+
# Strip comments and leading/trailing whitespace
|
|
248
|
+
clean_line = line.split("#")[0].strip()
|
|
249
|
+
signature_lines.append(clean_line)
|
|
250
|
+
# Stop capturing after the signature ends
|
|
251
|
+
if "):" in line:
|
|
252
|
+
break
|
|
253
|
+
|
|
254
|
+
full_signature = "".join(signature_lines)
|
|
255
|
+
# Extract content between the first '(' and the last '):'
|
|
256
|
+
try:
|
|
257
|
+
params_str = full_signature[
|
|
258
|
+
full_signature.find("(") + 1 : full_signature.rfind("):")
|
|
259
|
+
]
|
|
260
|
+
except IndexError as exc:
|
|
261
|
+
raise ValueError("Could not parse kernel signature.") from exc
|
|
262
|
+
|
|
263
|
+
# Clean up and split the parameters string
|
|
264
|
+
params = [p.strip() for p in params_str.replace("\n", "").split(",") if p.strip()]
|
|
265
|
+
|
|
266
|
+
positional_args = []
|
|
267
|
+
keyword_args = []
|
|
268
|
+
|
|
269
|
+
for param in params:
|
|
270
|
+
if "=" in param:
|
|
271
|
+
# Keyword arguments have a default value
|
|
272
|
+
arg_name = param.split("=")[0].strip()
|
|
273
|
+
keyword_args.append(arg_name)
|
|
274
|
+
else:
|
|
275
|
+
# Positional arguments do not have a default value
|
|
276
|
+
arg_name = param.split(":")[0].strip()
|
|
277
|
+
positional_args.append(arg_name)
|
|
278
|
+
|
|
279
|
+
logger.debug("Parsed positional args: %s", positional_args)
|
|
280
|
+
logger.debug("Parsed keyword args: %s", keyword_args)
|
|
281
|
+
return positional_args, keyword_args
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _generate_invocation_snippet(
|
|
285
|
+
positional_args: list[str], keyword_args: list[str]
|
|
286
|
+
) -> str:
|
|
287
|
+
"""Generates a single-line Python code snippet for kernel invocation."""
|
|
288
|
+
# Prepare positional args for direct injection into the call
|
|
289
|
+
pos_args_str = ", ".join([f'args_dict["{arg}"]' for arg in positional_args])
|
|
290
|
+
|
|
291
|
+
# Prepare keyword args for direct injection
|
|
292
|
+
kw_args_str = ", ".join([f'{arg}=args_dict["{arg}"]' for arg in keyword_args])
|
|
293
|
+
|
|
294
|
+
# Combine them, ensuring proper comma separation
|
|
295
|
+
all_args = []
|
|
296
|
+
if pos_args_str:
|
|
297
|
+
all_args.append(pos_args_str)
|
|
298
|
+
if kw_args_str:
|
|
299
|
+
all_args.append(kw_args_str)
|
|
300
|
+
|
|
301
|
+
# Create the single-line call
|
|
302
|
+
return f"imported_kernel_function[tuple(grid)]({', '.join(all_args)})"
|
|
@@ -40,7 +40,7 @@ import argparse
|
|
|
40
40
|
import json
|
|
41
41
|
import sys
|
|
42
42
|
from pathlib import Path
|
|
43
|
-
from typing import Any, List
|
|
43
|
+
from typing import Any, List, Union
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
def parse_line_ranges(lines_arg: str) -> set[int]:
|
|
@@ -174,7 +174,7 @@ def load_ndjson(
|
|
|
174
174
|
except FileNotFoundError:
|
|
175
175
|
print(f"Error: File '{file_path}' not found.", file=sys.stderr)
|
|
176
176
|
raise
|
|
177
|
-
except
|
|
177
|
+
except (OSError, UnicodeDecodeError) as e:
|
|
178
178
|
print(f"Error reading file '{file_path}': {e}", file=sys.stderr)
|
|
179
179
|
raise
|
|
180
180
|
|
|
@@ -201,19 +201,21 @@ def load_ndjson(
|
|
|
201
201
|
return json_objects
|
|
202
202
|
|
|
203
203
|
|
|
204
|
-
def save_prettified_json(
|
|
204
|
+
def save_prettified_json(
|
|
205
|
+
json_objects: Union[List[Any], Any], output_path: Path
|
|
206
|
+
) -> None:
|
|
205
207
|
"""
|
|
206
|
-
Save
|
|
208
|
+
Save JSON data to a prettified JSON file.
|
|
207
209
|
|
|
208
210
|
Args:
|
|
209
|
-
json_objects:
|
|
211
|
+
json_objects: Either a list of JSON objects or a single JSON-serializable object
|
|
210
212
|
output_path: Path where to save the prettified JSON file
|
|
211
213
|
"""
|
|
212
214
|
try:
|
|
213
215
|
with open(output_path, "w", encoding="utf-8") as f:
|
|
214
216
|
json.dump(json_objects, f, indent=2, ensure_ascii=False, sort_keys=True)
|
|
215
217
|
print(f"Successfully converted to prettified JSON: {output_path}")
|
|
216
|
-
except
|
|
218
|
+
except OSError as e:
|
|
217
219
|
print(f"Error writing to file '{output_path}': {e}", file=sys.stderr)
|
|
218
220
|
raise
|
|
219
221
|
|
tritonparse/utils.py
CHANGED
|
@@ -16,21 +16,15 @@ from .common import (
|
|
|
16
16
|
)
|
|
17
17
|
from .source_type import Source, SourceType
|
|
18
18
|
|
|
19
|
-
# argument parser for OSS
|
|
20
|
-
parser = None
|
|
21
19
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
global parser
|
|
25
|
-
|
|
26
|
-
parser = argparse.ArgumentParser(
|
|
27
|
-
description="analyze triton structured logs", conflict_handler="resolve"
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
# Add arguments for the parse command
|
|
20
|
+
def _add_parse_args(parser: argparse.ArgumentParser) -> None:
|
|
21
|
+
"""Add common 'parse' subcommand arguments to a parser."""
|
|
31
22
|
parser.add_argument(
|
|
32
23
|
"source",
|
|
33
|
-
help=
|
|
24
|
+
help=(
|
|
25
|
+
"Source of torch logs to be analyzed. It is expected to path to a local "
|
|
26
|
+
"directory or log"
|
|
27
|
+
),
|
|
34
28
|
)
|
|
35
29
|
parser.add_argument(
|
|
36
30
|
"-o",
|
|
@@ -40,7 +34,9 @@ def init_parser():
|
|
|
40
34
|
)
|
|
41
35
|
parser.add_argument(
|
|
42
36
|
"--overwrite",
|
|
43
|
-
help=
|
|
37
|
+
help=(
|
|
38
|
+
"Delete out directory if it already exists. Only does something if --out is set"
|
|
39
|
+
),
|
|
44
40
|
action="store_true",
|
|
45
41
|
)
|
|
46
42
|
parser.add_argument("-r", "--rank", help="Rank of logs to be analyzed", type=int)
|
|
@@ -54,7 +50,6 @@ def init_parser():
|
|
|
54
50
|
from tritonparse.fb.utils import append_parser
|
|
55
51
|
|
|
56
52
|
append_parser(parser)
|
|
57
|
-
return parser
|
|
58
53
|
|
|
59
54
|
|
|
60
55
|
def oss_run(
|
|
@@ -113,12 +108,6 @@ def oss_run(
|
|
|
113
108
|
print_parsed_files_summary(out_dir)
|
|
114
109
|
|
|
115
110
|
|
|
116
|
-
def unified_parse_from_cli():
|
|
117
|
-
parser = init_parser()
|
|
118
|
-
args = parser.parse_args()
|
|
119
|
-
return unified_parse(**vars(args))
|
|
120
|
-
|
|
121
|
-
|
|
122
111
|
def unified_parse(
|
|
123
112
|
source: str,
|
|
124
113
|
out: Optional[str] = None,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tritonparse
|
|
3
|
-
Version: 0.2.4.
|
|
3
|
+
Version: 0.2.4.dev20250924071525
|
|
4
4
|
Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
|
|
5
5
|
Author-email: Yueming Hao <yhao@meta.com>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -10,17 +10,22 @@ tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV
|
|
|
10
10
|
tritonparse/structured_logging.py,sha256=7r9pv6miUdb8-CCZfj8SkD3XItzwPeONmszEL7TZak4,43949
|
|
11
11
|
tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
|
|
12
12
|
tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
|
|
13
|
-
tritonparse/utils.py,sha256=
|
|
13
|
+
tritonparse/utils.py,sha256=ujx9iUrpOthJ5vWzaNs6RXtqX0dp_GeozOaQLqlUDxg,4269
|
|
14
14
|
tritonparse/reproducer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
-
tritonparse/reproducer/
|
|
15
|
+
tritonparse/reproducer/cli.py,sha256=JvnAi1FKSpNa6eHEapRn9jdXxsj1vvyrrEEnfxTJYa8,871
|
|
16
|
+
tritonparse/reproducer/orchestrator.py,sha256=9SQ_rATY-s4r3BZQZdKLw7WYGz8IQJ1StPMgRbKAs5s,2456
|
|
17
|
+
tritonparse/reproducer/utils.py,sha256=qi4XTKk0pWV4hgYg_GPBISEfVXlrI6tZR0A5ZZbwVyo,11132
|
|
18
|
+
tritonparse/reproducer/ingestion/ndjson.py,sha256=pEujTl5xXW2E2DEW8ngxXQ8qP9oawb90wBVTWHDs1jk,7372
|
|
19
|
+
tritonparse/reproducer/templates/example.py,sha256=XWfXD4tDOiE213YlWWK1l1ZgXbK3BX61NnvuVTkO-S0,8595
|
|
20
|
+
tritonparse/reproducer/templates/loader.py,sha256=HqjfThdDVg7q2bYWry78sIaVRkUpkcA8KQDt83YrlVE,1920
|
|
16
21
|
tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
22
|
tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
|
|
18
23
|
tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
|
|
19
24
|
tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
|
|
20
|
-
tritonparse/tools/prettify_ndjson.py,sha256=
|
|
25
|
+
tritonparse/tools/prettify_ndjson.py,sha256=YpJ7SFXTkZPZEXQeN1w5wkOf9pFrGqaqhhfHV7eobWA,10998
|
|
21
26
|
tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
|
|
22
|
-
tritonparse-0.2.4.
|
|
23
|
-
tritonparse-0.2.4.
|
|
24
|
-
tritonparse-0.2.4.
|
|
25
|
-
tritonparse-0.2.4.
|
|
26
|
-
tritonparse-0.2.4.
|
|
27
|
+
tritonparse-0.2.4.dev20250924071525.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
|
|
28
|
+
tritonparse-0.2.4.dev20250924071525.dist-info/METADATA,sha256=cxNEHWh9EoRq332ybQO3nsai4gC5eASCWvyPloW0gko,6580
|
|
29
|
+
tritonparse-0.2.4.dev20250924071525.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
30
|
+
tritonparse-0.2.4.dev20250924071525.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
|
|
31
|
+
tritonparse-0.2.4.dev20250924071525.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|