tritonparse 0.2.1.dev20250915071616__py3-none-any.whl → 0.2.1.dev20250917071511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic. Click here for more details.

File without changes
@@ -0,0 +1,147 @@
1
+ import importlib
2
+ import importlib.util
3
+ import json
4
+ import sys
5
+ from functools import lru_cache
6
+
7
+ import torch
8
+
9
# Whether the optional ``triton_kernels.tensor`` module is importable.
# ``find_spec`` imports the parent package of a dotted name, so when
# ``triton_kernels`` itself is not installed it raises ModuleNotFoundError
# instead of returning None. Treat that (and the ValueError find_spec raises
# for modules whose ``__spec__`` is None) as "not available", so importing
# this module never fails because of the optional dependency.
try:
    TRITON_KERNELS_CUSTOM_TYPES = (
        importlib.util.find_spec("triton_kernels.tensor") is not None
    )
except (ImportError, ValueError):
    TRITON_KERNELS_CUSTOM_TYPES = False
12
+
13
+
14
def create_args_from_json(json_path):
    """
    Load a logged kernel launch from a JSON file and rebuild its arguments.

    The file may contain either a single launch object (a dict) or a list
    holding exactly one such object. Each entry of the object's
    ``extracted_args`` mapping is converted to a concrete value by
    ``_create_arg_from_info``.

    Args:
        json_path (str): Path to the JSON file containing the kernel
            launch information.

    Returns:
        tuple: ``(grid, args_dict)``. ``grid`` is the logged ``"grid"`` value
        (``[]`` if absent). ``args_dict`` maps argument names to the
        reconstructed argument values.

    Raises:
        SystemExit: With status 1 if the JSON does not hold a single launch
            object (an error message is printed to stderr first).
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Accept a bare dict or a one-element list wrapping a dict; reject
    # everything else before any ``.get`` calls below can fail obscurely.
    if isinstance(data, list):
        if len(data) != 1:
            print(
                f"Error: Expected single element list, got list with {len(data)} elements",
                file=sys.stderr,
            )
            sys.exit(1)
        data = data[0]
        if not isinstance(data, dict):
            print(
                f"Error: Expected list element to be a dict, got {type(data)}",
                file=sys.stderr,
            )
            sys.exit(1)
    elif not isinstance(data, dict):
        print(f"Error: Expected list or dict, got {type(data)}", file=sys.stderr)
        sys.exit(1)

    grid = data.get("grid", [])
    # ``or {}`` also covers an explicit JSON ``null`` for extracted_args.
    extracted_args = data.get("extracted_args") or {}
    args_dict = {
        arg_name: _create_arg_from_info(arg_info)
        for arg_name, arg_info in extracted_args.items()
    }
    return grid, args_dict
47
+
48
+
49
@lru_cache(maxsize=1)
def _get_triton_tensor_types():
    """
    Import ``triton_kernels.tensor`` and return its wrapper classes.

    Returns:
        tuple: ``(Tensor, Storage, StridedLayout)`` from that module.

    The lookup is cached, so the import happens at most once per process
    (failed imports are not cached and will be retried).
    """
    tensor_module = importlib.import_module("triton_kernels.tensor")
    wanted = ("Tensor", "Storage", "StridedLayout")
    return tuple(getattr(tensor_module, class_name) for class_name in wanted)
57
+
58
+
59
def _torch_dtype_from_str(dtype_str):
    """
    Map a logged dtype string (e.g. ``"torch.float16"``) to a ``torch.dtype``.

    Only the text after the last ``.`` is looked up on the ``torch`` module.
    Returns None if ``dtype_str`` is not a string or the name does not
    resolve to a ``torch.dtype``.
    """
    if not isinstance(dtype_str, str):
        return None
    candidate = getattr(torch, dtype_str.split(".")[-1], None)
    return candidate if isinstance(candidate, torch.dtype) else None


def _make_random_tensor(shape, torch_dtype, device):
    """
    Allocate a tensor with the given shape/dtype/device filled with random data.

    Raises:
        NotImplementedError: If no generation strategy exists for the dtype.
    """
    # Floating point, signed integers, uint8 and bool are supported by random_()
    if torch_dtype.is_floating_point or torch_dtype in (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
        torch.bool,
    ):
        return torch.empty(shape, dtype=torch_dtype, device=device).random_()

    # Complex numbers are assembled from random real/imaginary parts.
    # complex128 needs float64 parts; other complex dtypes are built at
    # float32 and cast, so the result always has the requested dtype
    # (previously complex32 silently came back as complex128).
    if torch_dtype.is_complex:
        float_dtype = (
            torch.float64 if torch_dtype == torch.complex128 else torch.float32
        )
        real_part = torch.rand(shape, dtype=float_dtype, device=device)
        imag_part = torch.rand(shape, dtype=float_dtype, device=device)
        return torch.complex(real_part, imag_part).to(torch_dtype)

    # Other unsigned integers (like uint32) fail with random_()
    if "uint" in str(torch_dtype):
        return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)

    raise NotImplementedError(
        f"Random data generation not implemented for dtype: {torch_dtype}"
    )


def _require_triton_kernels(what):
    """Raise RuntimeError if ``triton_kernels.tensor`` is unavailable."""
    if not TRITON_KERNELS_CUSTOM_TYPES:
        raise RuntimeError(
            f"Optional dependency 'triton_kernels.tensor' is not installed; cannot construct {what}."
        )


def _create_arg_from_info(arg_info):
    """
    Recursively build a kernel argument from its JSON info dictionary.

    ``arg_info["type"]`` selects the strategy:

    * ``"int"`` / ``"float"`` / ``"bool"``: the logged ``"value"`` is
      returned as-is.
    * ``"tensor"``: a new tensor with the logged ``shape`` / ``dtype`` /
      ``device`` (defaults ``[]`` / ``torch.float32`` / ``"cpu"``) is filled
      with random data. Unresolvable dtypes fall back to ``torch.float32``.
    * ``"triton_kernels.tensor.Tensor"``, ``"triton_kernels.tensor.Storage"``,
      ``"StridedLayout"``: the matching ``triton_kernels.tensor`` object is
      rebuilt, recursing into nested descriptions. This requires the optional
      ``triton_kernels`` package (RuntimeError otherwise).

    A missing/non-dict description or an unknown type prints a warning and
    yields None.
    """
    # Nested descriptions (e.g. a Storage whose data was not logged) can be
    # absent; don't crash on ``None.get``.
    if not isinstance(arg_info, dict):
        print(f"Warning: Missing or malformed argument info {arg_info!r}. Returning None.")
        return None

    arg_type = arg_info.get("type")

    if arg_type in ("int", "float", "bool"):
        return arg_info.get("value")

    if arg_type == "tensor":
        torch_dtype = _torch_dtype_from_str(arg_info.get("dtype"))
        if torch_dtype is None:
            torch_dtype = torch.float32
        shape = arg_info.get("shape", [])
        device = arg_info.get("device", "cpu")
        return _make_random_tensor(shape, torch_dtype, device)

    if arg_type == "triton_kernels.tensor.Tensor":
        _require_triton_kernels("Tensor")
        Tensor, _, _ = _get_triton_tensor_types()
        storage = _create_arg_from_info(arg_info.get("storage"))
        dtype_str = arg_info.get("dtype")
        torch_dtype = getattr(torch, dtype_str.split(".")[-1])
        return Tensor(
            storage=storage,
            shape=arg_info.get("shape"),
            shape_max=arg_info.get("shape_max"),
            dtype=torch_dtype,
        )

    if arg_type == "triton_kernels.tensor.Storage":
        _require_triton_kernels("Storage")
        _, Storage, _ = _get_triton_tensor_types()
        data = _create_arg_from_info(arg_info.get("data"))
        layout = _create_arg_from_info(arg_info.get("layout"))
        return Storage(data=data, layout=layout)

    if arg_type == "StridedLayout":
        _require_triton_kernels("StridedLayout")
        _, _, StridedLayout = _get_triton_tensor_types()
        return StridedLayout(shape=arg_info.get("initial_shape"))

    print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
    return None
@@ -183,13 +183,80 @@ def convert(obj):
183
183
  asdict(obj)
184
184
  ) # Convert dataclass to dict and then process that dict
185
185
 
186
+ if _is_triton_kernels_layout(obj):
187
+ layout_info = {"type": type(obj).__name__}
188
+ if hasattr(obj, "initial_shape"):
189
+ layout_info["initial_shape"] = convert(obj.initial_shape)
190
+ if hasattr(obj, "name"):
191
+ layout_info["name"] = convert(obj.name)
192
+ return layout_info
193
+
186
194
  # 4. Common Triton constexpr objects
187
195
  if isinstance(obj, dtype):
188
196
  return f"triton.language.core.dtype('{str(obj)}')"
197
+
198
+ if TORCH_INSTALLED and isinstance(obj, torch.dtype):
199
+ return str(obj)
200
+
189
201
  log.warning(f"Unknown type: {type(obj)}")
190
202
  return str(obj) # Return primitive types as-is
191
203
 
192
204
 
205
def _is_triton_kernels_layout(obj):
    """
    Return True when some class in ``type(obj).__mro__`` is named ``Layout``
    and has a ``__module__`` starting with ``triton_kernels``.

    The match uses only the class name and module string. It therefore
    works without importing ``triton_kernels``.
    """
    return any(
        getattr(klass, "__name__", "") == "Layout"
        and getattr(klass, "__module__", "").startswith("triton_kernels")
        for klass in type(obj).__mro__
    )
217
+
218
+
219
def _is_from_triton_kernels_module(obj):
    """
    Return True if ``obj``'s own class is named ``Tensor`` or ``Storage`` and
    its ``__module__`` starts with ``triton_kernels``.

    Only the exact class is inspected, not its bases.
    """
    cls = type(obj)
    if getattr(cls, "__name__", "") not in ("Tensor", "Storage"):
        return False
    return getattr(cls, "__module__", "").startswith("triton_kernels")
230
+
231
+
232
def _log_torch_tensor_info(tensor_value):
    """
    Collect descriptive metadata (not the data itself) from a torch.Tensor.

    Args:
        tensor_value (torch.Tensor): Tensor to describe.

    Returns:
        dict: Keys ``type`` (always ``"tensor"``), ``shape``, ``dtype``,
        ``device``, ``stride``, ``numel``, ``is_contiguous``,
        ``element_size``, ``storage_offset`` and ``memory_usage``.
        ``data_ptr`` (hex string) is also included when available.
    """
    numel = tensor_value.numel()
    element_size = tensor_value.element_size()
    info = {
        "type": "tensor",
        "shape": list(tensor_value.shape),
        "dtype": str(tensor_value.dtype),
        "device": str(tensor_value.device),
        "stride": list(tensor_value.stride()),
        "numel": numel,
        "is_contiguous": tensor_value.is_contiguous(),
        "element_size": element_size,
        "storage_offset": tensor_value.storage_offset(),
        # Bytes occupied by the tensor's logical elements.
        "memory_usage": numel * element_size,
    }
    # Address is recorded for memory tracking.
    if hasattr(tensor_value, "data_ptr"):
        info["data_ptr"] = hex(tensor_value.data_ptr())
    return info
258
+
259
+
193
260
  def maybe_enable_debug_logging():
194
261
  """
