gabion 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,726 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import libcst as cst
6
+
7
+ from gabion.refactor.model import FieldSpec, RefactorPlan, RefactorRequest, TextEdit
8
+
9
+
10
class RefactorEngine:
    """Plan protocol-extraction refactors over Python sources using LibCST.

    The engine only *plans*: it reads and parses files and returns a
    ``RefactorPlan`` of whole-file ``TextEdit`` replacements plus any
    warnings/errors. It never writes to disk.
    """

    def __init__(self, project_root: Path | None = None) -> None:
        # Relative target paths are resolved against this root; when it is
        # set, call sites in other project files are rewritten as well.
        self.project_root = project_root

    def plan_protocol_extraction(self, request: RefactorRequest) -> RefactorPlan:
        """Plan extracting a ``Protocol`` class that bundles related fields.

        The protocol class is inserted after the target module's docstring and
        leading imports. ``from typing import Protocol`` is added when neither
        ``typing`` nor ``Protocol`` is already bound. When
        ``request.target_functions`` is non-empty, those functions are
        rewritten to take one bundle parameter. Their call sites are rewritten
        in the target file and, if ``project_root`` is set, across the project.

        Args:
            request: Target path, protocol name, bundle/field specs, optional
                rationale (recorded in the class docstring) and target
                functions.

        Returns:
            A ``RefactorPlan``. Fatal input problems are returned in
            ``errors``: an unreadable or unparsable file, a missing or invalid
            protocol/field name, or an empty bundle. Non-fatal issues are
            returned in ``warnings``.
        """
        path = Path(request.target_path)
        if self.project_root and not path.is_absolute():
            path = self.project_root / path
        try:
            # Python source is UTF-8 by default (PEP 3120); don't rely on the
            # platform's locale encoding.
            source = path.read_text(encoding="utf-8")
        except Exception as exc:
            return RefactorPlan(errors=[f"Failed to read {path}: {exc}"])
        try:
            module = cst.parse_module(source)
        except Exception as exc:
            return RefactorPlan(errors=[f"LibCST parse failed for {path}: {exc}"])

        protocol = (request.protocol_name or "").strip()
        if not protocol:
            return RefactorPlan(errors=["Protocol name is required for extraction."])
        if not _is_valid_identifier(protocol):
            return RefactorPlan(
                errors=[f"Protocol name '{protocol}' is not a valid Python identifier."]
            )

        fields, bundle = _normalize_fields(request)
        if not bundle:
            return RefactorPlan(errors=["Bundle fields are required for extraction."])
        invalid = [name for name, _ in fields if not _is_valid_identifier(name)]
        if invalid:
            return RefactorPlan(
                errors=[
                    "Field names are not valid Python identifiers: "
                    + ", ".join(invalid)
                ]
            )

        warnings: list[str] = []
        body = list(module.body)
        insert_idx = _find_import_insert_index(body)
        protocol_base, import_stmt = _protocol_base_for(body)
        class_def = _build_protocol_class(
            protocol=protocol,
            base=protocol_base,
            fields=fields,
            bundle=bundle,
            rationale=request.rationale,
            warnings=warnings,
        )

        new_body = list(body)
        if import_stmt is not None:
            new_body.insert(insert_idx, import_stmt)
            insert_idx += 1
        if insert_idx > 0:
            # Put one blank line between the preceding docstring/imports and
            # the class. It is attached as the class's leading line because
            # ``Module.body`` is meant to hold statements, not ``EmptyLine``s.
            class_def = class_def.with_changes(leading_lines=[cst.EmptyLine()])
        new_body.insert(insert_idx, class_def)
        new_module = module.with_changes(body=new_body)

        targets = {
            name.strip() for name in request.target_functions or [] if name.strip()
        }
        bundle_fields = [name for name, _ in fields]
        target_module = _module_name(path, self.project_root)
        if targets:
            transformer = _RefactorTransformer(
                targets=targets,
                bundle_fields=bundle_fields,
                protocol_hint=protocol,
            )
            new_module = new_module.visit(transformer)
            warnings.extend(transformer.warnings)
            # NOTE(review): rewritten call sites construct ``<protocol>(...)``
            # directly. Classes deriving from ``typing.Protocol`` raise
            # TypeError when instantiated, so callers will need a concrete
            # implementation. Confirm this is the intended design.
            call_warnings, rewritten = _rewrite_call_sites(
                new_module,
                file_path=path,
                target_path=path,
                target_module=target_module,
                protocol_name=protocol,
                bundle_fields=bundle_fields,
                targets=targets,
            )
            warnings.extend(call_warnings)
            if rewritten is not None:
                new_module = rewritten

        new_source = new_module.code
        if new_source == source:
            warnings.append("No changes generated for protocol extraction.")
            return RefactorPlan(warnings=warnings)
        edits = [
            TextEdit(
                path=str(path),
                start=(0, 0),
                end=(len(source.splitlines()), 0),
                replacement=new_source,
            )
        ]

        if targets and self.project_root:
            extra_edits, extra_warnings = _rewrite_call_sites_in_project(
                project_root=self.project_root,
                target_path=path,
                target_module=target_module,
                protocol_name=protocol,
                bundle_fields=bundle_fields,
                targets=targets,
            )
            edits.extend(extra_edits)
            warnings.extend(extra_warnings)
        return RefactorPlan(edits=edits, warnings=warnings)


