python-oop-analyzer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,291 @@
1
+ """
2
+ Encapsulation Rule - Tell Don't Ask principle.
3
+
4
+ This rule detects violations of encapsulation where objects are accessed
5
+ directly through their properties instead of through methods.
6
+
7
+ In OOP, we should "tell" objects what to do, not "ask" them for data
8
+ and then make decisions based on that data.
9
+ """
10
+
11
+ import ast
12
+ from typing import Any
13
+
14
+ from .base import BaseRule, RuleResult, RuleViolation
15
+
16
+
17
class EncapsulationRule(BaseRule):
    """
    Flag direct attribute access on objects ("Tell, Don't Ask").

    Reported (by EncapsulationVisitor):
    - Reading an object's attribute directly: obj.property
    - Chained access through several objects: obj.prop1.prop2

    Not reported:
    - Method calls such as obj.method()
    - self./cls. access (while allow_self_access is enabled)
    - Constant-looking attributes (UPPER_CASE)
    - Module attribute access such as json.JSONEncoder or redis.Redis

    Options read from ``self.options`` (defaults in parentheses):
    allow_self_access (True), allow_private_access (False),
    allow_dunder_access (True), max_chain_length (1),
    warn_dependency_access (True).
    """

    name = "encapsulation"
    description = "Check for direct property access (tell don't ask)"
    severity = "warning"

    def __init__(self, options: dict[str, Any] | None = None):
        super().__init__(options)
        opts = self.options
        self.allow_self_access = opts.get("allow_self_access", True)
        self.allow_private_access = opts.get("allow_private_access", False)
        self.allow_dunder_access = opts.get("allow_dunder_access", True)
        self.max_chain_length = opts.get("max_chain_length", 1)
        self.warn_dependency_access = opts.get("warn_dependency_access", True)

    def analyze(
        self,
        tree: ast.Module,
        source: str,
        file_path: str,
    ) -> RuleResult:
        """Walk *tree* once and collect encapsulation violations for *file_path*."""
        visitor = EncapsulationVisitor(
            file_path=file_path,
            source=source,
            max_chain_length=self.max_chain_length,
            allow_self_access=self.allow_self_access,
            allow_private_access=self.allow_private_access,
            allow_dunder_access=self.allow_dunder_access,
            warn_dependency_access=self.warn_dependency_access,
        )
        visitor.visit(tree)

        found = visitor.violations
        summary = {
            "total_violations": len(found),
            "files_analyzed": 1,
            "module_access_skipped": visitor.module_access_skipped,
        }
        return RuleResult(rule_name=self.name, violations=found, summary=summary)
