pydry-cli 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydry/engine.py ADDED
@@ -0,0 +1,518 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import hashlib
5
+ from bisect import bisect_right
6
+ from collections import defaultdict
7
+ from dataclasses import asdict
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from . import builtin_plugins # noqa: F401
11
+ from .analyze import (
12
+ _counter_jaccard,
13
+ _lcs_ratio,
14
+ canonicalize,
15
+ extract_features,
16
+ iter_functions,
17
+ iter_python_files,
18
+ occurrence_for,
19
+ )
20
+ from .models import ExactGroup, FunctionOccurrence, SimilarityEvidence, SimilarityResult
21
+ from .plugins import PairContext, PluginContext, apply_pair_plugins
22
+
23
+ if TYPE_CHECKING:
24
+ from pathlib import Path
25
+
26
# Feature mapping produced by ``extract_features`` for one scanned function.
_FeatureDict = dict[str, Any]
# Result ordering key: (-refactorability, -similarity, a.path, a.lineno,
# b.path, b.lineno) -- see ``_result_sort_key``.
_SortKey = tuple[float, float, str, int, str, int]

# Canonicalization switches that ``exact_groups`` passes to ``canonicalize``.
# ``exact_groups`` copies this mapping and overrides the two ``normalize_*``
# entries from its own keyword arguments.
DEFAULT_EXACT_OPTS = {
    "strip_docstrings": True,
    "strip_decorators": True,
    "normalize_arg_names": True,
    "strip_annotations": True,
    "normalize_local_names": False,
    "normalize_constants": False,
    "preserve_function_name": False,
}
38
+
39
+
40
def _sha(text: str) -> str:
    """Return the hexadecimal SHA-256 digest of *text* encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()
42
+
43
+
44
def scan_functions(
    root: Path,
    *,
    top_level_only: bool = False,
    strict: bool = False,
    scan_errors: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Parse every Python file under *root* and collect one record per function.

    Each record is a dict with the keys ``"occurrence"`` (from
    ``occurrence_for``), ``"node"`` (the function's AST node) and
    ``"features"`` (from ``extract_features``).

    Args:
        root: Directory (or file) handed to ``iter_python_files``.
        top_level_only: Forwarded to ``iter_functions``.
        strict: When true, any read/parse failure raises ``RuntimeError``.
        scan_errors: When given (and not strict), receives one
            ``"<path>: <ExcType>: <message>"`` line per unreadable file.

    Raises:
        RuntimeError: In strict mode, for the first file that fails to read
            or parse (chained to the original exception).
    """
    records: list[dict[str, Any]] = []
    for path in iter_python_files(root):
        try:
            module = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        except Exception as exc:  # deliberate best-effort: any read/parse failure
            reason = f"{type(exc).__name__}: {exc}"
            if strict:
                strict_msg = f"Failed to parse/read {path}: {reason}"
                raise RuntimeError(strict_msg) from exc
            if scan_errors is not None:
                scan_errors.append(f"{path}: {reason}")
            continue

        discovered = iter_functions(module, top_level_only=top_level_only)
        for fn, parents, is_method_flag in discovered:
            occurrence = occurrence_for(path, fn, parents, is_method_flag=is_method_flag)
            records.append(
                {
                    "occurrence": occurrence,
                    "node": fn,
                    "features": extract_features(fn),
                }
            )
    return records
