pyrefactor 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrefactor/__init__.py +3 -0
- pyrefactor/__main__.py +231 -0
- pyrefactor/analyzer.py +185 -0
- pyrefactor/ast_visitor.py +197 -0
- pyrefactor/config.py +224 -0
- pyrefactor/detectors/__init__.py +23 -0
- pyrefactor/detectors/boolean_logic.py +231 -0
- pyrefactor/detectors/comparisons.py +353 -0
- pyrefactor/detectors/complexity.py +248 -0
- pyrefactor/detectors/context_manager.py +188 -0
- pyrefactor/detectors/control_flow.py +156 -0
- pyrefactor/detectors/dict_operations.py +346 -0
- pyrefactor/detectors/duplication.py +358 -0
- pyrefactor/detectors/loops.py +267 -0
- pyrefactor/detectors/performance.py +267 -0
- pyrefactor/models.py +98 -0
- pyrefactor/py.typed +0 -0
- pyrefactor/reporter.py +208 -0
- pyrefactor-1.0.1.dist-info/METADATA +353 -0
- pyrefactor-1.0.1.dist-info/RECORD +24 -0
- pyrefactor-1.0.1.dist-info/WHEEL +5 -0
- pyrefactor-1.0.1.dist-info/entry_points.txt +2 -0
- pyrefactor-1.0.1.dist-info/licenses/LICENSE.md +70 -0
- pyrefactor-1.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Context manager detector for PyRefactor."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
from typing import Optional, Union, cast
|
|
5
|
+
|
|
6
|
+
from ..ast_visitor import BaseDetector
|
|
7
|
+
from ..config import Config
|
|
8
|
+
from ..models import Issue, Severity
|
|
9
|
+
|
|
10
|
+
# Callable names whose result owns a resource and should therefore be
# entered through a ``with`` statement. Matching against this set is purely
# syntactic (by name), so it cannot tell which module a callable came from.
CONTEXT_MANAGER_FUNCS = frozenset(
    (
        # builtins
        "open",
        # NOTE(review): ``file`` is a Python 2 builtin; on Python 3 this only
        # matches a user-defined callable that happens to be named ``file``.
        "file",
        # urllib.request
        "urlopen",
        # tempfile
        "NamedTemporaryFile",
        "SpooledTemporaryFile",
        "TemporaryDirectory",
        "TemporaryFile",
        # zipfile / tarfile
        "ZipFile",
        "PyZipFile",
        "TarFile",
        # subprocess / multiprocessing
        "Popen",
        "Pool",
    )
)

