invar-tools 1.3.3__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invar/core/formatter.py +6 -1
- invar/core/models.py +13 -0
- invar/core/patterns/__init__.py +53 -0
- invar/core/patterns/detector.py +249 -0
- invar/core/patterns/p0_exhaustive.py +207 -0
- invar/core/patterns/p0_literal.py +307 -0
- invar/core/patterns/p0_newtype.py +211 -0
- invar/core/patterns/p0_nonempty.py +307 -0
- invar/core/patterns/p0_validation.py +278 -0
- invar/core/patterns/registry.py +234 -0
- invar/core/patterns/types.py +167 -0
- invar/core/trivial_detection.py +189 -0
- invar/mcp/server.py +4 -0
- invar/shell/commands/guard.py +65 -0
- invar/shell/contract_coverage.py +358 -0
- invar/shell/guard_output.py +5 -0
- invar/shell/pattern_integration.py +234 -0
- invar/shell/templates.py +2 -2
- invar/shell/testing.py +13 -2
- invar/templates/config/CLAUDE.md.jinja +1 -0
- invar/templates/skills/develop/SKILL.md.jinja +49 -0
- invar/templates/skills/review/SKILL.md.jinja +196 -31
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/METADATA +24 -15
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/RECORD +29 -17
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/entry_points.txt +1 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/WHEEL +0 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/licenses/LICENSE +0 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/licenses/LICENSE-GPL +0 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/licenses/NOTICE +0 -0
invar/core/formatter.py
CHANGED
|
@@ -235,7 +235,7 @@ def format_guard_agent(report: GuardReport, combined_status: str | None = None)
|
|
|
235
235
|
status = combined_status if combined_status else ("passed" if report.passed else "failed")
|
|
236
236
|
static_passed = report.errors == 0
|
|
237
237
|
|
|
238
|
-
|
|
238
|
+
result = {
|
|
239
239
|
"status": status,
|
|
240
240
|
# DX-26: Separate static results from combined status
|
|
241
241
|
"static": {
|
|
@@ -252,6 +252,11 @@ def format_guard_agent(report: GuardReport, combined_status: str | None = None)
|
|
|
252
252
|
},
|
|
253
253
|
"fixes": [_violation_to_fix(v) for v in report.violations],
|
|
254
254
|
}
|
|
255
|
+
# DX-61: Add suggests count if any pattern suggestions exist
|
|
256
|
+
if report.suggests > 0:
|
|
257
|
+
result["static"]["suggests"] = report.suggests
|
|
258
|
+
result["summary"]["suggests"] = report.suggests
|
|
259
|
+
return result
|
|
255
260
|
|
|
256
261
|
|
|
257
262
|
@post(lambda result: "file" in result and "rule" in result and "severity" in result)
|
invar/core/models.py
CHANGED
|
@@ -28,6 +28,7 @@ class Severity(str, Enum):
|
|
|
28
28
|
ERROR = "error"
|
|
29
29
|
WARNING = "warning"
|
|
30
30
|
INFO = "info" # Phase 7: For informational issues like redundant type contracts
|
|
31
|
+
SUGGEST = "suggest" # DX-61: Functional pattern suggestions
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
class Contract(BaseModel):
|
|
@@ -90,6 +91,7 @@ class GuardReport(BaseModel):
|
|
|
90
91
|
errors: int = 0
|
|
91
92
|
warnings: int = 0
|
|
92
93
|
infos: int = 0 # Phase 7: Track INFO-level issues
|
|
94
|
+
suggests: int = 0 # DX-61: Track SUGGEST-level pattern suggestions
|
|
93
95
|
# P24: Contract coverage statistics (Core files only)
|
|
94
96
|
core_functions_total: int = 0
|
|
95
97
|
core_functions_with_contracts: int = 0
|
|
@@ -106,12 +108,18 @@ class GuardReport(BaseModel):
|
|
|
106
108
|
>>> report.add_violation(v)
|
|
107
109
|
>>> report.errors
|
|
108
110
|
1
|
|
111
|
+
>>> v2 = Violation(rule="pattern", severity=Severity.SUGGEST, file="x.py", message="sug")
|
|
112
|
+
>>> report.add_violation(v2)
|
|
113
|
+
>>> report.suggests
|
|
114
|
+
1
|
|
109
115
|
"""
|
|
110
116
|
self.violations.append(violation)
|
|
111
117
|
if violation.severity == Severity.ERROR:
|
|
112
118
|
self.errors += 1
|
|
113
119
|
elif violation.severity == Severity.WARNING:
|
|
114
120
|
self.warnings += 1
|
|
121
|
+
elif violation.severity == Severity.SUGGEST:
|
|
122
|
+
self.suggests += 1
|
|
115
123
|
else:
|
|
116
124
|
self.infos += 1
|
|
117
125
|
|
|
@@ -263,6 +271,11 @@ class RuleConfig(BaseModel):
|
|
|
263
271
|
timeout_crosshair: int = Field(default=300, ge=1, le=1800) # Symbolic verification total
|
|
264
272
|
timeout_crosshair_per_condition: int = Field(default=30, ge=1, le=300) # Per-contract limit
|
|
265
273
|
|
|
274
|
+
# DX-61: Pattern detection configuration
|
|
275
|
+
pattern_min_confidence: str = Field(default="medium") # low, medium, high
|
|
276
|
+
pattern_priorities: list[str] = Field(default_factory=lambda: ["P0"]) # P0, P1
|
|
277
|
+
pattern_exclude: list[str] = Field(default_factory=list) # Pattern IDs to exclude
|
|
278
|
+
|
|
266
279
|
|
|
267
280
|
# Phase 4: Perception models
|
|
268
281
|
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functional Pattern Detection (DX-61).
|
|
3
|
+
|
|
4
|
+
This module provides pattern detection for suggesting functional programming
|
|
5
|
+
improvements in Python code. Guard integrates with this module to provide
|
|
6
|
+
SUGGEST-level feedback when it detects opportunities for patterns like:
|
|
7
|
+
|
|
8
|
+
P0 (Core Patterns):
|
|
9
|
+
- NewType: Semantic clarity for multiple same-type parameters
|
|
10
|
+
- Validation: Error accumulation instead of fail-fast
|
|
11
|
+
- NonEmpty: Compile-time non-empty guarantees
|
|
12
|
+
- Literal: Type-safe finite value sets
|
|
13
|
+
- ExhaustiveMatch: assert_never for enum matching
|
|
14
|
+
|
|
15
|
+
P1 (Extended Patterns - future):
|
|
16
|
+
- SmartConstructor: Validation at construction time
|
|
17
|
+
- StructuredError: Typed errors for programmatic handling
|
|
18
|
+
|
|
19
|
+
Usage:
|
|
20
|
+
>>> from invar.core.patterns import detect_patterns
|
|
21
|
+
>>> source = "def f(a: str, b: str, c: str): pass"
|
|
22
|
+
>>> result = detect_patterns("test.py", source)
|
|
23
|
+
>>> result.has_suggestions
|
|
24
|
+
True
|
|
25
|
+
|
|
26
|
+
See .invar/examples/functional.py for pattern examples.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from invar.core.patterns.registry import (
|
|
30
|
+
PatternRegistry,
|
|
31
|
+
detect_patterns,
|
|
32
|
+
get_registry,
|
|
33
|
+
)
|
|
34
|
+
from invar.core.patterns.types import (
|
|
35
|
+
Confidence,
|
|
36
|
+
DetectionResult,
|
|
37
|
+
Location,
|
|
38
|
+
PatternID,
|
|
39
|
+
PatternSuggestion,
|
|
40
|
+
Priority,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
__all__ = [
|
|
44
|
+
"Confidence",
|
|
45
|
+
"DetectionResult",
|
|
46
|
+
"Location",
|
|
47
|
+
"PatternID",
|
|
48
|
+
"PatternRegistry",
|
|
49
|
+
"PatternSuggestion",
|
|
50
|
+
"Priority",
|
|
51
|
+
"detect_patterns",
|
|
52
|
+
"get_registry",
|
|
53
|
+
]
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pattern Detector Protocol (DX-61).
|
|
3
|
+
|
|
4
|
+
Base protocol for all pattern detectors.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import ast
|
|
8
|
+
from abc import abstractmethod
|
|
9
|
+
from typing import Protocol
|
|
10
|
+
|
|
11
|
+
from deal import post, pre
|
|
12
|
+
|
|
13
|
+
from invar.core.patterns.types import (
|
|
14
|
+
Confidence,
|
|
15
|
+
Location,
|
|
16
|
+
PatternID,
|
|
17
|
+
PatternSuggestion,
|
|
18
|
+
Priority,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PatternDetector(Protocol):
    """
    Structural interface that every functional-pattern detector satisfies.

    A detector is responsible for exactly one functional pattern. It walks a
    parsed module and reports each opportunity it finds as a
    ``PatternSuggestion`` that carries a confidence level.
    """

    @property
    @abstractmethod
    @post(lambda result: result in PatternID)
    def pattern_id(self) -> PatternID:
        """Identifier of the pattern this detector reports."""

    @property
    @abstractmethod
    @post(lambda result: result in Priority)
    def priority(self) -> Priority:
        """Priority tier the pattern belongs to (P0 or P1)."""

    @property
    @abstractmethod
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Non-empty, human-readable summary of the pattern."""

    @abstractmethod
    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Scan a parsed module for opportunities to apply this pattern.

        Args:
            tree: AST of the whole source file.
            file_path: Path of the file, used only to build suggestion locations.

        Returns:
            Every suggestion found in the file (possibly none).
        """
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class BaseDetector:
    """
    Base class with common detection utilities.

    Provides helper methods for AST analysis that can be reused
    across different pattern detectors.
    """

    @post(lambda result: all(isinstance(name, str) and name for name, _ in result))
    def get_function_params(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> list[tuple[str, str | None]]:
        """
        Extract parameter names and type annotations from a function.

        Positional-only, regular and keyword-only parameters are returned in
        signature order; ``*args`` and ``**kwargs`` are not included.
        Unannotated parameters get ``None`` as their type.

        >>> import ast
        >>> code = "def f(a: str, b: int, c): pass"
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> detector = BaseDetector()
        >>> params = detector.get_function_params(func)
        >>> params[0]
        ('a', 'str')
        >>> params[1]
        ('b', 'int')
        >>> params[2]
        ('c', None)
        >>> func = ast.parse("def g(a: str, /, b: str, *, c: str): pass").body[0]
        >>> detector.get_function_params(func)
        [('a', 'str'), ('b', 'str'), ('c', 'str')]
        """
        arguments = node.args
        # Keyword-only parameters matter for pattern detection too (e.g. several
        # same-typed ``*, a: str, b: str`` params), so include every named kind.
        all_args = [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]
        return [
            (
                arg.arg,
                self._annotation_to_str(arg.annotation) if arg.annotation is not None else None,
            )
            for arg in all_args
        ]

    @post(lambda result: isinstance(result, str) and len(result) > 0)
    def _annotation_to_str(self, annotation: ast.expr) -> str:
        """
        Convert an annotation AST node to string.

        String constants (forward references) are returned without quotes;
        anything not handled explicitly is rendered with ``ast.unparse``.

        >>> import ast
        >>> detector = BaseDetector()
        >>> detector._annotation_to_str(ast.Name(id="str"))
        'str'
        >>> detector._annotation_to_str(ast.Constant(value="str"))
        'str'
        >>> detector._annotation_to_str(ast.Constant(value=""))
        "''"
        >>> detector._annotation_to_str(ast.parse("typing.List[int]", mode="eval").body)
        'typing.List[int]'
        >>> detector._annotation_to_str(ast.parse("f().T", mode="eval").body)
        'f().T'
        >>> detector._annotation_to_str(ast.parse("int | None", mode="eval").body)
        'int | None'
        >>> detector._annotation_to_str(ast.parse("tuple[()]", mode="eval").body)
        'tuple[()]'
        """
        if isinstance(annotation, ast.Name):
            return annotation.id
        if isinstance(annotation, ast.Constant):
            text = str(annotation.value)
            # An empty string constant would violate the non-empty postcondition;
            # render it as a literal instead.
            return text if text else ast.unparse(annotation)
        if isinstance(annotation, ast.Subscript):
            # Handle generics like list[str] / dict[str, int]
            base = self._annotation_to_str(annotation.value)
            slice_node = annotation.slice
            if isinstance(slice_node, ast.Tuple) and slice_node.elts:
                args = ", ".join(self._annotation_to_str(e) for e in slice_node.elts)
            else:
                # Single argument, or an empty tuple such as tuple[()]
                args = self._annotation_to_str(slice_node)
            return f"{base}[{args}]"
        if isinstance(annotation, ast.BinOp) and isinstance(annotation.op, ast.BitOr):
            # Handle X | Y union syntax (recursing keeps forward refs unquoted)
            left = self._annotation_to_str(annotation.left)
            right = self._annotation_to_str(annotation.right)
            return f"{left} | {right}"
        # Qualified names (typing.List) and everything else. ast.unparse keeps
        # the full dotted chain even when its root is not a plain name.
        # Python 3.9+ always has ast.unparse (project requires 3.11+).
        return ast.unparse(annotation)

    @pre(lambda self, params, type_name: len(type_name) > 0)
    @post(lambda result: result >= 0)
    def count_type_occurrences(
        self, params: list[tuple[str, str | None]], type_name: str
    ) -> int:
        """
        Count how many parameters have a specific type.

        >>> detector = BaseDetector()
        >>> params = [("a", "str"), ("b", "str"), ("c", "int")]
        >>> detector.count_type_occurrences(params, "str")
        2
        """
        return sum(1 for _, t in params if t == type_name)

    @post(lambda result: isinstance(result, bool))
    def has_match_statement(self, node: ast.FunctionDef) -> bool:
        """
        Check if function contains a match statement (at any nesting depth).

        >>> import ast
        >>> code = '''
        ... def f(x):
        ...     match x:
        ...         case 1: pass
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> detector = BaseDetector()
        >>> detector.has_match_statement(func)
        True
        """
        return any(isinstance(child, ast.Match) for child in ast.walk(node))

    @post(lambda result: all(isinstance(c, str) for c in result))
    def get_enum_cases(self, match_node: ast.Match) -> list[str]:
        """
        Extract case patterns from a match statement.

        Value patterns are rendered as source text, the alternatives of
        or-patterns are flattened, ``case X as name`` contributes ``X``, and an
        irrefutable capture/wildcard case is reported as ``"_"``.

        >>> import ast
        >>> code = '''
        ... match status:
        ...     case Status.A: pass
        ...     case Status.B: pass
        ... '''
        >>> tree = ast.parse(code)
        >>> match = tree.body[0]
        >>> detector = BaseDetector()
        >>> cases = detector.get_enum_cases(match)
        >>> "Status.A" in cases
        True
        >>> code = '''
        ... match status:
        ...     case Status.A | Status.B: pass
        ...     case Status.C as s: pass
        ...     case _: pass
        ... '''
        >>> detector.get_enum_cases(ast.parse(code).body[0])
        ['Status.A', 'Status.B', 'Status.C', '_']
        """
        cases: list[str] = []
        for case in match_node.cases:
            cases.extend(self._case_pattern_values(case.pattern))
        return cases

    @post(lambda result: all(isinstance(c, str) for c in result))
    def _case_pattern_values(self, pattern: ast.pattern) -> list[str]:
        """
        Render the value alternatives of a single case pattern.

        >>> import ast
        >>> detector = BaseDetector()
        >>> detector._case_pattern_values(ast.MatchAs(pattern=None, name=None))
        ['_']
        """
        if isinstance(pattern, ast.MatchValue):
            return [ast.unparse(pattern.value)]
        if isinstance(pattern, ast.MatchOr):
            # case A | B: each alternative is a separate case value
            return [v for alt in pattern.patterns for v in self._case_pattern_values(alt)]
        if isinstance(pattern, ast.MatchAs):
            if pattern.pattern is None:
                return ["_"]  # Wildcard / irrefutable capture
            # case X as name: report the underlying pattern X
            return self._case_pattern_values(pattern.pattern)
        return []

    @pre(lambda self, pattern_id, priority, file_path, line, message, current_code, suggested_pattern, confidence, reference_pattern: line > 0)
    @post(lambda result: result.reference_file == ".invar/examples/functional.py")
    def make_suggestion(
        self,
        pattern_id: PatternID,
        priority: Priority,
        file_path: str,
        line: int,
        message: str,
        current_code: str,
        suggested_pattern: str,
        confidence: Confidence,
        reference_pattern: str,
    ) -> PatternSuggestion:
        """
        Create a pattern suggestion with standard reference file.

        >>> from invar.core.patterns.types import PatternID, Priority, Confidence
        >>> detector = BaseDetector()
        >>> suggestion = detector.make_suggestion(
        ...     pattern_id=PatternID.NEWTYPE,
        ...     priority=Priority.P0,
        ...     file_path="test.py",
        ...     line=10,
        ...     message="Test message",
        ...     current_code="def f(): pass",
        ...     suggested_pattern="NewType",
        ...     confidence=Confidence.HIGH,
        ...     reference_pattern="Pattern 1: NewType",
        ... )
        >>> suggestion.reference_file
        '.invar/examples/functional.py'
        """
        return PatternSuggestion(
            pattern_id=pattern_id,
            location=Location(file=file_path, line=line),
            message=message,
            confidence=confidence,
            priority=priority,
            current_code=current_code,
            suggested_pattern=suggested_pattern,
            reference_file=".invar/examples/functional.py",
            reference_pattern=reference_pattern,
        )
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Exhaustive Match Pattern Detector (DX-61, P0).
|
|
3
|
+
|
|
4
|
+
Detects match statements on enums that don't use assert_never
|
|
5
|
+
for exhaustiveness checking.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import ast
|
|
9
|
+
|
|
10
|
+
from deal import post, pre
|
|
11
|
+
|
|
12
|
+
from invar.core.patterns.detector import BaseDetector
|
|
13
|
+
from invar.core.patterns.types import (
|
|
14
|
+
Confidence,
|
|
15
|
+
PatternID,
|
|
16
|
+
PatternSuggestion,
|
|
17
|
+
Priority,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ExhaustiveMatchDetector(BaseDetector):
    """
    Detect non-exhaustive match statements on enums.

    These are candidates for the assert_never pattern, which lets a type
    checker report enum members that no case handles.

    Detection logic:
    - Find match statements
    - Collect enum-like value patterns (``case Status.A``), including the
      alternatives of or-patterns (``case Status.A | Status.B``)
    - Check whether an unguarded wildcard case exists that doesn't call
      assert_never (bare, ``typing.``-qualified, or returned)
    - Suggest adding assert_never for exhaustiveness

    >>> import ast
    >>> detector = ExhaustiveMatchDetector()
    >>> code = '''
    ... def handle(status: Status) -> str:
    ...     match status:
    ...         case Status.PENDING:
    ...             return "pending"
    ...         case Status.DONE:
    ...             return "done"
    ...         case _:
    ...             return "unknown"
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    """

    @property
    @post(lambda result: result == PatternID.EXHAUSTIVE)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.EXHAUSTIVE

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use assert_never for exhaustive enum matching"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find match statements with non-exhaustive patterns.

        >>> import ast
        >>> detector = ExhaustiveMatchDetector()
        >>> code = '''
        ... def f(x):
        ...     match x:
        ...         case A.ONE: return 1
        ...         case _: return 0
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, ast.Match):
                suggestion = self._check_match(node, file_path)
                if suggestion:
                    suggestions.append(suggestion)

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    def _check_match(
        self, node: ast.Match, file_path: str
    ) -> PatternSuggestion | None:
        """
        Check if match statement could benefit from assert_never.

        Returns a suggestion when the match has at least one enum-like value
        pattern and an unguarded wildcard case whose body does not call
        assert_never; otherwise returns None.

        >>> import ast
        >>> detector = ExhaustiveMatchDetector()
        >>> code = '''
        ... match x:
        ...     case Status.A: pass
        ...     case _: pass
        ... '''
        >>> tree = ast.parse(code)
        >>> match = tree.body[0]
        >>> suggestion = detector._check_match(match, "test.py")
        >>> suggestion is not None
        True
        >>> code = '''
        ... match x:
        ...     case Status.A | Status.B:
        ...         pass
        ...     case _:
        ...         pass
        ... '''
        >>> detector._check_match(ast.parse(code).body[0], "test.py") is not None
        True
        >>> code = '''
        ... match x:
        ...     case Status.A:
        ...         pass
        ...     case _ if flag:
        ...         pass
        ... '''
        >>> detector._check_match(ast.parse(code).body[0], "test.py") is None
        True
        """
        has_wildcard = False
        uses_assert_never = False
        enum_cases: list[str] = []

        for case in node.cases:
            pattern = case.pattern

            # Only an unguarded irrefutable case is a real fallback; a guarded
            # `case _ if cond:` does not catch everything, so assert_never
            # would be wrong there.
            is_catch_all = (
                isinstance(pattern, ast.MatchAs)
                and pattern.pattern is None
                and case.guard is None
            )
            if is_catch_all:
                has_wildcard = True
                # Check if body uses assert_never
                uses_assert_never = self._uses_assert_never(case.body)
            else:
                # Check for enum-like patterns (e.g., Status.PENDING)
                enum_cases.extend(self._enum_case_values(pattern))

        # Suggest if: has enum patterns + has wildcard + doesn't use assert_never
        if enum_cases and has_wildcard and not uses_assert_never:
            confidence = self._calculate_confidence(enum_cases)

            return self.make_suggestion(
                pattern_id=self.pattern_id,
                priority=self.priority,
                file_path=file_path,
                line=node.lineno,
                message="Match has wildcard without assert_never - missing cases won't be caught",
                current_code=self._format_match_preview(enum_cases),
                suggested_pattern="case _: assert_never(x)  # Type error if cases missing",
                confidence=confidence,
                reference_pattern="Pattern 5: Exhaustive Match",
            )

        return None

    @post(lambda result: all(isinstance(c, str) for c in result))
    def _enum_case_values(self, pattern: ast.pattern) -> list[str]:
        """
        Collect enum-like (attribute-access) values from one case pattern.

        Or-pattern alternatives are flattened and ``case X as name`` is
        inspected through to ``X``; literal values such as ``0`` are ignored.

        >>> import ast
        >>> detector = ExhaustiveMatchDetector()
        >>> code = '''
        ... match x:
        ...     case Status.A | Status.B | 0:
        ...         pass
        ...     case Status.C as s:
        ...         pass
        ... '''
        >>> match = ast.parse(code).body[0]
        >>> detector._enum_case_values(match.cases[0].pattern)
        ['Status.A', 'Status.B']
        >>> detector._enum_case_values(match.cases[1].pattern)
        ['Status.C']
        """
        if isinstance(pattern, ast.MatchValue):
            if isinstance(pattern.value, ast.Attribute):
                return [ast.unparse(pattern.value)]
            return []
        if isinstance(pattern, ast.MatchOr):
            return [v for alt in pattern.patterns for v in self._enum_case_values(alt)]
        if isinstance(pattern, ast.MatchAs) and pattern.pattern is not None:
            return self._enum_case_values(pattern.pattern)
        return []

    @post(lambda result: isinstance(result, bool))
    def _uses_assert_never(self, body: list[ast.stmt]) -> bool:
        """
        Check if body calls assert_never as a statement or return value.

        Both the bare name and a qualified call (``typing.assert_never``,
        ``typing_extensions.assert_never``) are recognized.

        >>> import ast
        >>> detector = ExhaustiveMatchDetector()
        >>> body = ast.parse("assert_never(x)").body
        >>> detector._uses_assert_never(body)
        True
        >>> detector._uses_assert_never(ast.parse("typing.assert_never(x)").body)
        True
        >>> func = ast.parse("def f(x):\\n    return assert_never(x)").body[0]
        >>> detector._uses_assert_never(func.body)
        True
        >>> detector._uses_assert_never(ast.parse("print(x)").body)
        False
        """
        for stmt in body:
            if not isinstance(stmt, (ast.Expr, ast.Return)):
                continue
            call = stmt.value  # None for a bare `return`
            if not isinstance(call, ast.Call):
                continue
            func = call.func
            if isinstance(func, ast.Name) and func.id == "assert_never":
                return True
            if isinstance(func, ast.Attribute) and func.attr == "assert_never":
                return True
        return False

    @post(lambda result: result in Confidence)
    def _calculate_confidence(self, enum_cases: list[str]) -> Confidence:
        """
        Calculate confidence based on context.

        The enum type of a case is everything before its last dot, so
        ``mod.Color.RED`` and ``mod.Shape.SQUARE`` count as different enums.

        >>> detector = ExhaustiveMatchDetector()
        >>> detector._calculate_confidence(["Status.A", "Status.B", "Status.C"])
        <Confidence.HIGH: 'high'>
        >>> detector._calculate_confidence(["mod.Color.RED", "mod.Shape.SQUARE"]) is Confidence.MEDIUM
        True
        """
        # High confidence if multiple enum cases from same type
        if len(enum_cases) >= 2:
            # Check if all from same enum (strip only the member name)
            prefixes = {c.rsplit(".", 1)[0] for c in enum_cases}
            if len(prefixes) == 1:
                return Confidence.HIGH

        # Medium confidence for any enum patterns
        if enum_cases:
            return Confidence.MEDIUM

        return Confidence.LOW

    @post(lambda result: "match" in result and "case" in result)
    def _format_match_preview(self, enum_cases: list[str]) -> str:
        """
        Format match statement preview (at most three cases shown).

        >>> detector = ExhaustiveMatchDetector()
        >>> detector._format_match_preview(["Status.A", "Status.B"])
        'match ...: case Status.A | Status.B | _: ...'
        """
        cases_str = " | ".join(enum_cases[:3])
        if len(enum_cases) > 3:
            cases_str += " | ..."
        return f"match ...: case {cases_str} | _: ..."
|