pydry-cli 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydry/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .cli import main
2
+
3
# Version string of this package release.
__version__ = "0.0.3"

# Public names exported by a star import of the package.
__all__ = [
    "__version__",
    "main",
]
pydry/__main__.py ADDED
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+ from .cli import main
4
+
5
# Run the CLI and hand its return value to the interpreter as the exit status.
exit_status = main()
raise SystemExit(exit_status)
pydry/analyze.py ADDED
@@ -0,0 +1,319 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import os
5
+ from collections import Counter
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ from .models import FunctionOccurrence
10
+ from .normalize import FunctionNormalizer
11
+
12
+ if TYPE_CHECKING:
13
+ from collections.abc import Generator, Iterable
14
+
15
+ _FuncNode = ast.FunctionDef | ast.AsyncFunctionDef
16
+
17
+ SIDE_EFFECT_CALLS = {
18
+ "print",
19
+ "open",
20
+ "write",
21
+ "send",
22
+ "post",
23
+ "put",
24
+ "delete",
25
+ "remove",
26
+ "unlink",
27
+ "save",
28
+ "commit",
29
+ }
30
+ CONTROL_FLOW_NODES = (
31
+ ast.If,
32
+ ast.For,
33
+ ast.AsyncFor,
34
+ ast.While,
35
+ ast.Try,
36
+ ast.With,
37
+ ast.AsyncWith,
38
+ ast.Match,
39
+ )
40
+ STMT_TYPES = (
41
+ ast.Assign,
42
+ ast.AnnAssign,
43
+ ast.AugAssign,
44
+ ast.Return,
45
+ ast.Expr,
46
+ ast.If,
47
+ ast.For,
48
+ ast.AsyncFor,
49
+ ast.While,
50
+ ast.Try,
51
+ ast.With,
52
+ ast.AsyncWith,
53
+ ast.Raise,
54
+ ast.Assert,
55
+ ast.Pass,
56
+ ast.Break,
57
+ ast.Continue,
58
+ ast.Import,
59
+ ast.ImportFrom,
60
+ ast.Delete,
61
+ ast.Match,
62
+ ast.Yield,
63
+ ast.YieldFrom,
64
+ )
65
+
66
# Directory names that ``iter_python_files`` prunes from its walk.
DEFAULT_EXCLUDED_DIRS = {
    ".eggs",
    ".git",
    ".hg",
    ".mypy_cache",
    ".nox",
    ".pytest_cache",
    ".ruff_cache",
    ".tox",
    ".venv",
    "__pycache__",
    "build",
    "dist",
    "site-packages",
    "venv",
}
82
+
83
+
84
def iter_python_files(root: Path) -> Iterable[Path]:
    """Yield the ``.py`` files found below *root*.

    Directories named in ``DEFAULT_EXCLUDED_DIRS`` are not descended into.
    Directory names and file names are each visited in sorted order.
    """
    for current_dir, subdirs, files in os.walk(root, topdown=True):
        # With topdown=True, os.walk honours in-place edits of the directory
        # list, so pruning here stops the walk from entering excluded dirs.
        kept = [entry for entry in subdirs if entry not in DEFAULT_EXCLUDED_DIRS]
        kept.sort()
        subdirs[:] = kept

        for entry in sorted(files):
            if not entry.endswith(".py"):
                continue
            candidate = Path(current_dir, entry)
            if candidate.is_file():
                yield candidate
92
+
93
+
94
def build_qualname(parents: list[str], name: str) -> str:
    """Return *name* prefixed by its dotted *parents* path.

    With no parents, *name* is returned unchanged.
    """
    if not parents:
        return name
    return ".".join([*parents, name])
