tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,824 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# pyre-strict

"""
MultiFileCallGraphAnalyzer: Orchestrates multi-file static call graph analysis.

This module provides the main analyzer that coordinates traversal across multiple
Python files, following imports to extract all transitively-called functions.
"""

import argparse
import ast
import json
import logging
import tempfile
from pathlib import Path
from typing import List, Optional, Set

from tritonparse.reproducer.ast_analyzer import CallGraph, Edge
from tritonparse.reproducer.consolidated_result import AnalysisStats, ConsolidatedResult
from tritonparse.reproducer.import_info import ImportInfo
from tritonparse.reproducer.import_parser import ImportParser
from tritonparse.reproducer.import_resolver import ImportResolver

logger = logging.getLogger(__name__)

def _auto_detect_code_root(entry_file: str) -> str:
    """
    Auto-detect code root from entry file path.

    Walks up the directory tree from the entry file until it finds a common
    code root indicator. Currently supports:
    - "fbcode"/"fbsource" directories (Meta's monorepo structure)
    - Directories containing "setup.py" or "pyproject.toml" (Python projects)
    - Git repository roots (directories containing ".git")

    Args:
        entry_file: Absolute path to the entry file

    Returns:
        Absolute path to the code root directory

    Raises:
        ValueError: If the code root cannot be detected from the entry file path
    """
    path = Path(entry_file).resolve()
    logger.debug("Auto-detecting code root from: %s", path)

    # Walk up the directory tree
    for parent in path.parents:
        # Check for fbcode/fbsource (Meta's monorepo)
        if parent.name == "fbcode" or parent.name == "fbsource":
            logger.info("Auto-detected code root (fbcode): %s", parent)
            return str(parent)

        # Check for Python project markers
        if (parent / "setup.py").exists() or (parent / "pyproject.toml").exists():
            logger.info("Auto-detected code root (Python project): %s", parent)
            return str(parent)

        # Check for Git repository root
        if (parent / ".git").exists():
            logger.info("Auto-detected code root (Git repository): %s", parent)
            return str(parent)

    raise ValueError(
        f"Could not auto-detect code root from entry file: {entry_file}. "
        "Please specify --code-root explicitly."
    )

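# A minimal sketch of the detection above, assuming a hypothetical checkout
# at /home/user/myproj whose root contains a pyproject.toml:
#
#     >>> _auto_detect_code_root("/home/user/myproj/src/kernels/entry.py")
#     '/home/user/myproj'
#
# Parents are scanned nearest-first; within each level the fbcode/fbsource
# name check runs before the project-marker and .git checks.
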
class MultiFileCallGraphAnalyzer:
    """
    Multi-file static call graph analyzer for creating reproducers.

    This analyzer orchestrates the analysis across multiple Python files by:
    1. Starting with an entry file and function
    2. Using CallGraph to extract dependencies within each file
    3. Following imports to analyze dependent files
    4. Consolidating results from all analyzed files
    """

    def __init__(
        self,
        entry_file: str,
        entry_function: str,
        code_roots: Optional[str] = None,
    ) -> None:
        """
        Initialize the multi-file analyzer.

        The analyzer automatically computes the qualified backend name from
        entry_file + entry_function, removing the need for manual backend
        specification.

        Args:
            entry_file: Path to the file containing the entry function
            entry_function: Name of the entry function (short name, e.g., "main_kernel")
            code_roots: Absolute path to the default code root directory
                (auto-detected if None)

        Example:
            >>> analyzer = MultiFileCallGraphAnalyzer(
            ...     entry_file="/path/to/kernel.py",
            ...     entry_function="main_kernel",
            ...     code_roots="/path/to/fbcode",  # Optional
            ... )
        """
        self.entry_file = entry_file
        self.entry_function = entry_function

        # Auto-detect code_roots if not provided
        if code_roots is None:
            code_roots = _auto_detect_code_root(entry_file)
        self.code_roots = code_roots

        # Track multiple code roots for files from different projects
        self.file_to_code_root: dict[str, str] = {entry_file: code_roots}

        entry_module = self._file_to_module_name(entry_file)
        qualified_backend = f"{entry_module}.{entry_function}"
        self.backends = [entry_function]
        self.qualified_backend = qualified_backend

        self.import_resolver = ImportResolver(code_roots)
        self.import_parser = ImportParser(self.import_resolver)

        self.visited_files: Set[str] = set()
        # Items are either a bare file path or a (file path, backend names)
        # tuple; analyze() accepts both forms when draining the queue.
        self.pending_files: list[str | tuple[str, list[str]]] = []
        self.file_analyzers: dict[str, CallGraph] = {}
        self.all_imports: dict[str, list[ImportInfo]] = {}
        self.used_imports: dict[str, list[ImportInfo]] = {}

    def analyze(self) -> ConsolidatedResult:
        """
        Perform multi-file analysis.

        This is the main entry point that:
        1. Analyzes the entry file
        2. Recursively analyzes imported files (breadth-first)
        3. Consolidates results

        Returns:
            ConsolidatedResult with functions, imports, and statistics
        """
        self._analyze_file(self.entry_file)

        while self.pending_files:
            pending_item = self.pending_files.pop(0)

            if isinstance(pending_item, tuple):
                file_path, backends_for_file = pending_item
            else:
                file_path = pending_item
                backends_for_file = None

            if file_path not in self.visited_files:
                self._analyze_file(file_path, backends_for_file)

        return self._consolidate_results()

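    # Worklist sketch (hypothetical two-file project): analyzing a.py finds
    # "from b import helper", so ("b.py", ["helper"]) is queued; b.py is then
    # analyzed with backends=["helper"], and its own internal imports are
    # queued in turn. Because files are added to visited_files before
    # processing, each file is analyzed at most once and the traversal
    # terminates even in the presence of import cycles.
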
    def _analyze_file(
        self, file_path: str, backends_for_file: Optional[list[str]] = None
    ) -> None:
        """
        Analyze a single file with CallGraph.

        This method:
        1. Marks file as visited
        2. Creates CallGraph analyzer for the file
        3. Extracts dependent functions
        4. Parses imports from AST
        5. Identifies which imports are used by dependent functions
        6. Adds imported files to pending_files queue

        Args:
            file_path: Absolute path to the Python file to analyze
            backends_for_file: Specific backends for this file (defaults to self.backends for entry file)
        """
        self.visited_files.add(file_path)
        logger.info("Analyzing %s", file_path)

        with open(file_path) as f:
            source_code = f.read()
        tree = ast.parse(source_code, filename=file_path)

        file_backends = (
            backends_for_file if backends_for_file is not None else self.backends
        )

        module_name = self._file_to_module_name(file_path)
        analyzer = CallGraph(
            filename=file_path,
            module_name=module_name,
            backends=file_backends,
            transitive_closure=True,
        )
        analyzer.visit(tree)

        self.file_analyzers[file_path] = analyzer

        package_name = (
            ".".join(module_name.split(".")[:-1]) if "." in module_name else None
        )
        imports = self.import_parser.parse_imports(
            tree, file_path, package=package_name
        )
        self.all_imports[file_path] = imports

        used_imports_list = self._identify_used_imports(imports, analyzer)
        self.used_imports[file_path] = used_imports_list

        imports_by_file: dict[str, list[str]] = {}
        for import_info in used_imports_list:
            if (
                import_info.resolved_path
                and not import_info.is_external
                and import_info.resolved_path not in self.visited_files
            ):
                if import_info.resolved_path not in imports_by_file:
                    imports_by_file[import_info.resolved_path] = []
                imports_by_file[import_info.resolved_path].extend(import_info.names)

        if imports_by_file:
            logger.info(
                "Found %d internal import file(s) to analyze from %s",
                len(imports_by_file),
                file_path,
            )
            for resolved_path, imported_names in imports_by_file.items():
                logger.debug(
                    "  → Will analyze %s for functions: %s",
                    resolved_path,
                    imported_names,
                )
        else:
            logger.debug(
                "No internal imports found in %s (all imports are external or already visited)",
                file_path,
            )

        for resolved_path, imported_names in imports_by_file.items():
            already_pending = any(
                (
                    pending_item[0] == resolved_path
                    if isinstance(pending_item, tuple)
                    else pending_item == resolved_path
                )
                for pending_item in self.pending_files
            )
            if not already_pending:
                unique_names = list(dict.fromkeys(imported_names))
                self.pending_files.append((resolved_path, unique_names))
                logger.info(
                    "Added %s to pending analysis queue with backends: %s",
                    resolved_path,
                    unique_names,
                )

    def _identify_used_imports(
        self,
        imports: list[ImportInfo],
        analyzer: CallGraph,
    ) -> list[ImportInfo]:
        """
        Identify which imports are actually used by dependent functions.

        Strategy:
        1. Get all callees from dependent functions (from call graph edges)
        2. Match callees to import statements
        3. Return only imports that are actually used

        Args:
            imports: All imports in the file
            analyzer: CallGraph analyzer for this file

        Returns:
            List of ImportInfo objects for imports that are used
        """
        used_symbols = self._extract_used_symbols(analyzer)
        logger.debug("Used symbols: %s", used_symbols)

        matching_imports = self._find_matching_imports(imports, used_symbols)
        import_groups = self._group_imports(matching_imports)
        used_imports_list = self._select_best_imports(import_groups)

        return used_imports_list

    def _extract_used_symbols(self, analyzer: CallGraph) -> Set[str]:
        """Extract all symbols used in the call graph."""
        used_symbols: Set[str] = set()
        for edge in analyzer.edges:
            used_symbols.add(edge.callee)
            callee_parts = edge.callee.split(".")
            if callee_parts:
                used_symbols.add(callee_parts[0])
        return used_symbols

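    # Example: an edge whose callee is "torch.nn.functional.relu" contributes
    # both the full dotted name and its first component "torch" to
    # used_symbols, so later matching can hit either a from-import of the
    # leaf name or a plain "import torch".
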
    def _find_matching_imports(
        self, imports: list[ImportInfo], used_symbols: Set[str]
    ) -> list[tuple[ImportInfo, bool]]:
        """Find imports that match the used symbols."""
        matching_imports: list[tuple[ImportInfo, bool]] = []

        for import_info in imports:
            logger.debug(
                "Checking import: %s %s -> %s (external: %s)",
                import_info.import_type,
                import_info.module,
                import_info.names,
                import_info.is_external,
            )

            is_used = self._is_import_used(import_info, used_symbols)
            if is_used:
                is_internal_match = not import_info.is_external
                matching_imports.append((import_info, is_internal_match))

        return matching_imports

    def _is_import_used(self, import_info: ImportInfo, used_symbols: Set[str]) -> bool:
        """Check if an import is used based on the used symbols."""
        if self._matches_qualified_name(import_info, used_symbols):
            return True
        if self._matches_short_name(import_info, used_symbols):
            return True
        if self._matches_name(import_info, used_symbols):
            return True
        if self._matches_alias(import_info, used_symbols):
            return True
        if self._matches_module_prefix(import_info, used_symbols):
            return True
        return False

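    # Illustration of the matcher cascade (hypothetical imports):
    #   from pkg.utils import helper  -> used if used_symbols contains
    #       "pkg.utils.helper" (qualified), "utils.helper" (short name),
    #       or "helper" (bare name)
    #   import numpy as np            -> used via the alias "np", or via any
    #       symbol starting with "numpy." (module prefix)
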
    def _matches_qualified_name(
        self, import_info: ImportInfo, used_symbols: Set[str]
    ) -> bool:
        """Check if import matches on qualified name."""
        if import_info.import_type != "from_import" or not import_info.module:
            return False

        for name in import_info.names:
            qualified_name = f"{import_info.module}.{name}"
            for symbol in used_symbols:
                if symbol == qualified_name or symbol.startswith(qualified_name + "."):
                    logger.debug(
                        "  Matched on qualified name: %s == %s",
                        qualified_name,
                        symbol,
                    )
                    return True
        return False

    def _matches_short_name(
        self, import_info: ImportInfo, used_symbols: Set[str]
    ) -> bool:
        """Check if import matches on short module name."""
        if import_info.import_type != "from_import" or not import_info.module:
            return False

        module_short_name = import_info.module.split(".")[-1]
        for name in import_info.names:
            short_qualified = f"{module_short_name}.{name}"
            if short_qualified in used_symbols:
                logger.debug(
                    "  Matched on short name: %s in used_symbols",
                    short_qualified,
                )
                return True
        return False

    def _matches_name(self, import_info: ImportInfo, used_symbols: Set[str]) -> bool:
        """Check if import matches on name."""
        for name in import_info.names:
            if name in used_symbols:
                logger.debug("  Matched on name: %s in used_symbols", name)
                return True
        return False

    def _matches_alias(self, import_info: ImportInfo, used_symbols: Set[str]) -> bool:
        """Check if import matches on alias."""
        for alias in import_info.aliases:
            if alias in used_symbols:
                logger.debug("  Matched on alias: %s in used_symbols", alias)
                return True
        return False

    def _matches_module_prefix(
        self, import_info: ImportInfo, used_symbols: Set[str]
    ) -> bool:
        """Check if import matches on module prefix."""
        if import_info.import_type != "import":
            return False

        module_prefix = import_info.module + "."
        for symbol in used_symbols:
            if symbol.startswith(module_prefix):
                logger.debug("  Matched on module prefix: %s", module_prefix)
                return True
        return False

    def _group_imports(
        self, matching_imports: list[tuple[ImportInfo, bool]]
    ) -> dict[tuple[str, tuple[str, ...]], list[tuple[ImportInfo, bool]]]:
        """Group imports by module and names."""
        import_groups: dict[
            tuple[str, tuple[str, ...]], list[tuple[ImportInfo, bool]]
        ] = {}

        for import_info, is_internal_match in matching_imports:
            module_short = (
                import_info.module.split(".")[-1] if import_info.module else ""
            )
            names_key = tuple(sorted(import_info.names))
            key = (module_short, names_key)

            if key not in import_groups:
                import_groups[key] = []
            import_groups[key].append((import_info, is_internal_match))

        return import_groups

    def _select_best_imports(
        self,
        import_groups: dict[tuple[str, tuple[str, ...]], list[tuple[ImportInfo, bool]]],
    ) -> list[ImportInfo]:
        """Select the best import from each group."""
        used_imports_list: list[ImportInfo] = []

        for group_imports in import_groups.values():
            group_imports.sort(key=lambda x: (not x[1], x[0].is_external))
            used_imports_list.append(group_imports[0][0])

        return used_imports_list

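    # Note on _select_best_imports: the sort key (not x[1], x[0].is_external)
    # orders each group so that imports matched as internal come first and
    # external imports come last; group_imports[0] is therefore the preferred
    # candidate.
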
    def _get_code_root_for_file(self, file_path: str) -> str:
        """
        Get or auto-detect the code root for a specific file.

        This allows handling files from different project roots within the same analysis.

        Args:
            file_path: Absolute path to Python file

        Returns:
            Code root for this file (from cache or auto-detected)
        """
        if file_path in self.file_to_code_root:
            return self.file_to_code_root[file_path]

        # Auto-detect code root for this file
        try:
            code_root = _auto_detect_code_root(file_path)
            self.file_to_code_root[file_path] = code_root
            logger.debug("Detected code root %s for file %s", code_root, file_path)
            return code_root
        except ValueError:
            # Fall back to the analyzer's default code root
            logger.warning(
                "Could not auto-detect code root for %s, using default: %s",
                file_path,
                self.code_roots,
            )
            self.file_to_code_root[file_path] = self.code_roots
            return self.code_roots

    def _file_to_module_name(self, file_path: str) -> str:
        """
        Convert file path to Python module name.

        Example:
            /data/users/wychi/fbsource/fbcode/pytorch/tritonparse/module.py
            -> pytorch.tritonparse.module

        Args:
            file_path: Absolute path to Python file

        Returns:
            Module name as a dotted string
        """
        code_root = self._get_code_root_for_file(file_path)
        code_root_path = Path(code_root)
        file = Path(file_path)

        try:
            rel_path = file.relative_to(code_root_path)
            module_path = str(rel_path).replace("/", ".").removesuffix(".py")
            return module_path
        except ValueError:
            # File is not under the detected code root
            # Fall back to using the file's stem as the module name
            logger.warning(
                "File %s is not under code root %s, using file stem as module name",
                file_path,
                code_root,
            )
            return file.stem

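    # Note on _file_to_module_name: the "/" -> "." rewrite assumes POSIX path
    # separators, e.g. "pytorch/tritonparse/module.py" under the code root
    # becomes "pytorch.tritonparse.module", matching the docstring example.
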
    def _consolidate_results(self) -> ConsolidatedResult:
        """
        Consolidate results from all file analyzers.

        This method:
        1. Collects all functions and their source code from all files
        2. Tracks function locations
        3. Collects and deduplicates imports
        4. Collects all call graph edges
        5. Builds statistics

        Returns:
            ConsolidatedResult with all analysis results
        """
        all_functions: dict[str, str] = {}
        function_to_file: dict[str, str] = {}
        all_edges: List[Edge] = []

        functions_to_extract = self._collect_functions_to_extract(all_edges)
        self._extract_function_sources(
            functions_to_extract, all_functions, function_to_file
        )
        all_imports_list = self._collect_all_imports()
        unique_imports = self._deduplicate_imports(all_imports_list)

        stats = AnalysisStats(
            total_files_analyzed=len(self.visited_files),
            total_functions_found=len(all_functions),
            total_imports=len(unique_imports),
            external_imports=sum(1 for imp in unique_imports if imp.is_external),
            internal_imports=sum(1 for imp in unique_imports if not imp.is_external),
        )

        return ConsolidatedResult(
            functions=all_functions,
            function_locations=function_to_file,
            function_short_names={
                qualified: qualified.split(".")[-1]
                for qualified in all_functions.keys()
            },
            imports=unique_imports,
            edges=all_edges,
            analyzed_files=self.visited_files.copy(),
            stats=stats,
        )

    def _collect_functions_to_extract(self, all_edges: List[Edge]) -> set[str]:
        """Collect all functions that need to be extracted from analyzers."""
        functions_to_extract: set[str] = set()

        for _file_path, analyzer in self.file_analyzers.items():
            dependent_funcs = analyzer.get_dependent_functions()
            functions_to_extract.update(dependent_funcs)

            for backend in analyzer.backends:
                for local_func in analyzer.local_functions:
                    if local_func.split(".")[-1] == backend:
                        # Skip the primary entry function - it's not a dependency
                        if local_func == self.qualified_backend:
                            logger.debug(
                                "Skipping entry function %s (not a dependency)",
                                local_func,
                            )
                            continue

                        logger.debug(
                            "Adding backend function %s from file %s (backend: %s)",
                            local_func,
                            _file_path,
                            backend,
                        )
                        functions_to_extract.add(local_func)

            all_edges.extend(analyzer.edges)

        return functions_to_extract

    def _extract_function_sources(
        self,
        functions_to_extract: set[str],
        all_functions: dict[str, str],
        function_to_file: dict[str, str],
    ) -> None:
        """Extract source code for all functions from file analyzers."""
        for file_path, analyzer in self.file_analyzers.items():
            for func_name in functions_to_extract:
                if func_name in analyzer.func_nodes:
                    self._extract_single_function_source(
                        func_name, file_path, analyzer, all_functions, function_to_file
                    )

    def _extract_single_function_source(
        self,
        func_name: str,
        file_path: str,
        analyzer: CallGraph,
        all_functions: dict[str, str],
        function_to_file: dict[str, str],
    ) -> None:
        """Extract source code for a single function with source location comment."""
        source_code_map = analyzer.get_dependent_functions_source_code()

        if func_name in source_code_map:
            # Source location comment is already added by get_dependent_functions_source_code
            all_functions[func_name] = source_code_map[func_name]
            function_to_file[func_name] = file_path
        else:
            self._extract_function_from_ast(
                func_name, file_path, analyzer, all_functions, function_to_file
            )

    def _extract_function_from_ast(
        self,
        func_name: str,
        file_path: str,
        analyzer: CallGraph,
        all_functions: dict[str, str],
        function_to_file: dict[str, str],
    ) -> None:
        """Extract function source code directly from AST node with source location comment."""
        node = analyzer.func_nodes[func_name]

        if not analyzer.source_code:
            with open(file_path) as f:
                analyzer.source_code = f.read()

        source_lines = analyzer.source_code.splitlines(keepends=True)

        if node.decorator_list:
            start_line = node.decorator_list[0].lineno
        else:
            start_line = node.lineno

        end_line = node.end_lineno

        if start_line is not None and end_line is not None:
            func_source = "".join(source_lines[start_line - 1 : end_line])
            # Add source location comment
            source_comment = f"# Source: {file_path}:{start_line}-{end_line}\n"
            all_functions[func_name] = source_comment + func_source
            function_to_file[func_name] = file_path

    def _collect_all_imports(self) -> list[ImportInfo]:
        """Collect all imports from visited files."""
        all_imports_list: list[ImportInfo] = []
        for file_path in self.visited_files:
            if file_path in self.used_imports:
                all_imports_list.extend(self.used_imports[file_path])
        return all_imports_list

    def _deduplicate_imports(self, imports: list[ImportInfo]) -> list[ImportInfo]:
        """
        Deduplicate imports while preserving order.

        Merges duplicate imports:
        - from X import A
        - from X import B
        -> from X import A, B

        Args:
            imports: List of ImportInfo objects (may contain duplicates)

        Returns:
            Deduplicated list of ImportInfo objects
        """
        import_groups: dict[tuple[str, str], ImportInfo] = {}

        for imp in imports:
            key = (imp.import_type, imp.module)
            if key in import_groups:
                existing = import_groups[key]
                for name in imp.names:
                    if name not in existing.names:
                        existing.names.append(name)
                existing.aliases.update(imp.aliases)
            else:
                import_groups[key] = ImportInfo(
                    import_type=imp.import_type,
                    module=imp.module,
                    names=imp.names.copy(),
                    source_file=imp.source_file,
                    resolved_path=imp.resolved_path,
                    is_external=imp.is_external,
                    lineno=imp.lineno,
                    aliases=imp.aliases.copy(),
                    level=imp.level,
                )

        return list(import_groups.values())

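# Merge sketch for _deduplicate_imports (hypothetical entries): two
# ImportInfo records for ("from_import", "pkg.utils") with names ["a"] and
# ["b"] collapse into one record with names ["a", "b"]; name order follows
# first appearance, and aliases from later duplicates are folded in.
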
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Multi-file call graph analyzer for Python code. "
        "Analyzes a Python function and its dependencies across multiple files, "
        "extracting all transitively-called functions and their imports."
    )
    parser.add_argument(
        "--entry-file",
        "-f",
        required=True,
        help="Path to the entry file containing the function to analyze",
    )
    parser.add_argument(
        "--entry-function",
        "-F",
        required=True,
        help="Name of the entry function to analyze",
    )
    parser.add_argument(
        "--code-roots",
        "-r",
        default="",
        help="Path to default code root directory (default: auto-detect from entry file path)",
    )
    parser.add_argument(
        "--output",
        "-o",
        default=None,
        help="Output JSON file path (default: creates a temp file in /tmp/)",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s: %(message)s",
    )

    # Auto-detect code_roots if not provided
    code_roots = args.code_roots
    if not code_roots:
        try:
            code_roots = _auto_detect_code_root(args.entry_file)
        except ValueError as e:
            logger.error("%s", e)
            exit(1)

    analyzer = MultiFileCallGraphAnalyzer(
        entry_file=args.entry_file,
        entry_function=args.entry_function,
        code_roots=code_roots,
    )

    logger.info("Starting analysis of %s in %s", args.entry_function, args.entry_file)
    result = analyzer.analyze()

    logger.info("Analysis complete:")
    logger.info("  Files analyzed: %d", result.stats.total_files_analyzed)
    logger.info("  Functions found: %d", result.stats.total_functions_found)
    logger.info("  Total imports: %d", result.stats.total_imports)
    logger.info("  External imports: %d", result.stats.external_imports)
    logger.info("  Internal imports: %d", result.stats.internal_imports)

    logger.info("\nDependent functions (short names):")
    for func_name in sorted(result.function_short_names.keys()):
        short_name = result.function_short_names[func_name]
        if short_name != args.entry_function:
            logger.info(
                "  - %s, code size: %d", short_name, len(result.functions[func_name])
            )

    output_data = {
        "entry_file": args.entry_file,
        "entry_function": args.entry_function,
        "qualified_backend": analyzer.qualified_backend,
        "stats": {
            "total_files_analyzed": result.stats.total_files_analyzed,
            "total_functions_found": result.stats.total_functions_found,
            "total_imports": result.stats.total_imports,
            "external_imports": result.stats.external_imports,
            "internal_imports": result.stats.internal_imports,
        },
        "analyzed_files": sorted(result.analyzed_files),
        "functions": {
            func_name: {
                "source": source,
                "file": result.function_locations.get(func_name, "unknown"),
                "short_name": result.function_short_names.get(func_name, func_name),
            }
            for func_name, source in result.functions.items()
        },
        "imports": [
            {
                "import_type": imp.import_type,
                "module": imp.module,
                "names": imp.names,
                "source_file": imp.source_file,
                "resolved_path": imp.resolved_path,
                "is_external": imp.is_external,
                "lineno": imp.lineno,
                "aliases": imp.aliases,
                "level": imp.level,
            }
            for imp in result.imports
        ],
        "edges": [
            {"caller": edge.caller, "callee": edge.callee} for edge in result.edges
        ],
    }

    if args.output:
        output_path = args.output
    else:
        # Create a temp file with a descriptive name
        import os

        temp_fd, temp_path = tempfile.mkstemp(
            suffix=".json",
            prefix=f"multi_file_analysis_{args.entry_function}_",
            dir="/tmp",
            text=True,
        )
        os.close(temp_fd)  # Close the file descriptor
        output_path = temp_path

    with open(output_path, "w") as f:
        json.dump(output_data, f, indent=2)
    logger.info("Detailed results written to: %s", output_path)
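
A minimal usage sketch for the CLI entry point above (paths are hypothetical; the flags are the ones defined by the argument parser):

    python -m tritonparse.reproducer.multi_file_analyzer \
        --entry-file /path/to/kernel.py \
        --entry-function main_kernel \
        --verbose

With no --output given, the consolidated results are written to a temp JSON file under /tmp/ and the path is logged.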