tritonparse 0.3.2.dev20251210071601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritonparse might be problematic. Click here for more details.
- tritonparse/__init__.py +0 -0
- tritonparse/__main__.py +7 -0
- tritonparse/cli.py +110 -0
- tritonparse/common.py +409 -0
- tritonparse/context_manager.py +64 -0
- tritonparse/event_diff.py +122 -0
- tritonparse/extract_source_mappings.py +49 -0
- tritonparse/info/__init__.py +30 -0
- tritonparse/info/cli.py +121 -0
- tritonparse/info/kernel_query.py +209 -0
- tritonparse/info/parse_helper.py +70 -0
- tritonparse/ir_analysis.py +427 -0
- tritonparse/ir_parser.py +365 -0
- tritonparse/mapper.py +102 -0
- tritonparse/reproducer/__init__.py +0 -0
- tritonparse/reproducer/ast_analyzer.py +636 -0
- tritonparse/reproducer/cli.py +72 -0
- tritonparse/reproducer/consolidated_result.py +52 -0
- tritonparse/reproducer/function_extractor.py +228 -0
- tritonparse/reproducer/import_info.py +25 -0
- tritonparse/reproducer/import_parser.py +178 -0
- tritonparse/reproducer/import_resolver.py +151 -0
- tritonparse/reproducer/ingestion/ndjson.py +237 -0
- tritonparse/reproducer/multi_file_analyzer.py +824 -0
- tritonparse/reproducer/orchestrator.py +110 -0
- tritonparse/reproducer/placeholder_replacer.py +335 -0
- tritonparse/reproducer/templates/__init__.py +0 -0
- tritonparse/reproducer/templates/example.py +38 -0
- tritonparse/reproducer/templates/loader.py +59 -0
- tritonparse/reproducer/templates/tritonbench.py +106 -0
- tritonparse/reproducer/templates/utils.py +48 -0
- tritonparse/reproducer/tests/__init__.py +0 -0
- tritonparse/reproducer/tests/artifacts/__init__.py +5 -0
- tritonparse/reproducer/tests/artifacts/triton_fused_kernel.py +65 -0
- tritonparse/reproducer/tests/artifacts/triton_preprocess.py +16 -0
- tritonparse/reproducer/tests/artifacts/triton_utils.py +14 -0
- tritonparse/reproducer/tests/test_import_parser.py +164 -0
- tritonparse/reproducer/tests/test_import_resolver.py +88 -0
- tritonparse/reproducer/tests/test_multi_file_analyzer.py +118 -0
- tritonparse/reproducer/types.py +20 -0
- tritonparse/reproducer/utils.py +580 -0
- tritonparse/shared_vars.py +12 -0
- tritonparse/source_type.py +56 -0
- tritonparse/sourcemap_utils.py +96 -0
- tritonparse/structured_logging.py +1634 -0
- tritonparse/tools/__init__.py +0 -0
- tritonparse/tools/decompress_bin_ndjson.py +120 -0
- tritonparse/tools/disasm.py +81 -0
- tritonparse/tools/extract_irs.py +244 -0
- tritonparse/tools/format_fix.py +151 -0
- tritonparse/tools/load_tensor.py +76 -0
- tritonparse/tools/prettify_ndjson.py +334 -0
- tritonparse/tools/readme.md +37 -0
- tritonparse/tp_logger.py +9 -0
- tritonparse/trace_processor.py +367 -0
- tritonparse/utils.py +155 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/METADATA +195 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/RECORD +62 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/WHEEL +5 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/entry_points.txt +2 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/licenses/LICENSE +29 -0
- tritonparse-0.3.2.dev20251210071601.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
from tritonparse.reproducer.types import KernelImportMode
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _add_reproducer_args(parser: argparse.ArgumentParser) -> None:
    """Register the reproducer's command-line options on ``parser``.

    Adds the positional ``input`` argument followed by the optional flags
    ``--line``, ``--kernel``, ``--launch-id``, ``--out-dir``, ``--template``,
    ``--kernel-import`` and ``--use-fbcode``, in that order.

    Note that the mutual exclusivity mentioned in the help texts is only
    described to the user; no mutually exclusive group is created here.

    Args:
        parser: The parser (or sub-parser) that is extended in place.
    """
    line_help = (
        "The line index (0-based) of the launch event in the input file to reproduce. "
        "Defaults to 0 (first launch event). Mutually exclusive with --kernel/--launch-id."
    )
    kernel_help = (
        "Kernel name (exact match, case-sensitive) to reproduce. "
        "Use with --launch-id to specify which launch of the kernel. "
        "Mutually exclusive with --line."
    )
    launch_id_help = (
        "0-based launch index for the kernel specified by --kernel. "
        "Defaults to 0 (first launch). Only used when --kernel is provided."
    )
    out_dir_help = (
        "Directory to save the reproducer script and context JSON. Defaults to "
        "'repro_output/<kernel_name>/' if not provided."
    )
    template_help = (
        "Template name (builtin, without .py) or a filesystem path to a .py file. "
        "Defaults to 'example'."
    )
    kernel_import_help = (
        "Kernel import strategy:\n"
        "  default: Import kernel from original file (current behavior)\n"
        "  copy: Embed kernel source code directly in reproducer\n"
        "  override-ttir: Use TTIR from compilation event (bypass Python frontend)\n"
        "Defaults to 'default'."
    )

    # (name or flag, keyword arguments for add_argument); the list order is
    # the registration order and therefore the order shown in --help.
    argument_specs = [
        ("input", {"help": "Path to the ndjson/ndjson.gz log file"}),
        ("--line", {"type": int, "default": 0, "help": line_help}),
        ("--kernel", {"type": str, "default": None, "help": kernel_help}),
        ("--launch-id", {"type": int, "default": 0, "help": launch_id_help}),
        ("--out-dir", {"default": "repro_output", "help": out_dir_help}),
        ("--template", {"default": "example", "help": template_help}),
        (
            "--kernel-import",
            {
                "type": KernelImportMode,
                "choices": list(KernelImportMode),
                "default": KernelImportMode.DEFAULT,
                "help": kernel_import_help,
            },
        ),
        (
            "--use-fbcode",
            {
                "action": "store_true",
                "help": "Use fbcode to setup repro environment.",
            },
        ),
    ]

    for name_or_flag, options in argument_specs:
        parser.add_argument(name_or_flag, **options)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