# Attribute (method) names whose call is treated as handing back a resource,
# e.g. ``path.open()`` or ``lock.acquire()``.
# NOTE(review): "start" also matches calls such as ``threading.Thread.start()``
# or ``re.Match.start()``, which do not return context managers -- confirm
# this entry is intended.
CONTEXT_MANAGER_METHODS = frozenset(("open", "acquire", "start"))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ContextManagerDetector(BaseDetector):
    """Detects resource-allocating operations that should use 'with' statements.

    An assignment (plain or annotated) or an expression statement is reported
    under rule R001 when its value is, or starts a method chain with, a call
    matched by ``CONTEXT_MANAGER_FUNCS`` / ``CONTEXT_MANAGER_METHODS`` --
    e.g. ``f = open(p)``, ``data = open(p).read()``,
    ``proc = subprocess.Popen(cmd)``.
    """

    def __init__(self, config: Config, file_path: str, source_lines: list[str]) -> None:
        """Initialize context manager detector."""
        super().__init__(config, file_path, source_lines)
        # NOTE(review): the next two attributes are not read or written by any
        # method of this class; kept so external code that touches them works.
        self.resource_assignments: dict[str, Union[ast.Assign, ast.AnnAssign]] = {}
        self.used_in_with: set[str] = set()
        # child node -> parent node, filled by _build_parent_map().
        self.parent_map: dict[ast.AST, ast.AST] = {}

    def analyze(self, tree: ast.AST) -> list[Issue]:
        """Run the detector on an AST and return issues found."""
        # Build parent map once for the entire tree
        self._build_parent_map(tree)
        self.visit(tree)
        return self.issues

    def _build_parent_map(self, tree: ast.AST) -> None:
        """Build a map of child -> parent for the entire tree.

        The map is rebuilt from scratch so that analysing several trees with
        one detector does not keep nodes from earlier trees alive.
        """
        self.parent_map = {
            child: parent
            for parent in ast.walk(tree)
            for child in ast.iter_child_nodes(parent)
        }

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "context_manager"

    def _create_issue(
        self,
        node: ast.AST,
        *,
        severity: Severity,
        rule_id: str,
        message: str,
        suggestion: str,
    ) -> Issue:
        """Create an Issue object for context manager issues."""
        return Issue(
            file=self.file_path,
            line=cast(int, getattr(node, "lineno", 0)),
            column=cast(int, getattr(node, "col_offset", 0)),
            severity=severity,
            rule_id=rule_id,
            message=message,
            suggestion=suggestion,
        )

    def _is_context_manager_call(self, node: ast.Call) -> bool:
        """Check if a call returns a context manager.

        - Bare names (``open(...)``) are matched against CONTEXT_MANAGER_FUNCS.
        - Attribute calls match a known method name (``lock.acquire()``,
          ``path.open()``) or a module-qualified factory from
          CONTEXT_MANAGER_FUNCS (``subprocess.Popen(...)``,
          ``tempfile.NamedTemporaryFile()``), which would otherwise be missed.
        """
        func = node.func
        if isinstance(func, ast.Name):
            return func.id in CONTEXT_MANAGER_FUNCS
        if isinstance(func, ast.Attribute):
            return (
                func.attr in CONTEXT_MANAGER_METHODS
                or func.attr in CONTEXT_MANAGER_FUNCS
            )
        return False

    def _is_used_in_return(self, node: ast.Call) -> bool:
        """Check if the call is part of a return statement.

        Walks up from ``node`` to the nearest enclosing statement and reports
        whether that statement is a ``return``. A ``return`` can only contain
        expressions, so no ancestor beyond the nearest statement can be one.
        """
        current = self.parent_map.get(node)
        while current is not None:
            if isinstance(current, ast.stmt):
                return isinstance(current, ast.Return)
            current = self.parent_map.get(current)
        return False

    def _is_used_in_with_context(self, node: ast.Call) -> bool:
        """Check if the call is already used in a 'with' statement.

        Only a call that is part of a with-item (``with open(p) as f:``,
        ``async with ...``) counts. A call in a statement that merely sits in
        the *body* of a ``with`` block is still a candidate for reporting.
        """
        current = self.parent_map.get(node)
        while current is not None:
            if isinstance(current, ast.withitem):
                return True
            if isinstance(current, ast.stmt):
                # Reached the statement owning the call without passing
                # through a with-item.
                return False
            current = self.parent_map.get(current)
        return False

    def visit_Assign(self, node: ast.Assign) -> None:
        """Check for resource-allocating assignments (``f = open(p)``, ``s = open(p).read()``)."""
        self._check_resource_statement(node, node.value)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Check annotated assignments such as ``f: IO[str] = open(p)``."""
        self._check_resource_statement(node, node.value)

    def visit_Expr(self, node: ast.Expr) -> None:
        """Check for context manager calls used as statements without assignment."""
        self._check_resource_statement(node, node.value)

    def _check_resource_statement(
        self, node: ast.stmt, value: Optional[ast.expr]
    ) -> None:
        """Report ``node`` when ``value`` holds a context manager call, then recurse.

        ``value`` is None for a bare annotation (``f: IO``), which is skipped.
        Suppressed statements are not reported but their children are still
        visited.
        """
        if not self.is_suppressed(node) and value is not None:
            cm_call = self._find_context_manager_call(value)
            if cm_call is not None:
                self._check_and_report_context_manager(node, cm_call)
        self.generic_visit(node)

    def _check_and_report_context_manager(
        self, node: ast.AST, cm_call: ast.Call
    ) -> None:
        """Check and report if a context manager call should use 'with' statement."""
        # Skip if already in a with statement
        if self._is_used_in_with_context(cm_call):
            return

        # Skip if this is in a return statement
        if self._is_used_in_return(cm_call):
            return

        # Get the function name for a better error message
        func_name = self._get_func_name(cm_call)

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.HIGH,
                rule_id="R001",
                message=f"Resource-allocating operation '{func_name}' should use 'with' statement",
                suggestion=f"Use 'with {func_name}(...) as resource:' to ensure proper resource cleanup",
            )
        )

    def _find_context_manager_call(self, node: ast.AST) -> Optional[ast.Call]:
        """Find a context manager call in an expression tree.

        Follows a method chain from its outermost call inwards, so ``open(p)``,
        ``open(p).read()`` and ``open(p).read().splitlines()`` all yield the
        ``open`` call. Call arguments are not searched.
        """
        current = node
        while isinstance(current, ast.Call):
            if self._is_context_manager_call(current):
                return current
            if not isinstance(current.func, ast.Attribute):
                break
            current = current.func.value
        return None

    def _get_func_name(self, call: ast.Call) -> str:
        """Extract the function name from a call node."""
        if isinstance(call.func, ast.Name):
            return call.func.id
        if isinstance(call.func, ast.Attribute):
            return call.func.attr
        return "unknown"
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""Control flow simplification detector for PyRefactor."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
from ..ast_visitor import BaseDetector
|
|
7
|
+
from ..models import Issue, Severity
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ControlFlowDetector(BaseDetector):
    """Detects unnecessary else/elif clauses after return/raise/break/continue.

    When the body of an ``if`` always ends in a terminating statement, a
    following ``else``/``elif`` only adds nesting. Rules: R002 (return),
    R003 (raise), R004 (break), R005 (continue).
    """

    # Map terminator types to rule IDs
    _TERMINATOR_RULES = {
        "return": "R002",
        "raise": "R003",
        "break": "R004",
        "continue": "R005",
    }

    # Statement classes that unconditionally leave the current block, mapped
    # to the keyword used in messages and as the key into _TERMINATOR_RULES.
    # Looked up by exact type (``type(stmt)``).
    _TERMINATOR_KEYWORDS = {
        ast.Return: "return",
        ast.Raise: "raise",
        ast.Break: "break",
        ast.Continue: "continue",
    }

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "control_flow"

    def _create_issue(
        self,
        node: ast.AST,
        *,
        severity: Severity,
        rule_id: str,
        message: str,
        suggestion: str,
    ) -> Issue:
        """Create an Issue object for control flow issues."""
        return Issue(
            file=self.file_path,
            line=cast(int, getattr(node, "lineno", 0)),
            column=cast(int, getattr(node, "col_offset", 0)),
            severity=severity,
            rule_id=rule_id,
            message=message,
            suggestion=suggestion,
        )

    def visit_If(self, node: ast.If) -> None:
        """Check for unnecessary else clauses."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        self._check_unnecessary_else(node)
        self.generic_visit(node)

    def _check_unnecessary_else(self, node: ast.If) -> None:
        """Check if the else clause is unnecessary after a terminating statement."""
        # Early return if no else clause
        if not node.orelse:
            return

        # Check if if-body always terminates
        if not self._always_terminates(node.body):
            return

        # Determine what kind of termination
        terminator = self._get_terminator_type(node.body)

        # Report issue if we have a known terminator
        if terminator in self._TERMINATOR_RULES:
            self._report_unnecessary_else(
                node, self._TERMINATOR_RULES[terminator], terminator
            )

    def _always_terminates(self, body: list[ast.stmt]) -> bool:
        """Check if a code block always terminates (return/raise/break/continue).

        Only the last statement of ``body`` is inspected:
        - a direct return/raise/break/continue terminates;
        - an ``if`` terminates when it has an else branch and both branches
          terminate;
        - a ``try`` terminates when its body, every handler and (if present)
          its ``else`` clause terminate; ``finally`` is not considered.
        Anything else is conservatively treated as non-terminating.
        """
        if not body:
            return False

        # Check the last statement
        last_stmt = body[-1]

        # Direct terminating statements
        if isinstance(last_stmt, (ast.Return, ast.Raise, ast.Break, ast.Continue)):
            return True

        # If statement - check if all branches terminate
        if isinstance(last_stmt, ast.If):
            # Must have an else clause to ensure all paths terminate
            if not last_stmt.orelse:
                return False

            # Check if both if and else terminate
            if_terminates = self._always_terminates(last_stmt.body)
            else_terminates = self._always_terminates(last_stmt.orelse)
            return if_terminates and else_terminates

        # Try statement - all branches must terminate
        if isinstance(last_stmt, ast.Try):
            try_terminates = self._always_terminates(last_stmt.body)
            handlers_terminate = all(
                self._always_terminates(handler.body) for handler in last_stmt.handlers
            )
            # If there's an else clause, it must also terminate
            else_terminates = (
                self._always_terminates(last_stmt.orelse) if last_stmt.orelse else True
            )
            # Finally doesn't affect termination
            return try_terminates and handlers_terminate and else_terminates

        return False

    def _get_terminator_type(self, body: list[ast.stmt]) -> str:
        """Get the type of terminator in a code block.

        Returns the keyword ("return", "raise", "break", "continue") of the
        block's final terminating statement, or "" when none is found. For a
        trailing ``if`` or ``try`` the keyword is taken from its primary body,
        which callers have already shown to terminate via _always_terminates().
        Handling ``try`` here keeps this method consistent with
        _always_terminates(), which also accepts a terminating ``try``.
        """
        if not body:
            return ""

        last_stmt = body[-1]

        keyword = self._TERMINATOR_KEYWORDS.get(type(last_stmt))
        if keyword is not None:
            return keyword

        # Check nested structures
        if isinstance(last_stmt, (ast.If, ast.Try)):
            return self._get_terminator_type(last_stmt.body)

        return ""

    def _report_unnecessary_else(
        self, node: ast.If, rule_id: str, terminator: str
    ) -> None:
        """Report an unnecessary else clause.

        ``elif b:`` and ``else:`` followed by an indented ``if b:`` both
        produce a single nested ``If`` in ``orelse``. The CPython 3.9+ parser
        places an ``elif`` node at the column of the ``elif`` keyword, which
        lines up with the outer ``if``, whereas a nested ``if`` is indented
        further. That column comparison picks the clause name in the message.
        """
        orelse = node.orelse
        is_elif = (
            len(orelse) == 1
            and isinstance(orelse[0], ast.If)
            and orelse[0].col_offset == node.col_offset
        )
        clause_type = "elif" if is_elif else "else"

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.MEDIUM,
                rule_id=rule_id,
                message=f"Unnecessary '{clause_type}' after '{terminator}' statement",
                suggestion=f"Remove '{clause_type}' and unindent its body since the "
                f"preceding code always executes '{terminator}'",
            )
        )
