codeshift 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/cli/commands/apply.py +24 -2
- codeshift/cli/commands/upgrade_all.py +4 -1
- codeshift/cli/package_manager.py +102 -0
- codeshift/knowledge/generator.py +11 -1
- codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
- codeshift/knowledge_base/libraries/attrs.yaml +181 -0
- codeshift/knowledge_base/libraries/celery.yaml +244 -0
- codeshift/knowledge_base/libraries/click.yaml +195 -0
- codeshift/knowledge_base/libraries/django.yaml +355 -0
- codeshift/knowledge_base/libraries/flask.yaml +270 -0
- codeshift/knowledge_base/libraries/httpx.yaml +183 -0
- codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
- codeshift/knowledge_base/libraries/numpy.yaml +429 -0
- codeshift/knowledge_base/libraries/pytest.yaml +192 -0
- codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
- codeshift/migrator/engine.py +60 -0
- codeshift/migrator/transforms/__init__.py +2 -0
- codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
- codeshift/migrator/transforms/attrs_transformer.py +570 -0
- codeshift/migrator/transforms/celery_transformer.py +546 -0
- codeshift/migrator/transforms/click_transformer.py +526 -0
- codeshift/migrator/transforms/django_transformer.py +852 -0
- codeshift/migrator/transforms/fastapi_transformer.py +12 -7
- codeshift/migrator/transforms/flask_transformer.py +505 -0
- codeshift/migrator/transforms/httpx_transformer.py +419 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
- codeshift/migrator/transforms/numpy_transformer.py +413 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
- codeshift/migrator/transforms/pytest_transformer.py +351 -0
- codeshift/migrator/transforms/requests_transformer.py +74 -1
- codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
- codeshift/scanner/dependency_parser.py +1 -1
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/RECORD +38 -17
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
"""Click 7.x to 8.x transformation using LibCST."""
|
|
2
|
+
|
|
3
|
+
import libcst as cst
|
|
4
|
+
|
|
5
|
+
from codeshift.migrator.ast_transforms import BaseTransformer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ClickTransformer(BaseTransformer):
    """Transform Click 7.x code to 8.x.

    Rewrites performed (each one is reported through ``record_change``):

    * ``from click import MultiCommand`` / ``BaseCommand`` -> ``Group`` /
      ``Command`` -- only for absolute imports from ``click`` or a
      ``click.*`` submodule, so same-named classes from other packages
      (e.g. ``django.core.management.base.BaseCommand``) are left alone.
    * Base classes ``click.MultiCommand`` / ``click.BaseCommand``, and bare
      ``MultiCommand`` / ``BaseCommand`` bases when that bare name was
      imported unaliased (or via ``*``) from click.
    * ``<expr>.output_bytes`` -> ``<expr>.output.encode()`` (matched on the
      attribute name alone).
    * ``click.__version__`` -> ``importlib.metadata.version("click")``.
    * ``click.get_terminal_size(...)`` -> ``shutil.get_terminal_size(...)``.
    * ``click.get_os_args()`` -> ``sys.argv[1:]``.
    * ``CliRunner(..., mix_stderr=...)`` -> ``CliRunner(...)``.
    * ``autocompletion=`` -> ``shell_complete=`` on ``option``/``argument``.
    * ``@<obj>.resultcallback`` -> ``@<obj>.result_callback``.

    The ``_needs_*`` flags record which stdlib modules the rewritten code
    references by qualified name, so a follow-up pass can import them.
    """

    # Deprecated Click class name -> (replacement name, transform name).
    _CLASS_RENAMES: dict[str, tuple[str, str]] = {
        "MultiCommand": ("Group", "multicommand_to_group"),
        "BaseCommand": ("Command", "basecommand_to_command"),
    }

    def __init__(self) -> None:
        super().__init__()
        # Stdlib modules the rewritten code now references by qualified name.
        self._needs_shutil_import = False
        self._needs_sys_import = False
        self._needs_importlib_metadata = False
        # `from <module> import ...` statements seen while visiting.
        # NOTE(review): these flags are set but not read in this module.
        self._has_shutil_import = False
        self._has_sys_import = False
        self._has_importlib_metadata_import = False
        self._has_click_import = False
        # Deprecated class names bound *unaliased* by a click import. Bare-name
        # base classes are only renamed when they appear in this set.
        self._click_base_names: set[str] = set()

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Track imports and rename deprecated classes imported from click.

        Args:
            original_node: The import before child transforms.
            updated_node: The import after child transforms.

        Returns:
            The import with ``MultiCommand``/``BaseCommand`` renamed when it
            is an absolute import from ``click`` (or a submodule), otherwise
            ``updated_node`` unchanged.
        """
        # Relative imports (`from . import x`, `from .click import x`) never
        # refer to the installed click package.
        if updated_node.module is None or updated_node.relative:
            return updated_node

        module_name = self._get_module_name(updated_node.module)

        # Track existing imports
        if module_name == "shutil":
            self._has_shutil_import = True
        elif module_name == "sys":
            self._has_sys_import = True
        elif module_name in ("importlib.metadata", "importlib"):
            self._has_importlib_metadata_import = True
        elif self._is_click_module(module_name):
            self._has_click_import = True

        if not self._is_click_module(module_name):
            return updated_node

        if isinstance(updated_node.names, cst.ImportStar):
            # `from click import *` puts the deprecated names in scope unaliased.
            self._click_base_names.update(self._CLASS_RENAMES)
            return updated_node

        new_names = []
        changed = False
        for alias in updated_node.names:
            imported_name = self._get_name_value(alias.name)
            rename = self._CLASS_RENAMES.get(imported_name) if imported_name else None
            if rename is None:
                new_names.append(alias)
                continue

            new_name, transform_name = rename
            if alias.asname is None:
                # The bare deprecated name disappears from scope, so class
                # bases spelled with it must be renamed as well.
                self._click_base_names.add(imported_name)
            new_names.append(alias.with_changes(name=cst.Name(new_name)))
            changed = True
            self.record_change(
                description=f"Replace {imported_name} import with {new_name}",
                line_number=1,
                original=f"from click import {imported_name}",
                replacement=f"from click import {new_name}",
                transform_name=transform_name,
            )

        if changed:
            return updated_node.with_changes(names=new_names)
        return updated_node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Transform attribute access like result.output_bytes and click.__version__.

        ``output_bytes`` is matched by attribute name only, whatever the
        receiver expression is. ``__version__`` is only rewritten when the
        receiver is the bare name ``click``.
        """
        attr_name = updated_node.attr.value

        # result.output_bytes -> result.output.encode()
        if attr_name == "output_bytes":
            self.record_change(
                description="Replace .output_bytes with .output.encode()",
                line_number=1,
                original=".output_bytes",
                replacement=".output.encode()",
                transform_name="output_bytes_to_encode",
            )
            output_attr = cst.Attribute(value=updated_node.value, attr=cst.Name("output"))
            return cst.Call(
                func=cst.Attribute(value=output_attr, attr=cst.Name("encode")),
                args=[],
            )

        # click.__version__ -> importlib.metadata.version("click")
        if (
            attr_name == "__version__"
            and isinstance(updated_node.value, cst.Name)
            and updated_node.value.value == "click"
        ):
            self._needs_importlib_metadata = True
            self.record_change(
                description="Replace click.__version__ with importlib.metadata.version('click')",
                line_number=1,
                original="click.__version__",
                replacement="importlib.metadata.version('click')",
                transform_name="version_attr_to_importlib",
            )
            return cst.Call(
                func=cst.Attribute(
                    value=cst.Attribute(value=cst.Name("importlib"), attr=cst.Name("metadata")),
                    attr=cst.Name("version"),
                ),
                args=[cst.Arg(value=cst.SimpleString('"click"'))],
            )

        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Transform calls to removed Click helpers and renamed/removed keywords."""
        # click.get_terminal_size() -> shutil.get_terminal_size()
        if self._is_click_call(updated_node, "get_terminal_size"):
            self._needs_shutil_import = True
            self.record_change(
                description="Replace click.get_terminal_size() with shutil.get_terminal_size()",
                line_number=1,
                original="click.get_terminal_size()",
                replacement="shutil.get_terminal_size()",
                transform_name="get_terminal_size_to_shutil",
            )
            return cst.Call(
                func=cst.Attribute(value=cst.Name("shutil"), attr=cst.Name("get_terminal_size")),
                args=updated_node.args,
            )

        # click.get_os_args() -> sys.argv[1:]
        if self._is_click_call(updated_node, "get_os_args"):
            self._needs_sys_import = True
            self.record_change(
                description="Replace click.get_os_args() with sys.argv[1:]",
                line_number=1,
                original="click.get_os_args()",
                replacement="sys.argv[1:]",
                transform_name="get_os_args_to_sys_argv",
            )
            return cst.Subscript(
                value=cst.Attribute(value=cst.Name("sys"), attr=cst.Name("argv")),
                slice=[
                    cst.SubscriptElement(slice=cst.Slice(lower=cst.Integer("1"), upper=None))
                ],
            )

        # CliRunner(..., mix_stderr=...) / click.testing.CliRunner(...) -> without mix_stderr
        if self._is_clirunner_call(updated_node):
            return self._remove_mix_stderr(updated_node)

        func = updated_node.func
        # @click.option/argument(..., autocompletion=...) -> shell_complete=...
        if (
            isinstance(func, cst.Attribute)
            and isinstance(func.value, cst.Name)
            and func.value.value == "click"
            and func.attr.value in ("option", "argument")
        ):
            return self._transform_autocompletion_param(updated_node, func.attr.value)

        # Same decorators imported directly (e.g. `from click import option`).
        if isinstance(func, cst.Name) and func.value in ("option", "argument"):
            return self._transform_autocompletion_param(updated_node, func.value)

        return updated_node

    @staticmethod
    def _is_clirunner_call(node: cst.Call) -> bool:
        """Return True for ``CliRunner(...)`` or ``<expr>.CliRunner(...)``."""
        func = node.func
        if isinstance(func, cst.Name):
            return func.value == "CliRunner"
        if isinstance(func, cst.Attribute):
            return func.attr.value == "CliRunner"
        return False

    def _remove_mix_stderr(self, node: cst.Call) -> cst.Call:
        """Drop every ``mix_stderr=`` keyword argument from a CliRunner call."""
        kept: list[cst.Arg] = []
        for arg in node.args:
            if isinstance(arg.keyword, cst.Name) and arg.keyword.value == "mix_stderr":
                self.record_change(
                    description="Remove deprecated mix_stderr parameter from CliRunner",
                    line_number=1,
                    original="CliRunner(mix_stderr=...)",
                    replacement="CliRunner()",
                    transform_name="clirunner_remove_mix_stderr",
                )
            else:
                kept.append(arg)

        if len(kept) == len(node.args):
            return node

        last_arg = node.args[-1]
        if kept and kept[-1] is not last_arg:
            # The final argument was dropped: hand its comma to the new last
            # argument so no dangling ", " is left before the closing paren.
            kept[-1] = kept[-1].with_changes(comma=last_arg.comma)
        return node.with_changes(args=kept)

    def _transform_autocompletion_param(self, node: cst.Call, decorator_name: str) -> cst.Call:
        """Rename ``autocompletion=`` keyword arguments to ``shell_complete=``."""
        new_args = []
        changed = False

        for arg in node.args:
            if isinstance(arg.keyword, cst.Name) and arg.keyword.value == "autocompletion":
                new_args.append(arg.with_changes(keyword=cst.Name("shell_complete")))
                changed = True
                self.record_change(
                    description=f"Rename {decorator_name}(autocompletion=...) to {decorator_name}(shell_complete=...)",
                    line_number=1,
                    original=f"@click.{decorator_name}(autocompletion=...)",
                    replacement=f"@click.{decorator_name}(shell_complete=...)",
                    transform_name="autocompletion_to_shell_complete",
                    notes="Callback signature changed from (ctx, args, incomplete) to (ctx, param, incomplete)",
                )
            else:
                new_args.append(arg)

        if changed:
            return node.with_changes(args=new_args)
        return node

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Transform decorator calls like @group.resultcallback to @group.result_callback.

        Both the called form ``@obj.resultcallback()`` and the bare form
        ``@obj.resultcallback`` are handled, for any receiver ``obj``.
        """
        decorator = updated_node.decorator

        if isinstance(decorator, cst.Call):
            func = decorator.func
            if isinstance(func, cst.Attribute) and func.attr.value == "resultcallback":
                self.record_change(
                    description="Rename @group.resultcallback() to @group.result_callback()",
                    line_number=1,
                    original="@group.resultcallback()",
                    replacement="@group.result_callback()",
                    transform_name="resultcallback_to_result_callback",
                )
                new_func = func.with_changes(attr=cst.Name("result_callback"))
                return updated_node.with_changes(decorator=decorator.with_changes(func=new_func))

        elif isinstance(decorator, cst.Attribute) and decorator.attr.value == "resultcallback":
            self.record_change(
                description="Rename @group.resultcallback to @group.result_callback",
                line_number=1,
                original="@group.resultcallback",
                replacement="@group.result_callback",
                transform_name="resultcallback_to_result_callback",
            )
            return updated_node.with_changes(
                decorator=decorator.with_changes(attr=cst.Name("result_callback"))
            )

        return updated_node

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Replace deprecated Click base classes in class definitions.

        ``click.MultiCommand`` / ``click.BaseCommand`` are always rewritten.
        A bare ``MultiCommand`` / ``BaseCommand`` base is only rewritten when
        that name was imported unaliased from click earlier in the module;
        otherwise it belongs to another library and is left untouched.
        """
        if not updated_node.bases:
            return updated_node

        new_bases = []
        changed = False

        for base in updated_node.bases:
            target = base.value

            if (
                isinstance(target, cst.Attribute)
                and isinstance(target.value, cst.Name)
                and target.value.value == "click"
                and target.attr.value in self._CLASS_RENAMES
            ):
                old_name = target.attr.value
                new_name, transform_name = self._CLASS_RENAMES[old_name]
                new_bases.append(
                    base.with_changes(value=target.with_changes(attr=cst.Name(new_name)))
                )
                changed = True
                self.record_change(
                    description=f"Replace click.{old_name} with click.{new_name} as base class",
                    line_number=1,
                    original=f"class MyClass(click.{old_name})",
                    replacement=f"class MyClass(click.{new_name})",
                    transform_name=transform_name,
                )
            elif isinstance(target, cst.Name) and target.value in self._click_base_names:
                old_name = target.value
                new_name, transform_name = self._CLASS_RENAMES[old_name]
                new_bases.append(base.with_changes(value=cst.Name(new_name)))
                changed = True
                self.record_change(
                    description=f"Replace {old_name} with {new_name} as base class",
                    line_number=1,
                    original=f"class MyClass({old_name})",
                    replacement=f"class MyClass({new_name})",
                    transform_name=transform_name,
                )
            else:
                new_bases.append(base)

        if changed:
            return updated_node.with_changes(bases=new_bases)
        return updated_node

    @staticmethod
    def _is_click_module(module_name: str) -> bool:
        """Return True for ``click`` itself or any ``click.*`` submodule."""
        return module_name == "click" or module_name.startswith("click.")

    def _is_click_call(self, node: cst.Call, func_name: str) -> bool:
        """Check if a call is click.<func_name>()."""
        if isinstance(node.func, cst.Attribute):
            if isinstance(node.func.value, cst.Name):
                return bool(node.func.value.value == "click" and node.func.attr.value == func_name)
        return False

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full dotted module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node (None for anything else)."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class ClickImportTransformer(BaseTransformer):
    """Separate transformer for handling import additions.

    This runs after the main transformer. The main transformer emits
    qualified references (``shutil.get_terminal_size``, ``sys.argv``,
    ``importlib.metadata.version``); this pass adds the plain ``import``
    statements that bind those names.

    An import only counts as "already present" if it is an unaliased
    top-level ``import X``. ``from shutil import get_terminal_size`` and
    ``import sys as _sys`` do not bind the bare module name the rewritten
    code uses.
    """

    def __init__(
        self,
        needs_shutil: bool = False,
        needs_sys: bool = False,
        needs_importlib_metadata: bool = False,
    ) -> None:
        super().__init__()
        self.needs_shutil = needs_shutil
        self.needs_sys = needs_sys
        self.needs_importlib_metadata = needs_importlib_metadata
        # Set by leave_Module: whether a suitable top-level import already exists.
        self._added_shutil = False
        self._added_sys = False
        self._added_importlib = False

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Add missing imports after the module docstring and leading imports.

        Returns:
            ``updated_node`` unchanged when nothing is missing; otherwise a
            module with the new ``import`` lines inserted.
        """
        bound = self._bound_module_names(updated_node)
        self._added_shutil = "shutil" in bound
        self._added_sys = "sys" in bound
        # `import importlib` alone does not load the `metadata` submodule, so
        # only `import importlib.metadata` guarantees the attribute exists.
        self._added_importlib = "importlib.metadata" in bound

        new_imports: list[cst.SimpleStatementLine] = []
        if self.needs_shutil and not self._added_shutil:
            new_imports.append(self._make_import(cst.Name("shutil")))
        if self.needs_sys and not self._added_sys:
            new_imports.append(self._make_import(cst.Name("sys")))
        if self.needs_importlib_metadata and not self._added_importlib:
            new_imports.append(
                self._make_import(
                    cst.Attribute(value=cst.Name("importlib"), attr=cst.Name("metadata"))
                )
            )

        if not new_imports:
            return updated_node

        insert_pos = self._insert_position(updated_node)
        body = list(updated_node.body)
        return updated_node.with_changes(
            body=[*body[:insert_pos], *new_imports, *body[insert_pos:]]
        )

    def _bound_module_names(self, module: cst.Module) -> set[str]:
        """Return dotted names bound by unaliased top-level ``import X`` statements."""
        names: set[str] = set()
        for stmt in module.body:
            if not isinstance(stmt, cst.SimpleStatementLine):
                continue
            for small_stmt in stmt.body:
                if not isinstance(small_stmt, cst.Import):
                    continue
                for alias in small_stmt.names:
                    if alias.asname is None:
                        names.add(self._get_module_name(alias.name))
        return names

    @staticmethod
    def _make_import(name: cst.Name | cst.Attribute) -> cst.SimpleStatementLine:
        """Build a one-line ``import <name>`` statement."""
        return cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=name)])])

    @staticmethod
    def _insert_position(module: cst.Module) -> int:
        """Return the index just past the module docstring and leading import lines.

        Scanning stops at the first statement that is neither the docstring
        nor an import line. This keeps the inserted imports ahead of any
        top-level code that might use them.
        """
        insert_pos = 0
        for index, stmt in enumerate(module.body):
            if not isinstance(stmt, cst.SimpleStatementLine):
                break
            small_stmts = stmt.body
            is_docstring = (
                index == 0
                and len(small_stmts) == 1
                and isinstance(small_stmts[0], cst.Expr)
                and isinstance(small_stmts[0].value, cst.SimpleString | cst.ConcatenatedString)
            )
            is_import_line = all(
                isinstance(s, cst.Import | cst.ImportFrom) for s in small_stmts
            )
            if not (is_docstring or is_import_line):
                break
            insert_pos = index + 1
        return insert_pos

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full dotted module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node (None for anything else)."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def transform_click(source_code: str) -> tuple[str, list]:
    """Rewrite Click 7.x source code so it targets the Click 8.x API.

    The work happens in two passes:
    ``ClickTransformer`` performs the API rewrites and records which stdlib
    modules the new code references. ``ClickImportTransformer`` then adds
    any of those imports that are missing.

    Args:
        source_code: Python source text to migrate.

    Returns:
        A ``(transformed_code, changes)`` pair. ``changes`` is the list of
        changes recorded by the first pass.

    Raises:
        SyntaxError: If ``source_code`` cannot be parsed by LibCST.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError as err:
        raise SyntaxError(f"Invalid Python syntax: {err}") from err

    # Pass 1: API rewrites; also collects which stdlib imports are needed.
    rewriter = ClickTransformer()
    rewriter.set_source(source_code)
    rewritten = module.visit(rewriter)

    # Pass 2: add imports for the names introduced by pass 1.
    finalized = rewritten.visit(
        ClickImportTransformer(
            needs_shutil=rewriter._needs_shutil_import,
            needs_sys=rewriter._needs_sys_import,
            needs_importlib_metadata=rewriter._needs_importlib_metadata,
        )
    )

    return finalized.code, rewriter.changes
|