pyrefactor 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,358 @@
1
+ """Code duplication detector for PyRefactor."""
2
+
3
+ import ast
4
+ import hashlib
5
+ import tokenize
6
+ from io import StringIO
7
+ from typing import Optional, cast
8
+
9
+ from ..ast_visitor import BaseDetector
10
+ from ..config import Config
11
+ from ..models import Issue, Severity
12
+
13
+
14
class _ExclusionVisitor(ast.NodeVisitor):
    """Collect line spans that duplication detection must ignore.

    Two kinds of span are gathered into ``ranges`` as ``(start, end)`` pairs
    (1-based, inclusive), in the order their nodes are visited:

    * set / list / dict / tuple literal nodes, and
    * the leading string-constant statement (docstring) of modules, classes,
      and sync or async functions.
    """

    def __init__(self) -> None:
        """Start with no excluded spans."""
        self.ranges: list[tuple[int, int]] = []

    def _add_node_range(self, node: ast.AST) -> None:
        """Record ``node``'s ``(lineno, end_lineno)`` span when both are present."""
        if not (hasattr(node, "lineno") and hasattr(node, "end_lineno")):
            return
        last_line = cast(Optional[int], getattr(node, "end_lineno"))
        if last_line is not None:
            first_line = cast(int, getattr(node, "lineno"))
            self.ranges.append((first_line, last_line))

    def _visit_literal(self, node: ast.AST) -> None:
        """Exclude a container literal's span, then keep walking its children."""
        self._add_node_range(node)
        self.generic_visit(node)

    def visit_Set(self, node: ast.Set) -> None:
        """Exclude a set literal."""
        self._visit_literal(node)

    def visit_List(self, node: ast.List) -> None:
        """Exclude a list literal."""
        self._visit_literal(node)

    def visit_Dict(self, node: ast.Dict) -> None:
        """Exclude a dict literal."""
        self._visit_literal(node)

    def visit_Tuple(self, node: ast.Tuple) -> None:
        """Exclude a tuple literal."""
        self._visit_literal(node)

    def _add_docstring_range(self, node: ast.AST) -> None:
        """Exclude ``node``'s docstring statement, if it has a non-empty one."""
        body = getattr(node, "body", None)
        if not body:
            return

        # ast.get_docstring only accepts these node types
        docstring_owners = (
            ast.AsyncFunctionDef,
            ast.FunctionDef,
            ast.ClassDef,
            ast.Module,
        )
        if isinstance(node, docstring_owners) and not ast.get_docstring(
            node, clean=False
        ):
            return

        head = body[0]
        is_string_statement = (
            isinstance(head, ast.Expr)
            and isinstance(head.value, ast.Constant)
            and isinstance(head.value.value, str)
        )
        if is_string_statement:
            self._add_node_range(head)

    def _visit_docstring_owner(self, node: ast.AST) -> None:
        """Exclude the node's docstring, then keep walking its children."""
        self._add_docstring_range(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Exclude a function's docstring."""
        self._visit_docstring_owner(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Exclude an async function's docstring."""
        self._visit_docstring_owner(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Exclude a class's docstring."""
        self._visit_docstring_owner(node)

    def visit_Module(self, node: ast.Module) -> None:
        """Exclude a module's docstring."""
        self._visit_docstring_owner(node)
97
+
98
class DuplicationDetector(BaseDetector):
    """Detects duplicated code blocks within a single file.

    Every window of ``min_duplicate_lines`` to ``MAX_BLOCK_SIZE`` consecutive
    source lines is normalized (number/string literals collapsed to
    ``LITERAL``; comments, indentation and non-logical newlines dropped),
    hashed, and bucketed. Later occurrences in a bucket are reported as D001
    issues pointing back at the earliest occurrence.
    """

    # Maximum block size to analyze (prevents excessive memory usage)
    MAX_BLOCK_SIZE = 20

    def __init__(self, config: Config, file_path: str, source_lines: list[str]) -> None:
        """Initialize duplication detector.

        Args:
            config: Configuration; ``config.duplication`` supplies
                ``min_duplicate_lines`` and ``similarity_threshold``.
            file_path: Path recorded on every generated issue.
            source_lines: Lines of the file under analysis.
        """
        super().__init__(config, file_path, source_lines)
        # hash of normalized code -> [(start_line, end_line, raw_block, normalized)]
        self.code_blocks: dict[str, list[tuple[int, int, str, str]]] = {}
        # Not read anywhere in this class; kept for backward compatibility.
        self.checked = False
        self.excluded_ranges: list[tuple[int, int]] = []

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "duplication"

    def analyze(self, tree: ast.AST) -> list[Issue]:
        """Run duplication detection on the entire file.

        Args:
            tree: Parsed AST of ``source_lines``; used only to locate
                literal and docstring ranges that are excluded.

        Returns:
            ``self.issues``, which the D001 findings are added to via
            ``add_issue``.
        """
        # Reset per-run state so calling analyze() twice does not re-bucket
        # (and therefore re-report) every window.
        self.code_blocks = {}

        # First, identify exclusion zones (data structures and docstrings)
        self._identify_excluded_ranges(tree)

        # Then, extract all code blocks
        self._extract_code_blocks()

        # Finally, find duplicates
        self._find_duplicates()

        return self.issues

    def _identify_excluded_ranges(self, tree: ast.AST) -> None:
        """Identify line ranges that should be excluded from duplication detection.

        This includes:
        - Data structure literals (Set, List, Dict, Tuple)
        - Docstrings (first string in functions/classes/modules)

        Args:
            tree: The AST of the source file
        """
        visitor = _ExclusionVisitor()
        visitor.visit(tree)
        self.excluded_ranges = visitor.ranges

    def _is_in_excluded_range(self, start_line: int, end_line: int) -> bool:
        """Check if a line range overlaps with any excluded range.

        Args:
            start_line: Start line of the range (1-indexed)
            end_line: End line of the range (1-indexed, inclusive)

        Returns:
            True if the range overlaps with any excluded range
        """
        return any(
            end_line >= excluded_start and start_line <= excluded_end
            for excluded_start, excluded_end in self.excluded_ranges
        )

    def _extract_code_blocks(self) -> None:
        """Bucket every candidate window of lines by the hash of its normalized code.

        A window is skipped when it overlaps an excluded range, has fewer
        than ``min_duplicate_lines`` meaningful lines, or cannot be tokenized.
        """
        min_lines = self.config.duplication.min_duplicate_lines
        total_lines = len(self.source_lines)

        for start in range(total_lines):
            # Clamping here guarantees ``end <= total_lines`` below
            max_length = min(self.MAX_BLOCK_SIZE, total_lines - start)
            for length in range(min_lines, max_length + 1):
                # 0-based exclusive end equals the 1-based inclusive end line
                end = start + length

                # Skip blocks that are in excluded ranges (data structures, docstrings)
                if self._is_in_excluded_range(start + 1, end):
                    continue

                code_block = "\n".join(self.source_lines[start:end])

                # Skip if block is mostly whitespace or comments
                if not self._is_meaningful_block(code_block):
                    continue

                normalized = self._normalize_code(code_block)
                if not normalized:
                    continue

                # MD5 is only a bucketing key here, not a security primitive;
                # usedforsecurity=False keeps it usable on FIPS-mode builds.
                code_hash = hashlib.md5(
                    normalized.encode(), usedforsecurity=False
                ).hexdigest()

                # Store normalized code along with block to avoid re-normalization
                self.code_blocks.setdefault(code_hash, []).append(
                    (start + 1, end, code_block, normalized)
                )

    def _find_duplicates(self) -> None:
        """Report later occurrences of each repeated block as D001 issues."""
        threshold = self.config.duplication.similarity_threshold
        reported_ranges: list[tuple[int, int]] = []

        for occurrences in self.code_blocks.values():
            if len(occurrences) <= 1:
                continue

            # Sort by start line so the earliest occurrence is the reference
            sorted_occurrences = sorted(
                occurrences, key=lambda item: item[0]  # type: ignore[misc]
            )
            first_start, first_end, _, first_normalized = sorted_occurrences[0]

            for start, end, _, normalized in sorted_occurrences[1:]:
                if self._is_block_suppressed(start):
                    continue

                # Skip windows overlapping an already reported duplicate or the
                # reference occurrence itself (e.g. runs of repeated lines)
                if self._overlaps_with_reported(start, end, reported_ranges):
                    continue
                if self._overlaps_with_reported(start, end, [(first_start, first_end)]):
                    continue

                # Bucket members share an MD5 of their normalized text, so this
                # is normally 1.0 whenever both token sets are non-empty.
                similarity = self._calculate_similarity_from_normalized(
                    first_normalized, normalized
                )
                if similarity < threshold:
                    continue

                self.add_issue(
                    Issue(
                        file=self.file_path,
                        line=start,
                        column=0,
                        severity=Severity.MEDIUM,
                        rule_id="D001",
                        message=f"Duplicate code block (lines {start}-{end}) similar to lines {first_start}-{first_end}",
                        suggestion="Extract duplicated code to a reusable function or method",
                        end_line=end,
                    )
                )
                reported_ranges.append((start, end))

    def _is_block_suppressed(self, line: int) -> bool:
        """Check if a code block has a suppression comment.

        The block's first line and up to three lines before it are searched
        for ``# pyrefactor: ignore`` or ``# noqa``.

        Args:
            line: The 1-indexed first line of the block

        Returns:
            True if a suppression comment was found
        """
        if line < 1 or line > len(self.source_lines):
            return False

        # Check a range of lines before the block (up to 3 lines back)
        for offset in range(4):
            check_line = line - offset
            if check_line < 1:
                break
            current_line = self.source_lines[check_line - 1]
            if "# pyrefactor: ignore" in current_line or "# noqa" in current_line:
                return True

        return False

    def _overlaps_with_reported(  # pyrefactor: ignore
        self, start: int, end: int, reported: list[tuple[int, int]]
    ) -> bool:
        """Check if a range overlaps with any already reported range.

        Args:
            start: Start line of the range
            end: End line of the range (inclusive)
            reported: List of already reported ranges

        Returns:
            True if the range overlaps with any reported range
        """
        # Two closed intervals overlap iff each starts before the other ends
        return any(
            end >= reported_start and start <= reported_end
            for reported_start, reported_end in reported
        )

    def _is_meaningful_block(self, code: str) -> bool:
        """Check that a block has at least ``min_duplicate_lines`` non-blank, non-comment lines."""
        lines = code.strip().split("\n")
        meaningful_lines = sum(
            1
            for line in lines
            if (stripped := line.strip()) and not stripped.startswith("#")
        )

        return meaningful_lines >= self.config.duplication.min_duplicate_lines

    def _normalize_code(self, code: str) -> str:
        """Normalize code for comparison by tokenizing and collapsing literals.

        Returns:
            Space-joined normalized tokens, or ``""`` when the block cannot be
            tokenized (e.g. a window that cuts through a multi-line construct).
        """
        try:
            tokens = tokenize.generate_tokens(StringIO(code).readline)
            # generate_tokens is lazy, so tokenizer errors surface here
            normalized_tokens = [self._normalize_token(token) for token in tokens]
            # Filter out None values
            return " ".join(token for token in normalized_tokens if token)
        except (tokenize.TokenError, IndentationError, SyntaxError):
            return ""

    def _normalize_token(
        self, token: tokenize.TokenInfo
    ) -> Optional[str]:  # pyrefactor: ignore
        """Normalize a single token for comparison.

        Args:
            token: Token to normalize

        Returns:
            The token text for names and operators, ``"LITERAL"`` for numbers
            and strings, ``"\\n"`` for logical newlines, or None to drop the
            token (comments, indentation, non-logical newlines, end marker).
        """
        if token.type in (tokenize.NAME, tokenize.OP):
            return token.string
        if token.type in (tokenize.NUMBER, tokenize.STRING):
            return "LITERAL"
        if token.type == tokenize.NEWLINE:
            return "\n"
        return None

    def _calculate_similarity_from_normalized(  # pyrefactor: ignore
        self, normalized1: str, normalized2: str
    ) -> float:
        """Calculate Jaccard similarity of the token sets of two normalized blocks.

        Args:
            normalized1: First normalized code block
            normalized2: Second normalized code block

        Returns:
            Similarity score between 0.0 and 1.0 (0.0 if either block is empty)
        """
        tokens1 = set(normalized1.split())
        tokens2 = set(normalized2.split())

        if not tokens1 or not tokens2:
            return 0.0

        # Both sets are non-empty, so the union is never empty
        return len(tokens1 & tokens2) / len(tokens1 | tokens2)

    def visit(self, node: ast.AST) -> None:
        """Override visit to prevent default traversal."""
        # Duplication detection works at file level (see analyze), not node level
        ...
@@ -0,0 +1,267 @@
1
+ """Loop optimization detector for PyRefactor."""
2
+
3
+ import ast
4
+ from typing import Optional, cast
5
+
6
+ from ..ast_visitor import BaseDetector
7
+ from ..models import Issue, Severity
8
+
9
+
10
class IndexTracker(ast.NodeVisitor):
    """Collect names of plain variables incremented in place by one (``name += 1``)."""

    def __init__(self) -> None:
        """Start with no tracked counter names."""
        self.manual_indices: set[str] = set()

    def visit_AugAssign(self, aug_node: ast.AugAssign) -> None:
        """Record ``name`` when ``aug_node`` is ``name += <constant equal to 1>``."""
        target = aug_node.target
        if not isinstance(target, ast.Name):
            return
        if not isinstance(aug_node.op, ast.Add):
            return
        increment = aug_node.value
        if isinstance(increment, ast.Constant) and increment.value == 1:
            self.manual_indices.add(target.id)
25
+
26
+
27
class InvariantChecker(ast.NodeVisitor):
    """Collect costly ``obj.method(...)`` calls that never mention the loop variable."""

    # Method names considered costly when re-evaluated on every iteration
    EXPENSIVE_METHODS = frozenset(
        ("compile", "search", "match", "split", "findall", "sub")
    )

    def __init__(self, var_name: str) -> None:
        """Remember the loop variable name and start with no flagged calls."""
        self.var_name = var_name
        self.invariant_calls: list[ast.Call] = []

    def visit_Call(self, call_node: ast.Call) -> None:
        """Flag loop-invariant calls to expensive methods, then visit nested calls."""
        callee = call_node.func
        is_candidate = (
            isinstance(callee, ast.Attribute)
            and callee.attr in self.EXPENSIVE_METHODS
            and not self._depends_on_var(call_node)
        )
        if is_candidate:
            self.invariant_calls.append(call_node)
        self.generic_visit(call_node)

    def _depends_on_var(self, node: ast.AST) -> bool:
        """Return True if any ``Name`` within ``node`` refers to the loop variable."""
        return any(
            isinstance(child, ast.Name) and child.id == self.var_name
            for child in ast.walk(node)
        )
54
+
55
+
56
class LoopsDetector(BaseDetector):
    """Detects loop optimization opportunities in ``for`` loops.

    Rules emitted:
        L001: ``range(len(x))`` iteration whose body indexes ``x[i]``.
        L002: ``name += 1`` counters inside the loop body.
        L003: nested loops (see ``_count_nested_loops``) containing comparisons.
        L004: expensive method calls that never reference the loop variable.
    """

    # Compared with a strict ">" against _count_nested_loops(), so L003 needs a
    # count of at least 3. NOTE(review): the name suggests ">=" may have been
    # intended; confirm before changing, since it alters which loops are flagged.
    MIN_NESTED_LOOPS_FOR_WARNING = 2

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "loops"

    def _create_issue(
        self,
        node: ast.AST,
        *,
        severity: Severity,
        rule_id: str,
        message: str,
        suggestion: str,
    ) -> Issue:
        """Create an Issue anchored at ``node``'s position (0 when it has none)."""
        return Issue(
            file=self.file_path,
            line=cast(int, getattr(node, "lineno", 0)),
            column=cast(int, getattr(node, "col_offset", 0)),
            severity=severity,
            rule_id=rule_id,
            message=message,
            suggestion=suggestion,
        )

    def visit_For(self, node: ast.For) -> None:
        """Run all loop checks on ``node`` unless suppressed, then visit children."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Check for range(len()) pattern
        self._check_range_len_pattern(node)

        # Check for manual index tracking
        self._check_manual_index_tracking(node)

        # Check for nested loops that might benefit from dict lookup
        self._check_nested_loop_optimization(node)

        # Check for loop invariant code
        self._check_loop_invariants(node)

        self.generic_visit(node)

    def _check_range_len_pattern(self, node: ast.For) -> None:
        """Report L001 for ``range(len(x))`` loops that index ``x`` with the loop variable."""
        if not self._is_range_len_call(node):
            return

        collection = self._extract_collection_from_range_len(node)
        if collection is None:
            return

        # Only worth suggesting enumerate() if the body actually does x[i]
        if not self._loop_body_accesses_collection(node, collection):
            return

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.LOW,
                rule_id="L001",
                message="Use enumerate() instead of range(len())",
                suggestion="Replace 'for i in range(len(items)):' with 'for i, item in enumerate(items):'",
            )
        )

    def _is_range_len_call(self, node: ast.For) -> bool:
        """Check if the loop iterates over a single-argument ``range(...)`` call.

        Only ``range(n)`` maps onto ``enumerate``; ``range(start, stop[, step])``
        produces a different index sequence, so multi-argument forms are not
        matched.
        """
        iter_call = node.iter
        if not isinstance(iter_call, ast.Call):
            return False

        if not isinstance(iter_call.func, ast.Name):
            return False

        return iter_call.func.id == "range" and len(iter_call.args) == 1

    def _extract_collection_from_range_len(self, node: ast.For) -> Optional[ast.AST]:
        """Return ``x`` from ``range(len(x), ...)``, or None if the shape differs."""
        if not isinstance(node.iter, ast.Call) or not node.iter.args:
            return None

        first_arg = node.iter.args[0]
        if not isinstance(first_arg, ast.Call):
            return None

        if not isinstance(first_arg.func, ast.Name) or first_arg.func.id != "len":
            return None

        if not first_arg.args:
            return None

        return first_arg.args[0]

    def _check_manual_index_tracking(self, node: ast.For) -> None:
        """Report L002 when the loop body contains ``name += 1`` counters.

        IndexTracker walks each body statement fully, so counters inside
        nested statements (including inner loops) are included as well.
        """
        tracker = IndexTracker()
        for stmt in node.body:
            tracker.visit(stmt)

        if tracker.manual_indices:
            # Sorted so the message is stable across runs; set iteration order
            # of str values depends on hash randomization.
            names = ", ".join(sorted(tracker.manual_indices))
            self.add_issue(
                self._create_issue(
                    node,
                    severity=Severity.INFO,
                    rule_id="L002",
                    message=f"Manual index tracking detected: {names}",
                    suggestion="Use enumerate() to track indices automatically",
                )
            )

    def _check_nested_loop_optimization(self, node: ast.For) -> None:
        """Report L003 for nested loops that contain comparisons."""
        nested_loop_count = self._count_nested_loops(node)

        if nested_loop_count > self.MIN_NESTED_LOOPS_FOR_WARNING:
            if self._has_comparison_in_loops(node):
                self.add_issue(
                    self._create_issue(
                        node,
                        severity=Severity.MEDIUM,
                        rule_id="L003",
                        message="Nested loops with comparisons detected",
                        suggestion="Consider using a dictionary or set for O(1) lookups instead of nested iteration",
                    )
                )

    def _count_nested_loops(self, node: ast.For) -> int:
        """Count this loop plus every ``for`` reachable through direct loop bodies.

        Only ``for`` statements placed directly in a ``for`` body are followed
        (not those under ``if``/``while``/``with``). Sibling inner loops are
        summed, so this is a total loop count rather than a maximum depth.
        """
        return 1 + sum(
            self._count_nested_loops(child)
            for child in node.body
            if isinstance(child, ast.For)
        )

    def _has_comparison_in_loops(self, node: ast.For) -> bool:
        """Return True if any ``Compare`` expression appears anywhere in the loop body.

        Despite the name, the comparison need not sit inside an inner loop;
        any comparison in ``node.body`` (at any depth) counts.
        """
        return any(
            isinstance(descendant, ast.Compare)
            for stmt in node.body
            for descendant in ast.walk(stmt)
        )

    def _check_loop_invariants(self, node: ast.For) -> None:
        """Report L004 for expensive calls that never reference the loop variable.

        Only loops whose target is a single name are checked.
        """
        loop_var = node.target
        if not isinstance(loop_var, ast.Name):
            return

        checker = InvariantChecker(loop_var.id)
        for stmt in node.body:
            checker.visit(stmt)

        if checker.invariant_calls:
            self.add_issue(  # pyrefactor: ignore
                self._create_issue(
                    node,
                    severity=Severity.MEDIUM,
                    rule_id="L004",
                    message="Loop-invariant code detected inside loop",
                    suggestion="Move computations that don't depend on loop variable outside the loop",
                )
            )

    def _loop_body_accesses_collection(
        self, loop_node: ast.For, collection: ast.AST
    ) -> bool:
        """Check if the loop body contains ``collection[index_var]``.

        ``collection`` is compared structurally via ``ast.dump``.
        """
        if not isinstance(loop_node.target, ast.Name):
            return False

        index_var = loop_node.target.id
        collection_dump = ast.dump(collection)

        for stmt in loop_node.body:
            for node in ast.walk(stmt):
                if (
                    isinstance(node, ast.Subscript)
                    and isinstance(node.slice, ast.Name)
                    and node.slice.id == index_var
                    and ast.dump(node.value) == collection_dump
                ):
                    return True
        return False

    def visit_While(self, node: ast.While) -> None:
        """Check while loops for optimization opportunities (none implemented yet)."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Could add while-loop specific checks here
        self.generic_visit(node)