75
+
76
+
77
class EncapsulationVisitor(ast.NodeVisitor):
    """AST visitor that detects encapsulation violations.

    Every ``ast.Attribute`` whose chain is rooted at a plain name
    (``obj.a.b``) is considered, except when it is:

    * the target of a call (``obj.method()``),
    * part of a class's base list (``class Foo(json.JSONEncoder)``),
    * rooted at ``self``/``cls`` while ``allow_self_access`` is set,
    * a single attribute starting with an uppercase letter on an imported
      name (``json.JSONEncoder``),
    * a chain containing a dunder (``allow_dunder_access``) or an
      underscore-prefixed (``allow_private_access``) attribute,
    * filtered out by ``_create_violation`` (constant-looking last attribute,
      or a base in a fixed list of stdlib module names).

    A chain is reported at most once. After a violation is recorded for the
    outermost attribute of ``obj.a.b``, the inner ``obj.a`` node, which has
    the same line and column, is not reported again.
    """

    def __init__(
        self,
        file_path: str,
        source: str,
        allow_self_access: bool = True,
        allow_private_access: bool = False,
        allow_dunder_access: bool = True,
        max_chain_length: int = 1,
        warn_dependency_access: bool = True,
    ):
        self.file_path = file_path
        self.source = source
        self.allow_self_access = allow_self_access
        self.allow_private_access = allow_private_access
        self.allow_dunder_access = allow_dunder_access
        self.max_chain_length = max_chain_length
        # Stored for parity with the rule's options; no check in this visitor
        # reads it.
        self.warn_dependency_access = warn_dependency_access
        self.violations: list[RuleViolation] = []
        self._in_class = False
        self._current_class: str | None = None
        # The id()-based sets below are only meaningful while the visited
        # tree is alive; this assumes one visitor per tree.
        self._call_targets: set[int] = set()  # Attribute nodes used as Call.func
        self._imported_modules: set[str] = set()  # Names bound by imports
        self._class_bases: set[int] = set()  # Nodes inside class base expressions
        self._reported_chain_nodes: set[int] = set()  # Inner links of reported chains
        self._source_lines: list[str] | None = None  # Lazily split copy of source
        self.module_access_skipped = 0

    def visit_Import(self, node: ast.Import) -> None:
        """Track imported module names (alias if given, else the top-level package)."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name.split(".")[0]
            self._imported_modules.add(name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Track names bound by ``from ... import ...`` statements.

        Relative imports like ``from . import utils`` have ``node.module`` set
        to None but still bind names, so they are tracked as well. Star
        imports bind no name visible here and are ignored.
        """
        for alias in node.names:
            if alias.name == "*":
                continue
            self._imported_modules.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Mark base-class expressions and track the enclosing class name."""
        # e.g. json.JSONEncoder in `class Foo(json.JSONEncoder)` must not be flagged.
        for base in node.bases:
            self._mark_as_class_base(base)

        old_in_class = self._in_class
        old_class = self._current_class
        self._in_class = True
        self._current_class = node.name
        self.generic_visit(node)
        self._in_class = old_in_class
        self._current_class = old_class

    def _mark_as_class_base(self, node: ast.expr) -> None:
        """Mark *node* and the attribute chain it wraps as part of a class base.

        Follows ``Attribute.value`` links (``pkg.mod.Base``) and the
        subscripted value of generic bases (``pkg.Base[T]``). Type arguments
        inside the brackets are not marked.
        """
        current: ast.expr = node
        while True:
            self._class_bases.add(id(current))
            if isinstance(current, (ast.Attribute, ast.Subscript)):
                current = current.value
            else:
                break

    def visit_Call(self, node: ast.Call) -> None:
        """Mark call targets so method calls are not flagged as violations."""
        # Marked before generic_visit so visit_Attribute sees the mark.
        if isinstance(node.func, ast.Attribute):
            self._call_targets.add(id(node.func))
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Check one attribute access for encapsulation violations."""
        node_id = id(node)

        # Inner link of a chain already reported through its outer node.
        if node_id in self._reported_chain_nodes:
            self.generic_visit(node)
            return

        # Method call target: `obj.method()` is the "tell" style we want.
        if node_id in self._call_targets:
            self.generic_visit(node)
            return

        # Part of a class base list, e.g. class Foo(json.JSONEncoder).
        if node_id in self._class_bases:
            self.module_access_skipped += 1
            self.generic_visit(node)
            return

        chain = self._get_attribute_chain(node)
        if not chain:
            self.generic_visit(node)
            return

        base_name, *attr_names = chain

        if self.allow_self_access and base_name in ("self", "cls"):
            self.generic_visit(node)
            return

        # Ordinary module usage (json.JSONEncoder, redis.Redis): a single
        # attribute whose first character is uppercase (class or constant).
        if (
            base_name in self._imported_modules
            and len(attr_names) == 1
            and attr_names[0][0].isupper()
        ):
            self.module_access_skipped += 1
            self.generic_visit(node)
            return

        if self.allow_dunder_access and any(
            attr.startswith("__") and attr.endswith("__") for attr in attr_names
        ):
            self.generic_visit(node)
            return

        if self.allow_private_access and any(attr.startswith("_") for attr in attr_names):
            self.generic_visit(node)
            return

        if attr_names:
            violation = self._create_violation(node, base_name, attr_names)
            if violation:
                self.violations.append(violation)
                # Avoid a duplicate report for `obj.a` when `obj.a.b` was reported.
                self._suppress_inner_chain(node)

        self.generic_visit(node)

    def _suppress_inner_chain(self, node: ast.Attribute) -> None:
        """Mark the inner Attribute links of *node*'s chain as already reported."""
        inner = node.value
        while isinstance(inner, ast.Attribute):
            self._reported_chain_nodes.add(id(inner))
            inner = inner.value

    def _get_attribute_chain(self, node: ast.Attribute) -> list[str]:
        """
        Get the full chain of attribute access.

        For `obj.prop1.prop2`, returns ["obj", "prop1", "prop2"]. Returns []
        when the chain is not rooted at a plain name (e.g. `f().x`).
        """
        chain: list[str] = []
        current: ast.expr = node

        while isinstance(current, ast.Attribute):
            chain.append(current.attr)
            current = current.value

        if not isinstance(current, ast.Name):
            # Complex expression, can't determine base
            return []

        chain.append(current.id)
        chain.reverse()
        return chain

    def _create_violation(
        self,
        node: ast.Attribute,
        base_name: str,
        attr_names: list[str],
    ) -> RuleViolation | None:
        """Build a violation for ``base_name.attr1...attrN``, or return None.

        Returns None when the last attribute looks like a constant (only
        uppercase letters, digits and underscores, e.g. ``cfg.HTTP2_PORT``),
        or when *base_name* is one of a fixed set of stdlib module names.
        Chains longer than ``max_chain_length`` get a Law-of-Demeter message;
        others get a Tell-Don't-Ask message.
        """
        last_attr = attr_names[-1]
        if all(ch.isupper() or ch.isdigit() or ch == "_" for ch in last_attr):
            return None

        if base_name in ("os", "sys", "math", "typing", "collections", "functools"):
            return None

        full_access = f"{base_name}.{'.'.join(attr_names)}"

        if len(attr_names) > self.max_chain_length:
            message = (
                f"Long attribute chain detected: '{full_access}'. "
                "This violates the Law of Demeter. Consider using delegation."
            )
            suggestion = (
                f"Instead of accessing '{full_access}', consider adding a method "
                f"to '{base_name}' that encapsulates this behavior."
            )
        else:
            message = (
                f"Direct property access: '{full_access}'. "
                "Consider using a method instead (Tell Don't Ask)."
            )
            suggestion = (
                f"Instead of accessing '{base_name}.{attr_names[0]}', "
                f"consider telling '{base_name}' what to do with a method call."
            )

        return RuleViolation(
            rule_name="encapsulation",
            message=message,
            file_path=self.file_path,
            line=node.lineno,
            column=node.col_offset,
            severity="warning",
            suggestion=suggestion,
            code_snippet=self._get_source_line(node.lineno),
            metadata={
                "base_object": base_name,
                "accessed_attributes": attr_names,
                "chain_length": len(attr_names),
            },
        )

    def _get_source_line(self, line_number: int) -> str:
        """Return 1-based source line *line_number*, stripped, or "" if out of range."""
        # Split once and reuse, instead of re-splitting for every violation.
        if self._source_lines is None:
            self._source_lines = self.source.splitlines()
        if 1 <= line_number <= len(self._source_lines):
            return self._source_lines[line_number - 1].strip()
        return ""
