tritonparse 0.3.2.dev20251210071601__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tritonparse might be problematic.

Files changed (62)
  1. tritonparse/__init__.py +0 -0
  2. tritonparse/__main__.py +7 -0
  3. tritonparse/cli.py +110 -0
  4. tritonparse/common.py +409 -0
  5. tritonparse/context_manager.py +64 -0
  6. tritonparse/event_diff.py +122 -0
  7. tritonparse/extract_source_mappings.py +49 -0
  8. tritonparse/info/__init__.py +30 -0
  9. tritonparse/info/cli.py +121 -0
  10. tritonparse/info/kernel_query.py +209 -0
  11. tritonparse/info/parse_helper.py +70 -0
  12. tritonparse/ir_analysis.py +427 -0
  13. tritonparse/ir_parser.py +365 -0
  14. tritonparse/mapper.py +102 -0
  15. tritonparse/reproducer/__init__.py +0 -0
  16. tritonparse/reproducer/ast_analyzer.py +636 -0
  17. tritonparse/reproducer/cli.py +72 -0
  18. tritonparse/reproducer/consolidated_result.py +52 -0
  19. tritonparse/reproducer/function_extractor.py +228 -0
  20. tritonparse/reproducer/import_info.py +25 -0
  21. tritonparse/reproducer/import_parser.py +178 -0
  22. tritonparse/reproducer/import_resolver.py +151 -0
  23. tritonparse/reproducer/ingestion/ndjson.py +237 -0
  24. tritonparse/reproducer/multi_file_analyzer.py +824 -0
  25. tritonparse/reproducer/orchestrator.py +110 -0
  26. tritonparse/reproducer/placeholder_replacer.py +335 -0
  27. tritonparse/reproducer/templates/__init__.py +0 -0
  28. tritonparse/reproducer/templates/example.py +38 -0
  29. tritonparse/reproducer/templates/loader.py +59 -0
  30. tritonparse/reproducer/templates/tritonbench.py +106 -0
  31. tritonparse/reproducer/templates/utils.py +48 -0
  32. tritonparse/reproducer/tests/__init__.py +0 -0
  33. tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
  34. tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
  35. tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
  36. tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
  37. tritonparse/reproducer/tests/test_import_parser.py +164 -0
  38. tritonparse/reproducer/tests/test_import_resolver.py +88 -0
  39. tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
  40. tritonparse/reproducer/types.py +20 -0
  41. tritonparse/reproducer/utils.py +580 -0
  42. tritonparse/shared_vars.py +12 -0
  43. tritonparse/source_type.py +56 -0
  44. tritonparse/sourcemap_utils.py +96 -0
  45. tritonparse/structured_logging.py +1634 -0
  46. tritonparse/tools/__init__.py +0 -0
  47. tritonparse/tools/decompress_bin_ndjson.py +120 -0
  48. tritonparse/tools/disasm.py +81 -0
  49. tritonparse/tools/extract_irs.py +244 -0
  50. tritonparse/tools/format_fix.py +151 -0
  51. tritonparse/tools/load_tensor.py +76 -0
  52. tritonparse/tools/prettify_ndjson.py +334 -0
  53. tritonparse/tools/readme.md +37 -0
  54. tritonparse/tp_logger.py +9 -0
  55. tritonparse/trace_processor.py +367 -0
  56. tritonparse/utils.py +155 -0
  57. tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
  58. tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
  59. tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
  60. tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
  61. tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
  62. tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
@@ -0,0 +1,824 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # pyre-strict
+
+ """
+ MultiFileCallGraphAnalyzer: Orchestrates multi-file static call graph analysis.
+
+ This module provides the main analyzer that coordinates traversal across multiple
+ Python files, following imports to extract all transitively-called functions.
+ """
+
+ import argparse
+ import ast
+ import json
+ import logging
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import List, Optional, Set, Union
+
+ from tritonparse.reproducer.ast_analyzer import CallGraph, Edge
+ from tritonparse.reproducer.consolidated_result import AnalysisStats, ConsolidatedResult
+ from tritonparse.reproducer.import_info import ImportInfo
+ from tritonparse.reproducer.import_parser import ImportParser
+ from tritonparse.reproducer.import_resolver import ImportResolver
+
+ logger = logging.getLogger(__name__)
+
+
+ def _auto_detect_code_root(entry_file: str) -> str:
+     """
+     Auto-detect the code root from an entry file path.
+
+     Walks up the directory tree from the entry file until it finds a common
+     code root indicator. Currently supports:
+     - "fbcode" and "fbsource" directories (Meta's monorepo structure)
+     - Directories containing "setup.py" or "pyproject.toml" (Python projects)
+     - Git repository roots (directories containing ".git")
+
+     Args:
+         entry_file: Absolute path to the entry file
+
+     Returns:
+         Absolute path to the code root directory
+
+     Raises:
+         ValueError: If the code root cannot be detected from the entry file path
+     """
+     path = Path(entry_file).resolve()
+     logger.debug("Auto-detecting code root from: %s", path)
+
+     # Walk up the directory tree
+     for parent in path.parents:
+         # Check for fbcode/fbsource (Meta's monorepo)
+         if parent.name == "fbcode" or parent.name == "fbsource":
+             logger.info("Auto-detected code root (fbcode): %s", parent)
+             return str(parent)
+
+         # Check for Python project markers
+         if (parent / "setup.py").exists() or (parent / "pyproject.toml").exists():
+             logger.info("Auto-detected code root (Python project): %s", parent)
+             return str(parent)
+
+         # Check for Git repository root
+         if (parent / ".git").exists():
+             logger.info("Auto-detected code root (Git repository): %s", parent)
+             return str(parent)
+
+     raise ValueError(
+         f"Could not auto-detect code root from entry file: {entry_file}. "
+         "Please specify --code-roots explicitly."
+     )
+
+
+ class MultiFileCallGraphAnalyzer:
+     """
+     Multi-file static call graph analyzer for creating reproducers.
+
+     This analyzer orchestrates the analysis across multiple Python files by:
+     1. Starting with an entry file and function
+     2. Using CallGraph to extract dependencies within each file
+     3. Following imports to analyze dependent files
+     4. Consolidating results from all analyzed files
+     """
+
+     def __init__(
+         self,
+         entry_file: str,
+         entry_function: str,
+         code_roots: Optional[str] = None,
+     ) -> None:
+         """
+         Initialize the multi-file analyzer.
+
+         The analyzer automatically computes the qualified backend name from
+         entry_file + entry_function, removing the need for manual backend
+         specification.
+
+         Args:
+             entry_file: Path to the file containing the entry function
+             entry_function: Name of the entry function (short name, e.g., "main_kernel")
+             code_roots: Absolute path to the default code root directory
+                 (auto-detected if None)
+
+         Example:
+             >>> analyzer = MultiFileCallGraphAnalyzer(
+             ...     entry_file="/path/to/kernel.py",
+             ...     entry_function="main_kernel",
+             ...     code_roots="/path/to/fbcode",  # Optional
+             ... )
+         """
+         self.entry_file = entry_file
+         self.entry_function = entry_function
+
+         # Auto-detect code_roots if not provided
+         if code_roots is None:
+             code_roots = _auto_detect_code_root(entry_file)
+         self.code_roots = code_roots
+
+         # Track multiple code roots for files from different projects
+         self.file_to_code_root: dict[str, str] = {entry_file: code_roots}
+
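+         # The qualified backend is the dotted module path plus the function name,
+         # e.g. ".../fbcode/pytorch/tritonparse/module.py" with entry function
+         # "main_kernel" becomes "pytorch.tritonparse.module.main_kernel".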
+         entry_module = self._file_to_module_name(entry_file)
+         qualified_backend = f"{entry_module}.{entry_function}"
+         self.backends = [entry_function]
+         self.qualified_backend = qualified_backend
+
+         self.import_resolver = ImportResolver(code_roots)
+         self.import_parser = ImportParser(self.import_resolver)
+
+         self.visited_files: Set[str] = set()
+         self.pending_files: list[Union[str, tuple[str, list[str]]]] = []
+         self.file_analyzers: dict[str, CallGraph] = {}
+         self.all_imports: dict[str, list[ImportInfo]] = {}
+         self.used_imports: dict[str, list[ImportInfo]] = {}
+
+     def analyze(self) -> ConsolidatedResult:
+         """
+         Perform the multi-file analysis.
+
+         This is the main entry point that:
+         1. Analyzes the entry file
+         2. Recursively analyzes imported files (breadth-first)
+         3. Consolidates the results
+
+         Returns:
+             ConsolidatedResult with functions, imports, and statistics
+         """
+         self._analyze_file(self.entry_file)
+
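+         # Breadth-first worklist: items are either bare file paths or
+         # (file_path, backends) tuples queued by _analyze_file().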
+         while self.pending_files:
+             pending_item = self.pending_files.pop(0)
+
+             if isinstance(pending_item, tuple):
+                 file_path, backends_for_file = pending_item
+             else:
+                 file_path = pending_item
+                 backends_for_file = None
+
+             if file_path not in self.visited_files:
+                 self._analyze_file(file_path, backends_for_file)
+
+         return self._consolidate_results()
+
+     def _analyze_file(
+         self, file_path: str, backends_for_file: Optional[list[str]] = None
+     ) -> None:
+         """
+         Analyze a single file with CallGraph.
+
+         This method:
+         1. Marks the file as visited
+         2. Creates a CallGraph analyzer for the file
+         3. Extracts dependent functions
+         4. Parses imports from the AST
+         5. Identifies which imports are used by dependent functions
+         6. Adds imported files to the pending_files queue
+
+         Args:
+             file_path: Absolute path to the Python file to analyze
+             backends_for_file: Specific backends for this file (defaults to
+                 self.backends for the entry file)
+         """
+         self.visited_files.add(file_path)
+         logger.info("Analyzing %s", file_path)
+
+         with open(file_path) as f:
+             source_code = f.read()
+         tree = ast.parse(source_code, filename=file_path)
+
+         file_backends = (
+             backends_for_file if backends_for_file is not None else self.backends
+         )
+
+         module_name = self._file_to_module_name(file_path)
+         analyzer = CallGraph(
+             filename=file_path,
+             module_name=module_name,
+             backends=file_backends,
+             transitive_closure=True,
+         )
+         analyzer.visit(tree)
+
+         self.file_analyzers[file_path] = analyzer
+
+         package_name = (
+             ".".join(module_name.split(".")[:-1]) if "." in module_name else None
+         )
+         imports = self.import_parser.parse_imports(
+             tree, file_path, package=package_name
+         )
+         self.all_imports[file_path] = imports
+
+         used_imports_list = self._identify_used_imports(imports, analyzer)
+         self.used_imports[file_path] = used_imports_list
+
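+         # Group newly discovered internal imports by their resolved source file,
+         # so each queued file is analyzed only for the names imported from it.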
+         imports_by_file: dict[str, list[str]] = {}
+         for import_info in used_imports_list:
+             if (
+                 import_info.resolved_path
+                 and not import_info.is_external
+                 and import_info.resolved_path not in self.visited_files
+             ):
+                 if import_info.resolved_path not in imports_by_file:
+                     imports_by_file[import_info.resolved_path] = []
+                 imports_by_file[import_info.resolved_path].extend(import_info.names)
+
+         if imports_by_file:
+             logger.info(
+                 "Found %d internal import file(s) to analyze from %s",
+                 len(imports_by_file),
+                 file_path,
+             )
+             for resolved_path, imported_names in imports_by_file.items():
+                 logger.debug(
+                     "  → Will analyze %s for functions: %s",
+                     resolved_path,
+                     imported_names,
+                 )
+         else:
+             logger.debug(
+                 "No internal imports found in %s (all imports are external or already visited)",
+                 file_path,
+             )
+
+         for resolved_path, imported_names in imports_by_file.items():
+             already_pending = any(
+                 (
+                     pending_item[0] == resolved_path
+                     if isinstance(pending_item, tuple)
+                     else pending_item == resolved_path
+                 )
+                 for pending_item in self.pending_files
+             )
+             if not already_pending:
+                 unique_names = list(dict.fromkeys(imported_names))
+                 self.pending_files.append((resolved_path, unique_names))
+                 logger.info(
+                     "Added %s to pending analysis queue with backends: %s",
+                     resolved_path,
+                     unique_names,
+                 )
+
+     def _identify_used_imports(
+         self,
+         imports: list[ImportInfo],
+         analyzer: CallGraph,
+     ) -> list[ImportInfo]:
+         """
+         Identify which imports are actually used by dependent functions.
+
+         Strategy:
+         1. Get all callees from dependent functions (from call graph edges)
+         2. Match callees to import statements
+         3. Return only the imports that are actually used
+
+         Args:
+             imports: All imports in the file
+             analyzer: CallGraph analyzer for this file
+
+         Returns:
+             List of ImportInfo objects for imports that are used
+         """
+         used_symbols = self._extract_used_symbols(analyzer)
+         logger.debug("Used symbols: %s", used_symbols)
+
+         matching_imports = self._find_matching_imports(imports, used_symbols)
+         import_groups = self._group_imports(matching_imports)
+         used_imports_list = self._select_best_imports(import_groups)
+
+         return used_imports_list
+
+     def _extract_used_symbols(self, analyzer: CallGraph) -> Set[str]:
+         """Extract all symbols used in the call graph."""
+         used_symbols: Set[str] = set()
+         for edge in analyzer.edges:
+             used_symbols.add(edge.callee)
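+             # Also record the leading component of a dotted callee
+             # (e.g. "mod.fn" -> "mod") so module-level imports can match.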
+             callee_parts = edge.callee.split(".")
+             if callee_parts:
+                 used_symbols.add(callee_parts[0])
+         return used_symbols
+
+     def _find_matching_imports(
+         self, imports: list[ImportInfo], used_symbols: Set[str]
+     ) -> list[tuple[ImportInfo, bool]]:
+         """Find imports that match the used symbols."""
+         matching_imports: list[tuple[ImportInfo, bool]] = []
+
+         for import_info in imports:
+             logger.debug(
+                 "Checking import: %s %s -> %s (external: %s)",
+                 import_info.import_type,
+                 import_info.module,
+                 import_info.names,
+                 import_info.is_external,
+             )
+
+             is_used = self._is_import_used(import_info, used_symbols)
+             if is_used:
+                 is_internal_match = not import_info.is_external
+                 matching_imports.append((import_info, is_internal_match))
+
+         return matching_imports
+
+     def _is_import_used(self, import_info: ImportInfo, used_symbols: Set[str]) -> bool:
+         """Check if an import is used, based on the used symbols."""
+         if self._matches_qualified_name(import_info, used_symbols):
+             return True
+         if self._matches_short_name(import_info, used_symbols):
+             return True
+         if self._matches_name(import_info, used_symbols):
+             return True
+         if self._matches_alias(import_info, used_symbols):
+             return True
+         if self._matches_module_prefix(import_info, used_symbols):
+             return True
+         return False
+
+     def _matches_qualified_name(
+         self, import_info: ImportInfo, used_symbols: Set[str]
+     ) -> bool:
+         """Check if the import matches on a fully qualified name."""
+         if import_info.import_type != "from_import" or not import_info.module:
+             return False
+
+         for name in import_info.names:
+             qualified_name = f"{import_info.module}.{name}"
+             for symbol in used_symbols:
+                 if symbol == qualified_name or symbol.startswith(qualified_name + "."):
+                     logger.debug(
+                         "  Matched on qualified name: %s == %s",
+                         qualified_name,
+                         symbol,
+                     )
+                     return True
+         return False
+
+     def _matches_short_name(
+         self, import_info: ImportInfo, used_symbols: Set[str]
+     ) -> bool:
+         """Check if the import matches on the short module name."""
+         if import_info.import_type != "from_import" or not import_info.module:
+             return False
+
+         module_short_name = import_info.module.split(".")[-1]
+         for name in import_info.names:
+             short_qualified = f"{module_short_name}.{name}"
+             if short_qualified in used_symbols:
+                 logger.debug(
+                     "  Matched on short name: %s in used_symbols",
+                     short_qualified,
+                 )
+                 return True
+         return False
+
+     def _matches_name(self, import_info: ImportInfo, used_symbols: Set[str]) -> bool:
+         """Check if the import matches on an imported name."""
+         for name in import_info.names:
+             if name in used_symbols:
+                 logger.debug("  Matched on name: %s in used_symbols", name)
+                 return True
+         return False
+
+     def _matches_alias(self, import_info: ImportInfo, used_symbols: Set[str]) -> bool:
+         """Check if the import matches on an alias."""
+         for alias in import_info.aliases:
+             if alias in used_symbols:
+                 logger.debug("  Matched on alias: %s in used_symbols", alias)
+                 return True
+         return False
+
+     def _matches_module_prefix(
+         self, import_info: ImportInfo, used_symbols: Set[str]
+     ) -> bool:
+         """Check if the import matches on a module prefix."""
+         if import_info.import_type != "import":
+             return False
+
+         module_prefix = import_info.module + "."
+         for symbol in used_symbols:
+             if symbol.startswith(module_prefix):
+                 logger.debug("  Matched on module prefix: %s", module_prefix)
+                 return True
+         return False
+
+     def _group_imports(
+         self, matching_imports: list[tuple[ImportInfo, bool]]
+     ) -> dict[tuple[str, tuple[str, ...]], list[tuple[ImportInfo, bool]]]:
+         """Group imports by short module name and imported names."""
+         import_groups: dict[
+             tuple[str, tuple[str, ...]], list[tuple[ImportInfo, bool]]
+         ] = {}
+
+         for import_info, is_internal_match in matching_imports:
+             module_short = (
+                 import_info.module.split(".")[-1] if import_info.module else ""
+             )
+             names_key = tuple(sorted(import_info.names))
+             key = (module_short, names_key)
+
+             if key not in import_groups:
+                 import_groups[key] = []
+             import_groups[key].append((import_info, is_internal_match))
+
+         return import_groups
+
+     def _select_best_imports(
+         self,
+         import_groups: dict[tuple[str, tuple[str, ...]], list[tuple[ImportInfo, bool]]],
+     ) -> list[ImportInfo]:
+         """Select the best import from each group."""
+         used_imports_list: list[ImportInfo] = []
+
+         for group_imports in import_groups.values():
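+             # Sort so internal matches come first, then non-external imports;
+             # the first element after sorting is the preferred candidate.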
+             group_imports.sort(key=lambda x: (not x[1], x[0].is_external))
+             used_imports_list.append(group_imports[0][0])
+
+         return used_imports_list
+
+     def _get_code_root_for_file(self, file_path: str) -> str:
+         """
+         Get or auto-detect the code root for a specific file.
+
+         This allows handling files from different project roots within the
+         same analysis.
+
+         Args:
+             file_path: Absolute path to the Python file
+
+         Returns:
+             Code root for this file (from the cache or auto-detected)
+         """
+         if file_path in self.file_to_code_root:
+             return self.file_to_code_root[file_path]
+
+         # Auto-detect the code root for this file
+         try:
+             code_root = _auto_detect_code_root(file_path)
+             self.file_to_code_root[file_path] = code_root
+             logger.debug("Detected code root %s for file %s", code_root, file_path)
+             return code_root
+         except ValueError:
+             # Fall back to the analyzer's default code root
+             logger.warning(
+                 "Could not auto-detect code root for %s, using default: %s",
+                 file_path,
+                 self.code_roots,
+             )
+             self.file_to_code_root[file_path] = self.code_roots
+             return self.code_roots
+
+     def _file_to_module_name(self, file_path: str) -> str:
+         """
+         Convert a file path to a Python module name.
+
+         Example:
+             /data/users/wychi/fbsource/fbcode/pytorch/tritonparse/module.py
+             -> pytorch.tritonparse.module
+
+         Args:
+             file_path: Absolute path to the Python file
+
+         Returns:
+             Module name as a dotted string
+         """
+         code_root = self._get_code_root_for_file(file_path)
+         code_root_path = Path(code_root)
+         file = Path(file_path)
+
+         try:
+             rel_path = file.relative_to(code_root_path)
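+             # NOTE: this conversion assumes POSIX ("/") path separators.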
+             module_path = str(rel_path).replace("/", ".").removesuffix(".py")
+             return module_path
+         except ValueError:
+             # File is not under the detected code root;
+             # fall back to using the file's stem as the module name.
+             logger.warning(
+                 "File %s is not under code root %s, using file stem as module name",
+                 file_path,
+                 code_root,
+             )
+             return file.stem
+
+     def _consolidate_results(self) -> ConsolidatedResult:
+         """
+         Consolidate the results from all file analyzers.
+
+         This method:
+         1. Collects all functions and their source code from all files
+         2. Tracks function locations
+         3. Collects and deduplicates imports
+         4. Collects all call graph edges
+         5. Builds statistics
+
+         Returns:
+             ConsolidatedResult with all analysis results
+         """
+         all_functions: dict[str, str] = {}
+         function_to_file: dict[str, str] = {}
+         all_edges: List[Edge] = []
+
+         functions_to_extract = self._collect_functions_to_extract(all_edges)
+         self._extract_function_sources(
+             functions_to_extract, all_functions, function_to_file
+         )
+         all_imports_list = self._collect_all_imports()
+         unique_imports = self._deduplicate_imports(all_imports_list)
+
+         stats = AnalysisStats(
+             total_files_analyzed=len(self.visited_files),
+             total_functions_found=len(all_functions),
+             total_imports=len(unique_imports),
+             external_imports=sum(1 for imp in unique_imports if imp.is_external),
+             internal_imports=sum(1 for imp in unique_imports if not imp.is_external),
+         )
+
+         return ConsolidatedResult(
+             functions=all_functions,
+             function_locations=function_to_file,
+             function_short_names={
+                 qualified: qualified.split(".")[-1]
+                 for qualified in all_functions.keys()
+             },
+             imports=unique_imports,
+             edges=all_edges,
+             analyzed_files=self.visited_files.copy(),
+             stats=stats,
+         )
+
+     def _collect_functions_to_extract(self, all_edges: List[Edge]) -> set[str]:
+         """Collect all functions that need to be extracted from the analyzers."""
+         functions_to_extract: set[str] = set()
+
+         for _file_path, analyzer in self.file_analyzers.items():
+             dependent_funcs = analyzer.get_dependent_functions()
+             functions_to_extract.update(dependent_funcs)
+
+             for backend in analyzer.backends:
+                 for local_func in analyzer.local_functions:
+                     if local_func.split(".")[-1] == backend:
+                         # Skip the primary entry function - it's not a dependency
+                         if local_func == self.qualified_backend:
+                             logger.debug(
+                                 "Skipping entry function %s (not a dependency)",
+                                 local_func,
+                             )
+                             continue
+
+                         logger.debug(
+                             "Adding backend function %s from file %s (backend: %s)",
+                             local_func,
+                             _file_path,
+                             backend,
+                         )
+                         functions_to_extract.add(local_func)
+
+             all_edges.extend(analyzer.edges)
+
+         return functions_to_extract
+
+     def _extract_function_sources(
+         self,
+         functions_to_extract: set[str],
+         all_functions: dict[str, str],
+         function_to_file: dict[str, str],
+     ) -> None:
+         """Extract source code for all functions from the file analyzers."""
+         for file_path, analyzer in self.file_analyzers.items():
+             for func_name in functions_to_extract:
+                 if func_name in analyzer.func_nodes:
+                     self._extract_single_function_source(
+                         func_name, file_path, analyzer, all_functions, function_to_file
+                     )
+
+     def _extract_single_function_source(
+         self,
+         func_name: str,
+         file_path: str,
+         analyzer: CallGraph,
+         all_functions: dict[str, str],
+         function_to_file: dict[str, str],
+     ) -> None:
+         """Extract source code for a single function, with a source location comment."""
+         source_code_map = analyzer.get_dependent_functions_source_code()
+
+         if func_name in source_code_map:
+             # Source location comment is already added by
+             # get_dependent_functions_source_code
+             all_functions[func_name] = source_code_map[func_name]
+             function_to_file[func_name] = file_path
+         else:
+             self._extract_function_from_ast(
+                 func_name, file_path, analyzer, all_functions, function_to_file
+             )
+
+     def _extract_function_from_ast(
+         self,
+         func_name: str,
+         file_path: str,
+         analyzer: CallGraph,
+         all_functions: dict[str, str],
+         function_to_file: dict[str, str],
+     ) -> None:
+         """Extract function source directly from the AST node, with a source location comment."""
+         node = analyzer.func_nodes[func_name]
+
+         if not analyzer.source_code:
+             with open(file_path) as f:
+                 analyzer.source_code = f.read()
+
+         source_lines = analyzer.source_code.splitlines(keepends=True)
+
+         if node.decorator_list:
+             start_line = node.decorator_list[0].lineno
+         else:
+             start_line = node.lineno
+
+         end_line = node.end_lineno
+
+         if start_line is not None and end_line is not None:
+             func_source = "".join(source_lines[start_line - 1 : end_line])
+             # Add a source location comment
+             source_comment = f"# Source: {file_path}:{start_line}-{end_line}\n"
+             all_functions[func_name] = source_comment + func_source
+             function_to_file[func_name] = file_path
+
+     def _collect_all_imports(self) -> list[ImportInfo]:
+         """Collect all used imports from the visited files."""
+         all_imports_list: list[ImportInfo] = []
+         for file_path in self.visited_files:
+             if file_path in self.used_imports:
+                 all_imports_list.extend(self.used_imports[file_path])
+         return all_imports_list
+
+     def _deduplicate_imports(self, imports: list[ImportInfo]) -> list[ImportInfo]:
+         """
+         Deduplicate imports while preserving order.
+
+         Merges duplicate imports:
+             from X import A
+             from X import B
+             -> from X import A, B
+
+         Args:
+             imports: List of ImportInfo objects (may contain duplicates)
+
+         Returns:
+             Deduplicated list of ImportInfo objects
+         """
+         import_groups: dict[tuple[str, str], ImportInfo] = {}
+
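+         # Key on (import_type, module): repeated "from X import ..." statements
+         # collapse into one entry whose names and aliases are merged.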
+         for imp in imports:
+             key = (imp.import_type, imp.module)
+             if key in import_groups:
+                 existing = import_groups[key]
+                 for name in imp.names:
+                     if name not in existing.names:
+                         existing.names.append(name)
+                 existing.aliases.update(imp.aliases)
+             else:
+                 import_groups[key] = ImportInfo(
+                     import_type=imp.import_type,
+                     module=imp.module,
+                     names=imp.names.copy(),
+                     source_file=imp.source_file,
+                     resolved_path=imp.resolved_path,
+                     is_external=imp.is_external,
+                     lineno=imp.lineno,
+                     aliases=imp.aliases.copy(),
+                     level=imp.level,
+                 )
+
+         return list(import_groups.values())
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Multi-file call graph analyzer for Python code. "
+         "Analyzes a Python function and its dependencies across multiple files, "
+         "extracting all transitively-called functions and their imports."
+     )
+     parser.add_argument(
+         "--entry-file",
+         "-f",
+         required=True,
+         help="Path to the entry file containing the function to analyze",
+     )
+     parser.add_argument(
+         "--entry-function",
+         "-F",
+         required=True,
+         help="Name of the entry function to analyze",
+     )
+     parser.add_argument(
+         "--code-roots",
+         "-r",
+         default="",
+         help="Path to the default code root directory (default: auto-detect from the entry file path)",
+     )
+     parser.add_argument(
+         "--output",
+         "-o",
+         default=None,
+         help="Output JSON file path (default: creates a temp file in /tmp/)",
+     )
+     parser.add_argument(
+         "--verbose",
+         "-v",
+         action="store_true",
+         help="Enable verbose logging",
+     )
+
+     args = parser.parse_args()
+
+     logging.basicConfig(
+         level=logging.DEBUG if args.verbose else logging.INFO,
+         format="%(levelname)s: %(message)s",
+     )
+
+     # Auto-detect code_roots if not provided
+     code_roots = args.code_roots
+     if not code_roots:
+         try:
+             code_roots = _auto_detect_code_root(args.entry_file)
+         except ValueError as e:
+             logger.error("%s", e)
+             raise SystemExit(1)
+
+     analyzer = MultiFileCallGraphAnalyzer(
+         entry_file=args.entry_file,
+         entry_function=args.entry_function,
+         code_roots=code_roots,
+     )
+
+     logger.info("Starting analysis of %s in %s", args.entry_function, args.entry_file)
+     result = analyzer.analyze()
+
+     logger.info("Analysis complete:")
+     logger.info("  Files analyzed: %d", result.stats.total_files_analyzed)
+     logger.info("  Functions found: %d", result.stats.total_functions_found)
+     logger.info("  Total imports: %d", result.stats.total_imports)
+     logger.info("  External imports: %d", result.stats.external_imports)
+     logger.info("  Internal imports: %d", result.stats.internal_imports)
+
+     logger.info("\nDependent functions (short names):")
+     for func_name in sorted(result.function_short_names.keys()):
+         short_name = result.function_short_names[func_name]
+         if short_name != args.entry_function:
+             logger.info(
+                 "  - %s. code size: %d", short_name, len(result.functions[func_name])
+             )
+
+     output_data = {
+         "entry_file": args.entry_file,
+         "entry_function": args.entry_function,
+         "qualified_backend": analyzer.qualified_backend,
+         "stats": {
+             "total_files_analyzed": result.stats.total_files_analyzed,
+             "total_functions_found": result.stats.total_functions_found,
+             "total_imports": result.stats.total_imports,
+             "external_imports": result.stats.external_imports,
+             "internal_imports": result.stats.internal_imports,
+         },
+         "analyzed_files": sorted(result.analyzed_files),
+         "functions": {
+             func_name: {
+                 "source": source,
+                 "file": result.function_locations.get(func_name, "unknown"),
+                 "short_name": result.function_short_names.get(func_name, func_name),
+             }
+             for func_name, source in result.functions.items()
+         },
+         "imports": [
+             {
+                 "import_type": imp.import_type,
+                 "module": imp.module,
+                 "names": imp.names,
+                 "source_file": imp.source_file,
+                 "resolved_path": imp.resolved_path,
+                 "is_external": imp.is_external,
+                 "lineno": imp.lineno,
+                 "aliases": imp.aliases,
+                 "level": imp.level,
+             }
+             for imp in result.imports
+         ],
+         "edges": [
+             {"caller": edge.caller, "callee": edge.callee} for edge in result.edges
+         ],
+     }
+
+     if args.output:
+         output_path = args.output
+     else:
+         # Create a temp file with a descriptive name
+         temp_fd, temp_path = tempfile.mkstemp(
+             suffix=".json",
+             prefix=f"multi_file_analysis_{args.entry_function}_",
+             dir="/tmp",
+             text=True,
+         )
+         os.close(temp_fd)  # Close the file descriptor
+         output_path = temp_path
+
+     with open(output_path, "w") as f:
+         json.dump(output_data, f, indent=2)
+     logger.info("Detailed results written to: %s", output_path)
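
For reference, a minimal usage sketch, mirroring the class docstring example and the CLI flags defined in the __main__ block above (all paths and the function name are placeholders):

    # Programmatic use:
    from tritonparse.reproducer.multi_file_analyzer import MultiFileCallGraphAnalyzer

    analyzer = MultiFileCallGraphAnalyzer(
        entry_file="/path/to/kernel.py",  # placeholder
        entry_function="main_kernel",     # placeholder
    )  # code_roots is auto-detected when omitted
    result = analyzer.analyze()
    print(result.stats.total_functions_found)

    # Equivalent command-line invocation (writes JSON to --output or a temp file):
    #   python -m tritonparse.reproducer.multi_file_analyzer \
    #       --entry-file /path/to/kernel.py --entry-function main_kernel --verbose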