invar-tools 1.3.3__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,307 @@
1
+ """
2
+ Literal Pattern Detector (DX-61, P0).
3
+
4
+ Detects runtime validation for finite value sets that could benefit
5
+ from Literal type for compile-time safety.
6
+ """
7
+
8
+ import ast
9
+
10
+ from deal import post, pre
11
+
12
+ from invar.core.patterns.detector import BaseDetector
13
+ from invar.core.patterns.types import (
14
+ Confidence,
15
+ PatternID,
16
+ PatternSuggestion,
17
+ Priority,
18
+ )
19
+
20
+
21
class LiteralDetector(BaseDetector):
    """
    Detect runtime checks for finite value sets.

    These are candidates for Literal type to catch invalid values
    at type-check time instead of runtime.

    Detection logic:
    - Find 'if x not in (...)', 'if x not in [...]' or 'if x not in {...}'
      patterns where 'x' is a plain name
    - Look for small, non-empty sets of string/int literals
    - Suggest Literal type for the checked variable

    >>> import ast
    >>> detector = LiteralDetector()
    >>> code = '''
    ... def set_level(level: str) -> None:
    ...     if level not in ("debug", "info", "warning", "error"):
    ...         raise ValueError(f"Invalid level: {level}")
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    """

    MAX_LITERAL_VALUES = 10  # Don't suggest Literal for large sets

    @property
    @post(lambda result: result == PatternID.LITERAL)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.LITERAL

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use Literal type for finite value sets"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find functions with runtime finite-set validation.

        Every (async) function definition reachable via ``ast.walk`` is
        inspected on its own, so nested functions are reported through
        their own visit rather than through their enclosing function.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def process(mode: str):
        ...     if mode not in ["fast", "slow", "auto"]:
        ...         raise ValueError("Bad mode")
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                func_suggestions = self._check_function(node, file_path)
                suggestions.extend(func_suggestions)

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def _check_function(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, file_path: str
    ) -> list[PatternSuggestion]:
        """
        Check if function has finite-set validation patterns.

        One suggestion is produced per membership check whose value set
        has at most ``MAX_LITERAL_VALUES`` entries; larger sets are ignored.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def f(x):
        ...     if x not in ("a", "b", "c"):
        ...         raise ValueError("Bad")
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> suggestions = detector._check_function(func, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []
        checks = self._find_membership_checks(node)

        for var_name, values, check_line in checks:
            if len(values) <= self.MAX_LITERAL_VALUES:
                confidence = self._calculate_confidence(values)

                suggestions.append(
                    self.make_suggestion(
                        pattern_id=self.pattern_id,
                        priority=self.priority,
                        file_path=file_path,
                        line=check_line,
                        message=f"Runtime check for {len(values)} values - consider Literal type",
                        current_code=self._format_check(var_name, values),
                        suggested_pattern=self._format_literal(var_name, values),
                        confidence=confidence,
                        reference_pattern="Pattern 4: Literal for Finite Value Sets",
                    )
                )

        return suggestions

    @post(lambda result: all(len(name) > 0 and len(vals) > 0 and line > 0 for name, vals, line in result))
    def _find_membership_checks(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> list[tuple[str, list[str | int], int]]:
        """
        Find 'if x not in (...)' patterns.

        Returns list of (var_name, values, line_number).
        Only checks function-level if statements, not nested functions.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def f(x):
        ...     if x not in ("a", "b"):
        ...         raise ValueError("Bad")
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> checks = detector._find_membership_checks(func)
        >>> len(checks) > 0
        True
        >>> checks[0][1]
        ['a', 'b']
        """
        checks: list[tuple[str, list[str | int], int]] = []
        self._collect_membership_checks(node.body, checks)
        return checks

    @pre(lambda self, stmts, checks: stmts is not None and checks is not None)
    def _collect_membership_checks(
        self,
        stmts: list[ast.stmt],
        checks: list[tuple[str, list[str | int], int]],
    ) -> None:
        """
        Recursively collect membership checks, avoiding nested functions.

        Results are appended to ``checks`` in place. Only ``if`` statements
        are examined; recursion follows their bodies and ``else``/``elif``
        branches, not other compound statements.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> stmts = ast.parse("if x not in ('a', 'b'): raise ValueError()").body
        >>> checks: list[tuple[str, list[str | int], int]] = []
        >>> detector._collect_membership_checks(stmts, checks)
        >>> len(checks)
        1
        """
        for stmt in stmts:
            # Skip nested functions
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue

            if isinstance(stmt, ast.If):
                result = self._extract_membership_check(stmt.test)
                if result:
                    var_name, values = result
                    checks.append((var_name, values, stmt.lineno))
                # Recurse into body and else
                self._collect_membership_checks(stmt.body, checks)
                self._collect_membership_checks(stmt.orelse, checks)

    @post(lambda result: result is None or (len(result[0]) > 0 and len(result[1]) > 0))
    def _extract_membership_check(
        self, test: ast.expr
    ) -> tuple[str, list[str | int]] | None:
        """
        Extract variable and values from membership check.

        Handles:
        - 'x not in ("a", "b", "c")'
        - 'x not in ["a", "b", "c"]'

        Returns ``(var_name, values)`` or ``None`` when the expression is
        not a single ``not in`` comparison of a plain name against a
        non-empty collection of str/int literals.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> test = ast.parse("x not in ('a', 'b')", mode="eval").body
        >>> result = detector._extract_membership_check(test)
        >>> result is not None
        True
        >>> result[0]
        'x'
        >>> result[1]
        ['a', 'b']
        >>> empty = ast.parse("x not in ()", mode="eval").body
        >>> detector._extract_membership_check(empty) is None
        True
        """
        # Handle 'x not in (...)'
        if isinstance(test, ast.Compare):
            if (
                len(test.ops) == 1
                and isinstance(test.ops[0], ast.NotIn)
                and len(test.comparators) == 1
            ):
                left = test.left
                right = test.comparators[0]

                if isinstance(left, ast.Name):
                    var_name = left.id
                    values = self._extract_literal_values(right)
                    if values:
                        return (var_name, values)

        return None

    @post(lambda result: result is None or len(result) > 0)
    def _extract_literal_values(self, node: ast.expr) -> list[str | int] | None:
        """
        Extract literal values from tuple/list/set.

        Returns the str/int constants in source order, or ``None`` when the
        node is not a tuple/list/set display, contains any element that is
        not a str/int constant, or is empty.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> node = ast.parse("('a', 'b', 'c')", mode="eval").body
        >>> detector._extract_literal_values(node)
        ['a', 'b', 'c']
        >>> detector._extract_literal_values(ast.parse("()", mode="eval").body) is None
        True
        >>> detector._extract_literal_values(ast.parse("('a', y)", mode="eval").body) is None
        True
        """
        if isinstance(node, (ast.Tuple, ast.List, ast.Set)):
            values = []
            for elt in node.elts:
                if isinstance(elt, ast.Constant) and isinstance(
                    elt.value, (str, int)
                ):
                    values.append(elt.value)
                else:
                    return None  # Non-literal value
            # An empty collection ('x not in ()') carries no Literal
            # candidates; returning [] here would break the post-condition.
            return values or None
        return None

    @pre(lambda self, values: len(values) > 0)
    @post(lambda result: result in Confidence)
    def _calculate_confidence(self, values: list[str | int]) -> Confidence:
        """
        Calculate confidence based on value characteristics.

        HIGH for at most 5 values that are all strings, MEDIUM for any
        other set of at most 8 values, LOW otherwise.

        >>> detector = LiteralDetector()
        >>> detector._calculate_confidence(["debug", "info", "warning", "error"])
        <Confidence.HIGH: 'high'>
        """
        # High confidence for small sets of strings
        if all(isinstance(v, str) for v in values) and len(values) <= 5:
            return Confidence.HIGH

        # Medium confidence for larger sets or mixed types
        if len(values) <= 8:
            return Confidence.MEDIUM

        return Confidence.LOW

    @pre(lambda self, var_name, values: len(var_name) > 0 and len(values) > 0)
    @post(lambda result: "not in" in result)
    def _format_check(self, var_name: str, values: list[str | int]) -> str:
        """
        Format the membership check for display.

        At most the first 4 values are shown; longer sets are elided
        with a trailing '...'.

        >>> detector = LiteralDetector()
        >>> detector._format_check("level", ["debug", "info"])
        "if level not in ('debug', 'info'): raise"
        """
        formatted_values = ", ".join(repr(v) for v in values[:4])
        if len(values) > 4:
            formatted_values += ", ..."
        return f"if {var_name} not in ({formatted_values}): raise"

    @pre(lambda self, _var_name, values: len(_var_name) > 0 and len(values) > 0)
    @post(lambda result: "Literal[" in result)
    def _format_literal(self, _var_name: str, values: list[str | int]) -> str:
        """
        Format the Literal type suggestion.

        At most the first 5 values are shown; longer sets are elided
        with a trailing '...'. ``_var_name`` is only validated by the
        pre-condition and does not appear in the output.

        >>> detector = LiteralDetector()
        >>> detector._format_literal("level", ["debug", "info", "error"])
        "Literal['debug', 'info', 'error']"
        """
        formatted_values = ", ".join(repr(v) for v in values[:5])
        if len(values) > 5:
            formatted_values += ", ..."
        return f"Literal[{formatted_values}]"
@@ -0,0 +1,211 @@
1
+ """
2
+ NewType Pattern Detector (DX-61, P0).
3
+
4
+ Detects opportunities to use NewType for semantic clarity when
5
+ multiple parameters share the same primitive type.
6
+ """
7
+
8
+ import ast
9
+ from typing import ClassVar
10
+
11
+ from deal import post, pre
12
+
13
+ from invar.core.patterns.detector import BaseDetector
14
+ from invar.core.patterns.types import (
15
+ Confidence,
16
+ PatternID,
17
+ PatternSuggestion,
18
+ Priority,
19
+ )
20
+
21
+
22
class NewTypeDetector(BaseDetector):
    """
    Detect functions with 3+ parameters of the same primitive type.

    These are candidates for NewType to prevent parameter confusion.

    Detection logic:
    - Find functions with 3+ params annotated with the same primitive
      type (one of ``PRIMITIVE_TYPES``)
    - Suggest NewType for semantic differentiation

    >>> import ast
    >>> detector = NewTypeDetector()
    >>> code = '''
    ... def process(user_id: str, order_id: str, product_id: str):
    ...     pass
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    >>> suggestions[0].pattern_id == PatternID.NEWTYPE
    True
    """

    PRIMITIVE_TYPES: ClassVar[set[str]] = {"str", "int", "float", "bool", "bytes"}
    MIN_SAME_TYPE_PARAMS: ClassVar[int] = 3

    @property
    @post(lambda result: result == PatternID.NEWTYPE)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.NEWTYPE

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use NewType for semantic clarity with multiple same-type parameters"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find functions with multiple parameters of the same primitive type.

        Every (async) function definition reachable via ``ast.walk`` is
        checked; each contributes at most one suggestion.

        >>> import ast
        >>> detector = NewTypeDetector()
        >>> code = '''
        ... def good(a: str, b: int):
        ...     pass
        ... def bad(user_id: str, order_id: str, product_id: str):
        ...     pass
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions)
        1
        >>> "bad" in suggestions[0].current_code
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                suggestion = self._check_function(node, file_path)
                # _check_function returns an optional value; test for None
                # explicitly rather than relying on object truthiness.
                if suggestion is not None:
                    suggestions.append(suggestion)

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    def _check_function(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, file_path: str
    ) -> PatternSuggestion | None:
        """
        Check if function has multiple params of same primitive type.

        Returns a suggestion for the first primitive type (in alphabetical
        order) that occurs at least ``MIN_SAME_TYPE_PARAMS`` times, or
        ``None`` when no type reaches that threshold.

        >>> import ast
        >>> detector = NewTypeDetector()
        >>> code = "def f(user_id: str, order_id: str, product_id: str): pass"
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> suggestion = detector._check_function(func, "test.py")
        >>> suggestion is not None
        True
        >>> suggestion.confidence == Confidence.HIGH
        True
        """
        # (name, type) pairs from a helper not defined in this class.
        params = self.get_function_params(node)

        # Skip if too few parameters
        if len(params) < self.MIN_SAME_TYPE_PARAMS:
            return None

        # Count occurrences of each primitive type.
        # PRIMITIVE_TYPES is a set, and set iteration order for strings
        # varies between interpreter runs (hash randomization). Sorting
        # keeps the reported type stable when several types qualify.
        for prim_type in sorted(self.PRIMITIVE_TYPES):
            count = self.count_type_occurrences(params, prim_type)
            if count >= self.MIN_SAME_TYPE_PARAMS:
                # Found opportunity
                matching_params = [name for name, t in params if t == prim_type]
                confidence = self._calculate_confidence(matching_params, node)

                return self.make_suggestion(
                    pattern_id=self.pattern_id,
                    priority=self.priority,
                    file_path=file_path,
                    line=node.lineno,
                    message=f"{count} '{prim_type}' params - consider NewType for semantic clarity",
                    current_code=self._format_signature(node),
                    suggested_pattern=self._suggest_newtypes(matching_params, prim_type),
                    confidence=confidence,
                    reference_pattern="Pattern 1: NewType for Semantic Clarity",
                )

        return None

    @pre(lambda self, param_names, _node: len(param_names) > 0)
    @post(lambda result: result in Confidence)
    def _calculate_confidence(
        self, param_names: list[str], _node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> Confidence:
        """
        Calculate confidence based on parameter naming patterns.

        Higher confidence if names suggest distinct entities (e.g., *_id patterns).
        ``_node`` is accepted but not used in the calculation.

        >>> detector = NewTypeDetector()
        >>> import ast
        >>> func = ast.parse("def f(user_id, order_id, product_id): pass").body[0]
        >>> detector._calculate_confidence(["user_id", "order_id", "product_id"], func)
        <Confidence.HIGH: 'high'>
        """
        # High confidence if at least two names follow the *_id, *_name,
        # *_code, or *_key patterns
        id_pattern = sum(1 for n in param_names if n.endswith(("_id", "_name", "_code", "_key")))
        if id_pattern >= 2:
            return Confidence.HIGH

        # Medium confidence for descriptive names (all longer than 3 chars)
        if all(len(n) > 3 for n in param_names):
            return Confidence.MEDIUM

        # Low confidence for short/generic names
        return Confidence.LOW

    @post(lambda result: len(result) > 0 and "def " in result)
    def _format_signature(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
        """
        Format function signature for display.

        Shows at most the first 5 parameters (with their type when one is
        available) and elides the rest with '...'.

        >>> import ast
        >>> detector = NewTypeDetector()
        >>> func = ast.parse("def process(a: str, b: str): pass").body[0]
        >>> sig = detector._format_signature(func)
        >>> "process" in sig
        True
        """
        params = self.get_function_params(node)
        param_str = ", ".join(
            f"{name}: {t}" if t else name
            for name, t in params[:5]  # Limit for readability
        )
        if len(params) > 5:
            param_str += ", ..."
        prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
        return f"{prefix} {node.name}({param_str})"

    @pre(lambda self, param_names, base_type: len(param_names) > 0 and len(base_type) > 0)
    @post(lambda result: "NewType" in result)
    def _suggest_newtypes(self, param_names: list[str], base_type: str) -> str:
        """
        Generate NewType suggestion for parameters.

        Only the first 3 parameter names are expanded; further names are
        summarized with a trailing '...'.

        >>> detector = NewTypeDetector()
        >>> detector._suggest_newtypes(["user_id", "order_id"], "str")
        "NewType('UserId', str), NewType('OrderId', str)"
        """
        newtypes = []
        for name in param_names[:3]:  # Limit suggestions
            # Convert snake_case to PascalCase
            pascal = "".join(word.capitalize() for word in name.split("_"))
            newtypes.append(f"NewType('{pascal}', {base_type})")
        if len(param_names) > 3:
            newtypes.append("...")
        return ", ".join(newtypes)