thailint 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,172 @@
1
+ """
2
+ Purpose: Detect function calls with string literal arguments in Python AST
3
+
4
+ Scope: Find function and method calls that consistently receive string arguments
5
+
6
+ Overview: Provides FunctionCallTracker class that traverses Python AST to find function
7
+ and method calls where string literals are passed as arguments. Tracks the function
8
+ name, parameter index, and string value to enable cross-file aggregation. When a
9
+ function is called with the same set of limited string values across files, it
10
+ suggests the parameter should be an enum. Handles both simple function calls
11
+ (foo("value")) and method calls (obj.method("value")).
12
+
13
+ Dependencies: ast module for AST parsing, dataclasses for pattern structure
14
+
15
+ Exports: FunctionCallTracker class, FunctionCallPattern dataclass
16
+
17
+ Interfaces: FunctionCallTracker.find_patterns(tree) -> list[FunctionCallPattern]
18
+
19
+ Implementation: AST NodeVisitor pattern with Call node handling for string arguments
20
+ """
21
+
22
+ import ast
23
+ from dataclasses import dataclass
24
+
25
+
26
@dataclass()
class FunctionCallPattern:
    """A single string literal handed to a function or method call.

    One instance is produced per string-literal argument so that results can be
    aggregated across files: a callee that keeps receiving the same small set of
    string values at the same position hints that the parameter should be an enum.

    Attributes:
        function_name: Dotted callee name, e.g. 'process' or 'obj.method'.
        param_index: Zero-based position of the argument carrying the string.
        string_value: The literal string that was passed.
        line_number: One-based source line on which the call appears.
        column: Zero-based column offset at which the call begins.
    """

    function_name: str
    param_index: int
    string_value: str
    line_number: int
    column: int
49
+
50
+
51
class FunctionCallTracker(ast.NodeVisitor):
    """Tracks function calls with string literal arguments.

    Finds patterns like 'process("active")' and 'obj.set_status("pending")' where
    string literals are used for arguments that could be enums. Only positional
    arguments are examined; keyword arguments are not recorded. Calls whose callee
    is neither a plain name nor an attribute access (e.g. 'handlers[k]("x")' or
    'factory()("x")') produce no pattern, although their arguments are still
    traversed so nested calls are found.
    """

    def __init__(self) -> None:
        """Initialize the tracker with an empty pattern list."""
        self.patterns: list[FunctionCallPattern] = []

    def find_patterns(self, tree: ast.AST) -> list[FunctionCallPattern]:
        """Find all function calls with string arguments in the AST.

        Patterns from any previous call are discarded, so one tracker instance
        can be reused for several trees.

        Args:
            tree: The AST to analyze

        Returns:
            List of FunctionCallPattern instances for each detected call
        """
        self.patterns = []
        self.visit(tree)
        return self.patterns

    def visit_Call(self, node: ast.Call) -> None:  # pylint: disable=invalid-name
        """Visit a Call node to check for string arguments.

        Handles both simple function calls and method calls, recording every
        string literal positional argument of a call whose name can be extracted.

        Args:
            node: The Call node to analyze
        """
        function_name = self._extract_function_name(node.func)
        if function_name is not None:
            self._check_positional_args(node, function_name)
        # Always recurse so nested calls such as outer(inner("x")) are tracked too.
        self.generic_visit(node)

    def _extract_function_name(self, func_node: ast.expr) -> str | None:
        """Extract the function name from a call expression.

        Handles simple names (foo) and attribute access (obj.method).

        Args:
            func_node: The function expression node

        Returns:
            Function name string, or None for any other callee expression
        """
        if isinstance(func_node, ast.Name):
            return func_node.id
        if isinstance(func_node, ast.Attribute):
            return self._extract_attribute_name(func_node)
        return None

    def _extract_attribute_name(self, node: ast.Attribute) -> str | None:
        """Extract function name from an attribute access.

        Builds qualified names like 'obj.method' or 'a.b.method'. At most three
        attribute levels above the method are resolved. Whenever the prefix of
        the chain cannot be resolved, it is replaced by a "_" placeholder:
        - the root is not a plain name: 'get().method' -> '_.method'
        - the depth limit is hit first: 'a.b.c.d.method' -> '_.b.c.d.method'

        Args:
            node: The Attribute node

        Returns:
            Qualified function name. The Optional return type is kept for
            interface compatibility; this implementation always returns a str.
        """
        parts: list[str] = [node.attr]
        current: ast.expr = node.value

        # Limit depth to avoid overly complex names
        max_depth = 3
        for _ in range(max_depth):
            if isinstance(current, ast.Name):
                parts.append(current.id)
                break
            if not isinstance(current, ast.Attribute):
                # Complex expression (call result, subscript, etc.)
                # Use placeholder to maintain function identity
                parts.append("_")
                break
            parts.append(current.attr)
            current = current.value
        else:
            # Depth limit reached before the root. Mark the dropped prefix so the
            # truncated name cannot collide with a call on a genuine root name
            # (without this, 'a.b.c.d.method' would read as 'b.c.d.method').
            parts.append("_")

        return ".".join(reversed(parts))

    def _check_positional_args(self, node: ast.Call, function_name: str) -> None:
        """Check positional arguments for string literals.

        Scanning stops at the first '*iterable' argument: its length is unknown
        statically, so any index assigned to a later argument would not be the
        parameter position that argument actually binds to.

        Args:
            node: The Call node
            function_name: Extracted function name
        """
        for param_index, arg in enumerate(node.args):
            if isinstance(arg, ast.Starred):
                break
            if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
                self._add_pattern(node, function_name, param_index, arg.value)

    def _add_pattern(
        self, node: ast.Call, function_name: str, param_index: int, string_value: str
    ) -> None:
        """Create and add a function call pattern to results.

        The recorded location is that of the Call node, not of the argument.

        Args:
            node: The Call node containing the pattern
            function_name: Name of the function being called
            param_index: Index of the string argument
            string_value: The string literal value
        """
        pattern = FunctionCallPattern(
            function_name=function_name,
            param_index=param_index,
            string_value=string_value,
            line_number=node.lineno,
            column=node.col_offset,
        )
        self.patterns.append(pattern)