96
+
97
+
98
def iter_functions(
    module: ast.Module, top_level_only: bool = False
) -> Generator[tuple[_FuncNode, list[str], bool]]:
    """Yield ``(function_node, parents, is_class_method)`` for functions in *module*.

    With ``top_level_only`` set, only function definitions directly in the
    module body are yielded, each with an empty parent list and a ``False``
    flag.  Otherwise the bodies of functions and classes are searched
    depth-first in source order; ``parents`` holds the names of the enclosing
    functions/classes and the flag is ``True`` when the immediate container
    is a class.
    """
    func_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    if top_level_only:
        for stmt in module.body:
            if isinstance(stmt, func_types):
                yield stmt, [], False
        return

    # Explicit stack of (statement iterator, enclosing names, container is a
    # class).  Pushing a nested body and resuming the outer iterator after it
    # is exhausted gives the same pre-order as a recursive descent.
    pending: list[tuple[Any, list[str], bool]] = [(iter(module.body), [], False)]
    while pending:
        statements, scope, in_class = pending[-1]
        for stmt in statements:
            if isinstance(stmt, func_types):
                yield stmt, scope, in_class
                pending.append((iter(stmt.body), [*scope, stmt.name], False))
                break
            if isinstance(stmt, ast.ClassDef):
                pending.append((iter(stmt.body), [*scope, stmt.name], True))
                break
        else:
            # This body is exhausted; go back to the enclosing one.
            pending.pop()
