tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritonparse might be problematic; see the package registry page for more details.

Files changed (62):
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,636 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import ast
4
+ import builtins
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, List, Optional, Set, Tuple
7
+
8
+
9
# Default built-in Python functions to filter out from call graph.
# Every public (non-underscore) callable exported by ``builtins`` -- e.g. len,
# print, range, and builtin classes such as dict -- in dir() (sorted) order.
DEFAULT_BUILTIN_FILTERS = list(
    filter(
        lambda attr: not attr.startswith("_") and callable(getattr(builtins, attr)),
        dir(builtins),
    )
)
15
+
16
+
17
@dataclass(frozen=True, eq=True)
class Site:
    """Immutable source location of a definition or call site.

    Being frozen, instances are hashable and compare by value.
    """

    filename: str  # path of the analyzed file
    lineno: int  # ast ``lineno`` (1-based); CallGraph passes -1 when missing
    col: int  # ast ``col_offset`` (0-based); CallGraph passes -1 when missing
22
+
23
+
24
@dataclass
class FuncDescriptor:
    """Name, decorators and definition site of a function seen by CallGraph."""

    name: str  # unqualified function name (``FunctionDef.name``)
    # One string per decorator in source order: a resolved dotted name, or a
    # "<dynamic_decorator...>" placeholder for decorators that are calls.
    decorators: List[str]
    site: Site  # location of the ``def`` statement
29
+
30
+
31
@dataclass
class Edge:
    """A single ``caller -> callee`` call recorded at a specific site."""

    caller: str  # qualified name of the calling scope (or "<module>")
    callee: str  # resolved callee name, or a "<dynamic_...>"/lambda placeholder
    site: Site  # where the call expression appears
    call_type: str  # CallGraph in this module always records "regular"
    # Descriptor of the callee when it is a function defined in a visible
    # scope of the analyzed file; None otherwise.
    callee_descriptor: Optional[FuncDescriptor]
