tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic. Click here for more details.

Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,151 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """
4
+ ImportResolver: Resolves Python import statements to file paths.
5
+
6
+ This module provides functionality to resolve import statements to their
7
+ corresponding file paths using Python's importlib system, without actually
8
+ importing the modules (no side effects).
9
+ """
10
+
11
+ import importlib.util
12
+ import logging
13
+ import sys
14
+ from importlib.machinery import ModuleSpec
15
+
16
# Module-level logger: ImportResolver reports resolution outcomes at DEBUG
# level and imports that resolve outside project_root at WARNING level.
logger = logging.getLogger(__name__)
17
+
18
+
19
+ class ImportResolver:
20
+ """
21
+ Resolves import statements to absolute file paths using importlib.
22
+
23
+ Uses Python's import resolution system (importlib.util.find_spec) to
24
+ locate module files without actually importing them, avoiding side effects.
25
+ """
26
+
27
+ def __init__(self, project_root: str) -> None:
28
+ """
29
+ Initialize the ImportResolver.
30
+
31
+ Args:
32
+ project_root: Absolute path to project root directory
33
+ (e.g., /path/to/project/root)
34
+ """
35
+ self.project_root = project_root
36
+
37
+ # Ensure project_root is in sys.path for importlib to find modules
38
+ if project_root not in sys.path:
39
+ sys.path.insert(0, project_root)
40
+
41
+ # Common third-party modules to exclude from analysis
42
+ self.external_modules: set[str] = {
43
+ "torch",
44
+ "triton",
45
+ "numpy",
46
+ "pandas",
47
+ "transformers",
48
+ "tqdm",
49
+ "typing_extensions",
50
+ "pydantic",
51
+ "requests",
52
+ "pytest",
53
+ "unittest",
54
+ }
55
+
56
+ def resolve_import(
57
+ self,
58
+ module_name: str,
59
+ package: str | None = None,
60
+ ) -> tuple[str | None, bool]:
61
+ """
62
+ Resolve import to absolute file path using importlib.
63
+
64
+ Args:
65
+ module_name: Module to import (e.g., "torch" or "pytorch.tritonparse")
66
+ package: Package context for relative imports (e.g., "pytorch.tritonparse")
67
+
68
+ Returns:
69
+ Tuple of (file_path, is_external):
70
+ - file_path: Absolute path to .py file, or None if not resolvable
71
+ - is_external: True if third-party/built-in module
72
+
73
+ Examples:
74
+ >>> resolver = ImportResolver("/data/.../fbcode")
75
+ >>> path, is_ext = resolver.resolve_import("pytorch.tritonparse.module")
76
+ >>> # path: "/data/.../fbcode/pytorch/tritonparse/module.py"
77
+ >>> # is_ext: False
78
+
79
+ >>> path, is_ext = resolver.resolve_import("torch")
80
+ >>> # path: None
81
+ >>> # is_ext: True
82
+ """
83
+ # Check if it's a known external module
84
+ base_module = module_name.split(".")[0]
85
+ if base_module in self.external_modules:
86
+ logger.debug(
87
+ "Import '%s' marked as external (known third-party module)",
88
+ module_name,
89
+ )
90
+ return None, True
91
+
92
+ try:
93
+ # Use importlib to find the module (without importing it)
94
+ spec: ModuleSpec | None = importlib.util.find_spec(module_name, package)
95
+
96
+ if spec is None or spec.origin is None:
97
+ # Module not found or built-in (has no file)
98
+ logger.debug(
99
+ "Import '%s' not resolvable (spec=%s, origin=%s)",
100
+ module_name,
101
+ spec,
102
+ spec.origin if spec else None,
103
+ )
104
+ return None, True
105
+
106
+ origin = spec.origin
107
+
108
+ # Skip special cases like frozen/built-in modules
109
+ if origin == "frozen" or origin == "built-in":
110
+ logger.debug("Import '%s' is a built-in module", module_name)
111
+ return None, True
112
+
113
+ # Check if the module is within project root
114
+ is_internal = origin.startswith(self.project_root)
115
+
116
+ if is_internal:
117
+ logger.debug(
118
+ "Import '%s' resolved to INTERNAL file: %s",
119
+ module_name,
120
+ origin,
121
+ )
122
+ return origin, False
123
+ else:
124
+ # External module (outside project)
125
+ logger.warning(
126
+ "Import '%s' resolved to file '%s' which is OUTSIDE project_root '%s'. "
127
+ "This import will be skipped. If this is unexpected, verify your --code-root parameter matches "
128
+ "the directory containing your source files.",
129
+ module_name,
130
+ origin,
131
+ self.project_root,
132
+ )
133
+ return None, True
134
+
135
+ except (ImportError, ValueError, AttributeError) as e:
136
+ # Module doesn't exist or can't be resolved
137
+ logger.debug("Import '%s' failed to resolve: %s", module_name, str(e))
138
+ return None, True
139
+
140
+ def is_external_module(self, module_name: str) -> bool:
141
+ """
142
+ Check if a module is external (third-party or built-in).
143
+
144
+ Args:
145
+ module_name: Name of the module to check
146
+
147
+ Returns:
148
+ True if module is external, False if it's internal (project)
149
+ """
150
+ base_module = module_name.split(".")[0]
151
+ return base_module in self.external_modules
@@ -0,0 +1,237 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ from tritonparse.tp_logger import logger
7
+
8
# Sentinel object to mark arguments that should be skipped during processing.
# _decode_arg returns it for tensor arguments, and build_context_bundle
# compares against it by identity (`is`) to leave tensors out of the kwargs.
_SKIP = object()
10
+
11
+
12
@dataclass
class KernelInfo:
    """Location, name and source of a Triton kernel, taken from a compilation event."""

    # Path of the file defining the kernel (python_source.file_path in the event).
    file_path: str
    # Kernel function name (the compilation metadata "name" field).
    function_name: str
    # Kernel source text; get_kernel_info trims it to start at '@triton.jit'
    # when that marker is present.
    source_code: str
    # Stack frames stored under the compilation event's "stack" key.
    call_stack: List[Dict[str, Any]]