|
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
"""Dictionary operations detector for PyRefactor."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
from typing import Optional, Tuple, cast
|
|
5
|
+
|
|
6
|
+
from ..ast_visitor import BaseDetector
|
|
7
|
+
from ..models import Issue, Severity
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DictOperationsDetector(BaseDetector):
    """Detects inefficient or non-idiomatic dictionary operations.

    Rules:
        R006 -- ``if k in d: x = d[k] else: x = default`` -> ``d.get(k, default)``
        R007 -- loop over a mapping's keys whose body reads ``mapping[key]``
        R009 -- ``for k in d.keys()`` -> ``for k in d``
        R010 -- ``dict([(k, v) for ...])`` -> dict comprehension
    """

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "dict_operations"

    def _create_issue(
        self,
        node: ast.AST,
        *,
        severity: Severity,
        rule_id: str,
        message: str,
        suggestion: str,
    ) -> Issue:
        """Create an Issue object for dictionary operation issues."""
        return Issue(
            file=self.file_path,
            line=cast(int, getattr(node, "lineno", 0)),
            column=cast(int, getattr(node, "col_offset", 0)),
            severity=severity,
            rule_id=rule_id,
            message=message,
            suggestion=suggestion,
        )

    # ------------------------------------------------------------------ R006

    def visit_If(self, node: ast.If) -> None:
        """Check for dict.get() opportunities."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Pattern: if key in dict: x = dict[key] else: x = default
        self._check_dict_get_pattern(node)

        self.generic_visit(node)

    def _check_dict_get_pattern(self, node: ast.If) -> None:
        """Check for pattern that could use dict.get()."""
        # Validate basic structure
        if not self._is_valid_dict_get_structure(node):
            return

        # Extract and validate components
        components = self._extract_dict_get_components(node)
        if not components:
            return

        var_name, key_name, dict_name, default_val = components

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.LOW,
                rule_id="R006",
                message="Consider using dict.get() instead of if/else for key lookup",
                suggestion=f"Use '{var_name} = {dict_name}.get({key_name}, {default_val})' "
                f"instead of if/else block",
            )
        )

    def _is_valid_dict_get_structure(self, node: ast.If) -> bool:
        """Check if node has the basic structure for dict.get() refactoring.

        Requires a single ``in`` comparison as the test and exactly one
        assignment statement in each of the if and else branches.
        """
        # Check if condition is "key in dict"
        if not isinstance(node.test, ast.Compare):
            return False

        if len(node.test.ops) != 1 or not isinstance(node.test.ops[0], ast.In):
            return False

        # Must have both if and else branches with single assignments
        if not node.orelse or len(node.body) != 1 or len(node.orelse) != 1:
            return False

        if_stmt = node.body[0]
        else_stmt = node.orelse[0]

        return isinstance(if_stmt, ast.Assign) and isinstance(else_stmt, ast.Assign)

    def _extract_dict_get_components(
        self, node: ast.If
    ) -> Optional[Tuple[str, str, str, str]]:
        """Extract variable names and values for dict.get() suggestion.

        Returns ``(var_name, key_name, dict_name, default_source)`` or None
        when the branches do not match ``x = d[k]`` / ``x = <default>``.
        """
        if_stmt = node.body[0]
        else_stmt = node.orelse[0]

        # Validate assignments structure
        if not self._validate_assignment_structure(if_stmt, else_stmt):
            return None

        # Cast to Assign after validation
        if_assign = cast(ast.Assign, if_stmt)
        else_assign = cast(ast.Assign, else_stmt)

        # Extract and validate condition components
        condition_data = self._extract_condition_data(node.test)
        if not condition_data:
            return None

        key_name, dict_name = condition_data

        # Verify if_stmt accesses dict[key]
        if not self._verify_dict_key_access(if_assign, dict_name, key_name):
            return None

        var_name = cast(ast.Name, if_assign.targets[0]).id
        default_val = (
            ast.unparse(else_assign.value) if hasattr(ast, "unparse") else "..."
        )

        return (var_name, key_name.id, dict_name.id, default_val)

    def _validate_assignment_structure(
        self, if_stmt: ast.stmt, else_stmt: ast.stmt
    ) -> bool:
        """Validate that both if and else branches have valid assignment structure."""
        # Check both are assignments
        if not (isinstance(if_stmt, ast.Assign) and isinstance(else_stmt, ast.Assign)):
            return False

        # Check both have exactly one target
        if len(if_stmt.targets) != 1 or len(else_stmt.targets) != 1:
            return False

        # Check both targets are simple names
        if not (
            isinstance(if_stmt.targets[0], ast.Name)
            and isinstance(else_stmt.targets[0], ast.Name)
        ):
            return False

        # Both should assign to the same variable
        if if_stmt.targets[0].id != else_stmt.targets[0].id:
            return False

        # Validate if-body is dict[key] access
        return isinstance(if_stmt.value, ast.Subscript)

    def _extract_condition_data(
        self, test: ast.expr
    ) -> Optional[Tuple[ast.Name, ast.Name]]:
        """Extract key and dict names from the condition."""
        if not isinstance(test, ast.Compare):
            return None

        key_name = test.left
        if not test.comparators:
            return None
        dict_name = test.comparators[0]

        if not isinstance(key_name, ast.Name) or not isinstance(dict_name, ast.Name):
            return None

        return (key_name, dict_name)

    def _verify_dict_key_access(
        self, if_stmt: ast.Assign, dict_name: ast.Name, key_name: ast.Name
    ) -> bool:
        """Verify that if_stmt accesses dict[key] correctly."""
        # Check if value is a subscript
        if not isinstance(if_stmt.value, ast.Subscript):
            return False

        # Check if the subscript is on the correct dict
        if not isinstance(if_stmt.value.value, ast.Name):
            return False
        if if_stmt.value.value.id != dict_name.id:
            return False

        # Check if the slice uses the correct key
        if not isinstance(if_stmt.value.slice, ast.Name):
            return False
        if if_stmt.value.slice.id != key_name.id:
            return False

        return True

    # ----------------------------------------------------------- R007 / R009

    def visit_For(self, node: ast.For) -> None:
        """Check for dictionary iteration improvements."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Check for .keys() that should be removed
        self._check_unnecessary_keys(node)

        # Check for .items() opportunity
        self._check_dict_items_opportunity(node)

        self.generic_visit(node)

    def _check_unnecessary_keys(self, node: ast.For) -> None:
        """Check for unnecessary .keys() in for loop.

        The receiver is rendered in full (``self.data`` rather than ``data``)
        so the suggestion can be applied verbatim.
        """
        # Pattern: for key in dict.keys()
        receiver = self._keys_call_receiver(node.iter)
        if receiver is None:
            return

        dict_name = ast.unparse(receiver)
        target_name = self._get_target_name(node.target)
        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.INFO,
                rule_id="R009",
                message="Unnecessary .keys() call when iterating dictionary",
                suggestion=f"Use 'for {target_name} in {dict_name}:' "
                f"instead of 'for {target_name} in {dict_name}.keys():'",
            )
        )

    def _check_dict_items_opportunity(self, node: ast.For) -> None:
        """Check if loop iterates keys but also accesses values.

        Matches ``for k in d:`` and ``for k in d.keys():`` where ``d`` is a
        name or dotted attribute chain and the body reads ``d[k]``.
        """
        # Pattern: for key in dict: ... dict[key] ...
        if not isinstance(node.target, ast.Name):
            return

        mapping = self._iterated_mapping(node.iter)
        if mapping is None:
            return

        iter_name = ast.unparse(mapping)
        key_name = node.target.id

        # Check if body contains dict[key] reads
        if self._has_dict_key_access(node.body, iter_name, key_name):
            self.add_issue(
                self._create_issue(
                    node,
                    severity=Severity.MEDIUM,
                    rule_id="R007",
                    message="Consider using .items() to access both keys and values",
                    suggestion=f"Use 'for {key_name}, value in {iter_name}.items():' "
                    f"to avoid repeated dict lookups",
                )
            )

    def _keys_call_receiver(self, node: ast.expr) -> Optional[ast.expr]:
        """Return ``X`` when ``node`` is an argument-less ``X.keys()`` call, else None.

        ``dict.keys()`` takes no arguments, so calls such as
        ``client.keys("user:*")`` belong to some other API and are ignored.
        """
        if not isinstance(node, ast.Call) or node.args or node.keywords:
            return None
        func = node.func
        if isinstance(func, ast.Attribute) and func.attr == "keys":
            return func.value
        return None

    def _iterated_mapping(self, node: ast.expr) -> Optional[ast.expr]:
        """Return the mapping expression a for-loop iterates keys of, or None.

        Accepts ``d`` / ``self.d`` directly or through an argument-less
        ``.keys()`` call; anything more complex is ignored.
        """
        receiver = self._keys_call_receiver(node)
        candidate = receiver if receiver is not None else node
        return candidate if self._is_simple_reference(candidate) else None

    def _is_simple_reference(self, node: ast.AST) -> bool:
        """Return True for a plain name or dotted attribute chain (``d``, ``self.d``)."""
        while isinstance(node, ast.Attribute):
            node = node.value
        return isinstance(node, ast.Name)

    def _has_dict_key_access(
        self, body: list[ast.stmt], dict_name: str, key_name: str
    ) -> bool:
        """Check if body contains a dict[key] read."""
        for stmt in body:
            for child in ast.walk(stmt):
                if self._is_dict_key_subscript(child, dict_name, key_name):
                    return True
        return False

    def _is_dict_key_subscript(
        self, node: ast.AST, dict_name: str, key_name: str
    ) -> bool:
        """Check if node reads ``<dict_name>[<key_name>]``.

        ``dict_name`` is compared with the full source text of the subscripted
        expression, so ``self.data`` matches ``self.data[k]`` but a different
        variable that merely shares the last attribute name (``data[k]``) does
        not. Only reads (Load context) count: ``d[k] = ...``, ``d[k] += ...``
        and ``del d[k]`` need the key lookup anyway.
        """
        if not isinstance(node, ast.Subscript) or not isinstance(node.ctx, ast.Load):
            return False

        # Check if subscript is on the correct dict
        if not self._is_simple_reference(node.value):
            return False
        if ast.unparse(node.value) != dict_name:
            return False

        # Check if slice is the correct key
        return isinstance(node.slice, ast.Name) and node.slice.id == key_name

    # ------------------------------------------------------------------ R010

    def visit_Call(self, node: ast.Call) -> None:
        """Check for dict comprehension opportunities."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Pattern: dict([(k, v) for ...]) or dict([...])
        self._check_dict_comprehension(node)

        self.generic_visit(node)

    def _check_dict_comprehension(self, node: ast.Call) -> None:
        """Check if dict() call can be replaced with dict comprehension."""
        # Check if it's a dict() call
        if not isinstance(node.func, ast.Name):
            return
        if node.func.id != "dict":
            return

        # Check if it has arguments
        if not node.args:
            return

        arg = node.args[0]

        # Check if it's a list comprehension with tuples
        if not isinstance(arg, ast.ListComp):
            return

        # Check if element is a 2-tuple
        if not isinstance(arg.elt, ast.Tuple):
            return
        if len(arg.elt.elts) != 2:
            return

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.LOW,
                rule_id="R010",
                message="Consider using dictionary comprehension instead of dict()",
                suggestion="Use '{k: v for ...}' instead of 'dict([(k, v) for ...])' "
                "for better readability and performance",
            )
        )

    # --------------------------------------------------------------- helpers

    def _get_name(self, node: ast.AST) -> Optional[str]:
        """Extract the name from a node (last component only for attributes)."""
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return node.attr
        return None

    def _get_target_name(self, node: ast.AST) -> str:
        """Get the target name from a for loop target ("item" if not a plain name)."""
        if isinstance(node, ast.Name):
            return node.id
        return "item"
|