invar-tools 1.4.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. invar/__init__.py +7 -1
  2. invar/core/entry_points.py +12 -10
  3. invar/core/formatter.py +21 -1
  4. invar/core/models.py +98 -0
  5. invar/core/patterns/__init__.py +53 -0
  6. invar/core/patterns/detector.py +249 -0
  7. invar/core/patterns/p0_exhaustive.py +207 -0
  8. invar/core/patterns/p0_literal.py +307 -0
  9. invar/core/patterns/p0_newtype.py +211 -0
  10. invar/core/patterns/p0_nonempty.py +307 -0
  11. invar/core/patterns/p0_validation.py +278 -0
  12. invar/core/patterns/registry.py +234 -0
  13. invar/core/patterns/types.py +167 -0
  14. invar/core/trivial_detection.py +189 -0
  15. invar/mcp/server.py +4 -0
  16. invar/shell/commands/guard.py +100 -8
  17. invar/shell/config.py +46 -0
  18. invar/shell/contract_coverage.py +358 -0
  19. invar/shell/guard_output.py +15 -0
  20. invar/shell/pattern_integration.py +234 -0
  21. invar/shell/testing.py +13 -2
  22. invar/templates/CLAUDE.md.template +18 -10
  23. invar/templates/config/CLAUDE.md.jinja +52 -30
  24. invar/templates/config/context.md.jinja +14 -0
  25. invar/templates/protocol/INVAR.md +1 -0
  26. invar/templates/skills/develop/SKILL.md.jinja +51 -1
  27. invar/templates/skills/review/SKILL.md.jinja +196 -31
  28. {invar_tools-1.4.0.dist-info → invar_tools-1.6.0.dist-info}/METADATA +12 -8
  29. {invar_tools-1.4.0.dist-info → invar_tools-1.6.0.dist-info}/RECORD +34 -22
  30. {invar_tools-1.4.0.dist-info → invar_tools-1.6.0.dist-info}/WHEEL +0 -0
  31. {invar_tools-1.4.0.dist-info → invar_tools-1.6.0.dist-info}/entry_points.txt +0 -0
  32. {invar_tools-1.4.0.dist-info → invar_tools-1.6.0.dist-info}/licenses/LICENSE +0 -0
  33. {invar_tools-1.4.0.dist-info → invar_tools-1.6.0.dist-info}/licenses/LICENSE-GPL +0 -0
  34. {invar_tools-1.4.0.dist-info → invar_tools-1.6.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,207 @@
1
+ """
2
+ Exhaustive Match Pattern Detector (DX-61, P0).
3
+
4
+ Detects match statements on enums that don't use assert_never
5
+ for exhaustiveness checking.
6
+ """
7
+
8
+ import ast
9
+
10
+ from deal import post, pre
11
+
12
+ from invar.core.patterns.detector import BaseDetector
13
+ from invar.core.patterns.types import (
14
+ Confidence,
15
+ PatternID,
16
+ PatternSuggestion,
17
+ Priority,
18
+ )
19
+
20
+
21
class ExhaustiveMatchDetector(BaseDetector):
    """
    Detect non-exhaustive match statements on enums.

    These are candidates for the assert_never pattern: with assert_never in
    the catch-all branch, an enum member that is not handled becomes a
    type-check error instead of silently falling into the wildcard case.

    Detection logic:
    - Find match statements
    - Collect enum-like value patterns (e.g. ``Status.PENDING``), including
      ones inside or-patterns (``Status.A | Status.B``) and ``as`` captures
    - Check if an unguarded catch-all case exists but doesn't use assert_never
    - Suggest adding assert_never for exhaustiveness

    >>> import ast
    >>> detector = ExhaustiveMatchDetector()
    >>> code = '''
    ... def handle(status: Status) -> str:
    ...     match status:
    ...         case Status.PENDING:
    ...             return "pending"
    ...         case Status.DONE:
    ...             return "done"
    ...         case _:
    ...             return "unknown"
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    """

    @property
    @post(lambda result: result == PatternID.EXHAUSTIVE)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.EXHAUSTIVE

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use assert_never for exhaustive enum matching"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find match statements with non-exhaustive patterns.

        Every ``match`` statement anywhere in ``tree`` is inspected; at most
        one suggestion is produced per statement.

        >>> import ast
        >>> detector = ExhaustiveMatchDetector()
        >>> code = '''
        ... def f(x):
        ...     match x:
        ...         case A.ONE: return 1
        ...         case _: return 0
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, ast.Match):
                suggestion = self._check_match(node, file_path)
                if suggestion:
                    suggestions.append(suggestion)

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    def _check_match(
        self, node: ast.Match, file_path: str
    ) -> PatternSuggestion | None:
        """
        Check if match statement could benefit from assert_never.

        Returns a suggestion when the match has at least one enum-like value
        pattern and an unguarded catch-all case (``case _:`` or
        ``case name:``) whose body does not call ``assert_never``.
        Otherwise it returns None.

        >>> import ast
        >>> detector = ExhaustiveMatchDetector()
        >>> code = '''
        ... match x:
        ...     case Status.A: pass
        ...     case _: pass
        ... '''
        >>> tree = ast.parse(code)
        >>> match = tree.body[0]
        >>> suggestion = detector._check_match(match, "test.py")
        >>> suggestion is not None
        True
        >>> ored = '''
        ... match x:
        ...     case Status.A | Status.B: pass
        ...     case _: pass
        ... '''
        >>> detector._check_match(ast.parse(ored).body[0], "test.py") is not None
        True
        >>> checked = '''
        ... match x:
        ...     case Status.A: pass
        ...     case _: typing.assert_never(x)
        ... '''
        >>> detector._check_match(ast.parse(checked).body[0], "test.py") is None
        True
        """
        has_wildcard = False
        uses_assert_never = False
        enum_cases: list[str] = []

        for case in node.cases:
            pattern = case.pattern

            # An unguarded irrefutable capture (``case _:`` / ``case other:``)
            # is the catch-all branch. A guarded one can still fall through,
            # so it does not count as the wildcard.
            if (
                isinstance(pattern, ast.MatchAs)
                and pattern.pattern is None
                and case.guard is None
            ):
                has_wildcard = True
                uses_assert_never = self._uses_assert_never(case.body)
            else:
                enum_cases.extend(self._enum_values(pattern))

        # Suggest only when enum members are matched explicitly and an
        # unguarded catch-all would silently absorb any member not listed.
        if not (enum_cases and has_wildcard) or uses_assert_never:
            return None

        # Name the real subject when it is a plain variable; otherwise keep
        # the generic placeholder.
        subject = node.subject.id if isinstance(node.subject, ast.Name) else "x"

        return self.make_suggestion(
            pattern_id=self.pattern_id,
            priority=self.priority,
            file_path=file_path,
            line=node.lineno,
            message="Match has wildcard without assert_never - missing cases won't be caught",
            current_code=self._format_match_preview(enum_cases),
            suggested_pattern=f"case _: assert_never({subject}) # Type error if cases missing",
            confidence=self._calculate_confidence(enum_cases),
            reference_pattern="Pattern 5: Exhaustive Match",
        )

    @post(lambda result: all(len(v) > 0 for v in result))
    def _enum_values(self, pattern: ast.pattern) -> list[str]:
        """
        Collect dotted value patterns (e.g. ``Status.PENDING``) from a case pattern.

        Looks through or-patterns (``A.X | A.Y``) and ``as`` captures
        (``A.X as v``). Any other pattern kind contributes nothing.

        >>> import ast
        >>> code = '''
        ... match x:
        ...     case A.ONE | A.TWO as y: pass
        ... '''
        >>> pattern = ast.parse(code).body[0].cases[0].pattern
        >>> ExhaustiveMatchDetector()._enum_values(pattern)
        ['A.ONE', 'A.TWO']
        """
        if isinstance(pattern, ast.MatchValue):
            if isinstance(pattern.value, ast.Attribute):
                return [ast.unparse(pattern.value)]
            return []
        if isinstance(pattern, ast.MatchOr):
            return [value for sub in pattern.patterns for value in self._enum_values(sub)]
        if isinstance(pattern, ast.MatchAs) and pattern.pattern is not None:
            return self._enum_values(pattern.pattern)
        return []

    @post(lambda result: isinstance(result, bool))
    def _uses_assert_never(self, body: list[ast.stmt]) -> bool:
        """
        Check if body contains an assert_never call.

        Recognizes the bare form ``assert_never(x)`` and qualified forms such
        as ``typing.assert_never(x)``. Calls nested inside other statements
        (e.g. ``return assert_never(x)``) are found as well.

        >>> import ast
        >>> detector = ExhaustiveMatchDetector()
        >>> body = ast.parse("assert_never(x)").body
        >>> detector._uses_assert_never(body)
        True
        >>> detector._uses_assert_never(ast.parse("typing.assert_never(x)").body)
        True
        >>> detector._uses_assert_never(ast.parse("print(x)").body)
        False
        """
        for stmt in body:
            for sub in ast.walk(stmt):
                if not isinstance(sub, ast.Call):
                    continue
                func = sub.func
                if isinstance(func, ast.Name) and func.id == "assert_never":
                    return True
                if isinstance(func, ast.Attribute) and func.attr == "assert_never":
                    return True
        return False

    @post(lambda result: result in Confidence)
    def _calculate_confidence(self, enum_cases: list[str]) -> Confidence:
        """
        Calculate confidence based on context.

        HIGH when two or more cases share the same owning enum (everything
        before the last dot). MEDIUM for any other non-empty set of enum
        cases. LOW when there are none.

        >>> detector = ExhaustiveMatchDetector()
        >>> detector._calculate_confidence(["Status.A", "Status.B", "Status.C"])
        <Confidence.HIGH: 'high'>
        >>> detector._calculate_confidence(["pkg.Status.A", "pkg.Color.B"]) == Confidence.MEDIUM
        True
        """
        if len(enum_cases) >= 2:
            # Strip only the member name, so ``pkg.Status.A`` and
            # ``pkg.Color.B`` are recognized as different enums.
            owners = {case.rsplit(".", 1)[0] for case in enum_cases}
            if len(owners) == 1:
                return Confidence.HIGH

        # Medium confidence for any enum patterns
        if enum_cases:
            return Confidence.MEDIUM

        return Confidence.LOW

    @post(lambda result: "match" in result and "case" in result)
    def _format_match_preview(self, enum_cases: list[str]) -> str:
        """
        Format match statement preview.

        At most three cases are shown; any remainder is elided with ``...``.

        >>> detector = ExhaustiveMatchDetector()
        >>> detector._format_match_preview(["Status.A", "Status.B"])
        'match ...: case Status.A | Status.B | _: ...'
        """
        cases_str = " | ".join(enum_cases[:3])
        if len(enum_cases) > 3:
            cases_str += " | ..."
        return f"match ...: case {cases_str} | _: ..."
@@ -0,0 +1,307 @@
1
+ """
2
+ Literal Pattern Detector (DX-61, P0).
3
+
4
+ Detects runtime validation for finite value sets that could benefit
5
+ from Literal type for compile-time safety.
6
+ """
7
+
8
+ import ast
9
+
10
+ from deal import post, pre
11
+
12
+ from invar.core.patterns.detector import BaseDetector
13
+ from invar.core.patterns.types import (
14
+ Confidence,
15
+ PatternID,
16
+ PatternSuggestion,
17
+ Priority,
18
+ )
19
+
20
+
21
class LiteralDetector(BaseDetector):
    """
    Detect runtime checks for finite value sets.

    These are candidates for Literal type to catch invalid values
    at type-check time instead of runtime.

    Detection logic:
    - Find 'if x not in (...)' / 'if x not in [...]' / 'if not x in {...}'
      guards anywhere in a function body (loops, with, try, match included),
      excluding nested function and class definitions
    - Look for small sets of string/int literals
    - Suggest Literal type for the parameter

    >>> import ast
    >>> detector = LiteralDetector()
    >>> code = '''
    ... def set_level(level: str) -> None:
    ...     if level not in ("debug", "info", "warning", "error"):
    ...         raise ValueError(f"Invalid level: {level}")
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    """

    MAX_LITERAL_VALUES = 10  # Don't suggest Literal for large sets

    @property
    @post(lambda result: result == PatternID.LITERAL)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.LITERAL

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use Literal type for finite value sets"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find functions with runtime finite-set validation.

        Every (async) function in ``tree`` is checked on its own, including
        nested ones.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def process(mode: str):
        ...     if mode not in ["fast", "slow", "auto"]:
        ...         raise ValueError("Bad mode")
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                suggestions.extend(self._check_function(node, file_path))

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def _check_function(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, file_path: str
    ) -> list[PatternSuggestion]:
        """
        Check if function has finite-set validation patterns.

        Checks with more than ``MAX_LITERAL_VALUES`` values are skipped.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def f(x):
        ...     if x not in ("a", "b", "c"):
        ...         raise ValueError("Bad")
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> suggestions = detector._check_function(func, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []
        checks = self._find_membership_checks(node)

        for var_name, values, check_line in checks:
            if len(values) <= self.MAX_LITERAL_VALUES:
                confidence = self._calculate_confidence(values)

                suggestions.append(
                    self.make_suggestion(
                        pattern_id=self.pattern_id,
                        priority=self.priority,
                        file_path=file_path,
                        line=check_line,
                        message=f"Runtime check for {len(values)} values - consider Literal type",
                        current_code=self._format_check(var_name, values),
                        suggested_pattern=self._format_literal(var_name, values),
                        confidence=confidence,
                        reference_pattern="Pattern 4: Literal for Finite Value Sets",
                    )
                )

        return suggestions

    @post(lambda result: all(len(name) > 0 and len(vals) > 0 and line > 0 for name, vals, line in result))
    def _find_membership_checks(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> list[tuple[str, list[str | int], int]]:
        """
        Find 'if x not in (...)' patterns.

        Returns list of (var_name, values, line_number).
        Only checks this function's own statements, not nested functions.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def f(x):
        ...     if x not in ("a", "b"):
        ...         raise ValueError("Bad")
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> checks = detector._find_membership_checks(func)
        >>> len(checks) > 0
        True
        >>> checks[0][1]
        ['a', 'b']
        """
        checks: list[tuple[str, list[str | int], int]] = []
        self._collect_membership_checks(node.body, checks)
        return checks

    @pre(lambda self, stmts, checks: stmts is not None and checks is not None)
    def _collect_membership_checks(
        self,
        stmts: list[ast.stmt],
        checks: list[tuple[str, list[str | int], int]],
    ) -> None:
        """
        Recursively collect membership checks, avoiding nested scopes.

        Descends into every compound statement: if/elif/else, loops, with,
        try/except/finally, and match cases. It does not descend into nested
        function or class definitions. Nested functions are reported
        separately by ``detect``. Results are appended to ``checks`` in
        place.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> stmts = ast.parse("if x not in ('a', 'b'): raise ValueError()").body
        >>> checks: list[tuple[str, list[str | int], int]] = []
        >>> detector._collect_membership_checks(stmts, checks)
        >>> len(checks)
        1
        >>> code = '''
        ... for item in items:
        ...     if item not in ("a", "b"):
        ...         raise ValueError(item)
        ... '''
        >>> nested: list[tuple[str, list[str | int], int]] = []
        >>> detector._collect_membership_checks(ast.parse(code).body, nested)
        >>> nested[0][:2]
        ('item', ['a', 'b'])
        """
        for stmt in stmts:
            # Skip nested scopes: they are not part of this function's own
            # validation logic.
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                continue

            if isinstance(stmt, ast.If):
                result = self._extract_membership_check(stmt.test)
                if result:
                    var_name, values = result
                    checks.append((var_name, values, stmt.lineno))

            for child_block in self._child_blocks(stmt):
                self._collect_membership_checks(child_block, checks)

    @post(lambda result: all(len(block) > 0 for block in result))
    def _child_blocks(self, stmt: ast.stmt) -> list[list[ast.stmt]]:
        """
        Return the non-empty statement lists nested directly inside ``stmt``.

        The lists come back in source order: body, except-handler or
        match-case bodies, else, finally. Simple statements yield an empty
        list.

        >>> import ast
        >>> code = '''
        ... try:
        ...     a
        ... except ValueError:
        ...     b
        ... finally:
        ...     c
        ... '''
        >>> stmt = ast.parse(code).body[0]
        >>> [ast.unparse(block[0]) for block in LiteralDetector()._child_blocks(stmt)]
        ['a', 'b', 'c']
        """
        blocks: list[list[ast.stmt]] = []
        body = getattr(stmt, "body", None)
        if isinstance(body, list):
            blocks.append(body)
        blocks.extend(handler.body for handler in getattr(stmt, "handlers", ()))
        blocks.extend(case.body for case in getattr(stmt, "cases", ()))
        for field in ("orelse", "finalbody"):
            extra = getattr(stmt, field, None)
            if isinstance(extra, list):
                blocks.append(extra)
        return [block for block in blocks if block]

    @post(lambda result: result is None or (len(result[0]) > 0 and len(result[1]) > 0))
    def _extract_membership_check(
        self, test: ast.expr
    ) -> tuple[str, list[str | int]] | None:
        """
        Extract variable and values from membership check.

        Handles:
        - 'x not in ("a", "b", "c")'
        - 'x not in ["a", "b", "c"]'
        - 'not x in {"a", "b", "c"}'  (same guard, different spelling)

        >>> import ast
        >>> detector = LiteralDetector()
        >>> test = ast.parse("x not in ('a', 'b')", mode="eval").body
        >>> result = detector._extract_membership_check(test)
        >>> result is not None
        True
        >>> result[0]
        'x'
        >>> result[1]
        ['a', 'b']
        >>> detector._extract_membership_check(ast.parse("not y in ('a',)", mode="eval").body)
        ('y', ['a'])
        """
        compare: ast.expr = test
        membership_op: type[ast.cmpop] = ast.NotIn
        # ``not x in (...)`` is the same guard as ``x not in (...)``.
        if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
            compare, membership_op = test.operand, ast.In

        if not (
            isinstance(compare, ast.Compare)
            and len(compare.ops) == 1
            and isinstance(compare.ops[0], membership_op)
            and len(compare.comparators) == 1
            and isinstance(compare.left, ast.Name)
        ):
            return None

        values = self._extract_literal_values(compare.comparators[0])
        return (compare.left.id, values) if values else None

    @post(lambda result: result is None or len(result) > 0)
    def _extract_literal_values(self, node: ast.expr) -> list[str | int] | None:
        """
        Extract literal values from tuple/list/set.

        Returns None when ``node`` is not a tuple/list/set display, when any
        element is not a str/int constant, or when the container is empty.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> node = ast.parse("('a', 'b', 'c')", mode="eval").body
        >>> detector._extract_literal_values(node)
        ['a', 'b', 'c']
        >>> detector._extract_literal_values(ast.parse("()", mode="eval").body) is None
        True
        """
        if not isinstance(node, (ast.Tuple, ast.List, ast.Set)):
            return None

        values: list[str | int] = []
        for elt in node.elts:
            if not (isinstance(elt, ast.Constant) and isinstance(elt.value, (str, int))):
                return None  # Non-literal value
            values.append(elt.value)

        # An empty container gives nothing to build a Literal from, and
        # returning [] would break this method's own post-condition.
        return values or None

    @pre(lambda self, values: len(values) > 0)
    @post(lambda result: result in Confidence)
    def _calculate_confidence(self, values: list[str | int]) -> Confidence:
        """
        Calculate confidence based on value characteristics.

        >>> detector = LiteralDetector()
        >>> detector._calculate_confidence(["debug", "info", "warning", "error"])
        <Confidence.HIGH: 'high'>
        """
        # High confidence for small sets of strings
        if all(isinstance(v, str) for v in values) and len(values) <= 5:
            return Confidence.HIGH

        # Medium confidence for larger sets or mixed types
        if len(values) <= 8:
            return Confidence.MEDIUM

        return Confidence.LOW

    @pre(lambda self, var_name, values: len(var_name) > 0 and len(values) > 0)
    @post(lambda result: "not in" in result)
    def _format_check(self, var_name: str, values: list[str | int]) -> str:
        """
        Format the membership check for display.

        At most four values are shown; any remainder is elided with ``...``.

        >>> detector = LiteralDetector()
        >>> detector._format_check("level", ["debug", "info"])
        "if level not in ('debug', 'info'): raise"
        """
        formatted_values = ", ".join(repr(v) for v in values[:4])
        if len(values) > 4:
            formatted_values += ", ..."
        return f"if {var_name} not in ({formatted_values}): raise"

    @pre(lambda self, _var_name, values: len(_var_name) > 0 and len(values) > 0)
    @post(lambda result: "Literal[" in result)
    def _format_literal(self, _var_name: str, values: list[str | int]) -> str:
        """
        Format the Literal type suggestion.

        At most five values are shown; any remainder is elided with ``...``.

        >>> detector = LiteralDetector()
        >>> detector._format_literal("level", ["debug", "info", "error"])
        "Literal['debug', 'info', 'error']"
        """
        formatted_values = ", ".join(repr(v) for v in values[:5])
        if len(values) > 5:
            formatted_values += ", ..."
        return f"Literal[{formatted_values}]"