76
+
77
+
78
def exact_groups(
    root: Path,
    *,
    min_count: int = 2,
    top_level_only: bool = False,
    include_canonical: bool = False,
    normalize_local_names: bool = False,
    normalize_constants: bool = False,
    strict: bool = False,
    scan_errors: list[str] | None = None,
) -> list[ExactGroup]:
    """Group functions under *root* whose canonicalized source is identical.

    Every scanned function is passed through ``canonicalize`` with
    ``DEFAULT_EXACT_OPTS`` (plus the two ``normalize_*`` overrides), and the
    canonical text is hashed with SHA-256. Hashes shared by at least
    *min_count* functions become an ``ExactGroup``.

    Returns:
        Groups sorted by descending size, then by hash. Each group's
        occurrences are sorted by ``(path, lineno, qualname)``. ``canonical``
        is filled in only when *include_canonical* is true.

    Raises:
        ValueError: If *min_count* is below 2.
    """
    if min_count < 2:
        msg = "min_count must be >= 2"
        raise ValueError(msg)

    opts = {
        **DEFAULT_EXACT_OPTS,
        "normalize_local_names": normalize_local_names,
        "normalize_constants": normalize_constants,
    }
    by_hash: defaultdict[str, list[FunctionOccurrence]] = defaultdict(list)
    canonical_text: dict[str, str] = {}

    scanned = scan_functions(
        root,
        top_level_only=top_level_only,
        strict=strict,
        scan_errors=scan_errors,
    )
    for record in scanned:
        text = canonicalize(record["node"], **opts)
        digest = _sha(text)
        by_hash[digest].append(record["occurrence"])
        if include_canonical:
            # First function seen for a hash supplies its canonical text.
            canonical_text.setdefault(digest, text)

    groups = [
        ExactGroup(
            hash=digest,
            count=len(members),
            occurrences=sorted(members, key=lambda o: (o.path, o.lineno, o.qualname)),
            canonical=canonical_text.get(digest),
        )
        for digest, members in by_hash.items()
        if len(members) >= min_count
    ]
    groups.sort(key=lambda g: (-g.count, g.hash))
    return groups
123
+
124
+
125
def _sig_similarity(a: _FeatureDict, b: _FeatureDict) -> float:
    """Score how alike two function signatures are, clamped to ``[0.0, 1.0]``.

    70% comes from parameter-count closeness. 30% comes from "modality":
    a generator mismatch (``has_yield``) costs 0.35 of that part and an
    await mismatch (``has_await``) costs 0.25.
    """
    low, high = sorted((a["param_count"], b["param_count"]))
    count_score = 1.0 - (high - low) / max(high, 1)

    modality = 1.0
    if a["has_yield"] != b["has_yield"]:
        modality -= 0.35
    if a["has_await"] != b["has_await"]:
        modality -= 0.25

    blended = 0.7 * count_score + 0.3 * modality
    return max(0.0, min(1.0, blended))
135
+
136
+
137
def _wrapper_score(a: _FeatureDict, b: _FeatureDict) -> float:
    """Score how strongly a pair looks like two thin wrappers.

    Returns 0.85 when both are wrappers around the same non-None target,
    0.5 when both are wrappers otherwise, 0.25 when only one is a wrapper,
    and 0.0 when neither is.
    """
    a_wraps = a["is_wrapper"]
    b_wraps = b["is_wrapper"]
    if not (a_wraps and b_wraps):
        return 0.25 if (a_wraps or b_wraps) else 0.0

    target = a["wrapper_target"]
    same_target = target is not None and target == b["wrapper_target"]
    return 0.5 + 0.35 if same_target else 0.5
149
+
150
+
151
def _curry_score(a: _FeatureDict, b: _FeatureDict) -> float:
    """Score partial-application likeness from the ``returns_lambda`` features.

    Returns 0.8 or 0.6 when both functions return lambdas (0.8 if their
    ``curry_depth`` matches), 0.4 when only one does, else 0.0.
    """
    a_curries = a["returns_lambda"]
    b_curries = b["returns_lambda"]
    if not (a_curries or b_curries):
        return 0.0
    if not (a_curries and b_curries):
        return 0.4
    return 0.8 if a["curry_depth"] == b["curry_depth"] else 0.6
157
+
158
+
159
def _shape_similarity(a: _FeatureDict, b: _FeatureDict) -> float:
    """Compare the ``node_types`` features of two functions via ``_counter_jaccard``."""
    a_types = a["node_types"]
    b_types = b["node_types"]
    return _counter_jaccard(a_types, b_types)
161
+
162
+
163
def _stmt_similarity(a: _FeatureDict, b: _FeatureDict) -> float:
    """Compare the ``stmt_seq`` features of two functions via ``_lcs_ratio``."""
    a_seq = a["stmt_seq"]
    b_seq = b["stmt_seq"]
    return _lcs_ratio(a_seq, b_seq)
165
+
166
+
167
def _call_similarity(a: _FeatureDict, b: _FeatureDict) -> float:
    """Compare the ``call_names`` features of two functions via ``_counter_jaccard``."""
    a_calls = a["call_names"]
    b_calls = b["call_names"]
    return _counter_jaccard(a_calls, b_calls)