# pyre-strict
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
ConsolidatedResult: Data structures for multi-file call graph analysis results.
|
|
7
|
+
|
|
8
|
+
This module defines the output structure for the MultiFileCallGraphAnalyzer,
|
|
9
|
+
containing all extracted functions, imports, edges, and statistics.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
from tritonparse.reproducer.ast_analyzer import Edge
|
|
15
|
+
from tritonparse.reproducer.import_info import ImportInfo
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class AnalysisStats:
    """Integer counters summarising one multi-file analysis run."""

    # File and function totals.
    total_files_analyzed: int
    total_functions_found: int

    # Import totals: the overall count, then the external/internal counts
    # (presumably split on ImportInfo.is_external -- confirm with the analyzer).
    total_imports: int
    external_imports: int
    internal_imports: int
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class ConsolidatedResult:
    """Merged output of the call graph analysis over every analyzed file.

    The three ``function*`` dictionaries are all keyed by the qualified
    function name.
    """

    # qualified_name -> source code of each dependent function
    functions: dict[str, str]

    # qualified_name -> path of the file that defines the function
    function_locations: dict[str, str]

    # qualified_name -> bare function name, used when generating a
    # standalone file
    function_short_names: dict[str, str]

    # Required imports, deduplicated and organized
    imports: list[ImportInfo]

    # Call graph edges collected across all files
    edges: list[Edge]

    # Paths of the files that were analyzed
    analyzed_files: set[str]

    # Summary counters for the run
    stats: AnalysisStats
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Function extractor for reproducer utility functions.
|
|
5
|
+
|
|
6
|
+
This module extracts utility functions from utils.py and load_tensor.py
|
|
7
|
+
using AST parsing, and generates standalone code for reproducers.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import ast
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def extract_utility_functions() -> str:
    """
    Build the standalone utility code embedded in a reproducer.

    The sources of ``utils.py`` (next to this module) and
    ``tools/load_tensor.py`` (one package level up) are parsed with ``ast``
    rather than imported, so no module-level side effects are triggered.

    The result is assembled in this order, separated by blank lines:
    the generated import header, the ``TRITON_KERNELS_CUSTOM_TYPES`` and
    ``TRITON_DTYPE_MAP`` assignments (each skipped if absent), the
    ``load_tensor`` function, and finally the helper functions from
    ``utils.py``.

    Returns:
        str: Python source containing the imports and all extracted code.

    Raises:
        FileNotFoundError: If one of the two source files is missing.
        SyntaxError: If one of the two source files cannot be parsed.
        ValueError: If a required function is not present in its file.
    """
    package_dir = Path(__file__).parent

    # Both files are parsed up front so that a missing or broken file is
    # reported before any extraction work starts.
    utils_tree, utils_lines = _parse_source_file(package_dir / "utils.py")
    load_tensor_tree, load_tensor_lines = _parse_source_file(
        package_dir.parent / "tools" / "load_tensor.py"
    )

    sections = [_generate_imports()]

    # Module-level constants are optional: a missing one is simply omitted.
    for constant_name in ("TRITON_KERNELS_CUSTOM_TYPES", "TRITON_DTYPE_MAP"):
        constant_source = _extract_assignment(utils_tree, utils_lines, constant_name)
        if constant_source:
            sections.append(constant_source)

    # load_tensor is emitted before the utils helpers.
    sections += _extract_functions(
        load_tensor_tree, load_tensor_lines, ["load_tensor"]
    )

    # Helpers from utils.py, listed in dependency order.
    sections += _extract_functions(
        utils_tree,
        utils_lines,
        [
            "_get_triton_tensor_types",
            "create_args_from_json_file",
            "create_args_from_json",
            "_apply_stride_and_offset",
            "_create_base_tensor",
            "_create_tensor",
            "_create_arg_from_info",
        ],
    )

    return "\n\n".join(sections)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _parse_source_file(file_path: Path) -> tuple[ast.Module, list[str]]:
    """
    Read a Python file as UTF-8 and parse it into an AST.

    Args:
        file_path: Location of the Python source file.

    Returns:
        tuple: The parsed module and the file content split into lines
        (without line terminators).

    Raises:
        FileNotFoundError: If the file does not exist; the message names the path.
        SyntaxError: If the file content is not valid Python; the message
            names the path and includes the original error text.
    """
    try:
        text = file_path.read_text(encoding="utf-8")
    except FileNotFoundError as missing:
        raise FileNotFoundError(f"Source file not found: {file_path}") from missing

    try:
        module = ast.parse(text, filename=str(file_path))
    except SyntaxError as invalid:
        raise SyntaxError(f"Failed to parse {file_path}: {invalid}") from invalid

    return module, text.splitlines()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _extract_assignment(
|
|
109
|
+
tree: ast.Module, lines: list[str], var_name: str
|
|
110
|
+
) -> str | None:
|
|
111
|
+
"""
|
|
112
|
+
Extract a module-level assignment statement by variable name.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
tree: AST tree of the source file
|
|
116
|
+
lines: Source code lines
|
|
117
|
+
var_name: Name of the variable to extract
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
Complete assignment statement source code, or None if not found
|
|
121
|
+
|
|
122
|
+
Example:
|
|
123
|
+
Extracts:
|
|
124
|
+
TRITON_KERNELS_CUSTOM_TYPES = (
|
|
125
|
+
importlib.util.find_spec("triton_kernels") is not None
|
|
126
|
+
and importlib.util.find_spec("triton_kernels.tensor") is not None
|
|
127
|
+
)
|
|
128
|
+
"""
|
|
129
|
+
# Search only at module level
|
|
130
|
+
for node in tree.body:
|
|
131
|
+
if isinstance(node, ast.Assign):
|
|
132
|
+
for target in node.targets:
|
|
133
|
+
if isinstance(target, ast.Name) and target.id == var_name:
|
|
134
|
+
# Found it! Extract source code using line numbers
|
|
135
|
+
start_line = node.lineno - 1 # Convert to 0-based index
|
|
136
|
+
end_line = node.end_lineno # Already suitable for slicing
|
|
137
|
+
assignment_lines = lines[start_line:end_line]
|
|
138
|
+
return "\n".join(assignment_lines)
|
|
139
|
+
return None
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _extract_function(tree: ast.Module, lines: list[str], func_name: str) -> str | None:
|
|
143
|
+
"""
|
|
144
|
+
Extract a function definition by name, including decorators.
|
|
145
|
+
|
|
146
|
+
Args:
|
|
147
|
+
tree: AST tree of the source file
|
|
148
|
+
lines: Source code lines
|
|
149
|
+
func_name: Name of the function to extract
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
Complete function source code including decorators, or None if not found
|
|
153
|
+
|
|
154
|
+
Example:
|
|
155
|
+
Extracts:
|
|
156
|
+
@lru_cache(maxsize=1)
|
|
157
|
+
def _get_triton_tensor_types():
|
|
158
|
+
'''Docstring'''
|
|
159
|
+
...
|
|
160
|
+
"""
|
|
161
|
+
# Walk the entire tree (handles nested functions if needed)
|
|
162
|
+
for node in ast.walk(tree):
|
|
163
|
+
if isinstance(node, ast.FunctionDef) and node.name == func_name:
|
|
164
|
+
# If function has decorators, start from the first decorator
|
|
165
|
+
if node.decorator_list:
|
|
166
|
+
start_line = node.decorator_list[0].lineno - 1
|
|
167
|
+
else:
|
|
168
|
+
start_line = node.lineno - 1
|
|
169
|
+
|
|
170
|
+
end_line = node.end_lineno
|
|
171
|
+
func_lines = lines[start_line:end_line]
|
|
172
|
+
return "\n".join(func_lines)
|
|
173
|
+
return None
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _extract_functions(
    tree: ast.Module, lines: list[str], func_names: list[str]
) -> list[str]:
    """
    Extract the source of several functions from one parsed file.

    Args:
        tree: AST tree of the source file
        lines: Source code lines
        func_names: Names of the functions to extract

    Returns:
        The function sources, in the same order as ``func_names``

    Raises:
        ValueError: As soon as one of the requested functions is not found
    """
    sources: list[str] = []
    for name in func_names:
        source = _extract_function(tree, lines, name)
        if source is not None:
            sources.append(source)
            continue
        # Fail on the first missing name instead of returning a partial list.
        raise ValueError(
            f"Function '{name}' not found in source. "
            "Available functions might have been renamed or removed."
        )
    return sources
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _generate_imports() -> str:
    """
    Build the import header placed in front of the extracted functions.

    Returns:
        str: The import statements, one per line, with a blank line between
        the standard-library group and the torch/triton group
    """
    stdlib_imports = (
        "import gzip",
        "import hashlib",
        "import importlib",
        "import importlib.util",
        "import io",
        "import json",
        "import logging",
        "import sys",
        "from functools import lru_cache",
        "from pathlib import Path",
        "from typing import Union",
    )
    third_party_imports = (
        "import torch",
        "import triton.language as tl",
    )
    # The empty string becomes the blank line separating the two groups.
    return "\n".join((*stdlib_imports, "", *third_party_imports))
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
# pyre-strict
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class ImportInfo:
|
|
10
|
+
"""Information about an import statement."""
|
|
11
|
+
|
|
12
|
+
# Import statement details
|
|
13
|
+
import_type: str # "import" or "from_import"
|
|
14
|
+
module: str # e.g., "torch.nn.functional"
|
|
15
|
+
names: list[str] # Imported names: ["func1", "func2"]
|
|
16
|
+
|
|
17
|
+
# Resolution metadata
|
|
18
|
+
source_file: str # File containing this import
|
|
19
|
+
resolved_path: str | None # Resolved file path (None if external)
|
|
20
|
+
is_external: bool # True for third-party/built-in
|
|
21
|
+
lineno: int # Line number in source file
|
|
22
|
+
|
|
23
|
+
# Fields with defaults (must come after required fields)
|
|
24
|
+
aliases: dict[str, str] = field(default_factory=dict) # {local_name: original_name}
|
|
25
|
+
level: int = 0 # 0 = absolute, 1 = ".", 2 = "..", etc.
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
|
|
3
|
+
# pyre-strict
|
|
4
|
+
|
|
5
|
+
import ast
|
|
6
|
+
|
|
7
|
+
from tritonparse.reproducer.import_info import ImportInfo
|
|
8
|
+
from tritonparse.reproducer.import_resolver import ImportResolver
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ImportParser:
    """
    Collect import statements from a parsed Python module.

    Every ``import`` / ``from ... import`` statement found in the AST is
    resolved through the supplied resolver and turned into one ImportInfo
    record per imported module or name.
    """

    def __init__(self, import_resolver: ImportResolver) -> None:
        """
        Create a parser bound to a resolver.

        Args:
            import_resolver: Object whose ``resolve_import(module_name)``
                returns a ``(resolved_path, is_external)`` pair
        """
        self.import_resolver = import_resolver

    def parse_imports(
        self, tree: ast.Module, source_file: str, package: str | None = None
    ) -> list[ImportInfo]:
        """
        Gather all imports contained anywhere in ``tree``.

        The whole tree is walked, so imports nested inside functions,
        classes or conditionals are included. Supported forms:
        - import X
        - import X as Y
        - from X import Y
        - from X import Y as Z
        - from . import X   (relative)
        - from .. import X  (relative)

        Args:
            tree: Parsed AST module
            source_file: Path of the file the AST was parsed from
            package: Package context for relative imports (e.g., "pytorch.tritonparse")

        Returns:
            ImportInfo records in the order the statements are visited
        """
        collected: list[ImportInfo] = []

        for node in ast.walk(tree):
            if isinstance(node, ast.ImportFrom):
                # from X import Y, Z as W
                collected += self._parse_import_from_node(node, source_file, package)
            elif isinstance(node, ast.Import):
                # import X, Y as Z
                collected += self._parse_import_node(node, source_file)

        return collected

    def _parse_import_node(
        self, node: ast.Import, source_file: str
    ) -> list[ImportInfo]:
        """
        Convert one ``import X[, Y as Z]`` statement into records.

        Each imported module is resolved individually.

        Args:
            node: ast.Import node
            source_file: File containing this import

        Returns:
            One ImportInfo per module listed in the statement
        """
        records: list[ImportInfo] = []

        for entry in node.names:
            path, external = self.import_resolver.resolve_import(entry.name)

            records.append(
                ImportInfo(
                    import_type="import",
                    module=entry.name,
                    # Only the last dotted component is recorded as the name.
                    names=[entry.name.rsplit(".", 1)[-1]],
                    # An alias entry exists only for "import X as Y".
                    aliases={entry.asname: entry.name} if entry.asname else {},
                    source_file=source_file,
                    resolved_path=path,
                    is_external=external,
                    lineno=node.lineno,
                    level=0,  # "import X" is always absolute
                )
            )

        return records

    def _parse_import_from_node(
        self,
        node: ast.ImportFrom,
        source_file: str,
        package: str | None = None,
    ) -> list[ImportInfo]:
        """
        Convert one ``from X import Y[, Z as W]`` statement into records.

        For a relative import with a known ``package``, the absolute module
        name is rebuilt by dropping ``level - 1`` trailing components from
        ``package`` and appending the module, if any:
        - package="a.b.c", level=1, module=None -> "a.b.c"
        - package="a.b.c", level=1, module="d"  -> "a.b.c.d"
        - package="a.b.c", level=2, module=None -> "a.b"
        - package="a.b.c", level=2, module="d"  -> "a.b.d"
        A level that exceeds the package depth leaves only the module name.
        Without a ``package`` the module name is used as written.

        The module is resolved once and the result is shared by every name
        imported from it.

        Args:
            node: ast.ImportFrom node
            source_file: File containing this import
            package: Package context for relative imports

        Returns:
            One ImportInfo per imported name
        """
        # node.module is None for statements such as "from . import X".
        module = node.module or ""
        level = node.level  # 0 = absolute, 1 = ".", 2 = "..", etc.

        if level > 0 and package:
            if level == 1:
                base_package = package
            else:
                components = package.split(".")
                if level <= len(components):
                    base_package = ".".join(components[: 1 - level])
                else:
                    base_package = ""

            if module and base_package:
                full_module = f"{base_package}.{module}"
            else:
                # At most one of the two is non-empty here.
                full_module = module or base_package
        else:
            full_module = module

        path, external = self.import_resolver.resolve_import(full_module)

        records: list[ImportInfo] = []
        for entry in node.names:
            records.append(
                ImportInfo(
                    import_type="from_import",
                    module=full_module,
                    names=[entry.name],
                    # An alias entry exists only for "from X import Y as Z".
                    aliases={entry.asname: entry.name} if entry.asname else {},
                    source_file=source_file,
                    resolved_path=path,
                    is_external=external,
                    lineno=node.lineno,
                    level=level,
                )
            )

        return records
|