@@ -0,0 +1,252 @@
1
+ """
2
+ Purpose: Detect scattered string comparisons in Python AST
3
+
4
+ Scope: Find equality/inequality comparisons with string literals across Python files
5
+
6
+ Overview: Provides ComparisonTracker class that traverses Python AST to find scattered
7
+ string comparisons like `if env == "production"`. Tracks the variable name, compared
8
+ string value, and operator to enable cross-file aggregation. When a variable is compared
9
+ to multiple unique string values across files, it suggests the variable should be an enum.
10
+ Excludes common false positives like `__name__ == "__main__"` and type name checks.
11
+
12
+ Dependencies: ast module for AST parsing, dataclasses for pattern structure
13
+
14
+ Exports: ComparisonTracker class, ComparisonPattern dataclass
15
+
16
+ Interfaces: ComparisonTracker.find_patterns(tree) -> list[ComparisonPattern]
17
+
18
+ Implementation: AST NodeVisitor pattern with Compare node handling for string comparisons
19
+ """
20
+
21
+ import ast
22
+ from dataclasses import dataclass
23
+
24
+
25
@dataclass()
class ComparisonPattern:
    """A single equality/inequality test of a variable against a string literal.

    Instances such as the one for `if (env == "production")` are gathered so that
    comparisons scattered across files can be grouped; a variable compared with
    many distinct literals hints at a missing enum.

    Attributes:
        variable_name: Name on the non-literal side, e.g. 'env' or 'self.status'.
        compared_value: The string literal on the other side.
        operator: Either '==' or '!='.
        line_number: One-based source line of the comparison.
        column: Zero-based column offset at which the comparison begins.
    """

    variable_name: str
    compared_value: str
    operator: str
    line_number: int
    column: int
48
+
49
+
50
# Variable names whose string comparisons are skipped as false positives:
# the module guard (`__name__`) and class-name lookups (`__class__.__name__`).
_EXCLUDED_VARIABLES = frozenset(("__name__", "__class__.__name__"))