169
+
170
+
171
def _difference_notes(a: _FeatureDict, b: _FeatureDict) -> list[str]:
    """List human-readable differences between two feature mappings.

    Notes appear in a fixed order. Wrapper targets are only compared when
    at least one side is a wrapper. Literal and control-flow counts must
    differ by 2 or more to be reported.
    """
    checks = (
        (
            a["param_count"] != b["param_count"],
            f"parameter count differs ({a['param_count']} vs {b['param_count']})",
        ),
        (a["has_await"] != b["has_await"], "async behavior differs"),
        (a["has_yield"] != b["has_yield"], "generator behavior differs"),
        (a["raises"] != b["raises"], "exception behavior differs"),
        (
            a["wrapper_target"] != b["wrapper_target"]
            and (a["is_wrapper"] or b["is_wrapper"]),
            "wrapper targets differ",
        ),
        (abs(a["literals"] - b["literals"]) >= 2, "literal density differs"),
        (
            abs(a["control_count"] - b["control_count"]) >= 2,
            "control-flow complexity differs",
        ),
    )
    return [note for differs, note in checks if differs]
192
+
193
+
194
def _risk_flags(a: _FeatureDict, b: _FeatureDict) -> list[str]:
    """Return the sorted list of refactoring risk flags for a pair.

    ``ambient_dependency_diff`` is raised when the two functions' external
    names differ in 6 or more entries (symmetric difference).
    """
    external_gap = len(set(a["external_names"]) ^ set(b["external_names"]))
    conditions = {
        "possible_side_effects": a["side_effect_calls"] or b["side_effect_calls"],
        "async_boundary_diff": a["has_await"] != b["has_await"],
        "return_shape_diff": a["has_yield"] != b["has_yield"],
        "exception_behavior_diff": a["raises"] != b["raises"],
        "ambient_dependency_diff": external_gap >= 6,
    }
    return sorted(flag for flag, hit in conditions.items() if hit)
208
+
209
+
210
def _pattern_labels(
    a: _FeatureDict, b: _FeatureDict, evidence: SimilarityEvidence
) -> list[str]:
    """Derive refactoring pattern labels from the pair's evidence and features.

    Labels are emitted in a fixed order. Each label's condition is a
    short-circuiting expression over the evidence scores (and, for two labels,
    the raw ``literals`` / ``param_count`` features).
    """
    token_gap = _literal_token_diff(a, b)
    ev = evidence
    candidates = (
        ("wrapper", ev.wrapper_score >= 0.5),
        ("partial_application", ev.curry_score >= 0.4),
        (
            "renamed_locals",
            ev.shape_similarity >= 0.9 and ev.call_similarity >= 0.85,
        ),
        (
            "literal_specialization",
            token_gap > 0
            and abs(a["literals"] - b["literals"]) <= 2
            and ev.shape_similarity >= 0.85
            and ev.call_similarity >= 0.6,
        ),
        (
            "extract_helper_candidate",
            ev.shape_similarity >= 0.8 and ev.stmt_similarity >= 0.8,
        ),
        (
            "same_shape_different_dependencies",
            ev.signature_similarity >= 0.8
            and abs(a["param_count"] - b["param_count"]) <= 1
            and ev.call_similarity < 0.5,
        ),
    )
    return [label for label, matched in candidates if matched]
237
+
238
+
239
def _literal_token_diff(a: _FeatureDict, b: _FeatureDict) -> int:
    """Sum of absolute per-token count differences between ``literal_tokens``.

    A missing ``literal_tokens`` entry is treated as an empty mapping.
    """
    left: dict[str, int] = a.get("literal_tokens", {})
    right: dict[str, int] = b.get("literal_tokens", {})
    total = 0
    for token in set(left).union(right):
        total += abs(left.get(token, 0) - right.get(token, 0))
    return total
