gabion 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabion/__init__.py +5 -0
- gabion/__main__.py +11 -0
- gabion/analysis/__init__.py +37 -0
- gabion/analysis/dataflow_audit.py +3173 -0
- gabion/analysis/engine.py +8 -0
- gabion/analysis/model.py +45 -0
- gabion/analysis/visitors.py +402 -0
- gabion/cli.py +503 -0
- gabion/config.py +45 -0
- gabion/lsp_client.py +111 -0
- gabion/refactor/__init__.py +4 -0
- gabion/refactor/engine.py +726 -0
- gabion/refactor/model.py +37 -0
- gabion/schema.py +84 -0
- gabion/server.py +447 -0
- gabion/synthesis/__init__.py +26 -0
- gabion/synthesis/merge.py +41 -0
- gabion/synthesis/model.py +41 -0
- gabion/synthesis/naming.py +45 -0
- gabion/synthesis/protocols.py +74 -0
- gabion/synthesis/schedule.py +87 -0
- gabion-0.1.0.dist-info/METADATA +250 -0
- gabion-0.1.0.dist-info/RECORD +26 -0
- gabion-0.1.0.dist-info/WHEEL +4 -0
- gabion-0.1.0.dist-info/entry_points.txt +3 -0
- gabion-0.1.0.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,726 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import libcst as cst
|
|
6
|
+
|
|
7
|
+
from gabion.refactor.model import FieldSpec, RefactorPlan, RefactorRequest, TextEdit
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RefactorEngine:
    """Plan protocol-extraction refactors as whole-file text edits.

    The engine only reads source files; every change it proposes is returned
    to the caller as a ``TextEdit`` inside a ``RefactorPlan``.
    """

    def __init__(self, project_root: Path | None = None) -> None:
        """Store the optional project root.

        The root is used to resolve relative target paths, to derive dotted
        module names, and as the base directory when scanning for call sites
        in other files.
        """
        self.project_root = project_root

    def plan_protocol_extraction(self, request: RefactorRequest) -> RefactorPlan:
        """Build a plan that introduces a Protocol class for a field bundle.

        Reads ``request.target_path``, inserts a class named
        ``request.protocol_name`` after the module's leading imports and, when
        ``request.target_functions`` is non-empty, rewrites those functions and
        their call sites to use a single bundle argument.

        Args:
            request: The extraction request. ``target_path``,
                ``protocol_name``, ``bundle``, ``fields``, ``rationale`` and
                ``target_functions`` are read.

        Returns:
            A ``RefactorPlan`` carrying ``errors`` when the file cannot be
            read or parsed or when the protocol name / bundle is missing,
            only ``warnings`` when the rewrite produced no change, and
            ``edits`` plus ``warnings`` otherwise.
        """
        path = Path(request.target_path)
        if self.project_root and not path.is_absolute():
            path = self.project_root / path
        try:
            # Decode explicitly as UTF-8 (Python's default source encoding)
            # instead of relying on the platform's locale encoding.
            source = path.read_text(encoding="utf-8")
        except Exception as exc:
            return RefactorPlan(errors=[f"Failed to read {path}: {exc}"])
        try:
            module = cst.parse_module(source)
        except Exception as exc:
            return RefactorPlan(errors=[f"LibCST parse failed for {path}: {exc}"])
        protocol = (request.protocol_name or "").strip()
        if not protocol:
            return RefactorPlan(errors=["Protocol name is required for extraction."])

        # Merge explicit field specs with bare bundle names, first one wins.
        bundle = [name.strip() for name in request.bundle or [] if name.strip()]
        field_specs: list[FieldSpec] = []
        seen_fields: set[str] = set()
        for spec in request.fields or []:
            name = (spec.name or "").strip()
            if not name or name in seen_fields:
                continue
            seen_fields.add(name)
            field_specs.append(spec)
        if bundle:
            for name in bundle:
                if name in seen_fields:
                    continue
                seen_fields.add(name)
                field_specs.append(FieldSpec(name=name))
        elif field_specs:
            bundle = [spec.name for spec in field_specs]
        if not bundle:
            return RefactorPlan(errors=["Bundle fields are required for extraction."])

        body = list(module.body)

        insert_idx = _find_import_insert_index(body)

        # Reuse whichever spelling of Protocol the module already has in
        # scope; otherwise schedule a ``from typing import Protocol``.
        protocol_base: cst.CSTNode
        import_stmt: cst.SimpleStatementLine | None = None
        if _has_typing_import(body):
            protocol_base = cst.Attribute(cst.Name("typing"), cst.Name("Protocol"))
        elif _has_typing_protocol_import(body):
            protocol_base = cst.Name("Protocol")
        else:
            protocol_base = cst.Name("Protocol")
            import_stmt = cst.SimpleStatementLine(
                [cst.ImportFrom(module=cst.Name("typing"), names=[cst.ImportAlias(name=cst.Name("Protocol"))])]
            )

        doc_lines = []
        if request.rationale:
            doc_lines.append(f"Rationale: {request.rationale}")
        doc_lines.append(f"Bundle: {', '.join(bundle)}")
        # The separator is an escaped newline, so the generated docstring
        # literal stays on one physical line of the emitted source.
        docstring = cst.SimpleStatementLine(
            [cst.Expr(cst.SimpleString('"""' + "\\n".join(doc_lines) + '"""'))]
        )
        warnings: list[str] = []

        def _annotation_for(hint: str | None) -> cst.BaseExpression:
            """Parse a type hint, falling back to ``object`` with a warning."""
            if not hint:
                return cst.Name("object")
            try:
                return cst.parse_expression(hint)
            except Exception as exc:
                warnings.append(f"Failed to parse type hint '{hint}': {exc}")
                return cst.Name("object")

        field_lines = [
            cst.SimpleStatementLine(
                [
                    cst.AnnAssign(
                        target=cst.Name(spec.name),
                        annotation=cst.Annotation(_annotation_for(spec.type_hint)),
                        value=None,
                    )
                ]
            )
            for spec in field_specs
        ]
        class_body = [docstring] + field_lines
        class_def = cst.ClassDef(
            name=cst.Name(protocol),
            bases=[cst.Arg(protocol_base)],
            body=cst.IndentedBlock(body=class_body),
        )

        new_body = list(body)
        if import_stmt is not None:
            new_body.insert(insert_idx, import_stmt)
            insert_idx += 1
        if insert_idx > 0 and not isinstance(new_body[insert_idx - 1], cst.EmptyLine):
            new_body.insert(insert_idx, cst.EmptyLine())
            insert_idx += 1
        new_body.insert(insert_idx, class_def)

        new_module = module.with_changes(body=new_body)

        targets = {name.strip() for name in request.target_functions or [] if name.strip()}
        bundle_fields = [spec.name for spec in field_specs]
        target_module = ""
        if targets:
            # First rewrite the target function signatures, then the calls to
            # them inside this same file.
            transformer = _RefactorTransformer(
                targets=targets,
                bundle_fields=bundle_fields,
                protocol_hint=protocol,
            )
            new_module = new_module.visit(transformer)
            warnings.extend(transformer.warnings)
            target_module = _module_name(path, self.project_root)
            call_warnings, rewritten = _rewrite_call_sites(
                new_module,
                file_path=path,
                target_path=path,
                target_module=target_module,
                protocol_name=protocol,
                bundle_fields=bundle_fields,
                targets=targets,
            )
            warnings.extend(call_warnings)
            if rewritten is not None:
                new_module = rewritten
        new_source = new_module.code
        if new_source == source:
            warnings.append("No changes generated for protocol extraction.")
            return RefactorPlan(warnings=warnings)

        # A single edit from (0, 0) to (line count, 0) replaces the whole file.
        end_line = len(source.splitlines())
        edits = [
            TextEdit(
                path=str(path),
                start=(0, 0),
                end=(end_line, 0),
                replacement=new_source,
            )
        ]

        if targets and self.project_root:
            extra_edits, extra_warnings = _rewrite_call_sites_in_project(
                project_root=self.project_root,
                target_path=path,
                target_module=target_module,
                protocol_name=protocol,
                bundle_fields=bundle_fields,
                targets=targets,
            )
            edits.extend(extra_edits)
            warnings.extend(extra_warnings)
        return RefactorPlan(edits=edits, warnings=warnings)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _module_name(path: Path, project_root: Path | None) -> str:
|
|
172
|
+
rel = path.with_suffix("")
|
|
173
|
+
if project_root is not None:
|
|
174
|
+
try:
|
|
175
|
+
rel = rel.relative_to(project_root)
|
|
176
|
+
except ValueError:
|
|
177
|
+
pass
|
|
178
|
+
parts = list(rel.parts)
|
|
179
|
+
if parts and parts[0] == "src":
|
|
180
|
+
parts = parts[1:]
|
|
181
|
+
return ".".join(parts)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _is_docstring(stmt: cst.CSTNode) -> bool:
    """Tell whether ``stmt`` is a statement line that starts with a plain string.

    Only a ``SimpleStatementLine`` whose first small statement is an
    expression wrapping a ``SimpleString`` qualifies.
    """
    if isinstance(stmt, cst.SimpleStatementLine) and stmt.body:
        first = stmt.body[0]
        if isinstance(first, cst.Expr):
            return isinstance(first.value, cst.SimpleString)
    return False
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _is_import(stmt: cst.CSTNode) -> bool:
    """Tell whether ``stmt`` is a statement line containing an import.

    Both ``import x`` and ``from x import y`` count, anywhere on the line.
    """
    if isinstance(stmt, cst.SimpleStatementLine):
        for small_stmt in stmt.body:
            if isinstance(small_stmt, (cst.Import, cst.ImportFrom)):
                return True
    return False
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _find_import_insert_index(body: list[cst.CSTNode]) -> int:
    """Return the index just past the module's leading import run.

    A docstring in first position is skipped, then every consecutive import
    statement after it. The result is where a new import can be inserted so
    that it follows the existing leading imports.
    """
    start = 1 if body and _is_docstring(body[0]) else 0
    for offset, stmt in enumerate(body[start:]):
        if not _is_import(stmt):
            return start + offset
    # Everything after the optional docstring was an import.
    return len(body)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _module_expr_to_str(expr: cst.BaseExpression | None) -> str | None:
    """Render a ``Name`` or ``Attribute`` chain as a dotted string.

    Returns ``None`` for ``None``, for any other expression type, and for an
    attribute chain from which no name segment could be collected.
    """
    if isinstance(expr, cst.Name):
        return expr.value
    if not isinstance(expr, cst.Attribute):
        return None
    segments: list[str] = []
    node: cst.BaseExpression | None = expr
    # Walk from the outermost attribute inwards, prepending each segment so
    # the list ends up in source order.
    while isinstance(node, cst.Attribute):
        if isinstance(node.attr, cst.Name):
            segments.insert(0, node.attr.value)
        node = node.value
    if isinstance(node, cst.Name):
        segments.insert(0, node.value)
    return ".".join(segments) if segments else None
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _has_typing_import(body: list[cst.CSTNode]) -> bool:
    """Tell whether the module binds the name ``typing`` via ``import typing``.

    Only top-level statement lines in ``body`` are inspected. An aliased
    import (``import typing as t``) does not count, because the name
    ``typing`` is then not available for a ``typing.Protocol`` reference.
    """
    for stmt in body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for item in stmt.body:
            if not isinstance(item, cst.Import):
                continue
            for alias in item.names:
                if (
                    isinstance(alias, cst.ImportAlias)
                    and isinstance(alias.name, cst.Name)
                    and alias.name.value == "typing"
                    and alias.asname is None
                ):
                    return True
    return False
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _has_typing_protocol_import(body: list[cst.CSTNode]) -> bool:
    """Tell whether the module binds ``Protocol`` via ``from typing import``.

    Only top-level statement lines in ``body`` are inspected.
    ``from typing import *`` counts as binding ``Protocol``; an aliased
    import (``from typing import Protocol as P``) does not, because the bare
    name ``Protocol`` is then not available.
    """
    for stmt in body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for item in stmt.body:
            if not isinstance(item, cst.ImportFrom):
                continue
            if _module_expr_to_str(item.module) != "typing":
                continue
            # ``names`` is a single ImportStar node for ``import *`` and is
            # not iterable in that case.
            if isinstance(item.names, cst.ImportStar):
                return True
            for alias in item.names:
                if (
                    isinstance(alias, cst.ImportAlias)
                    and isinstance(alias.name, cst.Name)
                    and alias.name.value == "Protocol"
                    and alias.asname is None
                ):
                    return True
    return False
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _collect_import_context(
    module: cst.Module,
    *,
    target_module: str,
    protocol_name: str,
) -> tuple[dict[str, str], dict[str, str], str | None]:
    """Collect how ``module`` imports ``target_module`` at top level.

    Args:
        module: Parsed module whose top-level imports are inspected.
        target_module: Dotted name of the module being refactored.
        protocol_name: Name of the extracted protocol class.

    Returns:
        A tuple ``(module_aliases, imported_targets, protocol_alias)``:
        ``module_aliases`` maps each local name bound by
        ``import target_module [as x]`` to the module name;
        ``imported_targets`` maps each local name bound by
        ``from target_module import y [as z]`` to the original name ``y``;
        ``protocol_alias`` is the local name under which ``protocol_name`` is
        imported, or ``None``.
    """
    module_aliases: dict[str, str] = {}
    imported_targets: dict[str, str] = {}
    protocol_alias: str | None = None
    for stmt in module.body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for item in stmt.body:
            if isinstance(item, cst.Import):
                for alias in item.names:
                    if not isinstance(alias, cst.ImportAlias):
                        continue
                    module_name = _module_expr_to_str(alias.name)
                    if not module_name or module_name != target_module:
                        continue
                    local = alias.asname.name.value if alias.asname else module_name
                    module_aliases[local] = module_name
            elif isinstance(item, cst.ImportFrom):
                # NOTE(review): leading dots of a relative import are not
                # considered here; only ``item.module`` is compared.
                if _module_expr_to_str(item.module) != target_module:
                    continue
                # ``from x import *`` stores one ImportStar node instead of a
                # sequence; the imported names are unknown, so skip it.
                if isinstance(item.names, cst.ImportStar):
                    continue
                for alias in item.names:
                    if not isinstance(alias, cst.ImportAlias):
                        continue
                    if not isinstance(alias.name, cst.Name):
                        continue
                    local = alias.asname.name.value if alias.asname else alias.name.value
                    imported_targets[local] = alias.name.value
                    if alias.name.value == protocol_name:
                        protocol_alias = local
    return module_aliases, imported_targets, protocol_alias
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _dotted_name_expr(dotted: str) -> cst.BaseExpression:
    """Build a ``Name``/``Attribute`` chain for a dotted path like ``pkg.mod``."""
    head, *tail = dotted.split(".")
    expr: cst.BaseExpression = cst.Name(head)
    for part in tail:
        expr = cst.Attribute(value=expr, attr=cst.Name(part))
    return expr