@@ -0,0 +1,331 @@
1
+ """
2
+ Functions to Objects Rule.
3
+
4
+ This rule detects standalone functions that could be better represented
5
+ as objects/classes, following OOP principles.
6
+ """
7
+
8
+ import ast
9
+ from typing import Any
10
+
11
+ from .base import BaseRule, RuleResult, RuleViolation
12
+
13
+
14
class FunctionsToObjectsRule(BaseRule):
    """
    Detects functions that could be replaced by objects.

    Patterns reported:
    - Public functions with more than ``max_params`` parameters
    - Public functions returning a dict display or a ``dict(...)`` call
    - Groups of three or more public functions sharing a name prefix
      (when ``check_related_functions`` is enabled)

    Options read from ``self.options`` (defaults in parentheses):
    max_params (4), check_dict_returns (True), check_related_functions (True).
    """

    name = "functions_to_objects"
    description = "Detect functions that could be objects"
    severity = "info"

    def __init__(self, options: dict[str, Any] | None = None):
        super().__init__(options)
        self.max_params = self.options.get("max_params", 4)
        self.check_dict_returns = self.options.get("check_dict_returns", True)
        self.check_related_functions = self.options.get("check_related_functions", True)

    def analyze(
        self,
        tree: ast.Module,
        source: str,
        file_path: str,
    ) -> RuleResult:
        """Analyze the AST for functions that could be objects."""
        visitor = FunctionVisitor(
            file_path=file_path,
            source=source,
            max_params=self.max_params,
            check_dict_returns=self.check_dict_returns,
        )
        visitor.visit(tree)

        violations = visitor.violations.copy()

        # Computed once and shared by the related-function check, the summary
        # and the metadata.
        function_groups = self._find_function_groups(visitor.function_info)

        if self.check_related_functions:
            violations.extend(
                self._check_related_functions(
                    visitor.function_info,
                    file_path,
                    source,
                    groups=function_groups,
                )
            )

        return RuleResult(
            rule_name=self.name,
            violations=violations,
            summary={
                "total_functions": len(visitor.function_info),
                "functions_with_many_params": visitor.many_params_count,
                "functions_returning_dicts": visitor.dict_return_count,
                "related_function_groups": len(function_groups),
            },
            metadata={
                "functions": visitor.function_info,
                "function_groups": function_groups,
            },
        )

    def _check_related_functions(
        self,
        function_info: list[dict[str, Any]],
        file_path: str,
        source: str,
        groups: dict[str, list[dict[str, Any]]] | None = None,
    ) -> list[RuleViolation]:
        """Report prefix groups with three or more functions as class candidates.

        Args:
            function_info: per-function dicts as collected by FunctionVisitor.
            file_path: path recorded on each violation.
            source: accepted for signature compatibility; not used.
            groups: precomputed result of ``_find_function_groups``. It is
                computed from *function_info* when omitted.
        """
        if groups is None:
            groups = self._find_function_groups(function_info)

        violations: list[RuleViolation] = []
        for prefix, functions in groups.items():
            # Two-member groups only count in the summary; three or more
            # members are reported.
            if len(functions) < 3:
                continue

            func_names = [f["name"] for f in functions]
            first_line = min(f["line"] for f in functions)

            violations.append(
                RuleViolation(
                    rule_name="functions_to_objects",
                    message=(
                        f"Found {len(functions)} related functions with prefix '{prefix}_': "
                        f"{', '.join(func_names[:5])}{'...' if len(func_names) > 5 else ''}. "
                        "Consider grouping into a class."
                    ),
                    file_path=file_path,
                    line=first_line,
                    column=0,
                    severity="info",
                    suggestion=(
                        "These functions appear related. Consider creating a class "
                        f"'{prefix.title().replace('_', '')}' with these as methods."
                    ),
                    metadata={
                        "pattern": "related_functions",
                        "prefix": prefix,
                        "functions": func_names,
                    },
                )
            )

        return violations

    def _find_function_groups(
        self,
        function_info: list[dict[str, Any]],
    ) -> dict[str, list[dict[str, Any]]]:
        """Group public functions by the text before the first underscore in their name.

        Names starting with "_" are ignored. Only prefixes of at least three
        characters count, and only groups with two or more members are
        returned.
        """
        from collections import defaultdict

        groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for func in function_info:
            name = func["name"]
            if name.startswith("_"):
                continue
            prefix, sep, _rest = name.partition("_")
            if sep and len(prefix) >= 3:
                groups[prefix].append(func)

        return {k: v for k, v in groups.items() if len(v) >= 2}
