invar-tools 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invar/core/formatter.py +6 -1
- invar/core/models.py +13 -0
- invar/core/patterns/__init__.py +53 -0
- invar/core/patterns/detector.py +249 -0
- invar/core/patterns/p0_exhaustive.py +207 -0
- invar/core/patterns/p0_literal.py +307 -0
- invar/core/patterns/p0_newtype.py +211 -0
- invar/core/patterns/p0_nonempty.py +307 -0
- invar/core/patterns/p0_validation.py +278 -0
- invar/core/patterns/registry.py +234 -0
- invar/core/patterns/types.py +167 -0
- invar/core/trivial_detection.py +189 -0
- invar/mcp/server.py +4 -0
- invar/shell/commands/guard.py +65 -0
- invar/shell/contract_coverage.py +358 -0
- invar/shell/guard_output.py +5 -0
- invar/shell/pattern_integration.py +234 -0
- invar/shell/testing.py +13 -2
- invar/templates/config/CLAUDE.md.jinja +1 -0
- invar/templates/skills/develop/SKILL.md.jinja +49 -0
- invar/templates/skills/review/SKILL.md.jinja +196 -31
- {invar_tools-1.4.0.dist-info → invar_tools-1.5.0.dist-info}/METADATA +12 -8
- {invar_tools-1.4.0.dist-info → invar_tools-1.5.0.dist-info}/RECORD +28 -16
- {invar_tools-1.4.0.dist-info → invar_tools-1.5.0.dist-info}/WHEEL +0 -0
- {invar_tools-1.4.0.dist-info → invar_tools-1.5.0.dist-info}/entry_points.txt +0 -0
- {invar_tools-1.4.0.dist-info → invar_tools-1.5.0.dist-info}/licenses/LICENSE +0 -0
- {invar_tools-1.4.0.dist-info → invar_tools-1.5.0.dist-info}/licenses/LICENSE-GPL +0 -0
- {invar_tools-1.4.0.dist-info → invar_tools-1.5.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NonEmpty Pattern Detector (DX-61, P0).
|
|
3
|
+
|
|
4
|
+
Detects runtime empty-collection checks that could benefit from
|
|
5
|
+
compile-time NonEmpty type safety.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import ast
|
|
9
|
+
|
|
10
|
+
from deal import post, pre
|
|
11
|
+
|
|
12
|
+
from invar.core.patterns.detector import BaseDetector
|
|
13
|
+
from invar.core.patterns.types import (
|
|
14
|
+
Confidence,
|
|
15
|
+
PatternID,
|
|
16
|
+
PatternSuggestion,
|
|
17
|
+
Priority,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class NonEmptyDetector(BaseDetector):
    """
    Detect runtime checks for empty collections.

    These are candidates for NonEmpty type to guarantee non-emptiness
    at compile time instead of runtime.

    Detection logic:
    - Find 'if not items:' or 'if len(items) == 0:' patterns (also
      'len(items) < 1', 'len(items) <= 0' and the mirrored spellings
      such as '0 == len(items)')
    - Look for raises or early returns after such checks
    - Suggest NonEmpty type for the parameter

    >>> import ast
    >>> detector = NonEmptyDetector()
    >>> code = '''
    ... def summarize(items: list[str]) -> str:
    ...     if not items:
    ...         raise ValueError("Cannot summarize empty list")
    ...     return f"First: {items[0]}"
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    """

    @property
    @post(lambda result: result == PatternID.NONEMPTY)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.NONEMPTY

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use NonEmpty type for compile-time non-empty guarantees"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find functions with runtime empty-collection checks.

        Every sync and async function definition reachable through
        ``ast.walk`` is inspected (nested definitions are therefore visited
        as functions of their own); at most one suggestion is produced per
        function.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> code = '''
        ... def process(data: list[int]):
        ...     if len(data) == 0:
        ...         raise ValueError("Empty data")
        ...     return data[0]
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                suggestion = self._check_function(node, file_path)
                if suggestion:
                    suggestions.append(suggestion)

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    def _check_function(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, file_path: str
    ) -> PatternSuggestion | None:
        """
        Check if function has empty-collection guard patterns.

        Returns a suggestion anchored at the first guard found, or None
        when the function contains no such guard.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> code = '''
        ... def f(items):
        ...     if not items:
        ...         raise ValueError("Empty")
        ...     return items[0]
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> suggestion = detector._check_function(func, "test.py")
        >>> suggestion is not None
        True
        """
        empty_checks = self._find_empty_checks(node)
        if not empty_checks:
            return None

        # Only the first guard is reported, so a function yields at most
        # one suggestion.
        var_name, check_line = empty_checks[0]
        param_type = self._get_param_type(node, var_name)
        confidence = self._calculate_confidence(node, var_name, param_type)

        return self.make_suggestion(
            pattern_id=self.pattern_id,
            priority=self.priority,
            file_path=file_path,
            line=check_line,
            message=f"Runtime empty check on '{var_name}' - consider NonEmpty type",
            current_code=self._format_check(var_name, param_type),
            suggested_pattern=f"NonEmpty[{param_type or 'T'}] guarantees non-empty at compile time",
            confidence=confidence,
            reference_pattern="Pattern 3: NonEmpty for Compile-Time Safety",
        )

    @post(lambda result: all(isinstance(v, str) and line > 0 for v, line in result))
    def _find_empty_checks(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> list[tuple[str, int]]:
        """
        Find 'if not x' or 'if len(x) == 0' patterns with raise/return.

        Only checks if statements at the function level (including if
        statements nested inside other if/elif/else branches), not nested
        functions.

        Returns (variable name, line number of the 'if') pairs in the
        order the guards are encountered.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> code = '''
        ... def f(items):
        ...     if not items:
        ...         raise ValueError("Empty")
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> checks = detector._find_empty_checks(func)
        >>> len(checks) > 0
        True
        >>> checks[0][0]
        'items'
        """
        checks: list[tuple[str, int]] = []
        self._collect_empty_checks(node.body, checks)
        return checks

    @pre(lambda self, stmts, checks: stmts is not None and checks is not None)
    def _collect_empty_checks(
        self, stmts: list[ast.stmt], checks: list[tuple[str, int]]
    ) -> None:
        """
        Recursively collect empty checks, avoiding nested functions.

        Matches are appended to ``checks`` in place. Recursion follows only
        the body and else-branch of ``if`` statements; other compound
        statements are not descended into.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> stmts = ast.parse("if not x: raise ValueError('e')").body
        >>> checks: list[tuple[str, int]] = []
        >>> detector._collect_empty_checks(stmts, checks)
        >>> len(checks)
        1
        """
        for stmt in stmts:
            # Skip nested functions
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue

            if isinstance(stmt, ast.If):
                var_name = self._extract_empty_check_var(stmt.test)
                if var_name and self._has_raise_or_return(stmt.body):
                    checks.append((var_name, stmt.lineno))
                # Recurse into if body and else (an 'elif' is an If node
                # stored in orelse, so it is covered as well)
                self._collect_empty_checks(stmt.body, checks)
                self._collect_empty_checks(stmt.orelse, checks)

    @post(lambda result: result is None or (isinstance(result, str) and len(result) > 0))
    def _len_call_arg(self, expr: ast.expr) -> str | None:
        """
        Return 'x' when expr is the call 'len(x)' on a bare name, else None.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> detector._len_call_arg(ast.parse("len(items)", mode="eval").body)
        'items'
        >>> detector._len_call_arg(ast.parse("size(items)", mode="eval").body) is None
        True
        """
        if (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Name)
            and expr.func.id == "len"
            and len(expr.args) == 1
            and isinstance(expr.args[0], ast.Name)
        ):
            return expr.args[0].id
        return None

    @post(lambda result: result is None or (isinstance(result, str) and len(result) > 0))
    def _extract_empty_check_var(self, test: ast.expr) -> str | None:
        """
        Extract variable name from empty-check condition.

        Handles:
        - 'not items' -> 'items'
        - 'len(items) == 0' -> 'items'
        - 'len(items) < 1' -> 'items'
        - 'len(items) <= 0' -> 'items'
        - mirrored forms: '0 == len(items)', '1 > len(items)',
          '0 >= len(items)' -> 'items'

        Returns None for any other expression, including chained
        comparisons and checks on attributes or subscripts.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> test = ast.parse("not items", mode="eval").body
        >>> detector._extract_empty_check_var(test)
        'items'
        >>> test = ast.parse("len(items) <= 0", mode="eval").body
        >>> detector._extract_empty_check_var(test)
        'items'
        >>> test = ast.parse("0 == len(items)", mode="eval").body
        >>> detector._extract_empty_check_var(test)
        'items'
        >>> test = ast.parse("len(items) > 0", mode="eval").body
        >>> detector._extract_empty_check_var(test) is None
        True
        """
        # Handle 'not items'
        if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
            if isinstance(test.operand, ast.Name):
                return test.operand.id

        # Handle a single (non-chained) comparison of len(x) with a literal
        if (
            isinstance(test, ast.Compare)
            and len(test.ops) == 1
            and len(test.comparators) == 1
        ):
            left = test.left
            op = test.ops[0]
            right = test.comparators[0]

            # len(x) on the left: '== 0', '<= 0' or '< 1'
            var_name = self._len_call_arg(left)
            if var_name is not None and isinstance(right, ast.Constant):
                if isinstance(op, (ast.Eq, ast.LtE)) and right.value == 0:
                    return var_name
                if isinstance(op, ast.Lt) and right.value == 1:
                    return var_name

            # len(x) on the right: '0 ==', '0 >=' or '1 >'
            var_name = self._len_call_arg(right)
            if var_name is not None and isinstance(left, ast.Constant):
                if isinstance(op, (ast.Eq, ast.GtE)) and left.value == 0:
                    return var_name
                if isinstance(op, ast.Gt) and left.value == 1:
                    return var_name

        return None

    @post(lambda result: isinstance(result, bool))
    def _has_raise_or_return(self, body: list[ast.stmt]) -> bool:
        """
        Check if body contains raise or return statement.

        Only the direct statements of ``body`` are examined; a raise or
        return nested inside another compound statement does not count.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> body = ast.parse("raise ValueError('x')").body
        >>> detector._has_raise_or_return(body)
        True
        """
        return any(isinstance(stmt, (ast.Raise, ast.Return)) for stmt in body)

    @pre(lambda self, node, var_name: len(var_name) > 0)
    def _get_param_type(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, var_name: str
    ) -> str | None:
        """
        Get type annotation for a parameter.

        Positional-only, regular and keyword-only parameters are all
        searched. Returns None when ``var_name`` is not such a parameter
        or carries no annotation.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> func = ast.parse("def f(items: list[str]): pass").body[0]
        >>> detector._get_param_type(func, "items")
        'list[str]'
        >>> func = ast.parse("def f(*, items: list[str]): pass").body[0]
        >>> detector._get_param_type(func, "items")
        'list[str]'
        >>> func = ast.parse("def f(items): pass").body[0]
        >>> detector._get_param_type(func, "items") is None
        True
        """
        arguments = node.args
        # ast.arguments keeps positional-only and keyword-only parameters in
        # separate lists; looking at .args alone would miss their annotations.
        named_params = (
            *arguments.posonlyargs,
            *arguments.args,
            *arguments.kwonlyargs,
        )
        for arg in named_params:
            if arg.arg == var_name and arg.annotation:
                return self._annotation_to_str(arg.annotation)
        return None

    @pre(lambda self, _node, var_name, param_type: len(var_name) > 0)
    @post(lambda result: result in Confidence)
    def _calculate_confidence(
        self,
        _node: ast.FunctionDef | ast.AsyncFunctionDef,
        var_name: str,
        param_type: str | None,
    ) -> Confidence:
        """
        Calculate confidence based on context.

        HIGH when the annotation text starts with 'list[', MEDIUM when the
        variable name contains a collection-like keyword, LOW otherwise.
        ``_node`` is accepted but not used.

        >>> import ast
        >>> detector = NonEmptyDetector()
        >>> func = ast.parse("def f(items: list[str]): pass").body[0]
        >>> detector._calculate_confidence(func, "items", "list[str]")
        <Confidence.HIGH: 'high'>
        """
        # High confidence if typed as list[T]
        if param_type and param_type.startswith("list["):
            return Confidence.HIGH

        # Medium confidence if var name suggests collection
        if any(kw in var_name.lower() for kw in ("items", "list", "elements", "data")):
            return Confidence.MEDIUM

        return Confidence.LOW

    @pre(lambda self, var_name, param_type: len(var_name) > 0)
    @post(lambda result: len(result) > 0 and "if not" in result)
    def _format_check(self, var_name: str, param_type: str | None) -> str:
        """
        Format the empty check for display.

        The output is a fixed, normalised rendering ('if not x: raise ...')
        regardless of how the original guard was spelled; the annotation is
        appended in parentheses when one is known.

        >>> detector = NonEmptyDetector()
        >>> detector._format_check("items", "list[str]")
        'if not items: raise ... (items: list[str])'
        >>> detector._format_check("items", None)
        'if not items: raise ...'
        """
        type_info = f" ({var_name}: {param_type})" if param_type else ""
        return f"if not {var_name}: raise ...{type_info}"
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Validation Pattern Detector (DX-61, P0).
|
|
3
|
+
|
|
4
|
+
Detects fail-fast validation patterns that could benefit from
|
|
5
|
+
error accumulation for better user experience.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import ast
|
|
9
|
+
from typing import ClassVar
|
|
10
|
+
|
|
11
|
+
from deal import post, pre
|
|
12
|
+
|
|
13
|
+
from invar.core.patterns.detector import BaseDetector
|
|
14
|
+
from invar.core.patterns.types import (
|
|
15
|
+
Confidence,
|
|
16
|
+
PatternID,
|
|
17
|
+
PatternSuggestion,
|
|
18
|
+
Priority,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ValidationDetector(BaseDetector):
    """
    Detect fail-fast validation that returns early on first error.

    These are candidates for error accumulation pattern.

    Detection logic:
    - Find functions with multiple early returns of error-like values
    - Look for patterns like: if condition: return Failure(...)
      (only ``return`` of a call to one of ERROR_PATTERNS is counted;
      ``raise`` statements are not)
    - Suggest accumulating all errors before returning

    >>> import ast
    >>> detector = ValidationDetector()
    >>> code = '''
    ... def validate(data):
    ...     if "name" not in data:
    ...         return Failure("Missing name")
    ...     if "email" not in data:
    ...         return Failure("Missing email")
    ...     if "age" not in data:
    ...         return Failure("Missing age")
    ...     return Success(data)
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    """

    # Minimum number of early error returns before a function is reported;
    # also the threshold for MEDIUM confidence.
    MIN_EARLY_RETURNS: ClassVar[int] = 3
    # Number of early error returns at which confidence becomes HIGH even
    # when the function name does not look like a validator.
    HIGH_CONFIDENCE_RETURNS: ClassVar[int] = 5
    # Callable names whose call result is treated as an error value.
    ERROR_PATTERNS: ClassVar[set[str]] = {"Failure", "Err", "Error", "Left"}
    # Substrings of a (lower-cased) function name that mark it as a validator.
    VALIDATION_KEYWORDS: ClassVar[set[str]] = {"validate", "check", "verify", "parse"}

    @property
    @post(lambda result: result == PatternID.VALIDATION)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.VALIDATION

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use error accumulation instead of fail-fast validation"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find functions with fail-fast validation patterns.

        Every sync and async function definition reachable through
        ``ast.walk`` is inspected; at most one suggestion is produced per
        function.

        >>> import ast
        >>> detector = ValidationDetector()
        >>> code = '''
        ... def check_config(cfg):
        ...     if not cfg.get("host"):
        ...         return Failure("No host")
        ...     if not cfg.get("port"):
        ...         return Failure("No port")
        ...     if not cfg.get("user"):
        ...         return Failure("No user")
        ...     return Success(cfg)
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                suggestion = self._check_function(node, file_path)
                if suggestion:
                    suggestions.append(suggestion)

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    def _check_function(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, file_path: str
    ) -> PatternSuggestion | None:
        """
        Check if function has fail-fast validation pattern.

        Returns a suggestion anchored at the function's own line when it
        contains at least MIN_EARLY_RETURNS early error returns, else None.

        >>> import ast
        >>> detector = ValidationDetector()
        >>> code = '''
        ... def validate(x):
        ...     if not x.a: return Failure("a")
        ...     if not x.b: return Failure("b")
        ...     if not x.c: return Failure("c")
        ...     return Success(x)
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> suggestion = detector._check_function(func, "test.py")
        >>> suggestion is not None
        True
        """
        early_returns = self._count_early_error_returns(node)
        if early_returns < self.MIN_EARLY_RETURNS:
            return None

        confidence = self._calculate_confidence(node, early_returns)

        return self.make_suggestion(
            pattern_id=self.pattern_id,
            priority=self.priority,
            file_path=file_path,
            line=node.lineno,
            message=f"{early_returns} early error returns - consider error accumulation",
            current_code=self._format_function_preview(node),
            suggested_pattern="Collect all errors, return list[Error]",
            confidence=confidence,
            reference_pattern="Pattern 2: Validation for Error Accumulation",
        )

    @post(lambda result: result >= 0)
    def _count_early_error_returns(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> int:
        """
        Count early returns with error-like values inside if statements.

        Starts from the if statements that are direct children of the
        function body and follows nested if/elif/else branches beneath
        them; nested functions are not entered.

        >>> import ast
        >>> detector = ValidationDetector()
        >>> code = '''
        ... def f(x):
        ...     if not x.a: return Failure("a")
        ...     if not x.b: return Failure("b")
        ...     return Success(x)
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> detector._count_early_error_returns(func)
        2
        """
        # Only iterate direct children, not nested functions
        return sum(self._count_if_error_returns(stmt) for stmt in node.body)

    @post(lambda result: result >= 0)
    def _count_if_error_returns(self, stmt: ast.stmt) -> int:
        """
        Recursively count if statements with error returns, avoiding nested functions.

        An ``if`` contributes 1 when its own body directly contains an
        error return (see ``_is_error_return``). Statements in its body
        and in its else/elif branch are then examined the same way. Any
        statement that is not an ``if`` contributes 0.

        >>> import ast
        >>> detector = ValidationDetector()
        >>> stmt = ast.parse("if x: return Failure('e')").body[0]
        >>> detector._count_if_error_returns(stmt)
        1
        >>> code = '''
        ... if a:
        ...     if b:
        ...         return Err("b")
        ... '''
        >>> detector._count_if_error_returns(ast.parse(code).body[0])
        1
        """
        if not isinstance(stmt, ast.If):
            return 0

        # An if counts at most once, however many error returns its body holds
        count = 0
        if any(
            isinstance(body_stmt, ast.Return) and self._is_error_return(body_stmt)
            for body_stmt in stmt.body
        ):
            count += 1

        # Recurse into the body and the else/elif branch but NOT into
        # nested functions
        for child in (*stmt.body, *stmt.orelse):
            if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                count += self._count_if_error_returns(child)

        return count

    @post(lambda result: isinstance(result, bool))
    def _is_error_return(self, node: ast.Return) -> bool:
        """
        Check if return value looks like an error.

        True only when the returned expression is a call whose callee is a
        bare name or an attribute whose final component is listed in
        ERROR_PATTERNS (e.g. ``Failure(...)`` or ``result.Err(...)``).
        A bare ``return`` or any other expression yields False.

        >>> import ast
        >>> detector = ValidationDetector()
        >>> ret = ast.parse("return Failure('error')").body[0].value
        >>> # This tests the return statement's value
        >>> detector._is_error_return(ast.Return(value=ret))
        True
        >>> detector._is_error_return(ast.Return(value=None))
        False
        """
        value = node.value
        # Covers both a bare 'return' (value is None) and non-call values
        if not isinstance(value, ast.Call):
            return False

        # Check for Failure(...), Err(...), etc.
        callee = value.func
        if isinstance(callee, ast.Name):
            return callee.id in self.ERROR_PATTERNS
        if isinstance(callee, ast.Attribute):
            return callee.attr in self.ERROR_PATTERNS

        return False

    @pre(lambda self, node, early_returns: early_returns >= 0)
    @post(lambda result: result in Confidence)
    def _calculate_confidence(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, early_returns: int
    ) -> Confidence:
        """
        Calculate confidence based on function characteristics.

        HIGH when the function name contains a validation keyword or there
        are at least HIGH_CONFIDENCE_RETURNS early returns, MEDIUM from
        MIN_EARLY_RETURNS upwards, LOW otherwise.

        >>> import ast
        >>> detector = ValidationDetector()
        >>> func = ast.parse("def validate_config(x): pass").body[0]
        >>> detector._calculate_confidence(func, 3)
        <Confidence.HIGH: 'high'>
        >>> func = ast.parse("def load(x): pass").body[0]
        >>> detector._calculate_confidence(func, 3)
        <Confidence.MEDIUM: 'medium'>
        """
        # High confidence if function name suggests validation
        func_name = node.name.lower()
        if any(kw in func_name for kw in self.VALIDATION_KEYWORDS):
            return Confidence.HIGH

        # High confidence if many early returns
        if early_returns >= self.HIGH_CONFIDENCE_RETURNS:
            return Confidence.HIGH

        # Medium confidence once the reporting threshold is reached; using
        # the class constant keeps this tier aligned with _check_function.
        if early_returns >= self.MIN_EARLY_RETURNS:
            return Confidence.MEDIUM

        return Confidence.LOW

    @post(lambda result: len(result) > 0 and "def " in result)
    def _format_function_preview(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> str:
        """
        Format function preview for display.

        Produces a one-line signature with at most three parameter names,
        followed by ', ...' when there are more.

        >>> import ast
        >>> detector = ValidationDetector()
        >>> func = ast.parse("def validate(data): pass").body[0]
        >>> preview = detector._format_function_preview(func)
        >>> "validate" in preview
        True
        """
        prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
        # get_function_params is not defined in this class (presumably
        # inherited from BaseDetector); it is consumed here as a sequence of
        # (name, <second item>) pairs of which only the name is used.
        params = self.get_function_params(node)
        param_str = ", ".join(name for name, _ in params[:3])
        if len(params) > 3:
            param_str += ", ..."
        return f"{prefix} {node.name}({param_str})"
|