def _rewrite_call_sites(
    module: cst.Module,
    *,
    file_path: Path,
    target_path: Path,
    target_module: str,
    protocol_name: str,
    bundle_fields: list[str],
    targets: set[str],
) -> tuple[list[str], cst.Module | None]:
    """Wrap the arguments of calls to the target functions in a bundle call.

    Args:
        module: Parsed module to rewrite.
        file_path: Path of the file ``module`` was parsed from.
        target_path: Path of the file that defines the targets; when equal to
            ``file_path`` calls are matched by their local names.
        target_module: Dotted name of the defining module, used to recognise
            imports of it in other files.
        protocol_name: Name of the bundle class used as the constructor.
        bundle_fields: Ordered field names of the bundle.
        targets: Function names (``func``) or qualified method names
            (``Class.method``) whose calls are rewritten.

    Returns:
        ``(warnings, new_module)``; ``new_module`` is ``None`` when there are
        no targets or no call was changed.
    """
    warnings: list[str] = []
    file_is_target = file_path == target_path
    if not targets:
        return warnings, None

    # Split targets into plain functions and ``Class.method`` entries.
    target_simple = {name for name in targets if "." not in name}
    target_methods: dict[str, set[str]] = {}
    for name in targets:
        if "." not in name:
            continue
        class_name, _, method = name.rpartition(".")
        target_methods.setdefault(class_name, set()).add(method)

    module_aliases: dict[str, str] = {}
    imported_targets: dict[str, str] = {}
    protocol_alias: str | None = None
    if not file_is_target and target_module:
        module_aliases, imported_targets, protocol_alias = _collect_import_context(
            module, target_module=target_module, protocol_name=protocol_name
        )

    # Decide how the bundle class is spelled in this file.
    constructor_expr: cst.BaseExpression
    needs_import = False
    if file_is_target:
        constructor_expr = cst.Name(protocol_name)
    elif protocol_alias:
        constructor_expr = cst.Name(protocol_alias)
    elif module_aliases:
        alias = sorted(module_aliases.keys())[0]
        # An unaliased ``import pkg.mod`` yields the dotted alias "pkg.mod",
        # which is not a valid identifier for a single Name node; build the
        # attribute chain instead.
        constructor_expr = cst.Attribute(_dotted_name_expr(alias), cst.Name(protocol_name))
    else:
        constructor_expr = cst.Name(protocol_name)
        needs_import = True

    transformer = _CallSiteTransformer(
        file_is_target=file_is_target,
        target_simple=target_simple,
        target_methods=target_methods,
        module_aliases=set(module_aliases.keys()),
        imported_targets=set(
            name for name, original in imported_targets.items() if original in target_simple
        ),
        bundle_fields=bundle_fields,
        constructor_expr=constructor_expr,
    )
    new_module = module.visit(transformer)
    warnings.extend(transformer.warnings)
    if not transformer.changed:
        return warnings, None

    if not file_is_target and needs_import and target_module:
        # The bundle class is referenced by bare name but not yet imported.
        body = list(new_module.body)
        insert_idx = _find_import_insert_index(body)
        try:
            module_expr = cst.parse_expression(target_module)
        except Exception:
            module_expr = cst.Name(target_module.split(".")[0])
        if not isinstance(module_expr, (cst.Name, cst.Attribute)):
            module_expr = cst.Name(target_module.split(".")[0])
        import_stmt = cst.SimpleStatementLine(
            [
                cst.ImportFrom(
                    module=module_expr,
                    names=[cst.ImportAlias(name=cst.Name(protocol_name))],
                )
            ]
        )
        body.insert(insert_idx, import_stmt)
        new_module = new_module.with_changes(body=body)
    return warnings, new_module
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _rewrite_call_sites_in_project(
    *,
    project_root: Path,
    target_path: Path,
    target_module: str,
    protocol_name: str,
    bundle_fields: list[str],
    targets: set[str],
) -> tuple[list[TextEdit], list[str]]:
    """Rewrite calls to the targets in every other Python file of the project.

    Scans ``project_root / "src"`` when that directory exists, otherwise
    ``project_root`` itself, visiting ``*.py`` files recursively in sorted
    order and skipping ``target_path``.

    Returns:
        ``(edits, warnings)``: one whole-file ``TextEdit`` per file whose
        source changed, plus the warnings collected along the way. Files that
        cannot be read or parsed are reported as warnings and skipped.
    """
    edits: list[TextEdit] = []
    warnings: list[str] = []
    scan_root = project_root / "src"
    if not scan_root.exists():
        scan_root = project_root
    for path in sorted(scan_root.rglob("*.py")):
        if path == target_path:
            continue
        try:
            # Decode explicitly as UTF-8 (Python's default source encoding)
            # instead of relying on the platform's locale encoding.
            source = path.read_text(encoding="utf-8")
        except Exception as exc:
            warnings.append(f"Failed to read {path}: {exc}")
            continue
        try:
            module = cst.parse_module(source)
        except Exception as exc:
            warnings.append(f"LibCST parse failed for {path}: {exc}")
            continue
        call_warnings, updated_module = _rewrite_call_sites(
            module,
            file_path=path,
            target_path=target_path,
            target_module=target_module,
            protocol_name=protocol_name,
            bundle_fields=bundle_fields,
            targets=targets,
        )
        warnings.extend(call_warnings)
        if updated_module is None:
            continue
        new_source = updated_module.code
        if new_source == source:
            continue
        # A single edit from (0, 0) to (line count, 0) replaces the whole file.
        end_line = len(source.splitlines())
        edits.append(
            TextEdit(
                path=str(path),
                start=(0, 0),
                end=(end_line, 0),
                replacement=new_source,
            )
        )
    return edits, warnings
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class _RefactorTransformer(cst.CSTTransformer):
    """Rewrite targeted function definitions to take one bundle parameter.

    A function is targeted when its bare name or its dotted qualified name
    (enclosing classes and functions joined with ``.``) is in ``targets``.
    Its parameter list is replaced by an optional leading ``self``/``cls``
    plus one bundle parameter, and the body is prefixed with assignments that
    unpack the bundle fields back into local names.

    ``async def`` functions need no separate hooks: LibCST represents them as
    ``FunctionDef`` nodes as well.
    """

    def __init__(
        self,
        *,
        targets: set[str],
        bundle_fields: list[str],
        protocol_hint: str,
    ) -> None:
        # dataflow-bundle: bundle_fields, protocol_hint, targets
        self.targets = targets
        self.bundle_fields = bundle_fields
        self.protocol_hint = protocol_hint
        # Messages for the caller; filled while visiting.
        self.warnings: list[str] = []
        # Names of the enclosing classes/functions of the current node.
        self._stack: list[str] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self._stack.append(node.name.value)
        return True

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.CSTNode:
        if self._stack:
            self._stack.pop()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        self._stack.append(node.name.value)
        return True

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.CSTNode:
        # Rewrite before popping: the qualified name still includes this
        # function's own name at this point.
        updated = self._maybe_rewrite_function(original_node, updated_node)
        if self._stack:
            self._stack.pop()
        return updated

    def _maybe_rewrite_function(
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.CSTNode:
        """Return the rewritten function, or ``updated_node`` if not targeted."""
        qualname = ".".join(self._stack)
        func_name = original_node.name.value
        if func_name not in self.targets and qualname not in self.targets:
            return updated_node

        original_params = self._ordered_param_names(original_node.params)
        if not original_params:
            return updated_node

        keep_self = original_params[0] in {"self", "cls"}
        self_param = None
        if keep_self:
            self_param = self._find_self_param(original_node.params, original_params[0])

        bundle_param_name = self._choose_bundle_name(original_params)
        new_params = self._build_parameters(self_param, bundle_param_name)

        bundle_set = set(self.bundle_fields)

        # The new signature keeps only self/cls and the bundle; report every
        # other parameter that disappears instead of dropping it silently.
        dropped = [
            param_name
            for index, param_name in enumerate(original_params)
            if param_name not in bundle_set
            and not (index == 0 and self_param is not None)
        ]
        star_arg = original_node.params.star_arg
        if isinstance(star_arg, cst.Param):
            dropped.append(f"*{star_arg.name.value}")
        star_kwarg = original_node.params.star_kwarg
        if star_kwarg is not None:
            dropped.append(f"**{star_kwarg.name.value}")
        if dropped:
            self.warnings.append(
                f"Parameters outside the bundle were dropped from '{qualname}': "
                f"{', '.join(dropped)}"
            )

        target_fields = [
            param_name
            for param_name in original_params
            if param_name in bundle_set and param_name not in {"self", "cls"}
        ]
        new_body = self._inject_preamble(
            updated_node.body, bundle_param_name, target_fields
        )
        return updated_node.with_changes(params=new_params, body=new_body)

    def _ordered_param_names(self, params: cst.Parameters) -> list[str]:
        """List positional-only, regular and keyword-only names, in that order."""
        return [
            param.name.value
            for group in (params.posonly_params, params.params, params.kwonly_params)
            for param in group
        ]

    def _find_self_param(
        self, params: cst.Parameters, name: str
    ) -> cst.Param | None:
        """Find the positional parameter called ``name``, if there is one."""
        for group in (params.posonly_params, params.params):
            for param in group:
                if param.name.value == name:
                    return param
        return None

    def _choose_bundle_name(self, existing: list[str]) -> str:
        """Pick ``bundle`` or the first free ``bundle_<n>`` not in ``existing``."""
        candidate = "bundle"
        if candidate not in existing:
            return candidate
        idx = 1
        while f"bundle_{idx}" in existing:
            idx += 1
        return f"bundle_{idx}"

    def _build_parameters(
        self, self_param: cst.Param | None, bundle_name: str
    ) -> cst.Parameters:
        """Build the new parameter list: optional self/cls plus the bundle."""
        params: list[cst.Param] = []
        if self_param is not None:
            params.append(self_param)
        annotation = None
        if self.protocol_hint:
            try:
                annotation = cst.Annotation(cst.parse_expression(self.protocol_hint))
            except Exception as exc:
                # Leave the bundle parameter unannotated rather than failing.
                self.warnings.append(
                    f"Failed to parse protocol type hint '{self.protocol_hint}': {exc}"
                )
        params.append(
            cst.Param(
                name=cst.Name(bundle_name),
                annotation=annotation,
            )
        )
        return cst.Parameters(
            params=params,
            star_arg=cst.MaybeSentinel.DEFAULT,
            kwonly_params=[],
            star_kwarg=None,
            posonly_params=[],
            posonly_ind=cst.MaybeSentinel.DEFAULT,
        )

    def _inject_preamble(
        self, body: cst.BaseSuite, bundle_name: str, fields: list[str]
    ) -> cst.BaseSuite:
        """Prefix ``body`` with ``field = bundle.field`` for each field.

        The assignments go after a leading docstring when there is one. A
        single-line suite (``def f(a): return a``) is first expanded into an
        indented block so that it can receive the assignments too.
        """
        if not fields:
            return body
        if isinstance(body, cst.SimpleStatementSuite):
            body = cst.IndentedBlock(
                body=[
                    cst.SimpleStatementLine(
                        body=body.body,
                        trailing_whitespace=body.trailing_whitespace,
                    )
                ]
            )
        if not isinstance(body, cst.IndentedBlock):
            return body
        assign_lines = [
            cst.SimpleStatementLine(
                [
                    cst.Assign(
                        targets=[cst.AssignTarget(cst.Name(field))],
                        value=cst.Attribute(
                            value=cst.Name(bundle_name),
                            attr=cst.Name(field),
                        ),
                    )
                ]
            )
            for field in fields
        ]
        existing = list(body.body)
        insert_at = 1 if existing and _is_docstring(existing[0]) else 0
        new_body = existing[:insert_at] + assign_lines + existing[insert_at:]
        return body.with_changes(body=new_body)
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
class _CallSiteTransformer(cst.CSTTransformer):
    """Wrap the arguments of matching calls in one bundle-constructor call.

    A matching call ``f(a, b)`` becomes ``f(Bundle(x=a, y=b))``, where the
    keyword names come from ``bundle_fields`` and ``Bundle`` is
    ``constructor_expr``. ``changed`` is set once at least one call was
    rewritten; skipped calls are explained in ``warnings``.
    """

    def __init__(
        self,
        *,
        file_is_target: bool,
        target_simple: set[str],
        target_methods: dict[str, set[str]],
        module_aliases: set[str],
        imported_targets: set[str],
        bundle_fields: list[str],
        constructor_expr: cst.BaseExpression,
    ) -> None:
        # dataflow-bundle: bundle_fields, constructor_expr, file_is_target, imported_targets, module_aliases, target_methods, target_simple
        self.file_is_target = file_is_target
        self.target_simple = target_simple
        self.target_methods = target_methods
        self.module_aliases = module_aliases
        self.imported_targets = imported_targets
        self.bundle_fields = bundle_fields
        self.constructor_expr = constructor_expr
        self.changed = False
        self.warnings: list[str] = []
        # Names of the classes enclosing the node being visited.
        self._class_stack: list[str] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self._class_stack.append(node.name.value)
        return True

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.CSTNode:
        if self._class_stack:
            self._class_stack.pop()
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        """Replace the call's arguments by one bundle argument when it matches."""
        if not self._is_target_call(updated_node.func) or self._already_wrapped(updated_node):
            return updated_node
        wrapped_args = self._build_bundle_args(updated_node)
        if wrapped_args is None:
            return updated_node
        self.changed = True
        constructor_call = cst.Call(func=self.constructor_expr, args=wrapped_args)
        return updated_node.with_changes(args=[cst.Arg(value=constructor_call)])

    def _is_target_call(self, func: cst.BaseExpression) -> bool:
        """Tell whether ``func`` refers to one of the target callables."""
        if isinstance(func, cst.Name):
            # Bare names: local definitions in the target file, imported
            # names everywhere else.
            known = self.target_simple if self.file_is_target else self.imported_targets
            return func.value in known
        if not isinstance(func, cst.Attribute) or not isinstance(func.attr, cst.Name):
            return False
        owner = func.value
        if not isinstance(owner, cst.Name):
            return False
        member = func.attr.value
        if self.file_is_target:
            # Method call written inside the class that defines the method.
            if not self._class_stack:
                return False
            enclosing = ".".join(self._class_stack)
            if member not in self.target_methods.get(enclosing, set()):
                return False
            return owner.value in {"self", "cls", self._class_stack[-1]}
        # Call through a module alias, e.g. ``mod.func(...)``.
        return owner.value in self.module_aliases and member in self.target_simple

    def _already_wrapped(self, call: cst.Call) -> bool:
        """Tell whether the call's only argument is already a bundle call."""
        if len(call.args) != 1:
            return False
        (only_arg,) = call.args
        if only_arg.star:
            return False
        inner = only_arg.value
        if not isinstance(inner, cst.Call):
            return False
        constructor = self.constructor_expr
        if isinstance(inner.func, cst.Name) and isinstance(constructor, cst.Name):
            return inner.func.value == constructor.value
        if isinstance(inner.func, cst.Attribute) and isinstance(constructor, cst.Attribute):
            # Only the final attribute name is compared.
            return inner.func.attr.value == constructor.attr.value
        return False

    def _build_bundle_args(self, call: cst.Call) -> list[cst.Arg] | None:
        """Map the call's arguments onto the bundle fields as keyword args.

        Returns ``None`` and records a warning when the call uses star
        arguments, an unknown keyword, too many positional arguments, or does
        not supply every bundle field.
        """
        for arg in call.args:
            if arg.star in {"*", "**"}:
                self.warnings.append("Skipped call with star args/kwargs during refactor.")
                return None

        positional: list[cst.Arg] = []
        by_keyword: dict[str, cst.BaseExpression] = {}
        for arg in call.args:
            if arg.keyword is None:
                positional.append(arg)
            elif isinstance(arg.keyword, cst.Name):
                by_keyword[arg.keyword.value] = arg.value

        unknown = next((key for key in by_keyword if key not in self.bundle_fields), None)
        if unknown is not None:
            self.warnings.append(
                f"Skipped call with unknown keyword '{unknown}' during refactor."
            )
            return None

        # Keywords claim their fields first; positional arguments then fill
        # the remaining fields in bundle order.
        assigned: dict[str, cst.BaseExpression] = dict(by_keyword)
        unfilled = [field for field in self.bundle_fields if field not in assigned]
        if len(positional) > len(unfilled):
            self.warnings.append("Skipped call with extra positional args during refactor.")
            return None
        for field, arg in zip(unfilled, positional):
            assigned[field] = arg.value
        if len(assigned) != len(self.bundle_fields):
            self.warnings.append("Skipped call with missing bundle fields during refactor.")
            return None
        return [
            cst.Arg(keyword=cst.Name(field), value=assigned[field])
            for field in self.bundle_fields
        ]
|