142
+
143
+
144
+ class FunctionVisitor(ast.NodeVisitor):
145
+ """AST visitor that analyzes functions."""
146
+
147
+ def __init__(
148
+ self,
149
+ file_path: str,
150
+ source: str,
151
+ max_params: int = 4,
152
+ check_dict_returns: bool = True,
153
+ ):
154
+ self.file_path = file_path
155
+ self.source = source
156
+ self.max_params = max_params
157
+ self.check_dict_returns = check_dict_returns
158
+
159
+ self.violations: list[RuleViolation] = []
160
+ self.function_info: list[dict[str, Any]] = []
161
+ self.many_params_count = 0
162
+ self.dict_return_count = 0
163
+
164
+ self._in_class = False
165
+
166
+ def visit_ClassDef(self, node: ast.ClassDef) -> None:
167
+ """Track when inside a class (skip methods)."""
168
+ old_in_class = self._in_class
169
+ self._in_class = True
170
+ self.generic_visit(node)
171
+ self._in_class = old_in_class
172
+
173
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
174
+ """Analyze function definitions."""
175
+ # Skip methods (inside classes)
176
+ if self._in_class:
177
+ self.generic_visit(node)
178
+ return
179
+
180
+ # Skip private/dunder functions for some checks
181
+ is_private = node.name.startswith("_")
182
+
183
+ # Count parameters
184
+ num_params = self._count_params(node)
185
+
186
+ # Check for dict returns
187
+ returns_dict = self._returns_dict(node) if self.check_dict_returns else False
188
+
189
+ # Store function info
190
+ self.function_info.append(
191
+ {
192
+ "name": node.name,
193
+ "line": node.lineno,
194
+ "params": num_params,
195
+ "returns_dict": returns_dict,
196
+ "is_private": is_private,
197
+ }
198
+ )
199
+
200
+ # Check for too many parameters
201
+ if num_params > self.max_params and not is_private:
202
+ self.many_params_count += 1
203
+ self._add_many_params_violation(node, num_params)
204
+
205
+ # Check for dict returns
206
+ if returns_dict and not is_private:
207
+ self.dict_return_count += 1
208
+ self._add_dict_return_violation(node)
209
+
210
+ self.generic_visit(node)
211
+
212
+ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
213
+ """Handle async functions same as regular functions."""
214
+ if self._in_class:
215
+ self.generic_visit(node)
216
+ return
217
+
218
+ is_private = node.name.startswith("_")
219
+ num_params = self._count_params(node)
220
+ returns_dict = self._returns_dict(node) if self.check_dict_returns else False
221
+
222
+ self.function_info.append(
223
+ {
224
+ "name": node.name,
225
+ "line": node.lineno,
226
+ "params": num_params,
227
+ "returns_dict": returns_dict,
228
+ "is_private": is_private,
229
+ "is_async": True,
230
+ }
231
+ )
232
+
233
+ if num_params > self.max_params and not is_private:
234
+ self.many_params_count += 1
235
+ self._add_many_params_violation(node, num_params)
236
+
237
+ if returns_dict and not is_private:
238
+ self.dict_return_count += 1
239
+ self._add_dict_return_violation(node)
240
+
241
+ self.generic_visit(node)
242
+
243
+ def _count_params(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> int:
244
+ """Count the number of parameters in a function."""
245
+ args = node.args
246
+ count = len(args.args) + len(args.kwonlyargs)
247
+ if args.vararg:
248
+ count += 1
249
+ if args.kwarg:
250
+ count += 1
251
+ return count
252
+
253
+ def _returns_dict(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
254
+ """Check if function returns a dictionary literal."""
255
+ for child in ast.walk(node):
256
+ if isinstance(child, ast.Return) and child.value:
257
+ if isinstance(child.value, ast.Dict):
258
+ return True
259
+ # Check for dict() call
260
+ if isinstance(child.value, ast.Call):
261
+ if isinstance(child.value.func, ast.Name):
262
+ if child.value.func.id == "dict":
263
+ return True
264
+ return False
265
+
266
+ def _add_many_params_violation(
267
+ self,
268
+ node: ast.FunctionDef | ast.AsyncFunctionDef,
269
+ num_params: int,
270
+ ) -> None:
271
+ """Add violation for function with too many parameters."""
272
+ self.violations.append(
273
+ RuleViolation(
274
+ rule_name="functions_to_objects",
275
+ message=(
276
+ f"Function '{node.name}' has {num_params} parameters. "
277
+ f"Consider converting to a class."
278
+ ),
279
+ file_path=self.file_path,
280
+ line=node.lineno,
281
+ column=node.col_offset,
282
+ severity="info",
283
+ suggestion=(
284
+ "Functions with many parameters often indicate the need for an object. "
285
+ "Consider creating a class where parameters become attributes, "
286
+ "and the function becomes a method."
287
+ ),
288
+ code_snippet=self._get_source_line(node.lineno),
289
+ metadata={
290
+ "pattern": "many_parameters",
291
+ "function": node.name,
292
+ "param_count": num_params,
293
+ },
294
+ )
295
+ )
296
+
297
+ def _add_dict_return_violation(
298
+ self,
299
+ node: ast.FunctionDef | ast.AsyncFunctionDef,
300
+ ) -> None:
301
+ """Add violation for function returning a dict."""
302
+ self.violations.append(
303
+ RuleViolation(
304
+ rule_name="functions_to_objects",
305
+ message=(
306
+ f"Function '{node.name}' returns a dictionary. "
307
+ f"Consider using a dataclass or named tuple instead."
308
+ ),
309
+ file_path=self.file_path,
310
+ line=node.lineno,
311
+ column=node.col_offset,
312
+ severity="info",
313
+ suggestion=(
314
+ "Returning dictionaries loses type information and makes code "
315
+ "harder to maintain. Consider using a dataclass, named tuple, "
316
+ "or a proper class to represent the returned data."
317
+ ),
318
+ code_snippet=self._get_source_line(node.lineno),
319
+ metadata={
320
+ "pattern": "dict_return",
321
+ "function": node.name,
322
+ },
323
+ )
324
+ )
325
+
326
+ def _get_source_line(self, line_number: int) -> str:
327
+ """Get a specific line from the source code."""
328
+ lines = self.source.splitlines()
329
+ if 1 <= line_number <= len(lines):
330
+ return lines[line_number - 1].strip()
331
+ return ""