def _is_valid_identifier(name: str) -> bool:
    """Return True if *name* can be emitted as a bare Python identifier."""
    import keyword

    return name.isidentifier() and not keyword.iskeyword(name)


def _normalize_fields(
    request: RefactorRequest,
) -> tuple[list[tuple[str, FieldSpec]], list[str]]:
    """Return de-duplicated ``(stripped_name, spec)`` pairs and the bundle list.

    Explicit ``request.fields`` come first. Any ``request.bundle`` names not
    already present follow as bare ``FieldSpec``s. With no explicit bundle, the
    bundle is the list of field names. The *stripped* name is carried
    alongside each spec because it is emitted verbatim as an identifier.
    """
    bundle = [name.strip() for name in request.bundle or [] if name.strip()]
    fields: list[tuple[str, FieldSpec]] = []
    seen: set[str] = set()
    for spec in request.fields or []:
        name = (spec.name or "").strip()
        if name and name not in seen:
            seen.add(name)
            fields.append((name, spec))
    if bundle:
        for name in bundle:
            if name not in seen:
                seen.add(name)
                fields.append((name, FieldSpec(name=name)))
    elif fields:
        bundle = [name for name, _ in fields]
    return fields, bundle


def _protocol_base_for(
    body: list[cst.CSTNode],
) -> tuple[cst.BaseExpression, cst.SimpleStatementLine | None]:
    """Choose how to spell ``Protocol`` in this module, plus an import if needed."""
    if _has_typing_import(body):
        return cst.Attribute(cst.Name("typing"), cst.Name("Protocol")), None
    if _has_typing_protocol_import(body):
        return cst.Name("Protocol"), None
    import_stmt = cst.SimpleStatementLine(
        [
            cst.ImportFrom(
                module=cst.Name("typing"),
                names=[cst.ImportAlias(name=cst.Name("Protocol"))],
            )
        ]
    )
    return cst.Name("Protocol"), import_stmt


def _docstring_literal(lines: list[str]) -> str:
    """Render *lines* as one triple-quoted string literal on a single line.

    Backslashes, double quotes and line breaks are escaped. User-supplied text,
    such as a rationale with triple quotes or a Windows path, therefore cannot
    end the literal early or introduce invalid escape sequences.
    """
    text = "\n".join(lines)
    text = (
        text.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\r", "\\r")
        .replace("\n", "\\n")
    )
    return '"""' + text + '"""'