244
+
245
+
246
def _shared_summary(a: _FeatureDict, b: _FeatureDict) -> str:
    """Describe what two functions share, for human consumption.

    Lists up to six shared statement kinds and up to six shared call names
    (each sorted alphabetically). Falls back to a fixed phrase when neither
    overlaps.
    """
    stmt_overlap = sorted(set(a["stmt_seq"]).intersection(b["stmt_seq"]))
    call_overlap = sorted(set(a["call_names"]).intersection(b["call_names"]))

    sections = []
    if stmt_overlap:
        sections.append("shared statements: " + ", ".join(stmt_overlap[:6]))
    if call_overlap:
        sections.append("shared calls: " + ", ".join(call_overlap[:6]))
    if sections:
        return "; ".join(sections)
    return "shared AST shape without strong call overlap"
257
+
258
+
259
def _suggest_refactor(
    labels: list[str], risks: list[str], evidence: SimilarityEvidence
) -> str:
    """Pick a refactoring kind from labels and risks (first matching rule wins).

    ``leave_separate`` is only suggested when no positive pattern applies and
    an ambient-dependency, async-boundary or return-shape risk is present.
    Otherwise the fallback is ``move_to_utils``.
    """
    blocking_risks = ("ambient_dependency_diff", "async_boundary_diff", "return_shape_diff")

    if "wrapper" in labels and evidence.wrapper_score >= 0.5:
        kind = "merge_into_single_function_with_param"
    elif "partial_application" in labels:
        kind = "introduce_partial"
    elif "extract_helper_candidate" in labels and "possible_side_effects" not in risks:
        kind = "extract_common_helper"
    elif "literal_specialization" in labels:
        kind = "parameterize_constant"
    elif any(flag in risks for flag in blocking_risks):
        kind = "leave_separate"
    else:
        kind = "move_to_utils"
    return kind
277
+
278
+
279
def _refactorability(
    labels: list[str], risks: list[str], evidence: SimilarityEvidence
) -> float:
    """Estimate how worthwhile refactoring a pair is, clamped to ``[0.0, 1.0]``.

    Starts from a weighted blend of the evidence scores. Adds 0.08 for
    ``extract_helper_candidate`` and 0.05 for ``literal_specialization``, then
    subtracts 0.1 per risk flag.
    """
    ev = evidence
    score = (
        0.35 * ev.shape_similarity
        + 0.15 * ev.stmt_similarity
        + 0.10 * ev.call_similarity
        + 0.10 * ev.signature_similarity
        + 0.20 * ev.wrapper_score
        + 0.10 * ev.curry_score
    )
    # Bonuses are applied one at a time, in this order.
    for label, bonus in (("extract_helper_candidate", 0.08), ("literal_specialization", 0.05)):
        if label in labels:
            score += bonus
    score -= 0.1 * len(risks)
    return max(0.0, min(1.0, score))
296
+
297
+
298
def _abstract_template(
    a_occ: FunctionOccurrence,
    b_occ: FunctionOccurrence,
    labels: list[str],
    shared_summary: str,
) -> str | None:
    """Sketch a placeholder ``shared_helper`` definition for a candidate pair.

    A sketch is only produced when one of ``extract_helper_candidate``,
    ``literal_specialization`` or ``wrapper`` is among *labels*. Otherwise
    ``None`` is returned.
    """
    triggers = ("extract_helper_candidate", "literal_specialization", "wrapper")
    if not any(trigger in labels for trigger in triggers):
        return None
    lines = [
        "def shared_helper(...):",
        f" # candidate abstraction for {a_occ.qualname} and {b_occ.qualname}",
        f" # {shared_summary}",
        " ...",
    ]
    return "\n".join(lines)
316
+
317
+
318
def _result_sort_key(result: SimilarityResult) -> _SortKey:
    """Ordering key: most refactorable first, then most similar, then by location."""
    first, second = result.a, result.b
    return (
        -result.refactorability_score,
        -result.similarity_score,
        first.path,
        first.lineno,
        second.path,
        second.lineno,
    )
327
+
328
+
329
def _combined_similarity(
    shape: float,
    stmt: float,
    call: float,
    signature: float,
    wrapper: float,
    curry: float,
) -> float:
    """Weighted blend of the per-aspect similarity scores.

    ``near_matches`` uses this for both its upper-bound prune and its final
    score, so the two always use the same weights.
    """
    return (
        0.40 * shape
        + 0.20 * stmt
        + 0.15 * call
        + 0.10 * signature
        + 0.10 * wrapper
        + 0.05 * curry
    )


