pyrefactor 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,353 @@
1
+ """Comparison improvements detector for PyRefactor."""
2
+
3
+ import ast
4
+ from typing import Optional, Tuple, cast
5
+
6
+ from ..ast_visitor import BaseDetector
7
+ from ..models import Issue, Severity
8
+
9
# Constants that must be compared by identity ('is' / 'is not') instead of by
# equality ('==' / '!='). ComparisonsDetector._is_singleton_const checks AST
# constants against this set using identity, so 1 / 0 do not match True / False.
SINGLETON_VALUES = frozenset((None, True, False))
11
+
12
+
13
class ComparisonsDetector(BaseDetector):
    """Detects non-idiomatic or inefficient comparison patterns.

    Rules reported by this detector:

    * ``R011`` -- ``x == a or x == b`` that could use the ``in`` operator.
    * ``R012`` -- ``a < b and b < c`` that could be a chained comparison.
    * ``R014`` -- ``==`` / ``!=`` against ``None``, ``True`` or ``False``.
    * ``R015`` -- ``type(x) == Y`` / ``type(x) is Y`` instead of ``isinstance``.
    """

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "comparisons"

    @staticmethod
    def _unparse(node: ast.AST, fallback: str) -> str:
        """Return the source text of *node*.

        ``ast.unparse`` only exists on Python 3.9+; on older interpreters the
        placeholder *fallback* is returned instead.
        """
        if hasattr(ast, "unparse"):
            return ast.unparse(node)
        return fallback

    def _create_issue(
        self,
        node: ast.AST,
        *,
        severity: Severity,
        rule_id: str,
        message: str,
        suggestion: str,
    ) -> Issue:
        """Create an Issue object for comparison issues.

        Position falls back to line/column 0 when *node* carries no location.
        """
        return Issue(
            file=self.file_path,
            line=cast(int, getattr(node, "lineno", 0)),
            column=cast(int, getattr(node, "col_offset", 0)),
            severity=severity,
            rule_id=rule_id,
            message=message,
            suggestion=suggestion,
        )

    def visit_BoolOp(self, node: ast.BoolOp) -> None:
        """Check for patterns that could use 'in' operator or chained comparisons."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        # Check for multiple equality comparisons
        if isinstance(node.op, ast.Or):
            self._check_consider_using_in(node)

        # Check for chainable comparisons
        if isinstance(node.op, ast.And):
            self._check_chained_comparison(node)

        self.generic_visit(node)

    def _check_consider_using_in(self, node: ast.BoolOp) -> None:
        """Report ``x == a or x == b [or ...]`` as a candidate for ``x in (...)``.

        Every operand of the ``or`` must be a single ``==`` comparison, and all
        of them must share the same left operand (compared structurally with
        ``ast.dump``); otherwise nothing is reported.
        """
        comparisons: list[ast.Compare] = []
        for value in node.values:
            if not (
                isinstance(value, ast.Compare)
                and len(value.ops) == 1
                and isinstance(value.ops[0], ast.Eq)
            ):
                return
            comparisons.append(value)

        if len(comparisons) < 2:
            return

        # All should compare the same left operand
        first_left = ast.dump(comparisons[0].left)
        if any(ast.dump(comp.left) != first_left for comp in comparisons):
            return

        var_name = self._unparse(comparisons[0].left, "x")
        values_str = ", ".join(
            self._unparse(comp.comparators[0], "value")
            for comp in comparisons
            if comp.comparators
        )

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.LOW,
                rule_id="R011",
                message="Multiple equality comparisons can be simplified using 'in' operator",
                suggestion=f"Use '{var_name} in ({values_str})' instead of multiple '==' comparisons. "
                f"Use a set if values are hashable for O(1) lookup.",
            )
        )

    def _check_chained_comparison(self, node: ast.BoolOp) -> None:
        """Check for pattern: a < b and b < c that can be chained.

        Adjacent operand pairs of the ``and`` are examined in order; only the
        first chainable pair is reported.
        """
        if len(node.values) < 2:
            return

        for first, second in zip(node.values, node.values[1:]):
            chain_info = self._try_extract_chainable_pair(first, second)
            if chain_info:
                self._report_chainable_comparison(node, chain_info)
                return  # Report once

    @staticmethod
    def _ops_chain_readably(first: ast.cmpop, second: ast.cmpop) -> bool:
        """Return True if ``a <first> b <second> c`` reads as one natural chain.

        That is the case for a monotonic range (``<``/``<=`` twice, or
        ``>``/``>=`` twice) or an all-equal test (``==`` twice). Mixed
        directions or families such as ``a < b > c`` or ``a == b != c`` are
        legal Python, but chaining them does not improve readability.
        """
        families = ((ast.Lt, ast.LtE), (ast.Gt, ast.GtE), (ast.Eq,))
        return any(
            isinstance(first, family) and isinstance(second, family)
            for family in families
        )

    def _try_extract_chainable_pair(
        self, val1: ast.expr, val2: ast.expr
    ) -> Optional[Tuple[str, str, str, str, str]]:
        """Try to extract chainable comparison info from two values.

        Returns (left1_str, op1, mid_str, op2, right2_str) if chainable, else None.
        A pair is chainable when both are single comparisons, the right operand
        of the first is structurally identical to the left operand of the
        second, and the two operators form a readable chain.
        """
        if not isinstance(val1, ast.Compare) or not isinstance(val2, ast.Compare):
            return None

        if len(val1.ops) != 1 or len(val2.ops) != 1:
            return None

        if not val1.comparators:
            return None

        shared = val1.comparators[0]
        if ast.dump(shared) != ast.dump(val2.left):
            return None

        first_op, second_op = val1.ops[0], val2.ops[0]
        if not self._ops_chain_readably(first_op, second_op):
            return None

        op1 = self._get_op_str(first_op)
        op2 = self._get_op_str(second_op)
        if not op1 or not op2:
            return None

        left1_str = self._unparse(val1.left, "a")
        mid_str = self._unparse(shared, "b")
        right2_str = (
            self._unparse(val2.comparators[0], "c") if val2.comparators else "c"
        )

        return (left1_str, op1, mid_str, op2, right2_str)

    def _report_chainable_comparison(
        self, node: ast.BoolOp, chain_info: Tuple[str, str, str, str, str]
    ) -> None:
        """Report a chainable comparison issue."""
        left1_str, op1, mid_str, op2, right2_str = chain_info
        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.LOW,
                rule_id="R012",
                message="Comparison can be chained for better readability",
                suggestion=f"Use '{left1_str} {op1} {mid_str} {op2} {right2_str}' "
                f"instead of separate comparisons",
            )
        )

    def visit_Compare(self, node: ast.Compare) -> None:
        """Check for singleton comparisons and type checks."""
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        self._check_singleton_comparison(node)
        self._check_unidiomatic_typecheck(node)

        self.generic_visit(node)

    def _is_singleton_const(self, node: ast.AST) -> bool:
        """Check if node is a singleton constant (True, False, or None).

        Uses identity checking (is) like pylint to avoid false positives
        with values like 1 which equals True but is not the same object.
        """
        if not isinstance(node, ast.Constant):
            return False
        # Use identity check (is) not equality (==) to avoid issues with 1 == True
        return any(node.value is value for value in SINGLETON_VALUES)

    def _report_none_comparison(
        self, node: ast.Compare, checking_for_absence: bool
    ) -> None:
        """Report an ``== None`` / ``!= None`` comparison (rule R014)."""
        correct_op = "is not" if checking_for_absence else "is"
        wrong_op = "!=" if checking_for_absence else "=="
        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.MEDIUM,
                rule_id="R014",
                message="Comparison with None should use 'is' or 'is not'",
                suggestion=f"Use '{correct_op}' instead of '{wrong_op}' when comparing with None",
            )
        )

    def _report_bool_comparison(
        self, node: ast.Compare, op: ast.cmpop, singleton_val: bool, other: ast.AST
    ) -> None:
        """Report a redundant ``== True`` / ``!= False`` style comparison (rule R014)."""
        other_str = self._unparse(other, "expr")
        suggestion = self._get_bool_comparison_suggestion(singleton_val, op, other_str)

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.INFO,
                rule_id="R014",
                message=f"Redundant comparison with {singleton_val}",
                suggestion=suggestion,
            )
        )

    def _get_bool_comparison_suggestion(
        self, singleton_val: bool, op: ast.cmpop, other_str: str
    ) -> str:
        """Generate suggestion text for boolean comparison."""
        is_eq = isinstance(op, ast.Eq)

        if singleton_val:  # True
            if is_eq:
                return f"Use '{other_str}' directly instead of comparing with True"
            return f"Use 'not {other_str}' instead of '!= True'"

        # False
        if is_eq:
            return f"Use 'not {other_str}' instead of comparing with False"
        return f"Use '{other_str}' directly instead of '!= False'"

    def _check_singleton_comparison(self, node: ast.Compare) -> None:
        """Check for comparisons with True/False/None using == instead of is.

        Only single ``==`` / ``!=`` comparisons are considered; the singleton
        may appear on either side. Implementation based on pylint's
        comparison checker.
        """
        if len(node.ops) != 1:
            return

        op = node.ops[0]
        comparator = node.comparators[0] if node.comparators else None

        if not comparator or not isinstance(op, (ast.Eq, ast.NotEq)):
            return

        # Check if either side is a singleton constant
        if self._is_singleton_const(node.left):
            singleton = node.left
            other = comparator
        elif self._is_singleton_const(comparator):
            singleton = comparator
            other = node.left
        else:
            return

        if not isinstance(singleton, ast.Constant):
            return

        singleton_val = singleton.value
        checking_for_absence = isinstance(op, ast.NotEq)

        if singleton_val is None:
            self._report_none_comparison(node, checking_for_absence)
        elif isinstance(singleton_val, bool):
            self._report_bool_comparison(node, op, singleton_val, other)

    @staticmethod
    def _is_single_arg_type_call(expr: ast.AST) -> bool:
        """Return True if *expr* is a one-positional-argument call to ``type``."""
        return (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Name)
            and expr.func.id == "type"
            and len(expr.args) == 1
        )

    def _check_unidiomatic_typecheck(self, node: ast.Compare) -> None:
        """Check for type(x) == Y / type(x) is Y instead of isinstance(x, Y).

        ``type(x) == type(y)`` (and the ``is`` form) is deliberately not
        reported: it asserts that two objects have exactly the same type,
        which ``isinstance`` cannot express.
        """
        if len(node.ops) != 1:
            return

        op = node.ops[0]
        if not isinstance(op, (ast.Eq, ast.Is)):
            return

        if not self._is_single_arg_type_call(node.left):
            return

        # type(x) == type(y): an exact-type match, not an isinstance candidate.
        if node.comparators and self._is_single_arg_type_call(node.comparators[0]):
            return

        call = cast(ast.Call, node.left)
        obj = self._unparse(call.args[0], "obj")
        type_name = (
            self._unparse(node.comparators[0], "Type") if node.comparators else "Type"
        )
        # Echo the operator actually used, so the suggestion quotes the real code.
        op_str = "==" if isinstance(op, ast.Eq) else "is"

        self.add_issue(
            self._create_issue(
                node,
                severity=Severity.MEDIUM,
                rule_id="R015",
                message="Use isinstance() for type checking instead of type() comparison",
                suggestion=f"Use 'isinstance({obj}, {type_name})' instead of 'type({obj}) {op_str} {type_name}'",
            )
        )

    def visit_Call(self, node: ast.Call) -> None:
        """Visit a call node; no call-level rule is implemented yet.

        This only keeps the traversal going (respecting suppression). Detecting
        consecutive ``isinstance`` calls would need parent context and is not
        done anywhere in this detector.
        """
        if self.is_suppressed(node):
            self.generic_visit(node)
            return

        self.generic_visit(node)

    def _get_op_str(self, op: ast.cmpop) -> Optional[str]:
        """Convert comparison operator to string, or None if it is not mapped."""
        op_map = {
            ast.Lt: "<",
            ast.LtE: "<=",
            ast.Gt: ">",
            ast.GtE: ">=",
            ast.Eq: "==",
            ast.NotEq: "!=",
        }
        return op_map.get(type(op))
@@ -0,0 +1,248 @@
1
+ """Complexity detector for PyRefactor."""
2
+
3
+ import ast
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Union
6
+
7
+ from ..ast_visitor import (
8
+ BaseDetector,
9
+ calculate_cyclomatic_complexity,
10
+ count_branches,
11
+ count_nesting_depth,
12
+ )
13
+ from ..models import Issue, Severity
14
+
15
+
16
@dataclass
class IssueParams:
    """Parameters for creating a complexity issue.

    Bundles the per-rule fields that ComplexityDetector._create_issue copies
    into an Issue; the location (line/column) comes from the function node.
    """

    # Severity level to attach to the issue.
    severity: Severity
    # Rule identifier, e.g. "C001".
    rule_id: str
    # Human-readable description of the problem.
    message: str
    # Suggested remediation shown to the user.
    suggestion: str
    # Optional last line of the flagged span; None when not applicable.
    end_line: Optional[int] = None
25
+
26
+
27
class LocalVarVisitor(ast.NodeVisitor):
    """Collect names bound by assignment within a single function scope.

    Visit each statement of a function body; the ids of ``ast.Name`` nodes in
    ``Store`` context accumulate in ``self.vars``. Constructs whose bindings are
    not locals of the enclosing function are not descended into:

    * nested functions and lambdas, which have their own scope;
    * nested class bodies, whose names are class attributes;
    * comprehension loop targets, which are scoped to the comprehension.

    Because nested function nodes are skipped, calling ``visit`` on a
    ``FunctionDef`` itself collects nothing; visit its ``body`` statements.
    """

    def __init__(self) -> None:
        # Names assigned (Store context) in the visited scope.
        self.vars: set[str] = set()

    def visit_Name(self, name_node: ast.Name) -> None:
        """Record names that are assigned to (Store context)."""
        if isinstance(name_node.ctx, ast.Store):
            self.vars.add(name_node.id)

    def visit_FunctionDef(self, func_node: ast.FunctionDef) -> None:
        """Don't descend into nested functions."""

    def visit_AsyncFunctionDef(self, func_node: ast.AsyncFunctionDef) -> None:
        """Don't descend into nested async functions."""

    def visit_Lambda(self, lambda_node: ast.Lambda) -> None:
        """Don't descend into lambdas; they have their own scope."""

    def visit_ClassDef(self, class_node: ast.ClassDef) -> None:
        """Don't descend into nested class bodies; their names are class attributes."""

    def visit_comprehension(self, comp_node: ast.comprehension) -> None:
        """Visit a comprehension's iterable and filters, but not its loop target.

        In Python 3 a comprehension's loop variable does not leak into the
        enclosing function. A walrus (``:=``) in the element expression does
        bind in the enclosing function; it lives outside this node and is
        still visited by the parent comprehension's generic traversal.
        """
        self.visit(comp_node.iter)
        for condition in comp_node.ifs:
            self.visit(condition)