195
262
  This logging is for logging module itself, not for logging the triton compilation.
@@ -769,7 +836,8 @@ def maybe_trace_triton(
769
836
 
770
837
  def extract_arg_info(arg_dict):
771
838
  """
772
- Extract detailed information from kernel arguments, especially for PyTorch tensors.
839
+ Extract detailed information from kernel arguments, especially for PyTorch
840
+ tensors.
773
841
 
774
842
  Args:
775
843
  arg_dict: Dictionary of kernel arguments
@@ -785,19 +853,40 @@ def extract_arg_info(arg_dict):
785
853
  # Check if it's a PyTorch tensor
786
854
  if TORCH_INSTALLED and isinstance(arg_value, torch.Tensor):
787
855
  arg_info["type"] = "tensor"
788
- arg_info["shape"] = list(arg_value.shape)
789
- arg_info["dtype"] = str(arg_value.dtype)
790
- arg_info["device"] = str(arg_value.device)
791
- arg_info["stride"] = list(arg_value.stride())
792
- arg_info["numel"] = arg_value.numel()
793
- arg_info["is_contiguous"] = arg_value.is_contiguous()
794
- arg_info["element_size"] = arg_value.element_size()
795
- arg_info["storage_offset"] = arg_value.storage_offset()
796
- # Memory usage in bytes
797
- arg_info["memory_usage"] = arg_value.numel() * arg_value.element_size()
798
- # Add data_ptr for memory tracking (optional)
799
- if hasattr(arg_value, "data_ptr"):
800
- arg_info["data_ptr"] = hex(arg_value.data_ptr())
856
+ arg_info.update(_log_torch_tensor_info(arg_value))
857
+ # Handle custom Tensor/Storage types from triton_kernels
858
+ elif _is_from_triton_kernels_module(arg_value):
859
+ type_name = type(arg_value).__name__
860
+ arg_info["type"] = f"triton_kernels.tensor.{type_name}"
861
+
862
+ if type_name == "Tensor":
863
+ # Dump all attributes needed to reconstruct the Tensor wrapper
864
+ if hasattr(arg_value, "shape"):
865
+ arg_info["shape"] = convert(arg_value.shape)
866
+ if hasattr(arg_value, "shape_max"):
867
+ arg_info["shape_max"] = convert(arg_value.shape_max)
868
+ if hasattr(arg_value, "dtype"):
869
+ arg_info["dtype"] = convert(arg_value.dtype)
870
+ if hasattr(arg_value, "storage"):
871
+ # Recursively process the storage, which can be another
872
+ # custom type or a torch.Tensor
873
+ storage_arg = {"storage": arg_value.storage}
874
+ arg_info["storage"] = extract_arg_info(storage_arg)["storage"]
875
+
876
+ elif type_name == "Storage":
877
+ # Dump all attributes needed to reconstruct the Storage object
878
+ if (
879
+ hasattr(arg_value, "data")
880
+ and TORCH_INSTALLED
881
+ and isinstance(arg_value.data, torch.Tensor)
882
+ ):
883
+ # The 'data' is a torch.Tensor, log its metadata fully
884
+ arg_info["data"] = _log_torch_tensor_info(arg_value.data)
885
+ if hasattr(arg_value, "layout"):
886
+ arg_info["layout"] = convert(arg_value.layout)
887
+ else:
888
+ log.warning(f"Unknown type: {type(arg_value)}")
889
+
801
890
  # Handle scalar values
802
891
  elif isinstance(arg_value, (int, float, bool)):
803
892
  arg_info["type"] = type(arg_value).__name__
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tritonparse
3
- Version: 0.2.1.dev20250915071616
3
+ Version: 0.2.1.dev20250917071511
4
4
  Summary: TritonParse: A Compiler Tracer, Visualizer, and mini-Reproducer Generator for Triton Kernels
5
5
  Author-email: Yueming Hao <yhao@meta.com>
6
6
  License-Expression: BSD-3-Clause
@@ -21,7 +21,7 @@ Dynamic: license-file
21
21
  [![License: BSD-3](https://img.shields.io/badge/License-BSD--3-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)
22
22
  [![GitHub Pages](https://img.shields.io/badge/GitHub%20Pages-Deploy-brightgreen)](https://meta-pytorch.org/tritonparse/)
23
23
 
24
- **A comprehensive visualization and analysis tool for Triton IR files** — helping developers analyze, debug, and understand Triton kernel compilation processes.
24
+ **A comprehensive visualization and analysis tool for Triton kernel compilation and launch** — helping developers analyze, debug, and understand Triton kernel compilation processes.
25
25
 
26
26
  🌐 **[Try it online →](https://meta-pytorch.org/tritonparse/?json_url=https://meta-pytorch.org/tritonparse/dedicated_log_triton_trace_findhao__mapped.ndjson.gz)**
27
27
 
@@ -81,10 +81,18 @@ INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/t
81
81
  ## 🛠️ Installation
82
82
 
83
83
  **For basic usage (trace generation):**
84
+ Four options:
84
85
  ```bash
86
+ # install nightly version
87
+ pip install -U --pre tritonparse
88
+ # install stable version
89
+ pip install tritonparse
90
+ # install from source
85
91
  git clone https://github.com/meta-pytorch/tritonparse.git
86
92
  cd tritonparse
87
93
  pip install -e .
94
+ # install the latest development version directly from GitHub
95
+ pip install git+https://github.com/meta-pytorch/tritonparse.git
88
96
  ```
89
97
 
90
98
  **Prerequisites:** Python ≥ 3.10, Triton ≥ 3.4.0, GPU required (NVIDIA/AMD)
@@ -7,18 +7,20 @@ tritonparse/mapper.py,sha256=prrczfi13P7Aa042OrEBsmRF1HW3jDhwxicANgPkWIM,4150
7
7
  tritonparse/shared_vars.py,sha256=-c9CvXJSDm9spYhDOJPEQProeT_xl3PaNmqTEYi_u4s,505
8
8
  tritonparse/source_type.py,sha256=nmYEQS8rfkIN9BhNhQbkmEvKnvS-3zAxRGLY4TaZdi8,1676
9
9
  tritonparse/sourcemap_utils.py,sha256=qsQmTDuEe9yuUVyxSHRbjTR38gi0hvJEijnPkrJVAV4,2037
10
- tritonparse/structured_logging.py,sha256=mrApsVigIHZln6De2ElwNTUSZfO61OHXgd08o2X26sM,40430
10
+ tritonparse/structured_logging.py,sha256=7wSKc9HnV40eLoOgFNosc7dXWvSxsJJo5mnRVoDZqIk,43678
11
11
  tritonparse/tp_logger.py,sha256=vXzY7hMDmVnRBGBhIjFZe3nHZzG5NKKPONGUszJhGgU,242
12
12
  tritonparse/trace_processor.py,sha256=QzUOKwnOkBbwTTKBsa5ZMUABPLMJIBFtTcG2SkhO0I8,12771
13
13
  tritonparse/utils.py,sha256=wt61tpbkqjGqHh0c7Nr2WlOv7PbQssmjULd6uA6aAko,4475
14
+ tritonparse/reproducer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ tritonparse/reproducer/utils.py,sha256=VfMBwnTEZO8Ug9_ZRlZUVTMaMczDkviAykXpnK5dacU,5093
14
16
  tritonparse/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
17
  tritonparse/tools/decompress_bin_ndjson.py,sha256=kpt7DM_sSA334F1X45xdkP2OR9LuB27Pc50EkGr6CPM,4144
16
18
  tritonparse/tools/format_fix.py,sha256=Ol0Sjui8D7OzHwbamAfGnq8V5Y63uwNaFTKSORN5HkQ,3867
17
19
  tritonparse/tools/load_tensor.py,sha256=tfdmNVd9gsZqO6msQBhbXIhOvUzgc83yF64k2GDWPNk,2122
18
20
  tritonparse/tools/prettify_ndjson.py,sha256=VOzVWoXpCbaAXYA4i_wBcQIHfh-JhAx7xR4cF_L8yDs,10928
19
21
  tritonparse/tools/readme.md,sha256=w6PWYfYnRgoPArLjxG9rVrpcLUkoVMGuRlbpF-o0IQM,110
20
- tritonparse-0.2.1.dev20250915071616.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
21
- tritonparse-0.2.1.dev20250915071616.dist-info/METADATA,sha256=CigDDEKt_dqH4EA7Le5lt9a-HQkeidokxdsYJBadVA0,6306
22
- tritonparse-0.2.1.dev20250915071616.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- tritonparse-0.2.1.dev20250915071616.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
24
- tritonparse-0.2.1.dev20250915071616.dist-info/RECORD,,
22
+ tritonparse-0.2.1.dev20250917071511.dist-info/licenses/LICENSE,sha256=4ZciugpyN7wcM4L-9pyDh_etvMUeIfBhDTyH1zeZlQM,1515
23
+ tritonparse-0.2.1.dev20250917071511.dist-info/METADATA,sha256=nesWLDcKVWKWtq3Wph_hHxTH5abc2SREFy_tYDeACdk,6580
24
+ tritonparse-0.2.1.dev20250917071511.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
+ tritonparse-0.2.1.dev20250917071511.dist-info/top_level.txt,sha256=ITcTKgp3vf_bXV9vixuQU9IrZa3L1EfDSZwvRzRaoJU,12
26
+ tritonparse-0.2.1.dev20250917071511.dist-info/RECORD,,