119
+
120
+
121
def param_count(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> int:
    """Count the parameters declared by *fn*.

    Positional-only, regular and keyword-only parameters each count once;
    ``*args`` and ``**kwargs`` add one each when present.
    """
    spec = fn.args
    named = len(spec.posonlyargs) + len(spec.args) + len(spec.kwonlyargs)
    variadic = [slot for slot in (spec.vararg, spec.kwarg) if slot is not None]
    return named + len(variadic)
129
+
130
+
131
def is_method(parents: list[str]) -> bool:
    """Return ``True`` when *parents* is non-empty.

    Any enclosing scope counts here, so a function nested inside another
    function is reported as a method too.
    """
    return True if parents else False
133
+
134
+
135
def canonicalize(fn: _FuncNode, **opts: Any) -> str:
    """Return the ``ast.dump`` of *fn* after ``FunctionNormalizer`` has visited it.

    *opts* are forwarded unchanged to ``FunctionNormalizer``.  The dump keeps
    field names and omits position attributes.
    """
    # Unparse and re-parse so the normalizer visits a fresh tree, not *fn*.
    source = ast.unparse(fn)
    detached = ast.parse(source).body[0]
    ast.fix_missing_locations(detached)

    normalizer = FunctionNormalizer(**opts)
    normalized = normalizer.visit(detached)
    ast.fix_missing_locations(normalized)

    return ast.dump(normalized, annotate_fields=True, include_attributes=False)
140
+
141
+
142
def _call_name(node: ast.Call) -> str:
    """Return a dotted name for the callee of *node*.

    A plain name yields its identifier.  An attribute chain yields the
    attributes joined with dots, prefixed by the base identifier when the
    chain starts at a plain name.  Any other callee yields ``"<dynamic>"``.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        return callee.id
    if not isinstance(callee, ast.Attribute):
        return "<dynamic>"

    # Attributes are collected from the outermost inwards, then flipped.
    segments = []
    cursor: ast.expr = callee
    while isinstance(cursor, ast.Attribute):
        segments.append(cursor.attr)
        cursor = cursor.value
    if isinstance(cursor, ast.Name):
        segments.append(cursor.id)
    return ".".join(segments[::-1])
156
+
157
+
158
+ def _literal_token(value: object) -> str:
159
+ if isinstance(value, str):
160
+ return f"str:{value}"
161
+ if isinstance(value, bytes):
162
+ return f"bytes:{value!r}"
163
+ if value is None:
164
+ return "none"
165
+ if isinstance(value, bool):
166
+ return f"bool:{value}"
167
+ if isinstance(value, int):
168
+ return f"int:{value}"
169
+ if isinstance(value, float):
170
+ return f"float:{value!r}"
171
+ if isinstance(value, complex):
172
+ return f"complex:{value!r}"
173
+ return f"type:{type(value).__name__}"
174
+
175
+
176
def _stmt_sequence(fn: _FuncNode) -> list[str]:
    """Return the class names of the ``STMT_TYPES`` nodes found under *fn*.

    Names are listed in the order ``ast.walk`` visits the nodes.
    """
    return [
        type(node).__name__
        for node in ast.walk(fn)
        if isinstance(node, STMT_TYPES)
    ]
182
+
183
+
184
def _counter_jaccard(a: Counter[str], b: Counter[str]) -> float:
    """Return the multiset Jaccard similarity of *a* and *b*.

    This is the sum of per-key minimum counts divided by the sum of per-key
    maximum counts.  The result is ``1.0`` when there are no keys at all or
    when the denominator is zero.
    """
    overlap = 0
    combined = 0
    for key in set(a) | set(b):
        left = a[key]
        right = b[key]
        overlap += min(left, right)
        combined += max(left, right)
    if not combined:
        return 1.0
    return overlap / combined
191
+
192
+
193
def _lcs_ratio(a: list[str], b: list[str]) -> float:
    """Return ``2 * LCS(a, b) / (len(a) + len(b))``.

    ``LCS`` is the length of the longest common subsequence.  Two empty
    sequences score ``1.0``; exactly one empty sequence scores ``0.0``.
    """
    if not a and not b:
        return 1.0
    if not (a and b):
        return 0.0

    # Keep the DP row as wide as the shorter sequence to limit memory.
    rows, cols = (a, b) if len(a) >= len(b) else (b, a)
    width = len(cols) + 1

    above = [0] * width
    for row_token in rows:
        row = [0] * width
        for idx, col_token in enumerate(cols, start=1):
            if row_token == col_token:
                row[idx] = above[idx - 1] + 1
            else:
                row[idx] = max(above[idx], row[idx - 1])
        above = row

    common = above[-1]
    return (2 * common) / (len(a) + len(b))
215
+
216
+
217
def _wrapper_features(fn: _FuncNode) -> tuple[bool, str | None, int, int]:
    """Return ``(is_wrapper, wrapper_target, fixed_args, passthrough_args)``.

    *fn* counts as a wrapper when its body is a single ``return <call>`` or a
    single bare ``<call>`` expression statement.
    """
    body = getattr(fn, "body", [])
    if len(body) != 1:
        return False, None, 0, 0
    stmt = body[0]
    if not (
        isinstance(stmt, (ast.Return, ast.Expr)) and isinstance(stmt.value, ast.Call)
    ):
        return False, None, 0, 0

    call = stmt.value
    # Only named parameters are considered; ``*args`` / ``**kwargs`` are not.
    param_names = {
        param.arg
        for param in (*fn.args.posonlyargs, *fn.args.args, *fn.args.kwonlyargs)
    }
    # Only positional call arguments are classified; keywords are ignored.
    passthrough = sum(
        1
        for arg in call.args
        if isinstance(arg, ast.Name) and arg.id in param_names
    )
    fixed = len(call.args) - passthrough
    return True, _call_name(call), fixed, passthrough


def _returned_lambda_depth(body: list[ast.stmt]) -> int:
    """Return the total nesting depth of lambdas returned at the top of *body*."""
    depth = 0
    for stmt in body:
        candidate = stmt.value if isinstance(stmt, ast.Return) else None
        # ``return lambda a: lambda b: ...`` adds one level per nested lambda.
        while isinstance(candidate, ast.Lambda):
            depth += 1
            candidate = candidate.body
    return depth


def extract_features(fn: _FuncNode) -> dict[str, Any]:
    """Collect structural features of the function node *fn*.

    Returns a dict with these keys:

    * ``node_types``, ``call_names``, ``literal_tokens``, ``external_names``:
      ``Counter`` objects over, respectively, node class names, call names,
      constant tokens, and identifiers read in ``Load`` context.
    * ``stmt_seq`` / ``stmt_count``: the statement sequence and its length.
    * ``param_count``, ``control_count``, ``returns``, ``raises``,
      ``literals``: integer counts.
    * ``has_yield``, ``has_await``: booleans.
    * ``side_effect_calls``: sorted call names whose last dotted segment is
      in ``SIDE_EFFECT_CALLS``.
    * ``is_wrapper``, ``wrapper_target``, ``fixed_args``,
      ``passthrough_args``: single-call wrapper detection.
    * ``returns_lambda``, ``curry_depth``: lambdas returned directly from the
      top level of the body.

    Every count covers all nodes reachable from *fn*, including those inside
    nested definitions.
    """
    node_types: Counter[str] = Counter()
    call_names: Counter[str] = Counter()
    literal_tokens: Counter[str] = Counter()
    external_names: Counter[str] = Counter()
    has_yield = False
    has_await = False
    control_count = 0
    returns = 0
    raises = 0

    # One traversal feeds every counter.  The checks are independent ``if``
    # statements so each counter sees exactly the nodes it would see in a
    # dedicated walk.
    for node in ast.walk(fn):
        node_types[type(node).__name__] += 1
        if isinstance(node, ast.Call):
            call_names[_call_name(node)] += 1
        if isinstance(node, ast.Constant):
            literal_tokens[_literal_token(node.value)] += 1
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            external_names[node.id] += 1
        if isinstance(node, (ast.Yield, ast.YieldFrom)):
            has_yield = True
        if isinstance(node, ast.Await):
            has_await = True
        if isinstance(node, CONTROL_FLOW_NODES):
            control_count += 1
        if isinstance(node, ast.Return):
            returns += 1
        if isinstance(node, ast.Raise):
            raises += 1

    stmt_seq = _stmt_sequence(fn)
    side_effect_calls = sorted(
        {name for name in call_names if name.split(".")[-1] in SIDE_EFFECT_CALLS}
    )
    is_wrapper, wrapper_target, fixed_args, passthrough_args = _wrapper_features(fn)
    curry_depth = _returned_lambda_depth(getattr(fn, "body", []))

    return {
        "node_types": node_types,
        "stmt_seq": stmt_seq,
        "call_names": call_names,
        "external_names": external_names,
        "param_count": param_count(fn),
        "has_yield": has_yield,
        "has_await": has_await,
        "control_count": control_count,
        "returns": returns,
        "raises": raises,
        "literals": sum(literal_tokens.values()),
        "literal_tokens": literal_tokens,
        "side_effect_calls": side_effect_calls,
        "is_wrapper": is_wrapper,
        "wrapper_target": wrapper_target,
        "fixed_args": fixed_args,
        "passthrough_args": passthrough_args,
        "returns_lambda": curry_depth > 0,
        "curry_depth": curry_depth,
        "stmt_count": len(stmt_seq),
    }
297
+
298
+
299
def occurrence_for(
    path: Path,
    fn: _FuncNode,
    parents: list[str],
    *,
    is_method_flag: bool | None = None,
) -> FunctionOccurrence:
    """Build the ``FunctionOccurrence`` record for *fn* found in *path*.

    When ``is_method_flag`` is ``None`` the method flag is derived from
    *parents* via ``is_method``; otherwise the given value is used as is.
    Position attributes missing from the node fall back to ``0`` (or
    ``None`` for ``end_lineno``).
    """
    if is_method_flag is None:
        method_flag = is_method(parents)
    else:
        method_flag = is_method_flag

    fields = {
        "path": str(path),
        "lineno": getattr(fn, "lineno", 0),
        "end_lineno": getattr(fn, "end_lineno", None),
        "col_offset": getattr(fn, "col_offset", 0),
        "name": fn.name,
        "qualname": build_qualname(parents, fn.name),
        "kind": "async def" if isinstance(fn, ast.AsyncFunctionDef) else "def",
        "param_count": param_count(fn),
        "is_method": method_flag,
    }
    return FunctionOccurrence(**fields)
@@ -0,0 +1,206 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from .plugins import PairContext, PairPluginResult, register_pair_plugin
6
+
7
+
8
def _literal_token_diff(a: dict[str, Any], b: dict[str, Any]) -> int:
    """Sum the absolute per-token count differences between two feature dicts.

    Each dict's ``"literal_tokens"`` mapping is read (a missing one counts as
    empty), and a token absent from one side counts as zero there.
    """
    left: dict[str, int] = a.get("literal_tokens", {})
    right: dict[str, int] = b.get("literal_tokens", {})
    total = 0
    for token in {*left, *right}:
        total += abs(left.get(token, 0) - right.get(token, 0))
    return total
13
+
14
+
15
@register_pair_plugin
class WrapperPlugin:
    """Pair plugin that labels pairs whose evidence carries a wrapper score of at least 0.5."""

    name = "wrapper"

    def analyze_pair(self, ctx: PairContext) -> PairPluginResult:
        """Return a ``wrapper`` result for *ctx*, or an empty result below the threshold."""
        feats_a = ctx.a.features
        feats_b = ctx.b.features
        evidence = ctx.evidence
        if evidence.wrapper_score < 0.5:
            return PairPluginResult()

        target_a = feats_a.get("wrapper_target")
        target_b = feats_b.get("wrapper_target")

        differences = []
        if target_a != target_b:
            differences = ["wrapper targets differ"]

        return PairPluginResult(
            pattern_labels=["wrapper"],
            key_differences=differences,
            suggested_refactor_kind="merge_into_single_function_with_param",
            refactorability_delta=0.05,
            metadata={
                "a_wrapper_target": target_a,
                "b_wrapper_target": target_b,
                "a_fixed_args": feats_a.get("fixed_args", 0),
                "b_fixed_args": feats_b.get("fixed_args", 0),
            },
        )
42
+
43
+
44
@register_pair_plugin
class CurryingPlugin:
    """Pair plugin that labels pairs whose evidence carries a curry score of at least 0.4."""

    name = "currying"

    def analyze_pair(self, ctx: PairContext) -> PairPluginResult:
        """Return a ``partial_application`` result, or an empty result below the threshold."""
        feats_a = ctx.a.features
        feats_b = ctx.b.features
        evidence = ctx.evidence
        if evidence.curry_score < 0.4:
            return PairPluginResult()

        details = {
            "a_returns_lambda": feats_a.get("returns_lambda"),
            "b_returns_lambda": feats_b.get("returns_lambda"),
            "a_curry_depth": feats_a.get("curry_depth"),
            "b_curry_depth": feats_b.get("curry_depth"),
        }
        return PairPluginResult(
            pattern_labels=["partial_application"],
            suggested_refactor_kind="introduce_partial",
            refactorability_delta=0.04,
            metadata=details,
        )
65
+
66
+
67
@register_pair_plugin
class SideEffectRiskPlugin:
    """Pair plugin that flags pairs where either side lists side-effect calls."""

    name = "side_effects"

    def analyze_pair(self, ctx: PairContext) -> PairPluginResult:
        """Return a risk result listing the combined calls, or an empty result if there are none."""
        feats_a = ctx.a.features
        feats_b = ctx.b.features
        combined = set(feats_a.get("side_effect_calls", []))
        combined |= set(feats_b.get("side_effect_calls", []))
        calls = sorted(combined)
        if len(calls) == 0:
            return PairPluginResult()
        return PairPluginResult(
            risk_flags=["possible_side_effects"],
            refactorability_delta=-0.05,
            metadata={"calls": calls},
        )
84
+
85
+
86
@register_pair_plugin
class AsyncBoundaryPlugin:
    """Pair plugin that flags differing ``has_await``, ``has_yield`` or ``raises`` features."""

    name = "async_boundary"

    def analyze_pair(self, ctx: PairContext) -> PairPluginResult:
        """Return one flag and difference per mismatching feature, or an empty result."""
        feats_a = ctx.a.features
        feats_b = ctx.b.features

        # (feature key, risk flag, difference text, penalty), checked in order.
        checks = (
            ("has_await", "async_boundary_diff", "async behavior differs", 0.08),
            ("has_yield", "return_shape_diff", "generator behavior differs", 0.08),
            ("raises", "exception_behavior_diff", "exception behavior differs", 0.05),
        )

        flags = []
        diffs = []
        delta = 0.0
        for key, flag, message, penalty in checks:
            if feats_a.get(key) != feats_b.get(key):
                flags.append(flag)
                diffs.append(message)
                delta -= penalty

        if not (flags or diffs):
            return PairPluginResult()

        metadata = {
            "a_has_await": feats_a.get("has_await"),
            "b_has_await": feats_b.get("has_await"),
            "a_has_yield": feats_a.get("has_yield"),
            "b_has_yield": feats_b.get("has_yield"),
            "a_raises": feats_a.get("raises"),
            "b_raises": feats_b.get("raises"),
        }
        return PairPluginResult(
            risk_flags=flags,
            key_differences=diffs,
            refactorability_delta=delta,
            metadata=metadata,
        )
123
+
124
+
125
@register_pair_plugin
class LiteralSpecializationPlugin:
    """Pair plugin that labels similar pairs whose literal tokens differ."""

    name = "literal_specialization"

    def analyze_pair(self, ctx: PairContext) -> PairPluginResult:
        """Return a ``literal_specialization`` result when all four conditions hold.

        The conditions are: the literal tokens differ, the literal counts
        differ by at most 2, shape similarity is at least 0.85, and call
        similarity is at least 0.6.  Otherwise an empty result is returned.
        """
        feats_a = ctx.a.features
        feats_b = ctx.b.features
        evidence = ctx.evidence
        token_delta = _literal_token_diff(feats_a, feats_b)

        qualifies = (
            token_delta > 0
            and abs(feats_a.get("literals", 0) - feats_b.get("literals", 0)) <= 2
            and evidence.shape_similarity >= 0.85
            and evidence.call_similarity >= 0.6
        )
        if not qualifies:
            return PairPluginResult()

        template = (
            "def shared_helper(..., configurable_value):\n"
            " # parameterize constant-like variation\n"
            " ..."
        )
        return PairPluginResult(
            pattern_labels=["literal_specialization"],
            suggested_refactor_kind="parameterize_constant",
            refactorability_delta=0.03,
            abstract_template=template,
            metadata={
                "a_literals": feats_a.get("literals"),
                "b_literals": feats_b.get("literals"),
                "literal_token_diff": token_delta,
            },
        )
156
+
157
+
158
@register_pair_plugin
class ExtractHelperPlugin:
    """Pair plugin that labels pairs with high shape and statement similarity."""

    name = "extract_helper"

    def analyze_pair(self, ctx: PairContext) -> PairPluginResult:
        """Return an ``extract_helper_candidate`` result when both similarities reach 0.8."""
        evidence = ctx.evidence
        similar = evidence.shape_similarity >= 0.8 and evidence.stmt_similarity >= 0.8
        if not similar:
            return PairPluginResult()

        template = (
            "def shared_helper(...):\n"
            " # candidate abstraction for "
            f"{ctx.a.occurrence.qualname}"
            f" and {ctx.b.occurrence.qualname}\n"
            " ..."
        )
        return PairPluginResult(
            pattern_labels=["extract_helper_candidate"],
            suggested_refactor_kind="extract_common_helper",
            refactorability_delta=0.05,
            abstract_template=template,
            metadata={
                "shape_similarity": evidence.shape_similarity,
                "stmt_similarity": evidence.stmt_similarity,
            },
        )
182
+
183
+
184
@register_pair_plugin
class DependencyDivergencePlugin:
    """Pair plugin that flags pairs whose external names differ by six or more."""

    name = "dependency_divergence"

    def analyze_pair(self, ctx: PairContext) -> PairPluginResult:
        """Return an ``ambient_dependency_diff`` risk result, or an empty result.

        The difference is the number of names in exactly one side's
        ``external_names``.  The ``same_shape_different_dependencies`` label
        is added only when signature similarity is at least 0.8 and call
        similarity is below 0.5.
        """
        feats_a = ctx.a.features
        feats_b = ctx.b.features
        evidence = ctx.evidence

        names_a = set(feats_a.get("external_names", {}))
        names_b = set(feats_b.get("external_names", {}))
        ext_diff = len(names_a.symmetric_difference(names_b))
        if ext_diff < 6:
            return PairPluginResult()

        labels = []
        if evidence.signature_similarity >= 0.8 and evidence.call_similarity < 0.5:
            labels = ["same_shape_different_dependencies"]

        return PairPluginResult(
            risk_flags=["ambient_dependency_diff"],
            pattern_labels=labels,
            refactorability_delta=-0.06,
            metadata={"external_name_symmetric_difference": ext_diff},
        )