# Literal values whose comparisons are skipped because they appear in legitimate
# idioms (e.g. `if __name__ == "__main__"`) rather than stringly-typed logic.
_EXCLUDED_VALUES = frozenset(("__main__",))
64
+
65
+
66
class ComparisonTracker(ast.NodeVisitor):  # thailint: ignore[srp]
    """Tracks scattered string comparisons in Python AST.

    Finds patterns like `if env == "production"` and `if status != "deleted"` where
    string literals are used for comparisons that could use enums instead. Only
    single-operator `==` / `!=` comparisons are examined; chained comparisons
    (`a == "x" == b`) and other operators (`in`, `is`, `<`, ...) are skipped.

    Note: Method count exceeds SRP limit because AST traversal requires multiple helper
    methods for extracting variable names, attribute names, and pattern filtering. All
    methods support the single responsibility of tracking string comparisons.
    """

    def __init__(self) -> None:
        """Initialize the tracker with an empty pattern list."""
        self.patterns: list[ComparisonPattern] = []

    def find_patterns(self, tree: ast.AST) -> list[ComparisonPattern]:
        """Find all string comparisons in the AST.

        Patterns from any previous call are discarded, so one tracker instance
        can be reused for several trees.

        Args:
            tree: The AST to analyze

        Returns:
            List of ComparisonPattern instances for each detected comparison
        """
        self.patterns = []
        self.visit(tree)
        return self.patterns

    def visit_Compare(self, node: ast.Compare) -> None:  # pylint: disable=invalid-name
        """Visit a Compare node to check for string comparisons.

        Handles both `var == "string"` and `"string" == var` patterns, then keeps
        walking so comparisons nested inside the operands are found as well.

        Args:
            node: The Compare node to analyze
        """
        self._check_comparison(node)
        self.generic_visit(node)

    def _check_comparison(self, node: ast.Compare) -> None:
        """Check if comparison is a string comparison to track.

        Args:
            node: The Compare node to check
        """
        # Only handle simple binary comparisons
        if len(node.ops) != 1 or len(node.comparators) != 1:
            return

        operator = node.ops[0]
        if isinstance(operator, ast.Eq):
            op_str = "=="
        elif isinstance(operator, ast.NotEq):
            op_str = "!="
        else:
            return

        left = node.left
        right = node.comparators[0]

        # Try both orientations: var == "string" and "string" == var.
        # At most one can match: a string literal never yields a variable name.
        self._try_extract_pattern(left, right, op_str, node)
        self._try_extract_pattern(right, left, op_str, node)

    def _try_extract_pattern(
        self,
        var_side: ast.expr,
        string_side: ast.expr,
        operator: str,
        node: ast.Compare,
    ) -> None:
        """Try to extract a pattern from a comparison.

        Args:
            var_side: The expression that might be a variable
            string_side: The expression that might be a string literal
            operator: The comparison operator
            node: The original Compare node for location info
        """
        # Check if string_side is a string literal
        if not (isinstance(string_side, ast.Constant) and isinstance(string_side.value, str)):
            return
        string_value = string_side.value

        var_name = self._extract_variable_name(var_side)
        if var_name is None or self._should_exclude(var_name, string_value):
            return

        self._add_pattern(var_name, string_value, operator, node)

    def _extract_variable_name(self, node: ast.expr) -> str | None:
        """Extract variable name from an expression.

        Handles simple names (var) and attribute access (obj.attr).

        Args:
            node: The expression to extract from

        Returns:
            Variable name string, or None for any other expression
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return self._extract_attribute_name(node)
        return None

    def _extract_attribute_name(self, node: ast.Attribute) -> str | None:
        """Extract attribute name from an attribute access.

        Builds qualified names like 'obj.attr' or 'a.b.attr'. At most three
        attribute levels above the final attribute are resolved. Whenever the
        prefix of the chain cannot be resolved, it is replaced by a "_" placeholder:
        - the root is not a plain name: 'type(x).__name__' -> '_.__name__'
        - the depth limit is hit first: 'a.b.c.d.attr' -> '_.b.c.d.attr'

        Args:
            node: The Attribute node

        Returns:
            Qualified attribute name. The Optional return type is kept for
            interface compatibility; this implementation always returns a str.
        """
        parts: list[str] = [node.attr]
        current: ast.expr = node.value

        # Limit depth to avoid overly complex names
        max_depth = 3
        for _ in range(max_depth):
            if isinstance(current, ast.Name):
                parts.append(current.id)
                break
            if not isinstance(current, ast.Attribute):
                # Complex expression, still return what we have
                parts.append("_")
                break
            parts.append(current.attr)
            current = current.value
        else:
            # Depth limit reached before the root. Mark the dropped prefix so the
            # truncated name cannot collide with a genuine root name
            # (without this, 'a.b.c.d.attr' would read as 'b.c.d.attr').
            parts.append("_")

        return ".".join(reversed(parts))

    def _should_exclude(self, var_name: str, string_value: str) -> bool:
        """Check if this comparison should be excluded.

        Filters out common patterns that are not stringly-typed code:
        - __name__ == "__main__"
        - type-name checks via a trailing `.__name__` attribute, such as
          `self.__class__.__name__`, `cls.__name__` or `type(x).__name__`

        Args:
            var_name: The variable name
            string_value: The string value

        Returns:
            True if the comparison should be excluded
        """
        if var_name in _EXCLUDED_VARIABLES or string_value in _EXCLUDED_VALUES:
            return True
        # Match on the attribute boundary so identifiers that merely end in
        # "__name__" (e.g. `obj.my__name__`) are still tracked; a bare `__name__`
        # is already covered by _EXCLUDED_VARIABLES.
        return var_name.endswith(".__name__")

    def _add_pattern(
        self, var_name: str, string_value: str, operator: str, node: ast.Compare
    ) -> None:
        """Create and add a comparison pattern to results.

        Args:
            var_name: The variable name
            string_value: The string value being compared
            operator: The comparison operator
            node: The Compare node for location info
        """
        pattern = ComparisonPattern(
            variable_name=var_name,
            compared_value=string_value,
            operator=operator,
            line_number=node.lineno,
            column=node.col_offset,
        )
        self.patterns.append(pattern)