22
+ @dataclass
23
+ class ContextBundle:
24
+ """Bundle of all context information needed to reproduce a kernel launch."""
25
+
26
+ kernel_info: KernelInfo
27
+ compile: Dict[str, Any]
28
+ launch: Dict[str, Any]
29
+ args: Dict[str, Any]
30
+ tensor_args: Dict[str, Any]
31
+ raw_launch_event: Dict[str, Any]
32
+ raw_comp_event: Dict[str, Any]
33
+
34
+
35
def get_launch_and_compilation_events(
    events: List[Dict[str, Any]], line_index: Optional[int] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Extract launch and compilation events from the event list.

    Args:
        events: List of parsed event dictionaries.
        line_index: 0-based index of the launch event to process.

    Returns:
        Tuple of (launch_event, compilation_event). When several compilation
        events share the launch's hash, the first one in ``events`` is used.

    Raises:
        ValueError: If line_index is None or outside ``[0, len(events))``, or
            the event at line_index is not a launch event.
        RuntimeError: If the launch event has no compilation hash, or no
            compilation event with that hash exists.
    """
    # Reject negative indices explicitly: Python indexing would otherwise
    # count from the end of the list and silently pick the wrong event.
    if line_index is None or not 0 <= line_index < len(events):
        raise ValueError(f"Invalid line_index: {line_index}")

    launch_event = events[line_index]
    if launch_event.get("event_type") != "launch":
        raise ValueError(f"Event at index {line_index} is not a launch event")

    # `or {}` also tolerates keys that are present but null in the trace.
    comp_meta = launch_event.get("compilation_metadata") or {}
    comp_hash = comp_meta.get("hash")
    if not comp_hash:
        raise RuntimeError("Could not find compilation hash in launch event.")

    def _comp_hash(event: Dict[str, Any]) -> Any:
        # Hash stored in a compilation event's payload metadata (None if absent).
        metadata = (event.get("payload") or {}).get("metadata") or {}
        return metadata.get("hash")

    comp_event = next(
        (
            event
            for event in events
            if event.get("event_type") == "compilation"
            and _comp_hash(event) == comp_hash
        ),
        None,
    )
    if comp_event is None:
        raise RuntimeError(f"Could not find compilation event for hash {comp_hash}.")
    return launch_event, comp_event
75
+
76
+
77
def get_kernel_info(comp_event: Dict[str, Any]) -> "KernelInfo":
    """
    Extract kernel information from a compilation event.

    Args:
        comp_event: Compilation event dictionary containing kernel metadata.

    Returns:
        KernelInfo object with extracted kernel details. The source code is
        trimmed to start at the first '@triton.jit' when that marker exists.

    Raises:
        RuntimeError: If file path or function name cannot be resolved
            (including when "payload"/"python_source"/"metadata" are missing
            or null).
    """
    # `or {}` guards against keys that are present but null in the trace.
    payload = comp_event.get("payload") or {}
    py_source = payload.get("python_source") or {}
    code = py_source.get("code") or ""

    # Extract file path and function name
    file_path = py_source.get("file_path")
    # The function name is in the compilation metadata payload
    func_name = (payload.get("metadata") or {}).get("name")

    if not file_path or not func_name:
        raise RuntimeError(
            "Could not resolve kernel file path or function name from compilation event."
            " The import-based strategy cannot proceed."
        )

    # Find '@triton.jit' decorator and slice the string from there
    jit_marker = "@triton.jit"
    jit_pos = code.find(jit_marker)
    if jit_pos != -1:
        code = code[jit_pos:]
        logger.debug("Extracted kernel source starting from '@triton.jit'.")

    return KernelInfo(file_path, func_name, code, comp_event.get("stack") or [])
112
+
113
+
114
+ def _decode_arg(raw: Any) -> Any:
115
+ """
116
+ Decode a raw argument value from event data.
117
+
118
+ Args:
119
+ raw: Raw argument value from event data.
120
+
121
+ Returns:
122
+ Decoded argument value, or _SKIP sentinel for tensors.
123
+ """
124
+ if not isinstance(raw, dict):
125
+ return raw
126
+ t = raw.get("type")
127
+ if t == "tensor":
128
+ return _SKIP
129
+ if t == "NoneType":
130
+ return None
131
+ return raw.get("value", raw.get("repr"))
132
+
133
+
134
+ def _pack_args(args: Dict[str, Any]) -> Dict[str, Any]:
135
+ """
136
+ Pack argument values into a standardized format.
137
+
138
+ Args:
139
+ args: Dictionary of argument names to values.
140
+
141
+ Returns:
142
+ Dictionary with packed argument information including type and metadata.
143
+ """
144
+ packed = {}
145
+ for k, v in args.items():
146
+ t = v.get("type") if isinstance(v, dict) else None
147
+ if t == "tensor":
148
+ packed[k] = {
149
+ "type": "tensor",
150
+ "shape": v.get("shape") if isinstance(v, dict) else None,
151
+ "dtype": v.get("dtype") if isinstance(v, dict) else None,
152
+ "device": v.get("device") if isinstance(v, dict) else None,
153
+ "stride": v.get("stride") if isinstance(v, dict) else None,
154
+ "is_contiguous": (
155
+ v.get("is_contiguous") if isinstance(v, dict) else None
156
+ ),
157
+ "numel": v.get("numel") if isinstance(v, dict) else None,
158
+ }
159
+ else:
160
+ # scalar / NoneType etc
161
+ if isinstance(v, dict):
162
+ packed[k] = {
163
+ "type": v.get("type"),
164
+ "value": v.get("value", v.get("repr")),
165
+ }
166
+ else:
167
+ packed[k] = {
168
+ "type": None,
169
+ "value": v,
170
+ }
171
+ return packed
172
+
173
+
174
def build_context_bundle(
    events: List[Dict[str, Any]], line_index: Optional[int] = None
) -> "ContextBundle":
    """
    Build a complete context bundle from events and line index.

    Args:
        events: List of parsed event dictionaries.
        line_index: 0-based index of the launch event to process.

    Returns:
        ContextBundle containing all information needed to reproduce the kernel
        launch: kernel info, a compile-metadata subset, the launch block
        (grid + non-tensor kwargs), packed descriptions of all args and of the
        tensor args only, and the raw launch/compilation events.

    Raises:
        ValueError: If line_index is invalid or event is not a launch event.
        RuntimeError: If compilation event cannot be found, or the kernel file
            path / function name cannot be resolved.
    """
    launch_event, comp_event = get_launch_and_compilation_events(events, line_index)
    kernel_info = get_kernel_info(comp_event)
    # `or {}` tolerates keys that are present but null in the trace.
    extracted_args = launch_event.get("extracted_args") or {}
    comp_meta = launch_event.get("compilation_metadata") or {}

    # Compile metadata subset we care about
    compile_block = {
        "num_warps": comp_meta.get("num_warps"),
        "num_stages": comp_meta.get("num_stages"),
        "arch": comp_meta.get("arch"),
        "backend": comp_meta.get("backend_name") or comp_meta.get("backend"),
        "triton_version": comp_meta.get("triton_version"),
        "hash": comp_meta.get("hash"),
    }

    # kwargs: include constexpr + explicit scalars used for launch (skip tensor args)
    kwargs = {}
    for name, raw in extracted_args.items():
        value = _decode_arg(raw)
        if value is not _SKIP:
            kwargs[name] = value

    # tensor args: only tensors
    raw_tensor_args = {
        name: raw
        for name, raw in extracted_args.items()
        if isinstance(raw, dict) and raw.get("type") == "tensor"
    }

    launch_block = {
        "grid": launch_event.get("grid"),
        "kwargs": kwargs,
    }

    return ContextBundle(
        kernel_info=kernel_info,
        compile=compile_block,
        launch=launch_block,
        # Packed description of every argument, tensors included.
        args=_pack_args(extracted_args),
        tensor_args=_pack_args(raw_tensor_args),
        raw_launch_event=launch_event,
        raw_comp_event=comp_event,
    )