def _build_protocol_class(
    *,
    protocol: str,
    base: cst.BaseExpression,
    fields: list[tuple[str, FieldSpec]],
    bundle: list[str],
    rationale: object,
    warnings: list[str],
) -> cst.ClassDef:
    """Build ``class <protocol>(<base>)`` with a docstring and annotated fields.

    Each field becomes ``name: <type_hint>``. A missing or unparsable hint
    falls back to ``object``; a parse failure is also recorded in *warnings*.
    """
    doc_lines = []
    if rationale:
        doc_lines.append(f"Rationale: {rationale}")
    doc_lines.append(f"Bundle: {', '.join(bundle)}")
    docstring = cst.SimpleStatementLine(
        [cst.Expr(cst.SimpleString(_docstring_literal(doc_lines)))]
    )

    def annotation_for(hint: str | None) -> cst.BaseExpression:
        if not hint:
            return cst.Name("object")
        try:
            return cst.parse_expression(hint)
        except Exception as exc:
            warnings.append(f"Failed to parse type hint '{hint}': {exc}")
            return cst.Name("object")

    field_lines = [
        cst.SimpleStatementLine(
            [
                cst.AnnAssign(
                    target=cst.Name(name),
                    annotation=cst.Annotation(annotation_for(spec.type_hint)),
                    value=None,
                )
            ]
        )
        for name, spec in fields
    ]
    return cst.ClassDef(
        name=cst.Name(protocol),
        bases=[cst.Arg(base)],
        body=cst.IndentedBlock(body=[docstring, *field_lines]),
    )
169
+
170
+
171
+ def _module_name(path: Path, project_root: Path | None) -> str:
172
+ rel = path.with_suffix("")
173
+ if project_root is not None:
174
+ try:
175
+ rel = rel.relative_to(project_root)
176
+ except ValueError:
177
+ pass
178
+ parts = list(rel.parts)
179
+ if parts and parts[0] == "src":
180
+ parts = parts[1:]
181
+ return ".".join(parts)
182
+
183
+
184
def _is_docstring(stmt: cst.CSTNode) -> bool:
    """Report whether *stmt* is a simple line starting with a bare string literal."""
    if isinstance(stmt, cst.SimpleStatementLine) and stmt.body:
        leading = stmt.body[0]
        return isinstance(leading, cst.Expr) and isinstance(
            leading.value, cst.SimpleString
        )
    return False
189
+
190
+
191
def _is_import(stmt: cst.CSTNode) -> bool:
    """Report whether *stmt* is a simple line containing any import statement."""
    import_kinds = (cst.Import, cst.ImportFrom)
    return isinstance(stmt, cst.SimpleStatementLine) and any(
        isinstance(small, import_kinds) for small in stmt.body
    )
195
+
196
+
197
def _find_import_insert_index(body: list[cst.CSTNode]) -> int:
    """Return the index just past the module docstring and leading imports.

    Only the uninterrupted run of simple import lines at the top counts; the
    first non-import statement ends the scan.
    """
    start = 1 if body and _is_docstring(body[0]) else 0
    for offset, stmt in enumerate(body[start:]):
        if not _is_import(stmt):
            return start + offset
    return len(body)
204
+
205
+
206
def _module_expr_to_str(expr: cst.BaseExpression | None) -> str | None:
    """Flatten a dotted module expression (``a`` or ``a.b.c``) into a string.

    Returns None for ``None`` and for any expression that is neither a
    ``Name`` nor an ``Attribute`` chain.
    """
    if isinstance(expr, cst.Name):
        return expr.value
    if not isinstance(expr, cst.Attribute):
        return None
    segments: list[str] = []
    node: cst.BaseExpression = expr
    # Walk from the rightmost attribute inwards, collecting names.
    while isinstance(node, cst.Attribute):
        if isinstance(node.attr, cst.Name):
            segments.append(node.attr.value)
        node = node.value
    if isinstance(node, cst.Name):
        segments.append(node.value)
    if not segments:
        return None
    return ".".join(segments[::-1])
223
+
224
+
225
def _has_typing_import(body: list[cst.CSTNode]) -> bool:
    """Return True if the module binds the bare name ``typing`` to the typing module.

    Only import forms that actually bind ``typing`` count: ``import typing``,
    ``import typing.<sub>`` (``import a.b`` binds ``a``), and
    ``import typing as typing``. An aliased import such as
    ``import typing as t`` does not bind ``typing``, so emitting
    ``typing.Protocol`` in that case would raise NameError.
    """
    for stmt in body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for item in stmt.body:
            if not isinstance(item, cst.Import):
                continue
            for alias in item.names:
                if not isinstance(alias, cst.ImportAlias):
                    continue
                module_name = _module_expr_to_str(alias.name)
                if module_name is None:
                    continue
                if alias.asname is None:
                    if module_name.split(".", 1)[0] == "typing":
                        return True
                elif (
                    module_name == "typing"
                    and isinstance(alias.asname.name, cst.Name)
                    and alias.asname.name.value == "typing"
                ):
                    return True
    return False
