pyrefactor 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrefactor/__init__.py +3 -0
- pyrefactor/__main__.py +231 -0
- pyrefactor/analyzer.py +185 -0
- pyrefactor/ast_visitor.py +197 -0
- pyrefactor/config.py +224 -0
- pyrefactor/detectors/__init__.py +23 -0
- pyrefactor/detectors/boolean_logic.py +231 -0
- pyrefactor/detectors/comparisons.py +353 -0
- pyrefactor/detectors/complexity.py +248 -0
- pyrefactor/detectors/context_manager.py +188 -0
- pyrefactor/detectors/control_flow.py +156 -0
- pyrefactor/detectors/dict_operations.py +346 -0
- pyrefactor/detectors/duplication.py +358 -0
- pyrefactor/detectors/loops.py +267 -0
- pyrefactor/detectors/performance.py +267 -0
- pyrefactor/models.py +98 -0
- pyrefactor/py.typed +0 -0
- pyrefactor/reporter.py +208 -0
- pyrefactor-1.0.1.dist-info/METADATA +353 -0
- pyrefactor-1.0.1.dist-info/RECORD +24 -0
- pyrefactor-1.0.1.dist-info/WHEEL +5 -0
- pyrefactor-1.0.1.dist-info/entry_points.txt +2 -0
- pyrefactor-1.0.1.dist-info/licenses/LICENSE.md +70 -0
- pyrefactor-1.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""Code duplication detector for PyRefactor."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import hashlib
|
|
5
|
+
import tokenize
|
|
6
|
+
from io import StringIO
|
|
7
|
+
from typing import Optional, cast
|
|
8
|
+
|
|
9
|
+
from ..ast_visitor import BaseDetector
|
|
10
|
+
from ..config import Config
|
|
11
|
+
from ..models import Issue, Severity
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _ExclusionVisitor(ast.NodeVisitor):
|
|
15
|
+
"""AST visitor to identify line ranges to exclude from duplication detection."""
|
|
16
|
+
|
|
17
|
+
def __init__(self) -> None:
|
|
18
|
+
"""Initialize exclusion visitor."""
|
|
19
|
+
self.ranges: list[tuple[int, int]] = []
|
|
20
|
+
|
|
21
|
+
def _add_node_range(self, node: ast.AST) -> None:
|
|
22
|
+
"""Add node's line range to exclusions if available."""
|
|
23
|
+
if not hasattr(node, "lineno") or not hasattr(node, "end_lineno"):
|
|
24
|
+
return
|
|
25
|
+
lineno = cast(int, getattr(node, "lineno"))
|
|
26
|
+
end_lineno = cast(Optional[int], getattr(node, "end_lineno"))
|
|
27
|
+
if end_lineno is None:
|
|
28
|
+
return
|
|
29
|
+
self.ranges.append((lineno, end_lineno))
|
|
30
|
+
|
|
31
|
+
def visit_Set(self, node: ast.Set) -> None:
|
|
32
|
+
"""Visit set literal and add its range to exclusions."""
|
|
33
|
+
self._add_node_range(node)
|
|
34
|
+
self.generic_visit(node)
|
|
35
|
+
|
|
36
|
+
def visit_List(self, node: ast.List) -> None:
|
|
37
|
+
"""Visit list literal and add its range to exclusions."""
|
|
38
|
+
self._add_node_range(node)
|
|
39
|
+
self.generic_visit(node)
|
|
40
|
+
|
|
41
|
+
def visit_Dict(self, node: ast.Dict) -> None:
|
|
42
|
+
"""Visit dict literal and add its range to exclusions."""
|
|
43
|
+
self._add_node_range(node)
|
|
44
|
+
self.generic_visit(node)
|
|
45
|
+
|
|
46
|
+
def visit_Tuple(self, node: ast.Tuple) -> None:
|
|
47
|
+
"""Visit tuple literal and add its range to exclusions."""
|
|
48
|
+
self._add_node_range(node)
|
|
49
|
+
self.generic_visit(node)
|
|
50
|
+
|
|
51
|
+
def _add_docstring_range(self, node: ast.AST) -> None:
|
|
52
|
+
"""Check if node has a docstring and add to exclusions."""
|
|
53
|
+
if not hasattr(node, "body"):
|
|
54
|
+
return
|
|
55
|
+
body = cast(list[ast.stmt], getattr(node, "body"))
|
|
56
|
+
if not body:
|
|
57
|
+
return
|
|
58
|
+
|
|
59
|
+
# Only these node types are supported by ast.get_docstring
|
|
60
|
+
if isinstance(
|
|
61
|
+
node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef, ast.Module)
|
|
62
|
+
):
|
|
63
|
+
docstring_text = ast.get_docstring(node, clean=False)
|
|
64
|
+
if not docstring_text:
|
|
65
|
+
return
|
|
66
|
+
|
|
67
|
+
first_stmt = body[0]
|
|
68
|
+
if not isinstance(first_stmt, ast.Expr):
|
|
69
|
+
return
|
|
70
|
+
if not isinstance(first_stmt.value, ast.Constant):
|
|
71
|
+
return
|
|
72
|
+
if not isinstance(first_stmt.value.value, str):
|
|
73
|
+
return
|
|
74
|
+
|
|
75
|
+
self._add_node_range(first_stmt)
|
|
76
|
+
|
|
77
|
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
|
78
|
+
"""Visit function and exclude its docstring."""
|
|
79
|
+
self._add_docstring_range(node)
|
|
80
|
+
self.generic_visit(node)
|
|
81
|
+
|
|
82
|
+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
|
83
|
+
"""Visit async function and exclude its docstring."""
|
|
84
|
+
self._add_docstring_range(node)
|
|
85
|
+
self.generic_visit(node)
|
|
86
|
+
|
|
87
|
+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
|
88
|
+
"""Visit class and exclude its docstring."""
|
|
89
|
+
self._add_docstring_range(node)
|
|
90
|
+
self.generic_visit(node)
|
|
91
|
+
|
|
92
|
+
def visit_Module(self, node: ast.Module) -> None:
|
|
93
|
+
"""Visit module and exclude its docstring."""
|
|
94
|
+
self._add_docstring_range(node)
|
|
95
|
+
self.generic_visit(node)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class DuplicationDetector(BaseDetector):
    """Detects code duplication.

    Every window of ``min_duplicate_lines`` to ``MAX_BLOCK_SIZE`` consecutive
    source lines is tokenized and normalized: string and number literals
    become ``LITERAL``, while comments, blank lines and indentation are
    dropped. Windows with identical normalized text are grouped, and every
    occurrence after the first in a group is reported as rule ``D001``.
    """

    # Maximum block size to analyze (prevents excessive memory usage)
    MAX_BLOCK_SIZE = 20

    def __init__(self, config: Config, file_path: str, source_lines: list[str]) -> None:
        """Initialize duplication detector.

        Args:
            config: Analyzer configuration; ``config.duplication`` supplies
                ``min_duplicate_lines`` and ``similarity_threshold``.
            file_path: Path recorded on each reported issue.
            source_lines: The file's source lines (index 0 is line 1).
        """
        super().__init__(config, file_path, source_lines)
        # digest of normalized text -> [(start_line, end_line, raw_text, normalized)]
        self.code_blocks: dict[str, list[tuple[int, int, str, str]]] = {}
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.checked = False
        # Inclusive, 1-indexed line spans that duplication detection skips.
        self.excluded_ranges: list[tuple[int, int]] = []

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "duplication"

    def analyze(self, tree: ast.AST) -> list[Issue]:
        """Run duplication detection on the entire file.

        Args:
            tree: The parsed AST of the file. It is used only to locate the
                literal and docstring spans to exclude; the windows
                themselves come from ``self.source_lines``.

        Returns:
            ``self.issues`` after any ``D001`` issues have been added.
        """
        # First, identify exclusion zones (data structures and docstrings)
        self._identify_excluded_ranges(tree)

        # Then, extract all code blocks
        self._extract_code_blocks()

        # Finally, find duplicates
        self._find_duplicates()

        return self.issues

    def _identify_excluded_ranges(self, tree: ast.AST) -> None:
        """Identify line ranges that should be excluded from duplication detection.

        This includes:
        - Data structure literals (Set, List, Dict, Tuple)
        - Docstrings (first string in functions/classes/modules)

        Args:
            tree: The AST of the source file
        """
        visitor = _ExclusionVisitor()
        visitor.visit(tree)
        self.excluded_ranges = visitor.ranges

    def _is_in_excluded_range(self, start_line: int, end_line: int) -> bool:
        """Check if a line range overlaps with any excluded range.

        Args:
            start_line: Start line of the range (1-indexed)
            end_line: End line of the range (1-indexed)

        Returns:
            True if the range overlaps with any excluded range
        """
        for excluded_start, excluded_end in self.excluded_ranges:
            # Check if there's any overlap
            if end_line >= excluded_start and start_line <= excluded_end:
                return True
        return False

    def _extract_code_blocks(self) -> None:
        """Extract normalized sliding-window code blocks for comparison.

        ``self.code_blocks`` is rebuilt from scratch. Without the reset,
        calling ``analyze`` a second time would register every window twice.
        """
        self.code_blocks = {}
        min_lines = self.config.duplication.min_duplicate_lines
        total_lines = len(self.source_lines)

        for start in range(total_lines):
            # max_length already stops windows at the end of the file, so no
            # bounds check is needed in the inner loop.
            max_length = min(self.MAX_BLOCK_SIZE, total_lines - start)
            for length in range(min_lines, max_length + 1):
                end = start + length

                # Skip blocks that are in excluded ranges (data structures, docstrings)
                if self._is_in_excluded_range(start + 1, end):
                    continue

                # Get the code block
                code_block = "\n".join(self.source_lines[start:end])

                # Skip if block is mostly whitespace or comments
                if not self._is_meaningful_block(code_block):
                    continue

                # Normalize the code for comparison
                normalized = self._normalize_code(code_block)
                if not normalized:
                    continue

                # The digest is only a grouping key, not a security measure;
                # usedforsecurity=False keeps md5 usable on FIPS-mode builds,
                # where a plain hashlib.md5() call raises ValueError.
                code_hash = hashlib.md5(
                    normalized.encode(), usedforsecurity=False
                ).hexdigest()

                # Store normalized code along with block to avoid re-normalization
                # (+1 for 1-indexed lines)
                self.code_blocks.setdefault(code_hash, []).append(
                    (start + 1, end, code_block, normalized)
                )

    def _find_duplicates(self) -> None:
        """Find and report duplicate code blocks.

        Groups are processed largest-first, measured by the length of their
        normalized text, so a long duplicated region is reported once, as a
        whole. The shorter windows inside it then overlap an already-reported
        range and are skipped. In insertion order, the shortest window would
        be reported first and the rest of the region would be split into
        further, separate issues.
        """
        reported_ranges: list[tuple[int, int]] = []
        threshold = self.config.duplication.similarity_threshold

        # Only digests seen at more than one location are duplicates.
        groups = [occ for occ in self.code_blocks.values() if len(occ) > 1]
        # sort() is stable: equally sized groups keep insertion order
        # (earliest starting window first).
        groups.sort(key=lambda occ: -len(occ[0][3]))

        for occurrences in groups:
            # Sort by line number; the earliest occurrence is the reference.
            sorted_occurrences = sorted(
                occurrences, key=lambda item: item[0]  # type: ignore[misc]
            )
            first_start, first_end, _, first_normalized = sorted_occurrences[0]

            # Report each duplicate (except the first occurrence)
            for start, end, _, normalized in sorted_occurrences[1:]:
                # Check for suppression comments
                if self._is_block_suppressed(start):
                    continue

                # Skip if this range overlaps already reported ranges, or the
                # group's first occurrence (self-overlapping repetition).
                if self._overlaps_with_reported(start, end, reported_ranges):
                    continue
                if self._overlaps_with_reported(start, end, [(first_start, first_end)]):
                    continue

                # Check similarity using already-normalized code
                similarity = self._calculate_similarity_from_normalized(
                    first_normalized, normalized
                )

                if similarity >= threshold:
                    self.add_issue(
                        Issue(
                            file=self.file_path,
                            line=start,
                            column=0,
                            severity=Severity.MEDIUM,
                            rule_id="D001",
                            message=f"Duplicate code block (lines {start}-{end}) similar to lines {first_start}-{first_end}",
                            suggestion="Extract duplicated code to a reusable function or method",
                            end_line=end,
                        )
                    )
                    reported_ranges.append((start, end))

    def _is_block_suppressed(self, line: int) -> bool:
        """Check if a code block has a suppression comment.

        The block's first line and up to three lines before it are searched
        for ``# pyrefactor: ignore`` or ``# noqa``.

        Args:
            line: The line number to check (1-indexed)

        Returns:
            True if a suppression comment was found
        """
        if line < 1 or line > len(self.source_lines):
            return False

        # Check a range of lines before the block (up to 3 lines back)
        for offset in range(4):
            check_line = line - offset
            if check_line < 1:
                break
            current_line = self.source_lines[check_line - 1]
            if "# pyrefactor: ignore" in current_line or "# noqa" in current_line:
                return True

        return False

    def _overlaps_with_reported(  # pyrefactor: ignore
        self, start: int, end: int, reported: list[tuple[int, int]]
    ) -> bool:
        """Check if a range overlaps with any already reported range.

        Args:
            start: Start line of the range
            end: End line of the range
            reported: List of already reported ranges

        Returns:
            True if the range overlaps with any reported range
        """
        for reported_start, reported_end in reported:
            # Check for any overlap (using De Morgan's law simplification)
            if end >= reported_start and start <= reported_end:
                return True
        return False

    def _is_meaningful_block(self, code: str) -> bool:
        """Check if a code block is meaningful (not just whitespace/comments).

        Returns:
            True when the block has at least ``min_duplicate_lines`` lines
            that are neither blank nor ``#`` comments.
        """
        lines = code.strip().split("\n")
        meaningful_lines = sum(
            1
            for line in lines
            if (stripped := line.strip()) and not stripped.startswith("#")
        )

        return meaningful_lines >= self.config.duplication.min_duplicate_lines

    def _normalize_code(self, code: str) -> str:
        """Normalize code for comparison by tokenizing and removing literals.

        Returns:
            Space-joined normalized tokens, or ``""`` if the window cannot be
            tokenized on its own.
        """
        try:
            tokens = tokenize.generate_tokens(StringIO(code).readline)
            normalized_tokens = [self._normalize_token(token) for token in tokens]
            # Filter out None values
            return " ".join(token for token in normalized_tokens if token)
        except (tokenize.TokenError, IndentationError, SyntaxError):
            # Return empty string for code blocks that can't be tokenized
            # (e.g., incomplete blocks with inconsistent indentation)
            return ""

    def _normalize_token(
        self, token: tokenize.TokenInfo
    ) -> Optional[str]:  # pyrefactor: ignore
        """Normalize a single token for comparison.

        Args:
            token: Token to normalize

        Returns:
            Normalized token string or None if token should be skipped
        """
        if token.type == tokenize.NAME:
            return token.string
        if token.type == tokenize.OP:
            return token.string
        if token.type in (tokenize.NUMBER, tokenize.STRING):
            return "LITERAL"
        if token.type == tokenize.NEWLINE:
            return "\n"
        return None

    def _calculate_similarity_from_normalized(  # pyrefactor: ignore
        self, normalized1: str, normalized2: str
    ) -> float:
        """Calculate similarity between two already-normalized code blocks.

        Uses the Jaccard index of the two blocks' whitespace-separated token
        sets.

        Args:
            normalized1: First normalized code block
            normalized2: Second normalized code block

        Returns:
            Similarity score between 0.0 and 1.0
        """
        tokens1 = set(normalized1.split())
        tokens2 = set(normalized2.split())

        if not tokens1 or not tokens2:
            return 0.0

        intersection = len(tokens1 & tokens2)
        union = len(tokens1 | tokens2)

        return intersection / union if union > 0 else 0.0

    def visit(self, node: ast.AST) -> None:
        """Override visit to prevent default traversal."""
        # Duplication detection works at file level, not node level
        return None
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""Loop optimization detector for PyRefactor."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
from typing import Optional, cast
|
|
5
|
+
|
|
6
|
+
from ..ast_visitor import BaseDetector
|
|
7
|
+
from ..models import Issue, Severity
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class IndexTracker(ast.NodeVisitor):
    """Collect names that are incremented with ``name += 1``.

    An augmented assignment counts only when its target is a bare name, its
    operator is ``+``, and its right-hand side is a constant equal to ``1``.
    Matching names accumulate in ``manual_indices``.
    """

    def __init__(self) -> None:
        """Start with an empty set of tracked names."""
        self.manual_indices: set[str] = set()

    def visit_AugAssign(self, aug_node: ast.AugAssign) -> None:
        """Track += 1 operations on variables."""
        target = aug_node.target
        step = aug_node.value
        if not isinstance(target, ast.Name):
            return
        if not isinstance(aug_node.op, ast.Add):
            return
        if isinstance(step, ast.Constant) and step.value == 1:
            self.manual_indices.add(target.id)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class InvariantChecker(ast.NodeVisitor):
    """Collect method calls that never mention a given loop variable.

    A call is appended to ``invariant_calls`` when it has the form
    ``obj.method(...)`` with ``method`` in ``EXPENSIVE_METHODS`` and no
    ``Name`` node inside the call equals ``var_name``. Nested calls are
    inspected too; an outer call is recorded before the calls inside it.
    """

    # Methods that are potentially expensive when called repeatedly in loops
    EXPENSIVE_METHODS = frozenset(
        {"compile", "search", "match", "split", "findall", "sub"}
    )

    def __init__(self, var_name: str) -> None:
        """Remember the loop variable name and start with no recorded calls."""
        self.var_name = var_name
        self.invariant_calls: list[ast.Call] = []

    def visit_Call(self, call_node: ast.Call) -> None:
        """Check if call depends on loop variable."""
        callee = call_node.func
        is_candidate = (
            isinstance(callee, ast.Attribute)
            and callee.attr in self.EXPENSIVE_METHODS
        )
        if is_candidate and not self._depends_on_var(call_node):
            self.invariant_calls.append(call_node)
        # Always descend so nested calls are checked as well.
        self.generic_visit(call_node)

    def _depends_on_var(self, node: ast.AST) -> bool:
        """Check if node uses the loop variable."""
        return any(
            isinstance(child, ast.Name) and child.id == self.var_name
            for child in ast.walk(node)
        )
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class LoopsDetector(BaseDetector):
    """Detects loop optimization opportunities.

    Rules reported for ``for`` loops:

    * ``L001`` - ``range(len(x))`` or ``range(0, len(x))`` whose body indexes
      ``x`` with the loop variable; suggests ``enumerate``.
    * ``L002`` - ``name += 1`` increments anywhere in the loop body.
    * ``L003`` - chains of directly nested loops above the threshold whose
      bodies contain comparisons.
    * ``L004`` - selected method calls (see ``InvariantChecker``) that never
      mention the loop variable.
    """

    # Minimum nesting level to trigger nested loop optimization warning.
    # NOTE(review): compared with ">" in _check_nested_loop_optimization, so a
    # count of at least 3 is needed; confirm whether ">=" was intended.
    MIN_NESTED_LOOPS_FOR_WARNING = 2

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "loops"

    def _create_issue(
        self,
        node: ast.AST,
        *,
        severity: Severity,
        rule_id: str,
        message: str,
        suggestion: str,
    ) -> Issue:
        """Create an Issue object for loop-related issues.

        Line and column come from the node's ``lineno``/``col_offset``,
        defaulting to 0 when the node has no position.
        """
        return Issue(
            file=self.file_path,
            line=cast(int, getattr(node, "lineno", 0)),
            column=cast(int, getattr(node, "col_offset", 0)),
            severity=severity,
            rule_id=rule_id,
            message=message,
            suggestion=suggestion,
        )

    def visit_For(self, node: ast.For) -> None:
        """Check for loop optimization opportunities.

        Suppressed loops are skipped, but their children are still visited.
        """
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Check for range(len()) pattern
        self._check_range_len_pattern(node)

        # Check for manual index tracking
        self._check_manual_index_tracking(node)

        # Check for nested loops that might benefit from dict lookup
        self._check_nested_loop_optimization(node)

        # Check for loop invariant code
        self._check_loop_invariants(node)

        self.generic_visit(node)

    def _check_range_len_pattern(self, node: ast.For) -> None:
        """Check for range(len(x)) that should use enumerate."""
        # Validate the basic pattern structure
        if not self._is_range_len_call(node):
            return

        # Extract the collection being iterated
        collection = self._extract_collection_from_range_len(node)
        if collection is None:
            return

        # Check if the loop body actually uses indexed access
        if not self._loop_body_accesses_collection(node, collection):
            return

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.LOW,
                rule_id="L001",
                message="Use enumerate() instead of range(len())",
                suggestion="Replace 'for i in range(len(items)):' with 'for i, item in enumerate(items):'",
            )
        )

    def _is_range_len_call(self, node: ast.For) -> bool:
        """Check if the loop iterates over a ``range(...)`` call with arguments."""
        if not isinstance(node.iter, ast.Call):
            return False

        if not isinstance(node.iter.func, ast.Name):
            return False

        return node.iter.func.id == "range" and bool(node.iter.args)

    def _extract_collection_from_range_len(self, node: ast.For) -> Optional[ast.AST]:
        """Extract the collection from a range(len(...)) call.

        Only ``range(len(x))`` and ``range(0, len(x))`` qualify, because they
        visit every index of ``x`` in order, as ``enumerate(x)`` does. Forms
        such as ``range(len(x), n)`` or ``range(len(x), 0, -1)`` cannot be
        rewritten with ``enumerate``.

        Returns:
            The expression passed to ``len``, or None if the pattern does
            not match.
        """
        call = node.iter
        if not isinstance(call, ast.Call):
            return None

        args = call.args
        if len(args) == 1:
            len_call = args[0]
        elif (
            len(args) == 2
            and isinstance(args[0], ast.Constant)
            and isinstance(args[0].value, int)
            and args[0].value == 0
        ):
            len_call = args[1]
        else:
            # No arguments, a non-zero / non-literal start, or an explicit
            # step: not an index sequence that enumerate() reproduces.
            return None

        if not isinstance(len_call, ast.Call):
            return None

        if not isinstance(len_call.func, ast.Name) or len_call.func.id != "len":
            return None

        if not len_call.args:
            return None

        return len_call.args[0]

    def _check_manual_index_tracking(self, node: ast.For) -> None:
        """Check for manual index variable incrementation.

        Every ``name += 1`` anywhere in the body (including nested loops)
        is collected by ``IndexTracker``.
        """
        tracker = IndexTracker()
        for stmt in node.body:
            tracker.visit(stmt)

        if tracker.manual_indices:
            # Sorted so the message is identical across runs; iteration order
            # of a str set depends on per-process hash randomization.
            names = ", ".join(sorted(tracker.manual_indices))
            self.add_issue(
                self._create_issue(
                    node,
                    severity=Severity.INFO,
                    rule_id="L002",
                    message=f"Manual index tracking detected: {names}",
                    suggestion="Use enumerate() to track indices automatically",
                )
            )

    def _check_nested_loop_optimization(self, node: ast.For) -> None:
        """Check for nested loops that might benefit from preprocessing."""
        nested_loop_count = self._count_nested_loops(node)

        if nested_loop_count > self.MIN_NESTED_LOOPS_FOR_WARNING:
            # Check if inner loop is doing lookups
            if self._has_comparison_in_loops(node):
                self.add_issue(
                    self._create_issue(
                        node,
                        severity=Severity.MEDIUM,
                        rule_id="L003",
                        message="Nested loops with comparisons detected",
                        suggestion="Consider using a dictionary or set for O(1) lookups instead of nested iteration",
                    )
                )

    def _count_nested_loops(self, node: ast.For) -> int:
        """Count the number of nested for loops.

        Counts this loop plus, recursively, each ``for`` loop that is a direct
        statement of a counted loop's body. Sibling inner loops are added
        together, so this is a total rather than a depth. Loops nested inside
        ``if``/``with``/``try`` blocks are not counted.
        """
        count = 1  # Current loop
        for child in node.body:
            if isinstance(child, ast.For):
                count += self._count_nested_loops(child)
        return count

    def _has_comparison_in_loops(self, node: ast.For) -> bool:
        """Check if any comparison appears anywhere in the loop body.

        ``ast.walk`` already descends into nested loops, so a separate
        recursion over inner ``for`` statements is unnecessary. The loop's
        ``else`` clause is not inspected.
        """
        return any(
            isinstance(inner, ast.Compare)
            for stmt in node.body
            for inner in ast.walk(stmt)
        )

    def _check_loop_invariants(self, node: ast.For) -> None:
        """Check for loop-invariant code that could be hoisted.

        Only loops whose target is a single name are checked.
        """
        # Look for expensive operations that don't depend on loop variable
        loop_var = node.target
        if not isinstance(loop_var, ast.Name):
            return

        loop_var_name = loop_var.id
        checker = InvariantChecker(loop_var_name)
        for stmt in node.body:
            checker.visit(stmt)

        if checker.invariant_calls:
            self.add_issue(  # pyrefactor: ignore
                self._create_issue(
                    node,
                    severity=Severity.MEDIUM,
                    rule_id="L004",
                    message="Loop-invariant code detected inside loop",
                    suggestion="Move computations that don't depend on loop variable outside the loop",
                )
            )

    def _loop_body_accesses_collection(
        self, loop_node: ast.For, collection: ast.AST
    ) -> bool:
        """Check if loop body accesses the collection by index.

        Returns True when some statement in the body contains a subscript
        ``<collection>[<loop variable>]`` whose value has the same
        ``ast.dump`` as ``collection``.
        """
        if not isinstance(loop_node.target, ast.Name):
            return False

        index_var = loop_node.target.id
        collection_dump = ast.dump(collection)

        for stmt in loop_node.body:
            for node in ast.walk(stmt):
                if (
                    isinstance(node, ast.Subscript)
                    and isinstance(node.slice, ast.Name)
                    and node.slice.id == index_var
                    and ast.dump(node.value) == collection_dump
                ):
                    return True
        return False

    def visit_While(self, node: ast.While) -> None:
        """Check while loops for optimization opportunities."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Could add while-loop specific checks here
        self.generic_visit(node)
|