tritonparse 0.2.1.dev20250915071616__py3-none-any.whl → 0.2.1.dev20250916071516__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/utils.py +147 -0
- tritonparse/structured_logging.py +103 -14
- {tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/METADATA +1 -1
- {tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/RECORD +8 -6
- {tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/WHEEL +0 -0
- {tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/licenses/LICENSE +0 -0
- {tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/top_level.txt +0 -0
tritonparse/reproducer/__init__.py: file without changes (added as an empty file, so there is no line diff to show)

tritonparse/reproducer/utils.py (new file)

@@ -0,0 +1,147 @@
+import importlib
+import importlib.util
+import json
+import sys
+from functools import lru_cache
+
+import torch
+
+TRITON_KERNELS_CUSTOM_TYPES = (
+    importlib.util.find_spec("triton_kernels.tensor") is not None
+)
+
+
+def create_args_from_json(json_path):
+    """
+    Creates a list of arguments for a kernel launch from a JSON file.
+
+    Args:
+        json_path (str): The path to the JSON file containing the kernel
+            launch information.
+
+    Returns:
+        tuple: A tuple containing the grid and a dictionary of arguments.
+    """
+    with open(json_path, "r") as f:
+        data = json.load(f)
+    # Handle data format validation and extraction
+    if isinstance(data, list):
+        if len(data) != 1:
+            print(
+                f"Error: Expected single element list, got list with {len(data)} elements"
+            )
+            sys.exit(1)
+        data = data[0]
+    elif not isinstance(data, dict):
+        print(f"Error: Expected list or dict, got {type(data)}")
+        sys.exit(1)
+
+    grid = data.get("grid", [])
+    args_dict = {}
+    extracted_args = data.get("extracted_args", {})
+
+    for arg_name, arg_info in extracted_args.items():
+        args_dict[arg_name] = _create_arg_from_info(arg_info)
+
+    return grid, args_dict
+
+
+@lru_cache(maxsize=1)
+def _get_triton_tensor_types():
+    mod = importlib.import_module("triton_kernels.tensor")
+    return (
+        getattr(mod, "Tensor"),
+        getattr(mod, "Storage"),
+        getattr(mod, "StridedLayout"),
+    )
+
+
+def _create_arg_from_info(arg_info):
+    """
+    Recursively creates a kernel argument from its JSON info dictionary.
+    """
+    arg_type = arg_info.get("type")
+
+    if arg_type in ["int", "bool"]:
+        return arg_info.get("value")
+
+    elif arg_type == "tensor":
+        dtype_str = arg_info.get("dtype")
+        try:
+            torch_dtype = getattr(torch, dtype_str.split(".")[-1])
+        except AttributeError:
+            torch_dtype = torch.float32
+
+        shape = arg_info.get("shape", [])
+        device = arg_info.get("device", "cpu")
+
+        # Use a dummy tensor to check properties of the dtype
+        tensor_props = torch.empty(0, dtype=torch_dtype)
+
+        # Case 1: Floating point, signed integers, uint8, and bool are supported by random_()
+        if tensor_props.is_floating_point() or torch_dtype in [
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+            torch.uint8,
+            torch.bool,
+        ]:
+            return torch.empty(shape, dtype=torch_dtype, device=device).random_()
+
+        # Case 2: Complex numbers need special handling
+        elif tensor_props.is_complex():
+            float_dtype = (
+                torch.float32 if torch_dtype == torch.complex64 else torch.float64
+            )
+            real_part = torch.rand(shape, dtype=float_dtype, device=device)
+            imag_part = torch.rand(shape, dtype=float_dtype, device=device)
+            return torch.complex(real_part, imag_part)
+
+        # Case 3: Handle other unsigned integers (like uint32) which fail with random_()
+        elif "uint" in str(torch_dtype):
+            return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
+
+        # Case 4: If we don't know how to handle the type, raise an error
+        else:
+            raise NotImplementedError(
+                f"Random data generation not implemented for dtype: {torch_dtype}"
+            )
+
+    elif arg_type == "triton_kernels.tensor.Tensor":
+        if not TRITON_KERNELS_CUSTOM_TYPES:
+            raise RuntimeError(
+                "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Tensor."
+            )
+        Tensor, Storage, StridedLayout = _get_triton_tensor_types()
+        storage = _create_arg_from_info(arg_info.get("storage"))
+        dtype_str = arg_info.get("dtype")
+        torch_dtype = getattr(torch, dtype_str.split(".")[-1])
+        return Tensor(
+            storage=storage,
+            shape=arg_info.get("shape"),
+            shape_max=arg_info.get("shape_max"),
+            dtype=torch_dtype,
+        )
+
+    elif arg_type == "triton_kernels.tensor.Storage":
+        if not TRITON_KERNELS_CUSTOM_TYPES:
+            raise RuntimeError(
+                "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Storage."
+            )
+        Tensor, Storage, StridedLayout = _get_triton_tensor_types()
+        data = _create_arg_from_info(arg_info.get("data"))
+        layout = _create_arg_from_info(arg_info.get("layout"))
+        return Storage(data=data, layout=layout)
+
+    elif arg_type == "StridedLayout":
+        if not TRITON_KERNELS_CUSTOM_TYPES:
+            raise RuntimeError(
+                "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct StridedLayout."
+            )
+        Tensor, Storage, StridedLayout = _get_triton_tensor_types()
+        return StridedLayout(shape=arg_info.get("initial_shape"))
+
+    else:
+        print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
+        return None
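
The JSON format this helper consumes can be read off the code above: a single-element list (or a bare dict) carrying a "grid" entry and an "extracted_args" mapping from argument name to a type-tagged info dict. Note that tensor arguments are regenerated as random data of the recorded shape, dtype, and device; the original values are not replayed. A minimal sketch of driving it (the file name and all values are illustrative, not taken from the package):

import json
import tempfile

from tritonparse.reproducer.utils import create_args_from_json

# Hypothetical launch record in the expected single-element-list shape
launch = [{
    "grid": [128, 1, 1],
    "extracted_args": {
        "x_ptr": {"type": "tensor", "dtype": "torch.float32",
                  "shape": [1024], "device": "cpu"},
        "n_elements": {"type": "int", "value": 1024},
    },
}]

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(launch, f)

grid, args = create_args_from_json(f.name)      # grid == [128, 1, 1]
print(args["x_ptr"].shape, args["n_elements"])  # torch.Size([1024]) 1024
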
tritonparse/structured_logging.py

@@ -183,13 +183,80 @@ def convert(obj):
             asdict(obj)
         )  # Convert dataclass to dict and then process that dict
 
+    if _is_triton_kernels_layout(obj):
+        layout_info = {"type": type(obj).__name__}
+        if hasattr(obj, "initial_shape"):
+            layout_info["initial_shape"] = convert(obj.initial_shape)
+        if hasattr(obj, "name"):
+            layout_info["name"] = convert(obj.name)
+        return layout_info
+
     # 4. Common Triton constexpr objects
     if isinstance(obj, dtype):
         return f"triton.language.core.dtype('{str(obj)}')"
+
+    if TORCH_INSTALLED and isinstance(obj, torch.dtype):
+        return str(obj)
+
     log.warning(f"Unknown type: {type(obj)}")
     return str(obj)  # Return primitive types as-is
 
 
+def _is_triton_kernels_layout(obj):
+    """
+    Check if an object is an instance of a Layout class from a
+    triton_kernels module by checking its MRO.
+    """
+    t = type(obj)
+    for base_class in t.__mro__:
+        module_name = getattr(base_class, "__module__", "")
+        type_name = getattr(base_class, "__name__", "")
+        if type_name == "Layout" and module_name.startswith("triton_kernels"):
+            return True
+    return False
+
+
+def _is_from_triton_kernels_module(obj):
+    """
+    Check if an object is an instance of Tensor or Storage from a
+    triton_kernels module.
+    """
+    t = type(obj)
+    module_name = getattr(t, "__module__", "")
+    type_name = getattr(t, "__name__", "")
+    return type_name in ("Tensor", "Storage") and module_name.startswith(
+        "triton_kernels"
+    )
+
+
+def _log_torch_tensor_info(tensor_value):
+    """
+    Extracts metadata from a torch.Tensor object.
+
+    Args:
+        tensor_value (torch.Tensor): The tensor to extract information from.
+
+    Returns:
+        dict: A dictionary containing tensor metadata.
+    """
+    arg_info = {}
+    arg_info["type"] = "tensor"
+    arg_info["shape"] = list(tensor_value.shape)
+    arg_info["dtype"] = str(tensor_value.dtype)
+    arg_info["device"] = str(tensor_value.device)
+    arg_info["stride"] = list(tensor_value.stride())
+    arg_info["numel"] = tensor_value.numel()
+    arg_info["is_contiguous"] = tensor_value.is_contiguous()
+    arg_info["element_size"] = tensor_value.element_size()
+    arg_info["storage_offset"] = tensor_value.storage_offset()
+    # Memory usage in bytes
+    arg_info["memory_usage"] = tensor_value.numel() * tensor_value.element_size()
+    # Add data_ptr for memory tracking (optional)
+    if hasattr(tensor_value, "data_ptr"):
+        arg_info["data_ptr"] = hex(tensor_value.data_ptr())
+    return arg_info
+
+
 def maybe_enable_debug_logging():
     """
     This logging is for logging module itself, not for logging the triton compilation.
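
For reference, a sketch of the metadata dict `_log_torch_tensor_info` produces for a small contiguous CPU tensor (it is a private helper, imported here purely for illustration; `data_ptr` varies per run):

import torch

from tritonparse.structured_logging import _log_torch_tensor_info

t = torch.zeros(2, 3)  # float32, contiguous, on CPU
info = _log_torch_tensor_info(t)
# info == {
#     "type": "tensor", "shape": [2, 3], "dtype": "torch.float32",
#     "device": "cpu", "stride": [3, 1], "numel": 6,
#     "is_contiguous": True, "element_size": 4, "storage_offset": 0,
#     "memory_usage": 24,   # 6 elements * 4 bytes each
#     "data_ptr": "0x...",  # memory address, varies per run
# }
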
@@ -769,7 +836,8 @@ def maybe_trace_triton(
 
 def extract_arg_info(arg_dict):
     """
-    Extract detailed information from kernel arguments, especially for PyTorch
+    Extract detailed information from kernel arguments, especially for PyTorch
+    tensors.
 
     Args:
         arg_dict: Dictionary of kernel arguments
|
|
|
785
853
|
# Check if it's a PyTorch tensor
|
|
786
854
|
if TORCH_INSTALLED and isinstance(arg_value, torch.Tensor):
|
|
787
855
|
arg_info["type"] = "tensor"
|
|
788
|
-
arg_info
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
arg_info["
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
856
|
+
arg_info.update(_log_torch_tensor_info(arg_value))
|
|
857
|
+
# Handle custom Tensor/Storage types from triton_kernels
|
|
858
|
+
elif _is_from_triton_kernels_module(arg_value):
|
|
859
|
+
type_name = type(arg_value).__name__
|
|
860
|
+
arg_info["type"] = f"triton_kernels.tensor.{type_name}"
|
|
861
|
+
|
|
862
|
+
if type_name == "Tensor":
|
|
863
|
+
# Dump all attributes needed to reconstruct the Tensor wrapper
|
|
864
|
+
if hasattr(arg_value, "shape"):
|
|
865
|
+
arg_info["shape"] = convert(arg_value.shape)
|
|
866
|
+
if hasattr(arg_value, "shape_max"):
|
|
867
|
+
arg_info["shape_max"] = convert(arg_value.shape_max)
|
|
868
|
+
if hasattr(arg_value, "dtype"):
|
|
869
|
+
arg_info["dtype"] = convert(arg_value.dtype)
|
|
870
|
+
if hasattr(arg_value, "storage"):
|
|
871
|
+
# Recursively process the storage, which can be another
|
|
872
|
+
# custom type or a torch.Tensor
|
|
873
|
+
storage_arg = {"storage": arg_value.storage}
|
|
874
|
+
arg_info["storage"] = extract_arg_info(storage_arg)["storage"]
|
|
875
|
+
|
|
876
|
+
elif type_name == "Storage":
|
|
877
|
+
# Dump all attributes needed to reconstruct the Storage object
|
|
878
|
+
if (
|
|
879
|
+
hasattr(arg_value, "data")
|
|
880
|
+
and TORCH_INSTALLED
|
|
881
|
+
and isinstance(arg_value.data, torch.Tensor)
|
|
882
|
+
):
|
|
883
|
+
# The 'data' is a torch.Tensor, log its metadata fully
|
|
884
|
+
arg_info["data"] = _log_torch_tensor_info(arg_value.data)
|
|
885
|
+
if hasattr(arg_value, "layout"):
|
|
886
|
+
arg_info["layout"] = convert(arg_value.layout)
|
|
887
|
+
else:
|
|
888
|
+
log.warning(f"Unknown type: {type(arg_value)}")
|
|
889
|
+
|
|
801
890
|
# Handle scalar values
|
|
802
891
|
elif isinstance(arg_value, (int, float, bool)):
|
|
803
892
|
arg_info["type"] = type(arg_value).__name__
|
|
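
Together with the reproducer module above, this closes a round trip: the dict recorded here for a `triton_kernels.tensor.Tensor` argument is exactly the shape of input that `_create_arg_from_info` replays. A hypothetical trace entry (field names come from the diff; the concrete values are illustrative):

arg_info = {
    "type": "triton_kernels.tensor.Tensor",
    "shape": [64, 128],
    "shape_max": [64, 128],
    "dtype": "torch.float16",
    "storage": {
        "type": "triton_kernels.tensor.Storage",
        # 'data' is a plain torch.Tensor, logged via _log_torch_tensor_info
        "data": {
            "type": "tensor", "shape": [64, 128], "dtype": "torch.float16",
            "device": "cuda:0", "stride": [128, 1], "numel": 8192,
            "is_contiguous": True, "element_size": 2, "storage_offset": 0,
            "memory_usage": 16384, "data_ptr": "0x...",
        },
        # the layout dict comes from convert() via _is_triton_kernels_layout
        "layout": {"type": "StridedLayout", "initial_shape": [64, 128]},
    },
}
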
{tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tritonparse
-Version: 0.2.1.dev20250915071616
+Version: 0.2.1.dev20250916071516
 Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
 Author-email: Yueming Hao <yhao@meta.com>
 License-Expression: BSD-3-Clause
{tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/RECORD

@@ -7,18 +7,20 @@ tritonparse/mapper.py,sha256=prrczfi13P7Aa042OrEBsmRF1HW3jDhwxicANgPkWIM,4150
 tritonparse/shared_vars.py,sha256=-c9CvXJSDm9spYhDOJPEQProeT_xl3PaNmqTEYi_u4s,505
 tritonparse/source_type.py,sha256=nmYEQS8rfkIN9BhNhQbkmEvKnvS-3zAxRGLY4TaZdi8,1676
 tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV4,2037
-tritonparse/structured_logging.py,sha256=…  (old hash truncated in the extracted diff)
+tritonparse/structured_logging.py,sha256=7wSKc9HnV40eLoOgFNosc7dXWvSxsJJo5mnRVoDZqIk,43678
 tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
 tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
 tritonparse/utils.py,sha256=wt61tpbkqjGqHh0c7Nr2WlOv7PbQssmjULd6uA6aAko,4475
+tritonparse/reproducer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tritonparse/reproducer/utils.py,sha256=VfMBwnTEZO8Ug9_ZRlZUVTMaMczDkviAykXpnK5dacU,5093
 tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
 tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
 tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
 tritonparse/tools/prettify_ndjson.py,sha256=VOzVWoXpCbaAXYA4i_wBcQIHfh-JhAx7xR4cF_L8yDs,10928
 tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
-tritonparse-0.2.1.dev20250915071616.dist-info/licenses/LICENSE,sha256=…
-tritonparse-0.2.1.dev20250915071616.dist-info/METADATA,sha256=…
-tritonparse-0.2.1.dev20250915071616.dist-info/WHEEL,sha256=…
-tritonparse-0.2.1.dev20250915071616.dist-info/top_level.txt,sha256=…
-tritonparse-0.2.1.dev20250915071616.dist-info/RECORD,,
+tritonparse-0.2.1.dev20250916071516.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
+tritonparse-0.2.1.dev20250916071516.dist-info/METADATA,sha256=5ic4t2gOHTBKvj4oZlYVOERcUOIaO93Ufkbyu6NG5aI,6306
+tritonparse-0.2.1.dev20250916071516.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tritonparse-0.2.1.dev20250916071516.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
+tritonparse-0.2.1.dev20250916071516.dist-info/RECORD,,
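
The hashes in RECORD follow the standard wheel convention (PEP 376/PEP 427): urlsafe base64 of the SHA-256 digest with the trailing padding stripped. A small sketch for verifying an entry from inside the unpacked wheel:

import base64
import hashlib

def record_hash(path: str) -> str:
    """Compute a wheel RECORD-style hash entry for a file."""
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# record_hash("tritonparse/reproducer/utils.py") should return
# "sha256=VfMBwnTEZO8Ug9_ZRlZUVTMaMczDkviAykXpnK5dacU" per the RECORD above.
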
{tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/WHEEL: file without changes

{tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/licenses/LICENSE: file without changes

{tritonparse-0.2.1.dev20250915071616.dist-info → tritonparse-0.2.1.dev20250916071516.dist-info}/top_level.txt: file without changes