38
+
39
+
40
def split_by_the_last_dot(s: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Split ``s`` into ``(prefix, last_component)`` at its final ``"."``.

    Returns ``(None, s)`` when ``s`` contains no dot, and ``(None, None)``
    when ``s`` is None.  Examples: ``"a.b.c" -> ("a.b", "c")``,
    ``"abc" -> (None, "abc")``.
    """
    if s is None:
        return None, None
    head, sep, tail = s.rpartition(".")
    # rpartition yields an empty separator exactly when there is no dot.
    if not sep:
        return None, s
    return head, tail
47
+
48
+
49
class CallGraph(ast.NodeVisitor):
    """
    AST visitor that builds a call graph by tracking function calls and definitions.

    This class traverses an AST and records:
    - Function definitions and their decorators
    - Function calls and their call sites
    - Import statements and name bindings
    - Lambda expressions

    Names are resolved through a stack of per-scope binding tables, so a call
    such as ``helper()`` is recorded under the qualified name the binding maps
    to (e.g. ``"pkg.mod.helper"``). Visiting an ``ast.Module`` additionally
    tries to read ``filename`` (for later source extraction) and, in
    transitive-closure mode, prunes ``edges`` down to those reachable from the
    backend functions.
    """

    def __init__(
        self,
        filename: str = "<string>",
        module_name: str = "<module>",
        backends: Optional[List[str]] = None,
        transitive_closure: bool = True,
        callee_prefix_filters: Optional[List[str]] = None,
        callee_name_filters: Optional[List[str]] = None,
    ):
        """
        Args:
            filename: Path of the analyzed file. Recorded in every ``Site`` and
                read from disk for source extraction.
            module_name: Qualified module name used as the root of all scopes.
            backends: Names (short or qualified) of the entry-point functions.
                Matched by substring against qualified names. Required.
            transitive_closure: If True, keep every edge reachable from the
                backends. If False, keep only calls made directly from callers
                whose qualified name contains a backend name.
            callee_prefix_filters: Callee prefixes to drop (e.g. "triton.", "tl.").
            callee_name_filters: Extra callee names (last dotted component) to
                drop, in addition to ``DEFAULT_BUILTIN_FILTERS``.
        """
        self.filename = filename

        self.edges: List[Edge] = []
        # Never appended to within this class.
        self.decorator_edges: List[Edge] = []
        assert backends is not None, "Backends must not be None"
        # Keyed by backend name in the caller's order. The list values are not
        # used within this class.
        self.backends: Dict[str, List[Any]] = {backend: [] for backend in backends}

        self.scope_stack: List[str] = []
        self.module_name = module_name

        # One table per lexical scope, mapping a local name to its qualified
        # target. Function descriptors share these tables under the key
        # "__<name>_descriptor__" (see _bind_func_descriptor).
        self.bindings_stack: List[Dict[str, Any]] = [dict()]
        self.local_functions: Set[str] = set()

        # Track functions in the call chain for transitive closure.
        # Note: backends are often provided as short names (e.g. "_attn_fwd_base_opt").
        # The qualified names of matching definitions are added in
        # _visit_function_like.
        self.transitive_closure = transitive_closure
        self.tracked_functions: Set[str] = (
            set(backends) if transitive_closure else set()
        )

        # Prefix filters to exclude certain callees (e.g., "triton.", "tl.")
        self.callee_prefix_filters = callee_prefix_filters or []

        # Name filters: user-provided names plus the default built-ins. Any
        # iterable is accepted; the original list concatenation required a list.
        self.callee_name_filters = set(callee_name_filters or ()) | set(
            DEFAULT_BUILTIN_FILTERS
        )

        # lambda node -> synthetic id (stable within this pass)
        self._lambda_ids: Dict[ast.Lambda, str] = {}

        # Qualified name -> FunctionDef / AsyncFunctionDef node, for source extraction.
        self.func_nodes: Dict[str, ast.FunctionDef] = {}
        self.source_code: str = ""

    # ---------- helpers ----------
    def _cur_scope(self) -> str:
        """Return the dotted name of the current scope, rooted at module_name."""
        return ".".join([self.module_name] + self.scope_stack).strip(".")

    def _push_scope(self, name: str) -> None:
        """Enter a nested scope with a fresh binding table."""
        self.scope_stack.append(name)
        self.bindings_stack.append({})

    def _pop_scope(self) -> None:
        """Leave the innermost scope and discard its bindings."""
        self.scope_stack.pop()
        self.bindings_stack.pop()

    def _bind(self, name: str, target: str) -> None:
        """Bind ``name`` to ``target`` in the innermost scope."""
        self.bindings_stack[-1][name] = target

    def _bind_func_descriptor(self, node, decorators: List[str]) -> None:
        """Store a FuncDescriptor for the function ``node`` in the innermost scope."""
        name = node.name
        site = Site(
            self.filename, getattr(node, "lineno", -1), getattr(node, "col_offset", -1)
        )
        self.bindings_stack[-1][f"__{name}_descriptor__"] = FuncDescriptor(
            name, decorators, site
        )

    def _resolve_name(self, id_: str) -> str:
        """Return the innermost binding of ``id_``, or ``id_`` itself if unbound."""
        for env in reversed(self.bindings_stack):
            if id_ in env:
                return env[id_]
        return id_

    def _resolve_func_descriptor(self, id_: str) -> Optional[FuncDescriptor]:
        """Return the FuncDescriptor for ``id_`` if a visible scope defines it.

        ``id_`` is a bare name or, for calls to locally defined functions, the
        qualified name produced by name resolution. The exact key is tried
        first. Otherwise the last dotted component is looked up, but only in a
        scope whose binding for that short name resolves to ``id_``. This
        guard stops an import such as ``triton.jit`` from picking up the
        descriptor of an unrelated local function named ``jit``.
        """
        exact_key = f"__{id_}_descriptor__"
        for env in reversed(self.bindings_stack):
            if exact_key in env:
                return env[exact_key]

        _, short = split_by_the_last_dot(id_)
        if short is None or short == id_:
            return None
        short_key = f"__{short}_descriptor__"
        for env in reversed(self.bindings_stack):
            if env.get(short) == id_ and short_key in env:
                return env[short_key]
        return None

    def _resolve_attr(self, node: ast.AST) -> str:
        """Resolve ``a.b.c`` to ``<binding of a>.b.c``, or "<dynamic_attr>"."""
        parts: List[str] = []
        cur = node
        while isinstance(cur, ast.Attribute):
            parts.append(cur.attr)
            cur = cur.value
        if not isinstance(cur, ast.Name):
            # e.g. call().attr or subscript[0].attr: no static name to resolve.
            return "<dynamic_attr>"
        head = self._resolve_name(cur.id)
        return ".".join([head] + list(reversed(parts)))

    def _lambda_id(self, node: ast.Lambda) -> str:
        """Return a synthetic name like ``scope.<lambda>@line:col`` for ``node``."""
        lid = self._lambda_ids.get(node)
        if lid is None:
            scope = self._cur_scope() or "<module>"
            lid = f"{scope}.<lambda>@{getattr(node,'lineno',-1)}:{getattr(node,'col_offset',-1)}"
            self._lambda_ids[node] = lid
        return lid

    def _record_call(
        self, callee: str, node: ast.AST, maybe_triton: bool = False, caller=None
    ) -> None:
        """Append an Edge ``caller -> callee`` for the call at ``node`` unless filtered.

        ``self.<name>`` callees are rewritten to ``<caller's enclosing scope>.<name>``.
        A callee is dropped if it starts with any of ``callee_prefix_filters``
        or if its last dotted component is in ``callee_name_filters``. In
        backend-only mode the edge is kept only when ``caller`` contains a
        backend name. In transitive mode every edge is kept here and pruned
        later by ``_filter_edges_by_reachability``.

        ``maybe_triton`` is accepted for API compatibility but is not used.
        """
        if caller is None:
            caller = self._cur_scope() or "<module>"
        # replace "self.<name>" with "<caller's class scope>.<name>"
        if "." in caller and callee.startswith("self."):
            caller_prefix, _ = split_by_the_last_dot(caller)
            callee = f"{caller_prefix}.{callee[len('self.'):]}"

        if any(callee.startswith(prefix) for prefix in self.callee_prefix_filters):
            return
        # "module.func" -> "func"; a name without dots is returned unchanged.
        if callee.rsplit(".", 1)[-1] in self.callee_name_filters:
            return

        if not self.transitive_closure and not any(
            backend in caller for backend in self.backends
        ):
            return

        site = Site(
            self.filename, getattr(node, "lineno", -1), getattr(node, "col_offset", -1)
        )
        self.edges.append(
            Edge(
                caller,
                callee,
                callee_descriptor=self._resolve_func_descriptor(callee),
                site=site,
                call_type="regular",
            )
        )

    def _filter_edges_by_reachability(self) -> None:
        """Filter edges to keep only those reachable from backend functions.

        This implements the transitive closure: starting from the backends
        (plus every local function whose qualified name contains one), follow
        recorded edges to collect all reachable callers. Afterwards ``edges``
        keeps only edges whose caller is reachable, and ``tracked_functions``
        becomes the reachable set. No-op in backend-only mode.
        """
        if not self.transitive_closure:
            return

        # caller -> callees over all recorded edges
        call_graph: Dict[str, Set[str]] = {}
        for edge in self.edges:
            call_graph.setdefault(edge.caller, set()).add(edge.callee)

        reachable: Set[str] = set()
        for backend in self.backends:
            # Both the name as given and any qualified names containing it.
            reachable.add(backend)
            reachable.update(f for f in self.local_functions if backend in f)

        # Worklist traversal: each function's callees are expanded once.
        worklist = list(reachable)
        while worklist:
            func = worklist.pop()
            for callee in call_graph.get(func, ()):
                if callee not in reachable:
                    reachable.add(callee)
                    worklist.append(callee)

        self.edges = [edge for edge in self.edges if edge.caller in reachable]
        self.tracked_functions = reachable

    def get_unique_edges(self) -> List[Edge]:
        """Return deduplicated edges, keeping only unique (caller, callee) pairs.

        When a function calls another function from multiple locations,
        only the first recorded edge for that pair is returned. Order follows
        ``edges``.
        """
        seen: Set[Tuple[str, str]] = set()
        unique_edges: List[Edge] = []
        for edge in self.edges:
            key = (edge.caller, edge.callee)
            if key not in seen:
                seen.add(key)
                unique_edges.append(edge)
        return unique_edges

    def _backend_qualified_name(self, backend: str) -> Optional[str]:
        """Return the local, tracked function that best matches ``backend``.

        An exact match, or one ending in ``"." + backend``, is preferred.
        Otherwise the lexicographically first substring match is used.
        Sorting makes the choice deterministic: plain set iteration order of
        strings varies between runs because of hash randomization.
        """
        candidates = sorted(
            tracked
            for tracked in self.tracked_functions
            if backend in tracked and tracked in self.local_functions
        )
        for candidate in candidates:
            if candidate == backend or candidate.endswith("." + backend):
                return candidate
        return candidates[0] if candidates else None

    def get_dependent_functions(self) -> Set[str]:
        """Return all functions that are transitively called from backend functions.

        In transitive closure mode (after visiting a module) this includes all
        functions in the call chain. In backend-only mode it includes only the
        direct callees of backend functions.

        Returns:
            Set of callee names that are dependencies of the backend functions.
            The backend names as given, and each backend's resolved local
            qualified name, are excluded.
        """
        dependent_funcs: Set[str] = {edge.callee for edge in self.edges}
        for backend in self.backends:
            dependent_funcs.discard(backend)
            backend_qualified = self._backend_qualified_name(backend)
            if backend_qualified is not None:
                dependent_funcs.discard(backend_qualified)
        return dependent_funcs

    def visit(self, node: ast.AST) -> Any:
        """Visit ``node``. For an ``ast.Module``, also load source and prune edges.

        The source read is best-effort. If ``filename`` cannot be read (e.g.
        the default ``"<string>"``), ``source_code`` stays empty, and
        ``get_dependent_functions_source_code`` retries the read and surfaces
        the error if source is actually needed.
        """
        is_module = isinstance(node, ast.Module)
        if is_module and not self.source_code:
            try:
                with open(self.filename, "r", encoding="utf-8") as f:
                    self.source_code = f.read()
            except (OSError, UnicodeDecodeError):
                # The call graph itself does not need the text; defer the error.
                pass

        result = super().visit(node)

        # Only filter edges when visiting the top-level Module node
        if is_module:
            self._filter_edges_by_reachability()
        return result

    def get_dependent_functions_source_code(self) -> Dict[str, str]:
        """Return source code for all dependent functions with source location comments.

        Returns:
            Dictionary mapping function qualified names (sorted) to their
            source code. It only includes functions defined in the analyzed
            file. Each snippet starts at the first decorator (or the ``def``
            line), ends at the function's last line, and is prefixed with
            ``# Source: <filename>:<start>-<end>``.

        Raises:
            OSError: if the source was not loaded during ``visit`` and
                ``filename`` cannot be read now.
        """
        dependent_funcs = self.get_dependent_functions()
        result: Dict[str, str] = {}

        if not self.source_code:
            # Source code wasn't stored during visit
            with open(self.filename, "r", encoding="utf-8") as f:
                self.source_code = f.read()

        source_lines = self.source_code.splitlines(keepends=True)

        for func_name in sorted(dependent_funcs):
            node = self.func_nodes.get(func_name)
            if node is None:
                continue
            # Start at the first decorator if any, else at the def line (1-indexed).
            if node.decorator_list:
                start_line = node.decorator_list[0].lineno
            else:
                start_line = node.lineno
            end_line = node.end_lineno

            if start_line is not None and end_line is not None:
                func_source = "".join(source_lines[start_line - 1 : end_line])
                source_comment = f"# Source: {self.filename}:{start_line}-{end_line}\n"
                result[func_name] = source_comment + func_source

        return result

    # ---------- imports / aliases ----------
    def visit_Import(self, node: ast.Import) -> None:
        """Bind names introduced by ``import ...`` statements.

        ``import a.b.c as x`` binds ``x -> "a.b.c"``. A plain
        ``import a.b.c`` binds only ``a -> "a"``, as Python itself does, so
        ``a.b.c.f`` resolves to ``"a.b.c.f"``.
        """
        for alias in node.names:
            if alias.asname:
                self._bind(alias.asname, alias.name)
            else:
                top_level = alias.name.split(".")[0]
                self._bind(top_level, top_level)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Bind ``from mod import name [as alias]`` to ``"mod.name"``.

        The relative-import level (leading dots) is not reflected in the target.
        """
        mod = node.module or ""
        for alias in node.names:
            local = alias.asname or alias.name
            target = f"{mod}.{alias.name}" if mod else alias.name
            self._bind(local, target)
        self.generic_visit(node)

    # ---------- defs ----------
    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        return self._visit_function_like(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        return self._visit_function_like(node)

    def _decorator_name(self, dec: ast.expr) -> str:
        """Return a best-effort name for one decorator expression.

        Plain names and attribute chains are resolved through bindings.
        Decorator calls are summarized as ``<dynamic_decorator_...>`` from their
        syntax. Anything else, including deeper attribute chains that used to
        raise AttributeError, becomes ``"<dynamic_decorator>"``.
        """
        if isinstance(dec, ast.Name):
            return self._resolve_name(dec.id)
        if isinstance(dec, ast.Attribute):
            return self._resolve_attr(dec)
        if isinstance(dec, ast.Call):
            func = dec.func
            if isinstance(func, ast.Name):
                return f"<dynamic_decorator_{func.id}>"
            if isinstance(func, ast.Attribute):
                base = func.value
                if isinstance(base, ast.Name):
                    return f"<dynamic_decorator_{base.id}.{func.attr}>"
                if isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
                    return f"<dynamic_decorator_{base.value.id}.{func.attr}>"
        return "<dynamic_decorator>"

    def _visit_function_like(self, node) -> None:
        """Register a (possibly async) function definition and visit its body."""
        scope = self._cur_scope()
        qual = f"{scope}.{node.name}" if scope else node.name
        self._bind(node.name, qual)
        self.local_functions.add(qual)

        # Store the AST node for later source code extraction
        self.func_nodes[qual] = node

        # If this function matches any backend (by name substring),
        # add its qualified name to tracked_functions for transitive tracking.
        if self.transitive_closure and any(
            backend in qual for backend in self.backends
        ):
            self.tracked_functions.add(qual)

        decorators = [self._decorator_name(dec) for dec in node.decorator_list]
        self._bind_func_descriptor(node, decorators)

        self._push_scope(node.name)
        if node.name in self.backends:
            self._record_call(node.name, node, maybe_triton=False, caller=node.name)
        # Decorators, defaults and body are all visited inside the function scope.
        self.generic_visit(node)
        self._pop_scope()

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Bind the class's qualified name and visit its body in a class scope."""
        qual = (self._cur_scope() + "." if self._cur_scope() else "") + node.name
        self._bind(node.name, qual)
        self._push_scope(node.name)
        self.generic_visit(node)
        self._pop_scope()

    # ---------- lambda support ----------
    def visit_Lambda(self, node: ast.Lambda) -> None:
        """
        Give each lambda a synthetic qualified name and traverse its body in that scope
        so we can record calls made inside lambda bodies.
        """
        lid = self._lambda_id(node)
        # Default values are evaluated where the lambda is created, so they
        # are visited in the enclosing scope.
        self.visit(node.args)
        # Enter a readable, stable scope name: "<lambda>@line:col"
        self._push_scope(lid.split(".")[-1])
        # The lambda body is a single expression; visit it so nested Calls are captured
        self.visit(node.body)
        self._pop_scope()
        # generic_visit is not called: args and body were already visited.

    def _symbol_of(self, value: ast.AST) -> Optional[str]:
        """Return the symbol an assignment RHS aliases, or None if it is not an alias."""
        if isinstance(value, ast.Name):
            return self._resolve_name(value.id)
        if isinstance(value, ast.Attribute):
            return self._resolve_attr(value)
        if isinstance(value, ast.Lambda):
            return self._lambda_id(value)
        return None

    def visit_Assign(self, node: ast.Assign) -> None:
        """Alias simple-name targets of ``x = name | a.b | lambda ...``."""
        sym = self._symbol_of(node.value)
        if sym:
            for t in node.targets:
                if isinstance(t, ast.Name):
                    self._bind(t.id, sym)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Alias ``x: T = name | a.b | lambda ...`` like visit_Assign."""
        if node.value is not None and isinstance(node.target, ast.Name):
            sym = self._symbol_of(node.value)
            if sym is not None:
                self._bind(node.target.id, sym)
        self.generic_visit(node)

    # ---------- call sites ----------
    def visit_Call(self, node: ast.Call) -> None:
        """Record an edge for this call site, then visit its children."""
        fn = node.func
        maybe_triton = False
        if isinstance(fn, ast.Name):
            callee = self._resolve_name(fn.id)
        elif isinstance(fn, ast.Attribute):
            callee = self._resolve_attr(fn)
        elif isinstance(fn, ast.Lambda):
            callee = self._lambda_id(fn)  # inline IIFE-style lambda
        elif isinstance(fn, ast.Subscript):
            # Likely a Triton kernel call with subscript syntax: kernel[grid](...).
            # The kernel name is taken verbatim, not resolved through bindings.
            maybe_triton = True
            target = fn.value
            if isinstance(target, ast.Name):
                callee = target.id
            elif isinstance(target, ast.Attribute) and isinstance(
                target.value, ast.Name
            ):
                callee = f"{target.value.id}.{target.attr}"
            else:
                # e.g. a.b.kernel[grid](...) or get_kernel()[grid](...);
                # this used to raise AttributeError.
                callee = "<dynamic_call>"
        else:
            callee = "<dynamic_call>"

        self._record_call(callee, node, maybe_triton=maybe_triton)
        self.generic_visit(node)
537
+
538
+
539
def test_call_graph_analysis(
    function_name: str, module_name: str, file_path: str
) -> None:
    """Print a human-readable call-graph report for one function in a file.

    Parses ``file_path`` and runs ``CallGraph`` with the single backend
    ``f"{module_name}.{function_name}"``, with transitive closure on and
    ``triton.`` / ``tl.`` callees filtered out. It then prints the tracked
    functions, a sorted sample of local functions, the unique edges, the
    dependent functions, and the first lines of each dependent function's
    extracted source.

    Despite the ``test_`` prefix, this is a manual CLI helper (see the
    ``__main__`` block), not a pytest test.
    """
    print(f"Analyzing call graph for: {function_name}")
    print(f"File: {file_path}")
    print("=" * 80)

    with open(file_path, "r", encoding="utf-8") as f:
        file_source = f.read()

    tree = ast.parse(file_source, filename=file_path)

    # Analyze with prefix filters
    analyzer = CallGraph(
        filename=file_path,
        module_name=module_name,
        backends=[f"{module_name}.{function_name}"],
        transitive_closure=True,
        callee_prefix_filters=["triton.", "tl."],
    )
    analyzer.visit(tree)

    print(f"\nTotal edges found: {len(analyzer.edges)}")
    print(f"Total tracked functions: {len(analyzer.tracked_functions)}")
    print(f"Total local functions: {len(analyzer.local_functions)}")

    print("\n--- Tracked Functions (all dependencies) ---")
    for func in sorted(analyzer.tracked_functions):
        print(f" - {func}")

    # Sort before slicing so the sample is deterministic; a set has no
    # meaningful order to slice.
    print("\n--- Sample of Local Functions ---")
    for func in sorted(analyzer.local_functions)[:20]:
        print(f" - {func}")

    print("\n--- Call Graph Edges (triton.* and tl.* filtered out) ---")
    unique_edges = analyzer.get_unique_edges()
    print(
        f"Unique edges: {len(unique_edges)} (total edges with duplicates: {len(analyzer.edges)})"
    )
    for i, edge in enumerate(unique_edges, 1):
        print(f"{i}. {edge.caller} -> {edge.callee} (line {edge.site.lineno})")

    print("\n--- Dependent Functions (transitively called from backend) ---")
    dependent_funcs = analyzer.get_dependent_functions()
    print(f"Total dependent functions: {len(dependent_funcs)}")
    for func in sorted(dependent_funcs):
        print(f" - {func}")

    print("\n--- Dependent Functions Source Code ---")
    source_code_map = analyzer.get_dependent_functions_source_code()
    print(f"Total functions with source code: {len(source_code_map)}")
    for func_name, func_source in sorted(source_code_map.items()):
        # splitlines() avoids counting the empty string after the final newline.
        lines = func_source.splitlines()
        print(f"\n{func_name}:")
        print(f" Lines of code: {len(lines)}")
        print(" First 3 lines:")
        for line in lines[:3]:
            print(f" {line}")

    print("\n" + "=" * 80)
    print("Analysis complete!")
603
+
604
+
605
if __name__ == "__main__":
    # Command-line entry point: all three options are required and are
    # forwarded unchanged to test_call_graph_analysis.
    import argparse

    cli = argparse.ArgumentParser(
        description="Analyze call graph for a specific function in a Python file."
    )
    option_specs = (
        (("--file-path", "-f"), "Path to the source file containing the function"),
        (
            ("--module-name", "-m"),
            "Fully qualified module name (e.g., module.submodule.file)",
        ),
        (("--function-name", "-n"), "Name of the function to analyze"),
    )
    for flags, help_text in option_specs:
        cli.add_argument(*flags, required=True, help=help_text)

    parsed = cli.parse_args()

    test_call_graph_analysis(
        function_name=parsed.function_name,
        module_name=parsed.module_name,
        file_path=parsed.file_path,
    )