45
+
46
+
47
class ComplexityDetector(BaseDetector):
    """Detects complexity issues in code.

    Every ``def`` / ``async def`` is checked against thresholds read from
    ``self.config.complexity``:

    * ``C001`` -- function length in lines (``max_function_lines``)
    * ``C002`` -- number of parameters (``max_arguments``)
    * ``C003`` -- number of assigned local names (``max_local_variables``)
    * ``C004`` -- number of branches (``max_branches``)
    * ``C005`` -- nesting depth (``max_nesting_depth``)
    * ``C006`` -- cyclomatic complexity (``max_cyclomatic_complexity``)
    """

    def get_detector_name(self) -> str:
        """Return the name of this detector."""
        return "complexity"

    def _create_issue(
        self,
        node: Union[ast.FunctionDef, ast.AsyncFunctionDef],
        params: IssueParams,
    ) -> Issue:
        """Create an Issue located at the function node's lineno/col_offset."""
        return Issue(
            file=self.file_path,
            line=node.lineno,
            column=node.col_offset,
            severity=params.severity,
            rule_id=params.rule_id,
            message=params.message,
            suggestion=params.suggestion,
            end_line=params.end_line,
        )

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Check function complexity, then visit nested definitions."""
        self._check_function(node)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Check async function complexity, then visit nested definitions."""
        self._check_function(node)
        self.generic_visit(node)

    def _check_function(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Run every complexity check for one function.

        A suppressed function is not checked. The visit_* callers still run
        ``generic_visit`` afterwards, so functions nested inside it are
        checked on their own.
        """
        if self.is_suppressed(node):
            return

        old_function = self.current_function
        self.current_function = node
        try:
            # Checks that only read attributes of the function node.
            self._check_function_length(node)
            self._check_arguments(node)

            # Checks that each walk the function's AST.
            self._check_local_variables(node)
            self._check_branches(node)
            self._check_nesting_depth(node)
            self._check_cyclomatic_complexity(node)
        finally:
            # Restore even if a check raises, so the detector's function
            # context is not left pointing at this node.
            self.current_function = old_function

    def _check_function_length(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Report C001 when the function spans more than ``max_function_lines`` lines."""
        if not hasattr(node, "lineno") or not hasattr(node, "end_lineno"):
            return

        if node.end_lineno is None:
            return

        function_lines = node.end_lineno - node.lineno + 1
        max_lines = self.config.complexity.max_function_lines

        if function_lines > max_lines:
            self.add_issue(
                self._create_issue(
                    node,
                    IssueParams(
                        severity=Severity.MEDIUM,
                        rule_id="C001",
                        message=f"Function '{node.name}' is too long ({function_lines} lines, max {max_lines})",
                        suggestion="Consider breaking this function into smaller, more focused functions",
                        end_line=node.end_lineno,
                    ),
                )
            )

    def _check_arguments(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Report C002 when the function takes more than ``max_arguments`` parameters.

        Positional-only, regular, keyword-only, ``*args`` and ``**kwargs``
        parameters all count; a leading ``self`` / ``cls`` does not.
        """
        args = node.args
        total_args = (
            len(args.args)
            + len(args.posonlyargs)
            + len(args.kwonlyargs)
            + (1 if args.vararg else 0)
            + (1 if args.kwarg else 0)
        )

        # Exclude an implicit receiver. It is the first positional parameter,
        # which may be positional-only (``def method(self, /, x)``).
        positional = [*args.posonlyargs, *args.args]
        if positional and positional[0].arg in ("self", "cls"):
            total_args -= 1

        max_args = self.config.complexity.max_arguments

        if total_args > max_args:
            self.add_issue(
                self._create_issue(
                    node,
                    IssueParams(
                        severity=Severity.MEDIUM,
                        rule_id="C002",
                        message=f"Function '{node.name}' has too many arguments ({total_args}, max {max_args})",
                        suggestion="Consider using a configuration object or dataclass to group related parameters",
                    ),
                )
            )

    def _check_local_variables(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Report C003 when the function assigns more than ``max_local_variables`` names."""
        visitor = LocalVarVisitor()
        # Visit the body statements, not ``node`` itself: LocalVarVisitor
        # deliberately does not descend into FunctionDef/AsyncFunctionDef
        # nodes, so visiting the function node directly would collect nothing.
        for statement in node.body:
            visitor.visit(statement)
        local_vars = visitor.vars

        max_vars = self.config.complexity.max_local_variables

        if len(local_vars) > max_vars:
            self.add_issue(
                self._create_issue(
                    node,
                    IssueParams(
                        severity=Severity.LOW,
                        rule_id="C003",
                        message=f"Function '{node.name}' has too many local variables ({len(local_vars)}, max {max_vars})",
                        suggestion="Consider extracting functionality into helper functions or classes",
                    ),
                )
            )

    def _check_branches(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Report C004 when ``count_branches`` exceeds ``max_branches``."""
        branches = count_branches(node)
        max_branches = self.config.complexity.max_branches

        if branches > max_branches:
            self.add_issue(
                self._create_issue(
                    node,
                    IssueParams(
                        severity=Severity.HIGH,
                        rule_id="C004",
                        message=f"Function '{node.name}' has too many branches ({branches}, max {max_branches})",
                        suggestion="Refactor using helper functions, guard clauses, or dictionary dispatch patterns",
                    ),
                )
            )

    def _check_nesting_depth(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Report C005 when ``count_nesting_depth`` exceeds ``max_nesting_depth``."""
        nesting = count_nesting_depth(node)
        max_nesting = self.config.complexity.max_nesting_depth

        if nesting > max_nesting:
            self.add_issue(
                self._create_issue(
                    node,
                    IssueParams(
                        severity=Severity.HIGH,
                        rule_id="C005",
                        message=f"Function '{node.name}' has excessive nesting depth ({nesting}, max {max_nesting})",
                        suggestion="Use early returns, guard clauses, or extract nested logic to separate functions",
                    ),
                )
            )

    def _check_cyclomatic_complexity(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Report C006 when ``calculate_cyclomatic_complexity`` exceeds the maximum."""
        complexity = calculate_cyclomatic_complexity(node)
        max_complexity = self.config.complexity.max_cyclomatic_complexity

        if complexity > max_complexity:
            self.add_issue(
                self._create_issue(
                    node,
                    IssueParams(
                        severity=Severity.MEDIUM,
                        rule_id="C006",
                        message=f"Function '{node.name}' has high cyclomatic complexity ({complexity}, max {max_complexity})",
                        suggestion="Simplify the function by reducing decision points or extracting logic",
                    ),
                )
            )