241
+
242
+
243
def _has_typing_protocol_import(body: list[cst.CSTNode]) -> bool:
    """Return True if the bare name ``Protocol`` is bound to ``typing.Protocol``.

    These forms count: ``from typing import Protocol``,
    ``from typing import Protocol as Protocol`` and ``from typing import *``
    (``typing.__all__`` includes ``Protocol`` on Python 3.8+).

    These do not count:
      * ``from typing import Protocol as P`` — it binds ``P``, not
        ``Protocol``;
      * relative imports such as ``from .typing import ...`` — they name a
        different module.
    """
    for stmt in body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for item in stmt.body:
            if not isinstance(item, cst.ImportFrom) or item.relative:
                continue
            if _module_expr_to_str(item.module) != "typing":
                continue
            if isinstance(item.names, cst.ImportStar):
                # LibCST represents ``*`` as a single node, not a sequence,
                # so it must be handled before iterating ``names``.
                return True
            for alias in item.names:
                if not isinstance(alias.name, cst.Name) or alias.name.value != "Protocol":
                    continue
                if alias.asname is None:
                    return True
                if (
                    isinstance(alias.asname.name, cst.Name)
                    and alias.asname.name.value == "Protocol"
                ):
                    return True
    return False
258
+
259
+
260
def _collect_import_context(
    module: cst.Module,
    *,
    target_module: str,
    protocol_name: str,
) -> tuple[dict[str, str], dict[str, str], str | None]:
    """Collect how *module* refers to *target_module* through its top-level imports.

    Returns ``(module_aliases, imported_targets, protocol_alias)``:
      * ``module_aliases``: local name -> module name, for
        ``import <target_module> [as x]``;
      * ``imported_targets``: local name -> original name, for
        ``from <target_module> import name [as x]``;
      * ``protocol_alias``: the local name bound to *protocol_name* by such a
        from-import, or None.

    Two import forms are skipped:
      * star imports — they have no per-name aliases to record, and LibCST
        represents them as a single ``ImportStar`` node that is not iterable;
      * relative imports — this function has no file context to resolve them
        against the absolute *target_module*.
    """
    module_aliases: dict[str, str] = {}
    imported_targets: dict[str, str] = {}
    protocol_alias: str | None = None
    for stmt in module.body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for item in stmt.body:
            if isinstance(item, cst.Import):
                for alias in item.names:
                    if not isinstance(alias, cst.ImportAlias):
                        continue
                    module_name = _module_expr_to_str(alias.name)
                    if not module_name or module_name != target_module:
                        continue
                    local = alias.asname.name.value if alias.asname else module_name
                    module_aliases[local] = module_name
            elif isinstance(item, cst.ImportFrom):
                if item.relative or isinstance(item.names, cst.ImportStar):
                    continue
                if _module_expr_to_str(item.module) != target_module:
                    continue
                for alias in item.names:
                    if not isinstance(alias, cst.ImportAlias):
                        continue
                    if not isinstance(alias.name, cst.Name):
                        continue
                    local = alias.asname.name.value if alias.asname else alias.name.value
                    imported_targets[local] = alias.name.value
                    if alias.name.value == protocol_name:
                        protocol_alias = local
    return module_aliases, imported_targets, protocol_alias