def _ordered_unique(*groups: list[str]) -> list[str]:
    """Concatenate *groups*, keeping only the first occurrence of each item."""
    merged: list[str] = []
    for group in groups:
        for item in group:
            if item not in merged:
                merged.append(item)
    return merged


def _insert_bounded(
    keys: list[_SortKey],
    rows: list[SimilarityResult],
    key: _SortKey,
    row: SimilarityResult,
    limit: int,
) -> None:
    """Insert *row* into the parallel sorted ``keys``/``rows`` lists (max *limit*).

    ``bisect_right`` places a row after existing equal keys, and a row whose
    key merely ties the current worst kept key is dropped. This matches a
    stable full sort followed by truncation.
    """
    full = len(rows) >= limit
    if full and not key < keys[-1]:
        return
    idx = bisect_right(keys, key)
    keys.insert(idx, key)
    rows.insert(idx, row)
    if full:
        keys.pop()
        rows.pop()


def _score_pair(
    a: dict[str, Any],
    b: dict[str, Any],
    threshold: float,
    plugin_errors: list[str] | None,
) -> SimilarityResult | None:
    """Build the ``SimilarityResult`` for one pair of ``scan_functions`` records.

    Returns ``None`` when the pair is filtered out by the size ratio or
    scores below *threshold*.
    """
    af = a["features"]
    bf = b["features"]
    # Skip pairs whose statement counts differ by more than a factor of 2.5.
    size_ratio = min(af["stmt_count"], bf["stmt_count"]) / max(
        af["stmt_count"], bf["stmt_count"], 1
    )
    if size_ratio < 0.4:
        return None

    shape_similarity = _shape_similarity(af, bf)
    call_similarity = _call_similarity(af, bf)
    signature_similarity = _sig_similarity(af, bf)
    wrapper_score = _wrapper_score(af, bf)
    curry_score = _curry_score(af, bf)

    # Prune before computing the deferred statement-sequence score: treat
    # stmt_similarity as 1.0 (its assumed maximum). If even that misses the
    # threshold, the real score cannot reach it.
    upper_bound = _combined_similarity(
        shape_similarity,
        1.0,
        call_similarity,
        signature_similarity,
        wrapper_score,
        curry_score,
    )
    if upper_bound < threshold:
        return None

    stmt_similarity = _stmt_similarity(af, bf)
    similarity = _combined_similarity(
        shape_similarity,
        stmt_similarity,
        call_similarity,
        signature_similarity,
        wrapper_score,
        curry_score,
    )
    if similarity < threshold:
        return None

    evidence = SimilarityEvidence(
        shape_similarity=shape_similarity,
        stmt_similarity=stmt_similarity,
        call_similarity=call_similarity,
        signature_similarity=signature_similarity,
        wrapper_score=wrapper_score,
        curry_score=curry_score,
    )
    base_risks = _risk_flags(af, bf)
    base_labels = _pattern_labels(af, bf, evidence)
    summary = _shared_summary(af, bf)
    base_diffs = _difference_notes(af, bf)

    plugin_result = apply_pair_plugins(
        PairContext(
            a=PluginContext(occurrence=a["occurrence"], node=a["node"], features=af),
            b=PluginContext(occurrence=b["occurrence"], node=b["node"], features=bf),
            evidence=evidence,
        ),
        plugin_errors=plugin_errors,
    )

    # Built-in findings come first; plugin findings are appended without repeats.
    risks = _ordered_unique(base_risks, plugin_result.risk_flags)
    labels = _ordered_unique(base_labels, plugin_result.pattern_labels)
    diffs = _ordered_unique(base_diffs, plugin_result.key_differences)

    refactorability = max(
        0.0,
        min(
            1.0,
            _refactorability(labels, risks, evidence)
            + plugin_result.refactorability_delta,
        ),
    )
    # A truthy plugin-supplied kind/template overrides the built-in choice.
    suggested = plugin_result.suggested_refactor_kind or _suggest_refactor(
        labels, risks, evidence
    )
    abstract_template = plugin_result.abstract_template or _abstract_template(
        a["occurrence"], b["occurrence"], labels, summary
    )
    metadata = {"size_ratio": round(size_ratio, 4), **plugin_result.metadata}

    return SimilarityResult(
        similarity_score=round(similarity, 4),
        refactorability_score=round(refactorability, 4),
        pattern_labels=labels,
        shared_structure_summary=summary,
        key_differences=diffs,
        risk_flags=risks,
        suggested_refactor_kind=suggested,
        a=a["occurrence"],
        b=b["occurrence"],
        evidence=evidence,
        abstract_template=abstract_template,
        metadata=metadata,
    )


