invar-tools 1.3.3__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invar/core/formatter.py +6 -1
- invar/core/models.py +13 -0
- invar/core/patterns/__init__.py +53 -0
- invar/core/patterns/detector.py +249 -0
- invar/core/patterns/p0_exhaustive.py +207 -0
- invar/core/patterns/p0_literal.py +307 -0
- invar/core/patterns/p0_newtype.py +211 -0
- invar/core/patterns/p0_nonempty.py +307 -0
- invar/core/patterns/p0_validation.py +278 -0
- invar/core/patterns/registry.py +234 -0
- invar/core/patterns/types.py +167 -0
- invar/core/trivial_detection.py +189 -0
- invar/mcp/server.py +4 -0
- invar/shell/commands/guard.py +65 -0
- invar/shell/contract_coverage.py +358 -0
- invar/shell/guard_output.py +5 -0
- invar/shell/pattern_integration.py +234 -0
- invar/shell/templates.py +2 -2
- invar/shell/testing.py +13 -2
- invar/templates/config/CLAUDE.md.jinja +1 -0
- invar/templates/skills/develop/SKILL.md.jinja +49 -0
- invar/templates/skills/review/SKILL.md.jinja +196 -31
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/METADATA +24 -15
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/RECORD +29 -17
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/entry_points.txt +1 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/WHEEL +0 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/licenses/LICENSE +0 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/licenses/LICENSE-GPL +0 -0
- {invar_tools-1.3.3.dist-info → invar_tools-1.5.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Literal Pattern Detector (DX-61, P0).
|
|
3
|
+
|
|
4
|
+
Detects runtime validation for finite value sets that could benefit
|
|
5
|
+
from Literal type for compile-time safety.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import ast
|
|
9
|
+
|
|
10
|
+
from deal import post, pre
|
|
11
|
+
|
|
12
|
+
from invar.core.patterns.detector import BaseDetector
|
|
13
|
+
from invar.core.patterns.types import (
|
|
14
|
+
Confidence,
|
|
15
|
+
PatternID,
|
|
16
|
+
PatternSuggestion,
|
|
17
|
+
Priority,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LiteralDetector(BaseDetector):
    """
    Detect runtime checks for finite value sets.

    These are candidates for Literal type to catch invalid values
    at type-check time instead of runtime.

    Detection logic:
    - Find 'if x not in (...)', 'if x not in [...]' or 'if x not in {...}'
      patterns (in the function's own body and nested if/else branches)
    - Look for small, non-empty sets of string/int literals
    - Suggest Literal type for the parameter

    >>> import ast
    >>> detector = LiteralDetector()
    >>> code = '''
    ... def set_level(level: str) -> None:
    ...     if level not in ("debug", "info", "warning", "error"):
    ...         raise ValueError(f"Invalid level: {level}")
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    """

    MAX_LITERAL_VALUES = 10  # Don't suggest Literal for large sets

    @property
    @post(lambda result: result == PatternID.LITERAL)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.LITERAL

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use Literal type for finite value sets"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find functions with runtime finite-set validation.

        Every (async) function reachable through ``ast.walk`` is inspected,
        including nested ones; each function's own body is scanned by
        ``_check_function``.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def process(mode: str):
        ...     if mode not in ["fast", "slow", "auto"]:
        ...         raise ValueError("Bad mode")
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions: list[PatternSuggestion] = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                suggestions.extend(self._check_function(node, file_path))
        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def _check_function(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, file_path: str
    ) -> list[PatternSuggestion]:
        """
        Check if function has finite-set validation patterns.

        Produces one suggestion per membership check whose value set has at
        most ``MAX_LITERAL_VALUES`` entries; larger sets are ignored.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def f(x):
        ...     if x not in ("a", "b", "c"):
        ...         raise ValueError("Bad")
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> suggestions = detector._check_function(func, "test.py")
        >>> len(suggestions) > 0
        True
        """
        suggestions = []
        checks = self._find_membership_checks(node)

        for var_name, values, check_line in checks:
            if len(values) <= self.MAX_LITERAL_VALUES:
                confidence = self._calculate_confidence(values)

                suggestions.append(
                    self.make_suggestion(
                        pattern_id=self.pattern_id,
                        priority=self.priority,
                        file_path=file_path,
                        line=check_line,
                        message=f"Runtime check for {len(values)} values - consider Literal type",
                        current_code=self._format_check(var_name, values),
                        suggested_pattern=self._format_literal(var_name, values),
                        confidence=confidence,
                        reference_pattern="Pattern 4: Literal for Finite Value Sets",
                    )
                )

        return suggestions

    @post(lambda result: all(len(name) > 0 and len(vals) > 0 and line > 0 for name, vals, line in result))
    def _find_membership_checks(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> list[tuple[str, list[str | int], int]]:
        """
        Find 'if x not in (...)' patterns.

        Returns list of (var_name, values, line_number), in source order.
        Only the function's own if statements (and their nested if/else
        branches) are checked, not nested functions.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> code = '''
        ... def f(x):
        ...     if x not in ("a", "b"):
        ...         raise ValueError("Bad")
        ... '''
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> checks = detector._find_membership_checks(func)
        >>> len(checks) > 0
        True
        >>> checks[0][1]
        ['a', 'b']
        """
        checks: list[tuple[str, list[str | int], int]] = []
        self._collect_membership_checks(node.body, checks)
        return checks

    @pre(lambda self, stmts, checks: stmts is not None and checks is not None)
    def _collect_membership_checks(
        self,
        stmts: list[ast.stmt],
        checks: list[tuple[str, list[str | int], int]],
    ) -> None:
        """
        Recursively collect membership checks, avoiding nested functions.

        Appends ``(var_name, values, line_number)`` to *checks* for every
        ``if`` statement in *stmts* whose test is a recognised membership
        check, then descends into that ``if``'s body and else branch.
        Nested function definitions are skipped here because ``detect``
        already visits them as functions in their own right.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> stmts = ast.parse("if x not in ('a', 'b'): raise ValueError()").body
        >>> checks: list[tuple[str, list[str | int], int]] = []
        >>> detector._collect_membership_checks(stmts, checks)
        >>> len(checks)
        1
        """
        for stmt in stmts:
            # Skip nested functions
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue

            if isinstance(stmt, ast.If):
                result = self._extract_membership_check(stmt.test)
                if result:
                    var_name, values = result
                    checks.append((var_name, values, stmt.lineno))
                # Recurse into body and else (elif chains live in orelse)
                self._collect_membership_checks(stmt.body, checks)
                self._collect_membership_checks(stmt.orelse, checks)

    @post(lambda result: result is None or (len(result[0]) > 0 and len(result[1]) > 0))
    def _extract_membership_check(
        self, test: ast.expr
    ) -> tuple[str, list[str | int]] | None:
        """
        Extract variable and values from membership check.

        Handles:
        - 'x not in ("a", "b", "c")'
        - 'x not in ["a", "b", "c"]'
        - 'x not in {"a", "b", "c"}'

        Returns None for anything else, including chained comparisons,
        non-name left operands and empty containers.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> test = ast.parse("x not in ('a', 'b')", mode="eval").body
        >>> result = detector._extract_membership_check(test)
        >>> result is not None
        True
        >>> result[0]
        'x'
        >>> result[1]
        ['a', 'b']
        >>> empty = ast.parse("x not in ()", mode="eval").body
        >>> detector._extract_membership_check(empty) is None
        True
        """
        # Handle 'x not in (...)'
        if isinstance(test, ast.Compare):
            if (
                len(test.ops) == 1
                and isinstance(test.ops[0], ast.NotIn)
                and len(test.comparators) == 1
            ):
                left = test.left
                right = test.comparators[0]

                if isinstance(left, ast.Name):
                    var_name = left.id
                    values = self._extract_literal_values(right)
                    if values:
                        return (var_name, values)

        return None

    @post(lambda result: result is None or len(result) > 0)
    def _extract_literal_values(self, node: ast.expr) -> list[str | int] | None:
        """
        Extract literal values from tuple/list/set.

        Returns None when *node* is not a tuple/list/set display, when any
        element is not a str/int constant, or when the display is empty.
        An empty container has no values to turn into a Literal. Returning
        ``[]`` would also violate this method's postcondition.
        Note: bool constants are accepted because ``bool`` is a subclass of
        ``int``.

        >>> import ast
        >>> detector = LiteralDetector()
        >>> node = ast.parse("('a', 'b', 'c')", mode="eval").body
        >>> detector._extract_literal_values(node)
        ['a', 'b', 'c']
        >>> detector._extract_literal_values(ast.parse("[]", mode="eval").body) is None
        True
        >>> detector._extract_literal_values(ast.parse("(a, 'b')", mode="eval").body) is None
        True
        """
        if not isinstance(node, (ast.Tuple, ast.List, ast.Set)) or not node.elts:
            return None

        values: list[str | int] = []
        for elt in node.elts:
            if not (isinstance(elt, ast.Constant) and isinstance(elt.value, (str, int))):
                return None  # Non-literal value
            values.append(elt.value)
        return values

    @pre(lambda self, values: len(values) > 0)
    @post(lambda result: result in Confidence)
    def _calculate_confidence(self, values: list[str | int]) -> Confidence:
        """
        Calculate confidence based on value characteristics.

        HIGH for up to 5 values that are all strings, MEDIUM for up to 8
        values otherwise, LOW beyond that.

        >>> detector = LiteralDetector()
        >>> detector._calculate_confidence(["debug", "info", "warning", "error"])
        <Confidence.HIGH: 'high'>
        """
        # High confidence for small sets of strings
        if all(isinstance(v, str) for v in values) and len(values) <= 5:
            return Confidence.HIGH

        # Medium confidence for larger sets or mixed types
        if len(values) <= 8:
            return Confidence.MEDIUM

        return Confidence.LOW

    @pre(lambda self, var_name, values: len(var_name) > 0 and len(values) > 0)
    @post(lambda result: "not in" in result)
    def _format_check(self, var_name: str, values: list[str | int]) -> str:
        """
        Format the membership check for display.

        At most four values are shown; longer sets end in ``...``. A single
        value gets a trailing comma so the display is a one-element tuple.
        Without it, ``('a')`` would read as a plain string.

        >>> detector = LiteralDetector()
        >>> detector._format_check("level", ["debug", "info"])
        "if level not in ('debug', 'info'): raise"
        >>> detector._format_check("mode", ["fast"])
        "if mode not in ('fast',): raise"
        """
        formatted_values = ", ".join(repr(v) for v in values[:4])
        if len(values) > 4:
            formatted_values += ", ..."
        elif len(values) == 1:
            formatted_values += ","
        return f"if {var_name} not in ({formatted_values}): raise"

    @pre(lambda self, _var_name, values: len(_var_name) > 0 and len(values) > 0)
    @post(lambda result: "Literal[" in result)
    def _format_literal(self, _var_name: str, values: list[str | int]) -> str:
        """
        Format the Literal type suggestion.

        At most five values are shown; longer sets end in ``...``.

        >>> detector = LiteralDetector()
        >>> detector._format_literal("level", ["debug", "info", "error"])
        "Literal['debug', 'info', 'error']"
        """
        formatted_values = ", ".join(repr(v) for v in values[:5])
        if len(values) > 5:
            formatted_values += ", ..."
        return f"Literal[{formatted_values}]"
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NewType Pattern Detector (DX-61, P0).
|
|
3
|
+
|
|
4
|
+
Detects opportunities to use NewType for semantic clarity when
|
|
5
|
+
multiple parameters share the same primitive type.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import ast
|
|
9
|
+
from typing import ClassVar
|
|
10
|
+
|
|
11
|
+
from deal import post, pre
|
|
12
|
+
|
|
13
|
+
from invar.core.patterns.detector import BaseDetector
|
|
14
|
+
from invar.core.patterns.types import (
|
|
15
|
+
Confidence,
|
|
16
|
+
PatternID,
|
|
17
|
+
PatternSuggestion,
|
|
18
|
+
Priority,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class NewTypeDetector(BaseDetector):
    """
    Detect functions with 3+ parameters of the same primitive type.

    These are candidates for NewType to prevent parameter confusion.

    Detection logic:
    - Find functions with 3+ str/int/float params of same type
    - Exclude common patterns (e.g., *args, **kwargs)
    - Suggest NewType for semantic differentiation

    >>> import ast
    >>> detector = NewTypeDetector()
    >>> code = '''
    ... def process(user_id: str, order_id: str, product_id: str):
    ...     pass
    ... '''
    >>> tree = ast.parse(code)
    >>> suggestions = detector.detect(tree, "test.py")
    >>> len(suggestions) > 0
    True
    >>> suggestions[0].pattern_id == PatternID.NEWTYPE
    True
    """

    # NOTE(review): bool is declared final in typeshed. Type checkers may
    # refuse NewType over it, so suggestions for bool parameters may not be
    # directly applicable -- confirm whether "bool" belongs here.
    PRIMITIVE_TYPES: ClassVar[set[str]] = {"str", "int", "float", "bool", "bytes"}
    MIN_SAME_TYPE_PARAMS: ClassVar[int] = 3

    @property
    @post(lambda result: result == PatternID.NEWTYPE)
    def pattern_id(self) -> PatternID:
        """Unique identifier for this pattern."""
        return PatternID.NEWTYPE

    @property
    @post(lambda result: result == Priority.P0)
    def priority(self) -> Priority:
        """Priority tier."""
        return Priority.P0

    @property
    @post(lambda result: len(result) > 0)
    def description(self) -> str:
        """Human-readable description."""
        return "Use NewType for semantic clarity with multiple same-type parameters"

    @post(lambda result: all(isinstance(s, PatternSuggestion) for s in result))
    def detect(self, tree: ast.AST, file_path: str) -> list[PatternSuggestion]:
        """
        Find functions with multiple parameters of the same primitive type.

        Every (async) function reachable through ``ast.walk`` is inspected;
        at most one suggestion is produced per function.

        >>> import ast
        >>> detector = NewTypeDetector()
        >>> code = '''
        ... def good(a: str, b: int):
        ...     pass
        ... def bad(user_id: str, order_id: str, product_id: str):
        ...     pass
        ... '''
        >>> tree = ast.parse(code)
        >>> suggestions = detector.detect(tree, "test.py")
        >>> len(suggestions)
        1
        >>> "bad" in suggestions[0].current_code
        True
        """
        suggestions = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                suggestion = self._check_function(node, file_path)
                if suggestion:
                    suggestions.append(suggestion)

        return suggestions

    @pre(lambda self, node, file_path: len(file_path) > 0)
    def _check_function(
        self, node: ast.FunctionDef | ast.AsyncFunctionDef, file_path: str
    ) -> PatternSuggestion | None:
        """
        Check if function has multiple params of same primitive type.

        Returns a suggestion for the first qualifying primitive type, or
        None. Types are tried in a deterministic order (see
        ``_ordered_primitive_types``), so the same source always yields the
        same suggestion.

        >>> import ast
        >>> detector = NewTypeDetector()
        >>> code = "def f(user_id: str, order_id: str, product_id: str): pass"
        >>> tree = ast.parse(code)
        >>> func = tree.body[0]
        >>> suggestion = detector._check_function(func, "test.py")
        >>> suggestion is not None
        True
        >>> suggestion.confidence == Confidence.HIGH
        True
        """
        params = self.get_function_params(node)

        # Skip if too few parameters
        if len(params) < self.MIN_SAME_TYPE_PARAMS:
            return None

        # Iterating the PRIMITIVE_TYPES set directly would depend on
        # per-process string-hash randomization. A function with several
        # qualifying types would then be reported for a different type
        # from run to run.
        for prim_type in self._ordered_primitive_types(params):
            count = self.count_type_occurrences(params, prim_type)
            if count >= self.MIN_SAME_TYPE_PARAMS:
                # Found opportunity
                matching_params = [name for name, t in params if t == prim_type]
                confidence = self._calculate_confidence(matching_params, node)

                return self.make_suggestion(
                    pattern_id=self.pattern_id,
                    priority=self.priority,
                    file_path=file_path,
                    line=node.lineno,
                    message=f"{count} '{prim_type}' params - consider NewType for semantic clarity",
                    current_code=self._format_signature(node),
                    suggested_pattern=self._suggest_newtypes(matching_params, prim_type),
                    confidence=confidence,
                    reference_pattern="Pattern 1: NewType for Semantic Clarity",
                )

        return None

    @post(lambda result: len(result) == len(set(result)))
    def _ordered_primitive_types(self, params: list[tuple[str, str | None]]) -> list[str]:
        """
        Return PRIMITIVE_TYPES in a deterministic, signature-driven order.

        Types are ordered by the position of the first parameter whose type
        string equals them. Types not found that way come last,
        alphabetically. Every member of PRIMITIVE_TYPES appears exactly once.
        The element shape of *params* is assumed to match what
        ``get_function_params`` returns, i.e. ``(name, type_or_falsy)``
        pairs -- TODO confirm.

        >>> detector = NewTypeDetector()
        >>> order = detector._ordered_primitive_types([("a", "int"), ("b", "str")])
        >>> order[:2]
        ['int', 'str']
        >>> sorted(order) == sorted(NewTypeDetector.PRIMITIVE_TYPES)
        True
        """
        first_seen: dict[str, int] = {}
        for index, (_, annotation) in enumerate(params):
            if isinstance(annotation, str) and annotation not in first_seen:
                first_seen[annotation] = index
        return sorted(
            self.PRIMITIVE_TYPES,
            key=lambda prim: (first_seen.get(prim, len(params)), prim),
        )

    @pre(lambda self, param_names, _node: len(param_names) > 0)
    @post(lambda result: result in Confidence)
    def _calculate_confidence(
        self, param_names: list[str], _node: ast.FunctionDef | ast.AsyncFunctionDef
    ) -> Confidence:
        """
        Calculate confidence based on parameter naming patterns.

        Higher confidence if names suggest distinct entities (e.g., *_id patterns).

        >>> detector = NewTypeDetector()
        >>> import ast
        >>> func = ast.parse("def f(user_id, order_id, product_id): pass").body[0]
        >>> detector._calculate_confidence(["user_id", "order_id", "product_id"], func)
        <Confidence.HIGH: 'high'>
        """
        # High confidence if names follow *_id, *_name, *_code or *_key patterns
        id_pattern = sum(1 for n in param_names if n.endswith(("_id", "_name", "_code", "_key")))
        if id_pattern >= 2:
            return Confidence.HIGH

        # Medium confidence for descriptive names
        if all(len(n) > 3 for n in param_names):
            return Confidence.MEDIUM

        # Low confidence for short/generic names
        return Confidence.LOW

    @post(lambda result: len(result) > 0 and "def " in result)
    def _format_signature(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
        """
        Format function signature for display.

        Shows at most five parameters; longer signatures end in ``...``.

        >>> import ast
        >>> detector = NewTypeDetector()
        >>> func = ast.parse("def process(a: str, b: str): pass").body[0]
        >>> sig = detector._format_signature(func)
        >>> "process" in sig
        True
        """
        params = self.get_function_params(node)
        param_str = ", ".join(
            f"{name}: {t}" if t else name
            for name, t in params[:5]  # Limit for readability
        )
        if len(params) > 5:
            param_str += ", ..."
        prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
        return f"{prefix} {node.name}({param_str})"

    @pre(lambda self, param_names, base_type: len(param_names) > 0 and len(base_type) > 0)
    @post(lambda result: "NewType" in result)
    def _suggest_newtypes(self, param_names: list[str], base_type: str) -> str:
        """
        Generate NewType suggestion for parameters.

        Suggests at most three NewTypes; further names are elided as ``...``.

        >>> detector = NewTypeDetector()
        >>> detector._suggest_newtypes(["user_id", "order_id"], "str")
        "NewType('UserId', str), NewType('OrderId', str)"
        """
        newtypes = []
        for name in param_names[:3]:  # Limit suggestions
            # Convert snake_case to PascalCase
            pascal = "".join(word.capitalize() for word in name.split("_"))
            newtypes.append(f"NewType('{pascal}', {base_type})")
        if len(param_names) > 3:
            newtypes.append("...")
        return ", ".join(newtypes)
|