298
+
299
+
300
def _rewrite_call_sites(
    module: cst.Module,
    *,
    file_path: Path,
    target_path: Path,
    target_module: str,
    protocol_name: str,
    bundle_fields: list[str],
    targets: set[str],
) -> tuple[list[str], cst.Module | None]:
    """Rewrite calls to *targets* in *module* so they pass one bundle object.

    Each entry in *targets* is either a plain function name (``"f"``) or a
    dotted ``"Class.method"`` qualname. Inside the target file the bundle is
    constructed as ``<protocol_name>(...)``.

    In other files, the constructor is chosen in this order:
      1. an existing from-import of the protocol;
      2. an ``import <target_module> [as x]`` alias (``x.<protocol_name>``);
      3. otherwise, a new ``from <target_module> import <protocol_name>``
         inserted after the leading imports.

    Returns:
        ``(warnings, new_module)``; ``new_module`` is None when no call was
        rewritten.
    """
    warnings: list[str] = []
    if not targets:
        return warnings, None
    in_target_file = file_path == target_path

    simple_targets: set[str] = set()
    methods_by_class: dict[str, set[str]] = {}
    for dotted in targets:
        owner, sep, method = dotted.rpartition(".")
        if sep:
            methods_by_class.setdefault(owner, set()).add(method)
        else:
            simple_targets.add(dotted)

    aliases: dict[str, str] = {}
    imported: dict[str, str] = {}
    proto_local: str | None = None
    if not in_target_file and target_module:
        aliases, imported, proto_local = _collect_import_context(
            module, target_module=target_module, protocol_name=protocol_name
        )

    needs_import = False
    ctor: cst.BaseExpression
    if in_target_file:
        ctor = cst.Name(protocol_name)
    elif proto_local:
        ctor = cst.Name(proto_local)
    elif aliases:
        ctor = cst.Attribute(cst.Name(min(aliases)), cst.Name(protocol_name))
    else:
        ctor = cst.Name(protocol_name)
        needs_import = True

    transformer = _CallSiteTransformer(
        file_is_target=in_target_file,
        target_simple=simple_targets,
        target_methods=methods_by_class,
        module_aliases=set(aliases),
        imported_targets={
            local for local, original in imported.items() if original in simple_targets
        },
        bundle_fields=bundle_fields,
        constructor_expr=ctor,
    )
    rewritten = module.visit(transformer)
    warnings.extend(transformer.warnings)
    if not transformer.changed:
        return warnings, None
    # needs_import is only ever set outside the target file.
    if not needs_import or not target_module:
        return warnings, rewritten

    module_expr: cst.BaseExpression | None
    try:
        module_expr = cst.parse_expression(target_module)
    except Exception:
        module_expr = None
    if not isinstance(module_expr, (cst.Name, cst.Attribute)):
        # Fall back to the top-level package name when the dotted name does
        # not parse as a plain name/attribute chain.
        module_expr = cst.Name(target_module.split(".")[0])
    statements = list(rewritten.body)
    statements.insert(
        _find_import_insert_index(statements),
        cst.SimpleStatementLine(
            [
                cst.ImportFrom(
                    module=module_expr,
                    names=[cst.ImportAlias(name=cst.Name(protocol_name))],
                )
            ]
        ),
    )
    return warnings, rewritten.with_changes(body=statements)
381
+
382
+
383
def _rewrite_call_sites_in_project(
    *,
    project_root: Path,
    target_path: Path,
    target_module: str,
    protocol_name: str,
    bundle_fields: list[str],
    targets: set[str],
) -> tuple[list[TextEdit], list[str]]:
    """Plan call-site rewrites in every other Python file of the project.

    Files are discovered under ``project_root/src`` when that directory
    exists, otherwise under ``project_root``. These are skipped:
      * the target file itself, which the caller handles;
      * files inside hidden directories (``.venv``, ``.git``, ``.tox``, ...);
      * ``site-packages`` and ``__pycache__`` trees.
    The skipped trees hold installed or generated code rather than project
    sources, and parsing them is slow.

    Returns:
        ``(edits, warnings)``: one whole-file ``TextEdit`` per changed file,
        plus read/parse failures and warnings from the call-site rewriter.
    """
    edits: list[TextEdit] = []
    warnings: list[str] = []
    scan_root = project_root / "src"
    if not scan_root.exists():
        scan_root = project_root
    # Compare resolved paths so that relative/absolute or symlinked spellings
    # of the target still match.
    resolved_target = target_path.resolve()
    skipped_dir_names = {"site-packages", "__pycache__"}
    for path in sorted(scan_root.rglob("*.py")):
        dir_parts = path.relative_to(scan_root).parts[:-1]
        if any(part.startswith(".") or part in skipped_dir_names for part in dir_parts):
            continue
        if path == target_path or path.resolve() == resolved_target:
            continue
        try:
            # Python source is UTF-8 by default (PEP 3120).
            source = path.read_text(encoding="utf-8")
        except Exception as exc:
            warnings.append(f"Failed to read {path}: {exc}")
            continue
        try:
            module = cst.parse_module(source)
        except Exception as exc:
            warnings.append(f"LibCST parse failed for {path}: {exc}")
            continue
        call_warnings, updated_module = _rewrite_call_sites(
            module,
            file_path=path,
            target_path=target_path,
            target_module=target_module,
            protocol_name=protocol_name,
            bundle_fields=bundle_fields,
            targets=targets,
        )
        warnings.extend(call_warnings)
        if updated_module is None:
            continue
        new_source = updated_module.code
        if new_source == source:
            continue
        edits.append(
            TextEdit(
                path=str(path),
                start=(0, 0),
                end=(len(source.splitlines()), 0),
                replacement=new_source,
            )
        )
    return edits, warnings