def near_matches(
    root: Path,
    *,
    threshold: float = 0.8,
    top_k: int | None = None,
    top_level_only: bool = False,
    strict: bool = False,
    scan_errors: list[str] | None = None,
    plugin_errors: list[str] | None = None,
) -> list[SimilarityResult]:
    """Find pairs of functions under *root* whose similarity reaches *threshold*.

    Every unordered pair of scanned functions is scored, and pair plugins may
    add labels, risks, notes and overrides.

    Args:
        root: Passed to ``scan_functions``.
        threshold: Minimum weighted similarity in ``[0, 1]``.
        top_k: When given, keep only the best *top_k* results (0 returns
            ``[]`` after scanning).
        top_level_only, strict, scan_errors: Forwarded to ``scan_functions``.
        plugin_errors: Forwarded to ``apply_pair_plugins``.

    Returns:
        Results ordered by ``_result_sort_key`` (most refactorable first).

    Raises:
        ValueError: If *threshold* is outside ``[0, 1]`` (or NaN) or *top_k*
            is negative.
    """
    if not 0.0 <= threshold <= 1.0:
        msg = "threshold must be between 0 and 1"
        raise ValueError(msg)
    if top_k is not None and top_k < 0:
        msg = "top_k must be >= 0"
        raise ValueError(msg)

    # Scanning happens before the top_k == 0 shortcut, so strict-mode
    # failures and scan_errors are still reported in that case.
    items = scan_functions(
        root,
        top_level_only=top_level_only,
        strict=strict,
        scan_errors=scan_errors,
    )
    if top_k == 0:
        return []

    results: list[SimilarityResult] = []
    keys: list[_SortKey] = []  # kept parallel to results only when top_k is set
    for i in range(len(items)):
        a = items[i]
        for j in range(i + 1, len(items)):
            result = _score_pair(a, items[j], threshold, plugin_errors)
            if result is None:
                continue
            if top_k is None:
                results.append(result)
            else:
                _insert_bounded(keys, results, _result_sort_key(result), result, top_k)

    if top_k is None:
        results.sort(key=_result_sort_key)
    return results
486
+
487
+
488
def abstract_candidates(
    root: Path,
    *,
    threshold: float = 0.82,
    top_k: int | None = None,
    top_level_only: bool = False,
    strict: bool = False,
    scan_errors: list[str] | None = None,
    plugin_errors: list[str] | None = None,
) -> list[SimilarityResult]:
    """Return near matches worth abstracting, best first.

    Same as ``near_matches`` but drops results whose suggested refactor is
    ``"leave_separate"``. *top_k* limits the number of results returned
    *after* that filtering, so discarded rows never use up the quota.

    Raises:
        ValueError: Propagated from ``near_matches`` for an out-of-range
            *threshold* or a negative *top_k*.
    """
    # Forward top_k only when near_matches must handle it itself: 0 (empty
    # result after scanning) or negative (rejected). A positive limit is
    # applied below, after filtering.
    forwarded_top_k = top_k if top_k is not None and top_k <= 0 else None
    matches = near_matches(
        root,
        threshold=threshold,
        top_k=forwarded_top_k,
        top_level_only=top_level_only,
        strict=strict,
        scan_errors=scan_errors,
        plugin_errors=plugin_errors,
    )
    kept = [m for m in matches if m.suggested_refactor_kind != "leave_separate"]
    if top_k is not None:
        kept = kept[:top_k]
    return kept
