tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,636 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import builtins
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Names of every public callable exposed by the ``builtins`` module
# (``print``, ``len``, exception classes, ...). CallGraph adds these to its
# callee-name filter so calls to Python built-ins never become graph edges.
DEFAULT_BUILTIN_FILTERS = [
    public_name
    for public_name in dir(builtins)
    if not public_name.startswith("_")
    and callable(getattr(builtins, public_name))
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
class Site:
    """Immutable (and therefore hashable) source location of an AST node."""

    # File the node was found in, as passed to the analyzer.
    filename: str
    # Line number of the node; CallGraph stores -1 when the node has none.
    lineno: int
    # Column offset of the node; CallGraph stores -1 when the node has none.
    col: int
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class FuncDescriptor:
    """Description of a function definition seen during AST traversal."""

    # Unqualified function name (the ``name`` of the def node).
    name: str
    # One entry per decorator, in source order; unresolvable decorators are
    # recorded by CallGraph as "<dynamic_decorator...>" placeholders.
    decorators: List[str]
    # Location of the function definition.
    site: Site
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class Edge:
    """A single caller -> callee relationship recorded at one call site."""

    # Name of the scope that performs the call.
    caller: str
    # Name of the called target (resolved through known bindings when possible).
    callee: str
    # Location of the call expression.
    site: Site
    # Kind of call; CallGraph currently always records "regular".
    call_type: str
    # Descriptor of the callee's definition when one was found, else None.
    callee_descriptor: Optional[FuncDescriptor]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def split_by_the_last_dot(s: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a dotted name into ``(prefix, last_component)``.

    Returns ``(None, None)`` when ``s`` is ``None`` and ``(None, s)`` when
    ``s`` contains no dot. Otherwise the split happens at the right-most dot,
    e.g. ``"a.b.c"`` -> ``("a.b", "c")``.
    """
    if s is None:
        return None, None
    head, dot, tail = s.rpartition(".")
    if not dot:
        # No separator found: rpartition puts the whole string in ``tail``.
        return None, tail
    return head, tail
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class CallGraph(ast.NodeVisitor):
    """
    AST visitor that builds a call graph by tracking function calls and definitions.

    This class traverses an AST and records:
    - Function definitions and their decorators
    - Function calls and their call sites
    - Import statements and name bindings
    - Lambda expressions

    Names are resolved through a stack of per-scope binding dictionaries that
    is pushed/popped as functions, classes and lambdas are entered and left.
    Visiting an ``ast.Module`` additionally caches the file's source text (when
    it can be read) and, in transitive-closure mode, prunes the recorded edges
    to those whose caller is reachable from one of the backends.
    """

    def __init__(
        self,
        filename: str = "<string>",
        module_name: str = "<module>",
        backends: Optional[List[str]] = None,
        transitive_closure: bool = True,
        callee_prefix_filters: Optional[List[str]] = None,
        callee_name_filters: Optional[List[str]] = None,
    ):
        """Create an analyzer for a single module.

        Args:
            filename: Path of the analysed file; used for call sites and for
                reading the source text during extraction.
            module_name: Prefix used to build qualified names.
            backends: Names of the root functions. Matching against callers
                and definitions is done by substring. Must not be ``None``.
            transitive_closure: When true, every edge is recorded and later
                pruned by reachability from the backends; when false, only
                edges whose caller matches a backend are recorded.
            callee_prefix_filters: Callees starting with any of these prefixes
                are ignored (e.g. "triton.", "tl.").
            callee_name_filters: Additional unqualified callee names to
                ignore, on top of ``DEFAULT_BUILTIN_FILTERS``.
        """
        self.filename = filename

        self.edges: List[Edge] = []
        self.decorator_edges: List[Edge] = []
        assert backends is not None, "Backends must not be None"
        self.backends: Dict[str, List[Any]] = {backend: [] for backend in backends}

        self.scope_stack: List[str] = []
        self.module_name = module_name

        self.bindings_stack: List[Dict[str, str]] = [dict()]
        self.local_functions: Set[str] = set()

        # Track functions in the call chain for transitive closure
        self.transitive_closure = transitive_closure
        # Note: backends are provided as short names (e.g., "_attn_fwd_base_opt")
        # but we'll need to match them against fully qualified names later
        # We store both the short name and will add the qualified name when we see the function definition
        self.tracked_functions: Set[str] = (
            set(backends) if transitive_closure else set()
        )

        # Prefix filters to exclude certain callees (e.g., "triton.", "tl.")
        self.callee_prefix_filters = callee_prefix_filters or []

        # Name filters to exclude specific built-in function names
        # Combine user-provided filters with default built-ins. Building the
        # union from two sets accepts any iterable of names, not only lists.
        self.callee_name_filters = set(callee_name_filters or ()) | set(
            DEFAULT_BUILTIN_FILTERS
        )

        # lambda node -> synthetic id (stable within this pass)
        self._lambda_ids: Dict[ast.Lambda, str] = {}

        # Store function AST nodes and source code for extraction
        self.func_nodes: Dict[str, ast.FunctionDef] = {}
        self.source_code: str = ""

    # ---------- helpers ----------
    def _cur_scope(self) -> str:
        """Return the dotted name of the scope currently being visited."""
        return ".".join([self.module_name] + self.scope_stack).strip(".")

    def _qualify(self, name: str) -> str:
        """Return ``name`` prefixed with the current scope (if there is one)."""
        scope = self._cur_scope()
        return f"{scope}.{name}" if scope else name

    def _push_scope(self, name: str) -> None:
        """Enter a nested scope with a fresh, empty binding table."""
        self.scope_stack.append(name)
        self.bindings_stack.append({})

    def _pop_scope(self) -> None:
        """Leave the innermost scope and drop its bindings."""
        self.scope_stack.pop()
        self.bindings_stack.pop()

    def _bind(self, name: str, target: str) -> None:
        """Bind a local name to a resolved target in the innermost scope."""
        self.bindings_stack[-1][name] = target

    def _bind_func_descriptor(self, node, decorators: List[str]) -> None:
        """Store a FuncDescriptor for ``node`` in the innermost scope.

        The descriptor is keyed by the function's unqualified name.
        """
        name = node.name
        site = Site(
            self.filename, getattr(node, "lineno", -1), getattr(node, "col_offset", -1)
        )
        self.bindings_stack[-1][f"__{name}_descriptor__"] = FuncDescriptor(
            name, decorators, site
        )

    def _resolve_name(self, id_: str) -> str:
        """Resolve ``id_`` through the scopes, innermost first.

        Unknown names are returned unchanged.
        """
        for env in reversed(self.bindings_stack):
            if id_ in env:
                return env[id_]
        return id_

    def _resolve_func_descriptor(self, id_: str) -> Optional[FuncDescriptor]:
        """Return the descriptor bound for ``id_`` in any enclosing scope, or None.

        NOTE(review): descriptors are stored under the *unqualified* function
        name, so a lookup with an already-qualified callee finds nothing.
        Behavior is kept as-is; confirm whether qualified callees should match.
        """
        key = f"__{id_}_descriptor__"
        for env in reversed(self.bindings_stack):
            if key in env:
                return env[key]
        return None

    def _resolve_attr(self, node: ast.AST) -> str:
        """Resolve an attribute chain such as ``a.b.c`` to a dotted name.

        The left-most name is resolved through the bindings. Chains that do
        not start from a plain name yield ``"<dynamic_attr>"``.
        """
        parts: List[str] = []
        cur = node
        while isinstance(cur, ast.Attribute):
            parts.append(cur.attr)
            cur = cur.value
        if not isinstance(cur, ast.Name):
            return "<dynamic_attr>"
        head = self._resolve_name(cur.id)
        return ".".join([head] + list(reversed(parts)))

    @staticmethod
    def _raw_dotted_name(node: ast.AST) -> Optional[str]:
        """Return ``a.b.c`` for a Name/Attribute chain without resolving bindings.

        Returns None when the chain does not start from a plain name.
        """
        parts: List[str] = []
        while isinstance(node, ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            return None
        parts.append(node.id)
        return ".".join(reversed(parts))

    def _lambda_id(self, node: ast.Lambda) -> str:
        """Return the synthetic id of a lambda, creating it on first use."""
        lid = self._lambda_ids.get(node)
        if lid is None:
            scope = self._cur_scope() or "<module>"
            lid = f"{scope}.<lambda>@{getattr(node,'lineno',-1)}:{getattr(node,'col_offset',-1)}"
            self._lambda_ids[node] = lid
        return lid

    def _rhs_symbol(self, n: ast.AST) -> Optional[str]:
        """Return the symbol an assignment's right-hand side refers to, if any.

        Handles plain names, attribute chains and lambdas; anything else
        yields None.
        """
        if isinstance(n, ast.Name):
            return self._resolve_name(n.id)
        if isinstance(n, ast.Attribute):
            return self._resolve_attr(n)
        if isinstance(n, ast.Lambda):
            return self._lambda_id(n)
        return None

    def _decorator_name(self, dec: ast.AST) -> str:
        """Return the recorded name of one decorator expression.

        Plain names and attribute chains are resolved through the bindings.
        Decorator calls are recorded as ``<dynamic_decorator_ROOT.ATTR>`` (or
        ``<dynamic_decorator_NAME>``), and anything that cannot be described
        that way becomes ``<dynamic_decorator>``.
        """
        if isinstance(dec, ast.Name):
            return self._resolve_name(dec.id)
        if isinstance(dec, ast.Attribute):
            return self._resolve_attr(dec)
        if isinstance(dec, ast.Call):
            # best effort to guess the structure of the decorator
            func = dec.func
            if isinstance(func, ast.Name):
                return f"<dynamic_decorator_{func.id}>"
            if isinstance(func, ast.Attribute):
                # Walk down to the root of the chain instead of assuming a
                # fixed depth; chains longer than three components used to
                # raise AttributeError here.
                root = func.value
                while isinstance(root, ast.Attribute):
                    root = root.value
                if isinstance(root, ast.Name):
                    return f"<dynamic_decorator_{root.id}.{func.attr}>"
        return "<dynamic_decorator>"

    def _read_source(self) -> str:
        """Read and return the text of ``self.filename``."""
        with open(self.filename, "r") as f:
            return f.read()

    def _record_call(
        self, callee: str, node: ast.AST, maybe_triton: bool = False, caller=None
    ) -> None:
        """Record an edge ``caller -> callee`` unless it is filtered out.

        Args:
            callee: Resolved name of the call target.
            node: AST node of the call; supplies the call site.
            maybe_triton: Accepted for interface compatibility; not used.
            caller: Calling scope; defaults to the current scope.
        """
        if caller is None:
            caller = self._cur_scope() or "<module>"
        # replace callee with caller class name if it is "self." call
        if "." in caller and callee.startswith("self."):
            caller_prefix, _ = split_by_the_last_dot(caller)
            # remove the "self." prefix
            callee = caller_prefix + "." + callee[len("self.") :]

        # Check if callee should be filtered out based on prefix filters
        if any(callee.startswith(prefix) for prefix in self.callee_prefix_filters):
            return

        # Check if callee should be filtered out based on exact name match
        # Extract the function name from qualified name (e.g., "module.func" -> "func")
        if callee.rsplit(".", 1)[-1] in self.callee_name_filters:
            return

        # In transitive closure mode every edge is recorded during traversal
        # and filtered afterwards based on reachability from the backends.
        # In backend-only mode an edge is kept only if the caller matches a
        # backend (substring match).
        if not self.transitive_closure and not any(
            backend in caller for backend in self.backends
        ):
            return

        site = Site(
            self.filename, getattr(node, "lineno", -1), getattr(node, "col_offset", -1)
        )
        self.edges.append(
            Edge(
                caller,
                callee,
                callee_descriptor=self._resolve_func_descriptor(callee),
                site=site,
                call_type="regular",
            )
        )

    def _filter_edges_by_reachability(self) -> None:
        """Filter edges to keep only those reachable from backend functions.

        This implements the transitive closure: starting from backend functions,
        we add all callees until no new functions are added. Afterwards
        ``tracked_functions`` holds the reachable set. No-op outside
        transitive-closure mode.
        """
        if not self.transitive_closure:
            return

        # Build caller -> callees mapping from all recorded edges
        call_graph: Dict[str, Set[str]] = {}
        for edge in self.edges:
            call_graph.setdefault(edge.caller, set()).add(edge.callee)

        # Initialize reachable functions with backends: both the name as given
        # and any local qualified name that contains it.
        reachable: Set[str] = set()
        for backend in self.backends:
            reachable.add(backend)
            reachable.update(
                func for func in self.local_functions if backend in func
            )

        # Worklist traversal: each function is expanded exactly once.
        pending = list(reachable)
        while pending:
            func = pending.pop()
            for callee in call_graph.get(func, ()):
                if callee not in reachable:
                    reachable.add(callee)
                    pending.append(callee)

        # Filter edges to keep only those where caller is reachable
        self.edges = [edge for edge in self.edges if edge.caller in reachable]

        # Update tracked_functions to reflect reachable functions
        self.tracked_functions = reachable

    def get_unique_edges(self) -> List[Edge]:
        """Return deduplicated edges, keeping only unique (caller, callee) pairs.

        When a function calls another function from multiple locations,
        this returns only one edge (the first recorded) for that relationship.
        """
        seen: Set[Tuple[str, str]] = set()
        unique_edges: List[Edge] = []

        for edge in self.edges:
            key = (edge.caller, edge.callee)
            if key not in seen:
                seen.add(key)
                unique_edges.append(edge)

        return unique_edges

    def get_dependent_functions(self) -> Set[str]:
        """Return all functions that are called along the recorded edges.

        In transitive closure mode (after the module has been visited) this
        covers the whole call chain from the backends. In backend-only mode it
        covers only direct callees of backend functions.

        Returns:
            Set of callee names that are dependencies of the backend functions.
            Each backend name as given is removed, and so is one locally
            defined, tracked function whose qualified name contains it.
        """
        # Get all callees from the edges (functions that are called)
        dependent_funcs: Set[str] = {edge.callee for edge in self.edges}

        # Remove backend functions from the result if they appear as callees
        for backend in self.backends:
            dependent_funcs.discard(backend)
            # Also drop the backend's qualified name. Candidates are tracked,
            # locally defined functions containing the backend name. The
            # shortest one (ties broken alphabetically) is chosen so the
            # result does not depend on set iteration order.
            candidates = [
                tracked
                for tracked in self.tracked_functions
                if backend in tracked and tracked in self.local_functions
            ]
            if candidates:
                dependent_funcs.discard(
                    min(candidates, key=lambda name: (len(name), name))
                )

        return dependent_funcs

    def visit(self, node: ast.AST) -> Any:
        """Override visit to filter edges and store source code after traversal."""
        is_module = isinstance(node, ast.Module)

        # Cache the source text for later extraction. Reading is best effort
        # here: with an unreadable filename (such as the default "<string>")
        # the call-graph analysis itself must still run, and
        # get_dependent_functions_source_code() retries and reports the error.
        if is_module and not self.source_code:
            try:
                self.source_code = self._read_source()
            except OSError:
                pass

        result = super().visit(node)

        # Only filter edges when visiting the top-level Module node
        if is_module:
            self._filter_edges_by_reachability()
        return result

    def get_dependent_functions_source_code(self) -> Dict[str, str]:
        """Return source code for all dependent functions with source location comments.

        Returns:
            Dictionary mapping function qualified names to their source code.
            Only includes functions that are defined in the analyzed file.
            Each function's source code (including its decorators) is prefixed
            with a comment indicating the source file and line numbers.

        Raises:
            OSError: If the source text was not cached and the file cannot be read.
        """
        import io

        dependent_funcs = self.get_dependent_functions()
        result: Dict[str, str] = {}

        if not self.source_code:
            # Source code wasn't stored during visit
            self.source_code = self._read_source()

        # Split on "\n" only. str.splitlines() also breaks on form feeds and
        # other separators that the Python parser does not count as line
        # ends, which would shift the slices relative to AST line numbers.
        source_lines = io.StringIO(self.source_code).readlines()

        for func_name in dependent_funcs:
            node = self.func_nodes.get(func_name)
            if node is None:
                continue

            # Start from the first decorator when there is one, otherwise
            # from the function definition itself (both 1-indexed).
            if node.decorator_list:
                start_line = node.decorator_list[0].lineno
            else:
                start_line = node.lineno

            # Get the end line of the function (1-indexed)
            end_line = node.end_lineno
            if start_line is None or end_line is None:
                continue

            # Convert to 0-indexed and extract source lines
            func_source = "".join(source_lines[start_line - 1 : end_line])
            # Add source location comment
            source_comment = f"# Source: {self.filename}:{start_line}-{end_line}\n"
            result[func_name] = source_comment + func_source

        return result

    # ---------- imports / aliases ----------
    def visit_Import(self, node: ast.Import) -> None:
        """Bind the names introduced by ``import ...`` statements."""
        for alias in node.names:
            if alias.asname:
                # ``import a.b.c as x`` binds x to the full module a.b.c
                self._bind(alias.asname, alias.name)
            else:
                # ``import a.b.c`` only binds the top-level name ``a``, and it
                # refers to package ``a`` itself; binding it to "a.b.c" would
                # resolve ``a.b.c.f`` to "a.b.c.b.c.f".
                top_level = alias.name.split(".")[0]
                self._bind(top_level, top_level)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Bind the names introduced by ``from mod import name [as alias]``."""
        mod = node.module or ""
        for alias in node.names:
            local = alias.asname or alias.name
            target = f"{mod}.{alias.name}" if mod else alias.name
            self._bind(local, target)
        self.generic_visit(node)

    # ---------- defs ----------
    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        return self._visit_function_like(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        return self._visit_function_like(node)

    def _visit_function_like(self, node) -> None:
        """Register a (sync or async) function definition and visit its body."""
        qual = self._qualify(node.name)
        self._bind(node.name, qual)
        self.local_functions.add(qual)

        # Store the AST node for later source code extraction
        self.func_nodes[qual] = node

        # If this function matches any backend (by name substring),
        # add its qualified name to tracked_functions for transitive tracking
        if self.transitive_closure and any(
            backend in qual for backend in self.backends
        ):
            self.tracked_functions.add(qual)

        decorators = [self._decorator_name(dec) for dec in node.decorator_list]
        self._bind_func_descriptor(node, decorators)

        self._push_scope(node.name)
        if node.name in self.backends:
            # Record a self-edge for a backend definition.
            self._record_call(node.name, node, maybe_triton=False, caller=node.name)
        self.generic_visit(node)
        self._pop_scope()

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Bind the class name and visit the class body in its own scope."""
        self._bind(node.name, self._qualify(node.name))
        self._push_scope(node.name)
        self.generic_visit(node)
        self._pop_scope()

    # ---------- lambda support ----------
    def visit_Lambda(self, node: ast.Lambda) -> None:
        """
        Give each lambda a synthetic qualified name and traverse its body in that scope
        so we can record calls made inside lambda bodies.
        """
        lid = self._lambda_id(node)
        # Enter a readable, stable scope name
        scope_name = lid.split(".")[-1]  # "<lambda>@line:col"
        self._push_scope(scope_name)
        # The lambda body is a single expression; visit it so nested Calls are captured
        self.visit(node.body)
        self._pop_scope()
        # Do not call generic_visit (we already visited body)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Track aliases created by ``name = <name | attribute | lambda>``."""
        sym = self._rhs_symbol(node.value)
        if sym:
            for t in node.targets:
                if isinstance(t, ast.Name):
                    self._bind(t.id, sym)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Track aliases created by annotated assignments, e.g. ``a: T = lambda ...``."""
        if node.value is not None and isinstance(node.target, ast.Name):
            sym = self._rhs_symbol(node.value)
            if sym:
                self._bind(node.target.id, sym)
        self.generic_visit(node)

    # ---------- call sites ----------
    def visit_Call(self, node: ast.Call) -> None:
        """Record the call expressed by ``node`` and continue into its children."""
        fn = node.func
        maybe_triton = False
        if isinstance(fn, ast.Name):
            callee = self._resolve_name(fn.id)
        elif isinstance(fn, ast.Attribute):
            callee = self._resolve_attr(fn)
        elif isinstance(fn, ast.Lambda):
            callee = self._lambda_id(fn)  # inline IIFE-style lambda
        elif isinstance(fn, ast.Subscript):
            # Likely a Triton kernel call with subscript syntax. The name is
            # taken as written (not resolved through bindings). Bases that
            # are not a plain name/attribute chain are reported as dynamic
            # instead of raising AttributeError.
            callee = self._raw_dotted_name(fn.value) or "<dynamic_call>"
            maybe_triton = True
        else:
            callee = "<dynamic_call>"

        self._record_call(callee, node, maybe_triton=maybe_triton)
        self.generic_visit(node)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def test_call_graph_analysis(
    function_name: str, module_name: str, file_path: str
) -> None:
    """Run a call-graph analysis on one function and print a report.

    Parses ``file_path``, analyses it with ``CallGraph`` in transitive-closure
    mode using ``"<module_name>.<function_name>"`` as the only backend and
    ``triton.`` / ``tl.`` as filtered callee prefixes, then prints edge and
    function statistics plus a short preview of each dependent function's
    source.

    Args:
        function_name: Name of the function to analyze.
        module_name: Module name used to qualify names.
        file_path: Path of the Python source file to parse.
    """
    print(f"Analyzing call graph for: {function_name}")
    print(f"File: {file_path}")
    print("=" * 80)

    with open(file_path, "r") as f:
        module_source = f.read()

    tree = ast.parse(module_source, filename=file_path)

    # Analyze with prefix filters
    analyzer = CallGraph(
        filename=file_path,
        module_name=module_name,
        backends=[f"{module_name}.{function_name}"],
        transitive_closure=True,
        callee_prefix_filters=["triton.", "tl."],
    )
    analyzer.visit(tree)

    print(f"\nTotal edges found: {len(analyzer.edges)}")
    print(f"Total tracked functions: {len(analyzer.tracked_functions)}")
    print(f"Total local functions: {len(analyzer.local_functions)}")

    print("\n--- Tracked Functions (all dependencies) ---")
    for func in sorted(analyzer.tracked_functions):
        print(f" - {func}")

    print("\n--- Sample of Local Functions ---")
    # Sort first, then truncate: slicing the unordered set before sorting
    # printed a different sample on every run.
    for func in sorted(analyzer.local_functions)[:20]:
        print(f" - {func}")

    print("\n--- Call Graph Edges (triton.* and tl.* filtered out) ---")
    unique_edges = analyzer.get_unique_edges()
    print(
        f"Unique edges: {len(unique_edges)} (total edges with duplicates: {len(analyzer.edges)})"
    )
    for i, edge in enumerate(unique_edges, 1):
        print(f"{i}. {edge.caller} -> {edge.callee} (line {edge.site.lineno})")

    print("\n--- Dependent Functions (transitively called from backend) ---")
    dependent_funcs = analyzer.get_dependent_functions()
    print(f"Total dependent functions: {len(dependent_funcs)}")
    for func in sorted(dependent_funcs):
        print(f" - {func}")

    print("\n--- Dependent Functions Source Code ---")
    source_code_map = analyzer.get_dependent_functions_source_code()
    print(f"Total functions with source code: {len(source_code_map)}")
    for func_name in sorted(source_code_map):
        # splitlines() does not produce a trailing empty entry for the final
        # newline, so the reported count is not inflated by one.
        lines = source_code_map[func_name].splitlines()
        print(f"\n{func_name}:")
        print(f" Lines of code: {len(lines)}")
        print(" First 3 lines:")
        for line in lines[:3]:
            print(f" {line}")

    print("\n" + "=" * 80)
    print("Analysis complete!")
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: all three options are mandatory and are
    # forwarded unchanged to test_call_graph_analysis().
    cli = argparse.ArgumentParser(
        description="Analyze call graph for a specific function in a Python file."
    )
    for long_flag, short_flag, help_text in (
        ("--file-path", "-f", "Path to the source file containing the function"),
        (
            "--module-name",
            "-m",
            "Fully qualified module name (e.g., module.submodule.file)",
        ),
        ("--function-name", "-n", "Name of the function to analyze"),
    ):
        cli.add_argument(long_flag, short_flag, required=True, help=help_text)

    parsed = cli.parse_args()

    test_call_graph_analysis(
        function_name=parsed.function_name,
        module_name=parsed.module_name,
        file_path=parsed.file_path,
    )
|