435
+
436
+
437
class _RefactorTransformer(cst.CSTTransformer):
    """Rewrite target function signatures to take a single bundle parameter.

    A function or method is targeted when its bare name, or its dotted
    qualname (enclosing classes/functions joined with "."), is in ``targets``.
    Its parameters become ``(self_or_cls?, <bundle>: <protocol_hint>)``.
    A line ``field = <bundle>.field`` is injected after any docstring for each
    original parameter that is a bundle field. Parameters that are not bundle
    fields are removed and reported in ``warnings``.
    """

    def __init__(
        self,
        *,
        targets: set[str],
        bundle_fields: list[str],
        protocol_hint: str,
    ) -> None:
        # dataflow-bundle: bundle_fields, protocol_hint, targets
        self.targets = targets
        self.bundle_fields = bundle_fields
        self.protocol_hint = protocol_hint
        self.warnings: list[str] = []
        # Names of the enclosing classes/functions. Joined with "." they form
        # the qualname that "Class.method" entries in ``targets`` match.
        self._stack: list[str] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self._stack.append(node.name.value)
        return True

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.CSTNode:
        if self._stack:
            self._stack.pop()
        return updated_node

    # LibCST models ``async def`` as a ``FunctionDef`` with ``asynchronous``
    # set (it has no separate AsyncFunctionDef node), so these two hooks cover
    # both kinds of function.
    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        self._stack.append(node.name.value)
        return True

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.CSTNode:
        # Rewrite before popping so the qualname still includes this function.
        updated = self._maybe_rewrite_function(original_node, updated_node)
        if self._stack:
            self._stack.pop()
        return updated

    def _maybe_rewrite_function(
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.CSTNode:
        qualname = ".".join(self._stack)
        name = original_node.name.value
        if name not in self.targets and qualname not in self.targets:
            return updated_node

        original_params = self._ordered_param_names(original_node.params)
        if not original_params:
            return updated_node

        keep_self = original_params[0] in {"self", "cls"}
        self_param = None
        if keep_self:
            self_param = self._find_self_param(original_node.params, original_params[0])

        bundle_param_name = self._choose_bundle_name(original_params)
        new_params = self._build_parameters(self_param, bundle_param_name)

        bundle_set = set(self.bundle_fields)
        target_fields = [
            param_name
            for param_name in original_params
            if param_name in bundle_set and param_name not in {"self", "cls"}
        ]

        # The new signature keeps only self/cls and the bundle; surface every
        # other parameter instead of dropping it silently, since the body may
        # still reference it.
        represented = set(target_fields)
        if self_param is not None:
            represented.add(self_param.name.value)
        dropped = [p for p in original_params if p not in represented]
        star_arg = original_node.params.star_arg
        if isinstance(star_arg, cst.Param):
            dropped.append("*" + star_arg.name.value)
        star_kwarg = original_node.params.star_kwarg
        if star_kwarg is not None:
            dropped.append("**" + star_kwarg.name.value)
        if dropped:
            self.warnings.append(
                f"Removed parameters not in bundle from '{qualname}': "
                + ", ".join(dropped)
            )

        new_body = self._inject_preamble(
            updated_node.body, bundle_param_name, target_fields
        )
        return updated_node.with_changes(params=new_params, body=new_body)

    def _ordered_param_names(self, params: cst.Parameters) -> list[str]:
        """Names of positional-only, regular and keyword-only params, in order."""
        names: list[str] = []
        for param in params.posonly_params:
            names.append(param.name.value)
        for param in params.params:
            names.append(param.name.value)
        for param in params.kwonly_params:
            names.append(param.name.value)
        return names

    def _find_self_param(
        self, params: cst.Parameters, name: str
    ) -> cst.Param | None:
        """Return the positional(-only) param called *name*, if any."""
        for param in params.posonly_params:
            if param.name.value == name:
                return param
        for param in params.params:
            if param.name.value == name:
                return param
        return None

    def _choose_bundle_name(self, existing: list[str]) -> str:
        """Pick "bundle", or "bundle_<n>" when that name is already a parameter."""
        candidate = "bundle"
        if candidate not in existing:
            return candidate
        idx = 1
        while f"bundle_{idx}" in existing:
            idx += 1
        return f"bundle_{idx}"

    def _build_parameters(
        self, self_param: cst.Param | None, bundle_name: str
    ) -> cst.Parameters:
        """Build ``(self_param?, bundle_name: <protocol_hint>)``."""
        params: list[cst.Param] = []
        if self_param is not None:
            params.append(self_param)
        annotation = None
        if self.protocol_hint:
            try:
                annotation = cst.Annotation(cst.parse_expression(self.protocol_hint))
            except Exception as exc:
                self.warnings.append(
                    f"Failed to parse protocol type hint '{self.protocol_hint}': {exc}"
                )
        params.append(
            cst.Param(
                name=cst.Name(bundle_name),
                annotation=annotation,
            )
        )
        return cst.Parameters(
            params=params,
            star_arg=cst.MaybeSentinel.DEFAULT,
            kwonly_params=[],
            star_kwarg=None,
            posonly_params=[],
            posonly_ind=cst.MaybeSentinel.DEFAULT,
        )

    def _inject_preamble(
        self, body: cst.BaseSuite, bundle_name: str, fields: list[str]
    ) -> cst.BaseSuite:
        """Insert ``field = <bundle_name>.field`` lines after any docstring.

        Single-line bodies (``def f(): ...``) are not ``IndentedBlock``s and
        are returned unchanged.
        """
        if not fields:
            return body
        if not isinstance(body, cst.IndentedBlock):
            return body
        assign_lines = [
            cst.SimpleStatementLine(
                [
                    cst.Assign(
                        targets=[cst.AssignTarget(cst.Name(field))],
                        value=cst.Attribute(
                            value=cst.Name(bundle_name),
                            attr=cst.Name(field),
                        ),
                    )
                ]
            )
            for field in fields
        ]
        existing = list(body.body)
        insert_at = 1 if existing and _is_docstring(existing[0]) else 0
        new_body = existing[:insert_at] + assign_lines + existing[insert_at:]
        return body.with_changes(body=new_body)
610
+
611
+
612
class _CallSiteTransformer(cst.CSTTransformer):
    """Rewrite calls to refactored functions so they pass one bundle object.

    A matching call ``f(x, b=y)`` becomes ``f(<constructor_expr>(a=x, b=y))``,
    with keyword arguments emitted in ``bundle_fields`` order. Some calls
    cannot be mapped cleanly: star args, unknown keywords, or too many or too
    few arguments. Those are left untouched and a warning is recorded.
    ``changed`` reports whether any call was rewritten.
    """

    def __init__(
        self,
        *,
        file_is_target: bool,
        target_simple: set[str],
        target_methods: dict[str, set[str]],
        module_aliases: set[str],
        imported_targets: set[str],
        bundle_fields: list[str],
        constructor_expr: cst.BaseExpression,
    ) -> None:
        # dataflow-bundle: bundle_fields, constructor_expr, file_is_target, imported_targets, module_aliases, target_methods, target_simple
        self.file_is_target = file_is_target
        self.target_simple = target_simple
        self.target_methods = target_methods
        self.module_aliases = module_aliases
        self.imported_targets = imported_targets
        self.bundle_fields = bundle_fields
        self.constructor_expr = constructor_expr
        self.changed = False
        self.warnings: list[str] = []
        # Enclosing class names, innermost last.
        self._class_stack: list[str] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self._class_stack.append(node.name.value)
        return True

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.CSTNode:
        if self._class_stack:
            self._class_stack.pop()
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        if not self._is_target_call(updated_node.func) or self._already_wrapped(updated_node):
            return updated_node
        keyword_args = self._build_bundle_args(updated_node)
        if keyword_args is None:
            return updated_node
        self.changed = True
        wrapped = cst.Arg(value=cst.Call(func=self.constructor_expr, args=keyword_args))
        return updated_node.with_changes(args=[wrapped])

    def _is_target_call(self, func: cst.BaseExpression) -> bool:
        """Decide whether *func* names one of the refactored callables.

        In the target file, bare names match ``target_simple``. Inside a class,
        ``self.m`` / ``cls.m`` / ``<Class>.m`` also match that class's target
        methods. Elsewhere, bare names match ``imported_targets``, and
        ``<module_alias>.f`` matches ``target_simple``.
        """
        if isinstance(func, cst.Name):
            candidates = self.target_simple if self.file_is_target else self.imported_targets
            return func.value in candidates
        if not isinstance(func, cst.Attribute) or not isinstance(func.attr, cst.Name):
            return False
        receiver = func.value
        if not isinstance(receiver, cst.Name):
            return False
        attr = func.attr.value
        if self.file_is_target:
            if not self._class_stack:
                return False
            methods = self.target_methods.get(".".join(self._class_stack), set())
            receivers = {"self", "cls", self._class_stack[-1]}
            return attr in methods and receiver.value in receivers
        return receiver.value in self.module_aliases and attr in self.target_simple

    def _already_wrapped(self, call: cst.Call) -> bool:
        """True if *call* already passes exactly one constructor-call argument."""
        if len(call.args) != 1 or call.args[0].star:
            return False
        inner = call.args[0].value
        if not isinstance(inner, cst.Call):
            return False
        ctor = self.constructor_expr
        inner_func = inner.func
        if isinstance(inner_func, cst.Name) and isinstance(ctor, cst.Name):
            return inner_func.value == ctor.value
        if isinstance(inner_func, cst.Attribute) and isinstance(ctor, cst.Attribute):
            # Only the final attribute name is compared, not the receiver.
            return inner_func.attr.value == ctor.attr.value
        return False

    def _build_bundle_args(self, call: cst.Call) -> list[cst.Arg] | None:
        """Map the call's arguments onto ``bundle_fields`` as keyword arguments.

        Returns None, after recording a warning, when the mapping is not exact.
        """
        if any(arg.star in {"*", "**"} for arg in call.args):
            self.warnings.append("Skipped call with star args/kwargs during refactor.")
            return None
        positional_values: list[cst.BaseExpression] = []
        by_field: dict[str, cst.BaseExpression] = {}
        for arg in call.args:
            if arg.keyword is None:
                positional_values.append(arg.value)
            elif isinstance(arg.keyword, cst.Name):
                by_field[arg.keyword.value] = arg.value
        unknown = next((key for key in by_field if key not in self.bundle_fields), None)
        if unknown is not None:
            self.warnings.append(
                f"Skipped call with unknown keyword '{unknown}' during refactor."
            )
            return None
        # NOTE(review): positional arguments are assigned in bundle_fields
        # order. That may differ from the target function's original parameter
        # order — confirm the two always agree.
        unfilled = [field for field in self.bundle_fields if field not in by_field]
        if len(positional_values) > len(unfilled):
            self.warnings.append("Skipped call with extra positional args during refactor.")
            return None
        by_field.update(zip(unfilled, positional_values))
        if len(by_field) != len(self.bundle_fields):
            self.warnings.append("Skipped call with missing bundle fields during refactor.")
            return None
        return [
            cst.Arg(keyword=cst.Name(field), value=by_field[field])
            for field in self.bundle_fields
        ]