508
+
509
+
510
def to_jsonable(obj: Any) -> Any:
    """Recursively convert dataclass instances and containers into plain data.

    Dataclass instances become dicts (via ``dataclasses.asdict``), lists and
    tuples become lists, and dict values are converted recursively. Any other
    object, including a dataclass *class*, is returned unchanged.
    """
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(x) for x in obj]
    # ``__dataclass_fields__`` also exists on the dataclass type itself, but
    # ``asdict`` only accepts instances.
    if hasattr(obj, "__dataclass_fields__") and not isinstance(obj, type):
        return to_jsonable(asdict(obj))
    if isinstance(obj, dict):
        return {k: to_jsonable(v) for k, v in obj.items()}
    return obj
pydry/models.py ADDED
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+
7
@dataclass(frozen=True)
class FunctionOccurrence:
    """Location and basic signature facts for one scanned function.

    ``engine.scan_functions`` builds these with ``occurrence_for``.
    ``frozen=True`` makes instances immutable; with the default ``eq=True``,
    dataclasses also generates ``__hash__``, so instances are hashable.
    Field order defines the positional constructor signature.
    """

    path: str  # source file path, stored as a string
    lineno: int  # starting line; NOTE(review): presumably the AST node's lineno — confirm in occurrence_for
    end_lineno: int | None  # may be None; presumably mirrors the AST end_lineno — confirm
    col_offset: int  # presumably the AST node's column offset — confirm
    name: str  # the function's own name
    qualname: str  # NOTE(review): qualified name including enclosing scopes; exact format is set in occurrence_for
    kind: str  # NOTE(review): category label set by occurrence_for; possible values not visible here
    param_count: int  # number of parameters as counted by occurrence_for
    is_method: bool  # the is_method_flag yielded by iter_functions (presumably "defined in a class body")
18
+
19
+
20
@dataclass
class ExactGroup:
    """A set of functions whose canonicalized source text is identical.

    Built by ``engine.exact_groups``, which sorts groups by descending
    ``count`` and then ``hash``.
    """

    hash: str  # SHA-256 hex digest of the canonical text (field name shadows the builtin)
    count: int  # number of occurrences; equals len(occurrences)
    occurrences: list[FunctionOccurrence]  # sorted by (path, lineno, qualname) in exact_groups
    canonical: str | None = None  # canonical text, only filled when include_canonical=True
26
+
27
+
28
@dataclass
class SimilarityEvidence:
    """Per-aspect scores behind a ``SimilarityResult``.

    ``engine.near_matches`` combines them as
    0.40*shape + 0.20*stmt + 0.15*call + 0.10*signature + 0.10*wrapper + 0.05*curry.
    """

    shape_similarity: float  # _counter_jaccard over the "node_types" features
    stmt_similarity: float  # _lcs_ratio over the "stmt_seq" features
    call_similarity: float  # _counter_jaccard over the "call_names" features
    signature_similarity: float  # param-count closeness + generator/async modality, clamped to [0, 1]
    wrapper_score: float  # engine._wrapper_score: 0.0, 0.25, 0.5 or 0.85
    curry_score: float  # engine._curry_score: 0.0, 0.4, 0.6 or 0.8
36
+
37
+
38
@dataclass
class SimilarityResult:
    """One reported near-duplicate pair of functions ``a`` and ``b``.

    Produced by ``engine.near_matches``. Plugins may extend labels, risks and
    differences, and may override the suggestion and template.
    """

    similarity_score: float  # weighted similarity, rounded to 4 places in near_matches
    refactorability_score: float  # clamped to [0, 1] and rounded to 4 places in near_matches
    pattern_labels: list[str]  # e.g. "wrapper", "literal_specialization"; duplicates removed
    shared_structure_summary: str  # human-readable description of shared statements/calls
    key_differences: list[str]  # human-readable difference notes; duplicates removed
    risk_flags: list[str]  # e.g. "possible_side_effects", "async_boundary_diff"
    suggested_refactor_kind: str  # e.g. "extract_common_helper", "leave_separate"
    a: FunctionOccurrence  # the earlier-scanned function of the pair
    b: FunctionOccurrence  # the later-scanned function of the pair
    evidence: SimilarityEvidence  # component scores behind similarity_score
    abstract_template: str | None = None  # sketch of a shared helper, when one is suggested
    metadata: dict[str, Any] = field(default_factory=dict)  # near_matches stores "size_ratio" plus plugin metadata