codeshift 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/cli/commands/apply.py +24 -2
- codeshift/cli/commands/upgrade_all.py +4 -1
- codeshift/cli/package_manager.py +102 -0
- codeshift/knowledge/generator.py +11 -1
- codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
- codeshift/knowledge_base/libraries/attrs.yaml +181 -0
- codeshift/knowledge_base/libraries/celery.yaml +244 -0
- codeshift/knowledge_base/libraries/click.yaml +195 -0
- codeshift/knowledge_base/libraries/django.yaml +355 -0
- codeshift/knowledge_base/libraries/flask.yaml +270 -0
- codeshift/knowledge_base/libraries/httpx.yaml +183 -0
- codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
- codeshift/knowledge_base/libraries/numpy.yaml +429 -0
- codeshift/knowledge_base/libraries/pytest.yaml +192 -0
- codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
- codeshift/migrator/engine.py +60 -0
- codeshift/migrator/transforms/__init__.py +2 -0
- codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
- codeshift/migrator/transforms/attrs_transformer.py +570 -0
- codeshift/migrator/transforms/celery_transformer.py +546 -0
- codeshift/migrator/transforms/click_transformer.py +526 -0
- codeshift/migrator/transforms/django_transformer.py +852 -0
- codeshift/migrator/transforms/fastapi_transformer.py +12 -7
- codeshift/migrator/transforms/flask_transformer.py +505 -0
- codeshift/migrator/transforms/httpx_transformer.py +419 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
- codeshift/migrator/transforms/numpy_transformer.py +413 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
- codeshift/migrator/transforms/pytest_transformer.py +351 -0
- codeshift/migrator/transforms/requests_transformer.py +74 -1
- codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
- codeshift/scanner/dependency_parser.py +1 -1
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/RECORD +38 -17
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
"""Pytest 6.x to 7.x/8.x transformation using LibCST."""
|
|
2
|
+
|
|
3
|
+
import libcst as cst
|
|
4
|
+
from libcst import matchers as m
|
|
5
|
+
|
|
6
|
+
from codeshift.migrator.ast_transforms import BaseTransformer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _ReferenceRenamer(cst.CSTTransformer):
    """Rename bare variable references inside a subtree according to a mapping.

    Used after a function parameter has been renamed so that the function body keeps
    referring to the same object. Attribute names (``obj.tmpdir``) and call keyword
    names (``f(tmpdir=...)``) are restored to their original spelling because they do
    not refer to the local variable.
    """

    def __init__(self, renames: dict[str, str]) -> None:
        super().__init__()
        self._renames = renames

    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        new_value = self._renames.get(updated_node.value)
        if new_value is None:
            return updated_node
        return updated_node.with_changes(value=new_value)

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.Attribute:
        # Only the object expression (``value``) may reference the renamed variable.
        return updated_node.with_changes(attr=original_node.attr)

    def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg:
        # ``keyword`` names a parameter of the callee, not a variable in this scope.
        return updated_node.with_changes(keyword=original_node.keyword)


class PytestTransformer(BaseTransformer):
    """Transform pytest 6.x code to 7.x/8.x compatible code."""

    # Fixture parameter renames applied to every function definition.
    _FIXTURE_PARAM_RENAMES = {
        "tmpdir": "tmp_path",
        "tmpdir_factory": "tmp_path_factory",
    }
    # Change-record description for each fixture rename above.
    _FIXTURE_RENAME_DESCRIPTIONS = {
        "tmpdir": "Convert tmpdir fixture to tmp_path (pathlib.Path)",
        "tmpdir_factory": "Convert tmpdir_factory fixture to tmp_path_factory",
    }
    # Conftest hook name -> parameter renames for 7.x/8.x compatibility.
    _HOOK_PARAM_RENAMES = {
        "pytest_collect_file": {"path": "file_path"},
        "pytest_ignore_collect": {"path": "collection_path"},
        "pytest_pycollect_makemodule": {"path": "module_path"},
        "pytest_report_header": {"startdir": "start_path"},
        "pytest_report_collectionfinish": {"startdir": "start_path"},
    }
    # pytest functions whose ``msg=`` keyword is rewritten to ``reason=``.
    _MSG_TO_REASON_FUNCS = ("skip", "fail", "exit")

    def __init__(self) -> None:
        super().__init__()
        # Enclosing class/function scopes, innermost last. Each entry is the class
        # name for a test class, or None for any other class or for a function.
        self._scope_stack: list[str | None] = []
        # Mirror the innermost scope; refreshed by _sync_class_state().
        self._in_test_class = False
        self._current_class_name: str | None = None

    # ------------------------------------------------------------------ scopes

    def _push_scope(self, test_class_name: str | None) -> None:
        """Enter a class or function scope (``None`` unless it is a test class)."""
        self._scope_stack.append(test_class_name)
        self._sync_class_state()

    def _pop_scope(self) -> None:
        """Leave the innermost class or function scope."""
        self._scope_stack.pop()
        self._sync_class_state()

    def _sync_class_state(self) -> None:
        """Refresh the test-class flags from the innermost enclosing scope."""
        innermost = self._scope_stack[-1] if self._scope_stack else None
        self._in_test_class = innermost is not None
        self._current_class_name = innermost

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        """Enter a class scope, noting whether it looks like a test class."""
        class_name = node.name.value
        is_test_class = class_name.startswith("Test") or any(
            isinstance(base.value, cst.Name) and base.value.value == "TestCase"
            for base in node.bases
        )
        self._push_scope(class_name if is_test_class else None)
        return True

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Leave the class scope, restoring the enclosing scope's state."""
        self._pop_scope()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        """Enter a function scope so a nested ``def setup`` is not treated as a method."""
        self._push_scope(None)
        return True

    # -------------------------------------------------------------- decorators

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Transform @pytest.yield_fixture and @pytest.yield_fixture(...) to @pytest.fixture."""
        yield_fixture = m.Attribute(value=m.Name("pytest"), attr=m.Name("yield_fixture"))
        fixture_attr = cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("fixture"))
        decorator = updated_node.decorator

        # Bare form: @pytest.yield_fixture
        if m.matches(decorator, yield_fixture):
            self.record_change(
                description="Convert @pytest.yield_fixture to @pytest.fixture",
                line_number=1,
                original="@pytest.yield_fixture",
                replacement="@pytest.fixture",
                transform_name="yield_fixture_to_fixture",
            )
            return updated_node.with_changes(decorator=fixture_attr)

        # Called form: @pytest.yield_fixture(...) keeps its arguments.
        if isinstance(decorator, cst.Call) and m.matches(decorator.func, yield_fixture):
            self.record_change(
                description="Convert @pytest.yield_fixture(...) to @pytest.fixture(...)",
                line_number=1,
                original="@pytest.yield_fixture(...)",
                replacement="@pytest.fixture(...)",
                transform_name="yield_fixture_to_fixture",
            )
            return updated_node.with_changes(decorator=decorator.with_changes(func=fixture_attr))

        return updated_node

    # --------------------------------------------------------------- functions

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Rename fixture/hook parameters and setup/teardown methods of test classes."""
        # Pop this function's own scope first so the stack top is its container.
        self._pop_scope()

        updated_node = self._transform_fixture_params(updated_node)
        updated_node = self._transform_hook_params(updated_node)

        # setup/teardown are only renamed when defined directly in a test class body.
        if not self._in_test_class:
            return updated_node

        func_name = updated_node.name.value
        if func_name == "setup":
            self.record_change(
                description="Rename setup() to setup_method() for pytest 8.x compatibility",
                line_number=1,
                original="def setup(self):",
                replacement="def setup_method(self):",
                transform_name="setup_to_setup_method",
            )
            return updated_node.with_changes(name=cst.Name("setup_method"))

        if func_name == "teardown":
            self.record_change(
                description="Rename teardown() to teardown_method() for pytest 8.x compatibility",
                line_number=1,
                original="def teardown(self):",
                replacement="def teardown_method(self):",
                transform_name="teardown_to_teardown_method",
            )
            return updated_node.with_changes(name=cst.Name("teardown_method"))

        return updated_node

    def _rename_params(
        self, node: cst.FunctionDef, renames: dict[str, str]
    ) -> tuple[cst.FunctionDef, list[tuple[str, str]]]:
        """Rename positional parameters of ``node`` and every reference to them in its body.

        A rename is skipped when the target name already appears anywhere in the
        function. Applying it would then either create a duplicate parameter (a syntax
        error) or merge two distinct names. The check is conservative: attribute and
        keyword names count too.

        Note: only names change. Code relying on methods that the old object had but
        the new one lacks still needs manual review.

        Returns:
            The possibly-updated node and the ``(old, new)`` pairs that were applied,
            in parameter order.
        """
        taken = {name.value for name in m.findall(node, m.Name()) if isinstance(name, cst.Name)}
        applied: dict[str, str] = {}
        new_params = []
        for param in node.params.params:
            old_name = param.name.value
            new_name = renames.get(old_name)
            if new_name is None or new_name in taken:
                new_params.append(param)
                continue
            applied[old_name] = new_name
            new_params.append(param.with_changes(name=cst.Name(new_name)))

        if not applied:
            return node, []

        # Keep body references pointing at the renamed parameter.
        new_body = node.body.visit(_ReferenceRenamer(applied))
        assert isinstance(new_body, cst.BaseSuite)
        new_node = node.with_changes(
            params=node.params.with_changes(params=new_params), body=new_body
        )
        return new_node, list(applied.items())

    def _transform_fixture_params(self, node: cst.FunctionDef) -> cst.FunctionDef:
        """Transform tmpdir/tmpdir_factory fixture parameters to tmp_path/tmp_path_factory."""
        node, applied = self._rename_params(node, self._FIXTURE_PARAM_RENAMES)
        for old_name, new_name in applied:
            self.record_change(
                description=self._FIXTURE_RENAME_DESCRIPTIONS[old_name],
                line_number=1,
                original=f"def func({old_name}):",
                replacement=f"def func({new_name}):",
                transform_name=f"{old_name}_to_{new_name}",
            )
        return node

    def _transform_hook_params(self, node: cst.FunctionDef) -> cst.FunctionDef:
        """Transform pytest hook parameters for 7.x/8.x compatibility."""
        func_name = node.name.value
        renames = self._HOOK_PARAM_RENAMES.get(func_name)
        if renames is None:
            return node

        node, applied = self._rename_params(node, renames)
        for param_name, new_name in applied:
            self.record_change(
                description=f"Rename {func_name} parameter '{param_name}' to '{new_name}'",
                line_number=1,
                original=f"def {func_name}({param_name}):",
                replacement=f"def {func_name}({new_name}):",
                transform_name=f"hook_{param_name}_to_{new_name}",
            )
        return node

    # ------------------------------------------------------------------- calls

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Transform pytest function calls."""
        # pytest.skip(msg=...), pytest.fail(msg=...), pytest.exit(msg=...)
        updated_node = self._transform_msg_to_reason(updated_node)
        # pytest.warns(None)
        updated_node = self._transform_warns_none(updated_node)
        return updated_node

    @staticmethod
    def _pytest_attr_name(node: cst.Call) -> str | None:
        """Return ``name`` when ``node`` is a ``pytest.<name>(...)`` call, else None."""
        func = node.func
        if (
            isinstance(func, cst.Attribute)
            and isinstance(func.value, cst.Name)
            and func.value.value == "pytest"
        ):
            return func.attr.value
        return None

    def _transform_msg_to_reason(self, node: cst.Call) -> cst.Call:
        """Transform the msg parameter to reason in pytest.skip/fail/exit.

        Calls that already pass ``reason=`` are left alone. Renaming ``msg`` there
        would produce a duplicate keyword argument, which is a syntax error.
        """
        func_name = self._pytest_attr_name(node)
        if func_name not in self._MSG_TO_REASON_FUNCS:
            return node

        keywords = {arg.keyword.value for arg in node.args if arg.keyword is not None}
        if "msg" not in keywords or "reason" in keywords:
            return node

        self.record_change(
            description=f"Rename 'msg' parameter to 'reason' in pytest.{func_name}()",
            line_number=1,
            original=f"pytest.{func_name}(msg=...)",
            replacement=f"pytest.{func_name}(reason=...)",
            transform_name=f"{func_name}_msg_to_reason",
        )
        new_args = [
            arg.with_changes(keyword=cst.Name("reason"))
            if arg.keyword is not None and arg.keyword.value == "msg"
            else arg
            for arg in node.args
        ]
        return node.with_changes(args=new_args)

    def _transform_warns_none(self, node: cst.Call) -> cst.Call:
        """Transform pytest.warns(None) to pytest.warns(), keeping any later arguments."""
        if self._pytest_attr_name(node) != "warns" or not node.args:
            return node

        first = node.args[0]
        if first.keyword is not None or not m.matches(first.value, m.Name("None")):
            return node

        self.record_change(
            description="Convert pytest.warns(None) to pytest.warns()",
            line_number=1,
            original="pytest.warns(None)",
            replacement="pytest.warns()",
            transform_name="warns_none_to_warns",
        )
        # Drop the positional None argument.
        return node.with_changes(args=list(node.args[1:]))

    # -------------------------------------------------------------- attributes

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Transform attribute accesses like .fspath to .path and funcargnames to fixturenames."""
        attr_name = updated_node.attr.value

        if attr_name == "fspath":
            # ``os.fspath`` is the stdlib path-protocol function, not pytest's
            # attribute; rewriting it would produce the invalid call ``os.path(...)``.
            if m.matches(updated_node.value, m.Name("os")):
                return updated_node
            self.record_change(
                description="Convert .fspath (py.path.local) to .path (pathlib.Path)",
                line_number=1,
                original=".fspath",
                replacement=".path",
                transform_name="fspath_to_path",
            )
            return updated_node.with_changes(attr=cst.Name("path"))

        if attr_name == "funcargnames":
            self.record_change(
                description="Convert .funcargnames to .fixturenames",
                line_number=1,
                original=".funcargnames",
                replacement=".fixturenames",
                transform_name="funcargnames_to_fixturenames",
            )
            return updated_node.with_changes(attr=cst.Name("fixturenames"))

        return updated_node

    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.BaseExpression:
        """Leave bare names untouched.

        References to renamed fixture and hook parameters are rewritten together with
        the parameter itself in leave_FunctionDef, where the owning scope is known. A
        file-wide rename of every ``tmpdir`` name here would not be scope-aware.
        """
        return updated_node
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def transform_pytest(source_code: str) -> tuple[str, list]:
    """Rewrite pytest 6.x-era source into its pytest 7.x/8.x form.

    Args:
        source_code: Python source text to rewrite.

    Returns:
        A ``(new_source, changes)`` pair, where ``changes`` is the list of change
        records collected by ``PytestTransformer`` while it walked the tree.

    Raises:
        SyntaxError: LibCST could not parse ``source_code``.
        RuntimeError: visiting the tree or rendering the result raised.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError as exc:
        raise SyntaxError(f"Invalid Python syntax: {exc}") from exc

    rewriter = PytestTransformer()
    rewriter.set_source(source_code)

    try:
        rewritten_module = module.visit(rewriter)
        outcome = (rewritten_module.code, rewriter.changes)
    except Exception as exc:
        raise RuntimeError(f"Transform failed: {exc}") from exc
    return outcome
|
|
@@ -10,16 +10,42 @@ class RequestsTransformer(BaseTransformer):
|
|
|
10
10
|
|
|
11
11
|
def __init__(self) -> None:
    """Initialise the transformer with an empty queue of top-level imports."""
    super().__init__()
    # Module names that leave_ImportFrom queues (e.g. "urllib3") and that
    # leave_Module later emits as ``import <name>`` statements.
    pending: list[str] = []
    self._imports_to_add = pending
|
|
13
14
|
|
|
14
15
|
def leave_ImportFrom(
|
|
15
16
|
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
|
|
16
|
-
) -> cst.ImportFrom:
|
|
17
|
+
) -> cst.ImportFrom | cst.RemovalSentinel:
|
|
17
18
|
"""Transform requests imports."""
|
|
18
19
|
if original_node.module is None:
|
|
19
20
|
return updated_node
|
|
20
21
|
|
|
21
22
|
module_name = self._get_module_name(original_node.module)
|
|
22
23
|
|
|
24
|
+
# Transform 'from requests.packages import urllib3' to 'import urllib3'
|
|
25
|
+
if module_name == "requests.packages":
|
|
26
|
+
if isinstance(updated_node.names, cst.ImportStar):
|
|
27
|
+
return updated_node
|
|
28
|
+
|
|
29
|
+
remaining_names = []
|
|
30
|
+
for name in updated_node.names:
|
|
31
|
+
if isinstance(name, cst.ImportAlias) and isinstance(name.name, cst.Name):
|
|
32
|
+
if name.name.value == "urllib3":
|
|
33
|
+
self.record_change(
|
|
34
|
+
description="Import urllib3 directly instead of through requests.packages",
|
|
35
|
+
line_number=1,
|
|
36
|
+
original="from requests.packages import urllib3",
|
|
37
|
+
replacement="import urllib3",
|
|
38
|
+
transform_name="urllib3_top_level_import_fix",
|
|
39
|
+
)
|
|
40
|
+
self._imports_to_add.append("urllib3")
|
|
41
|
+
continue
|
|
42
|
+
remaining_names.append(name)
|
|
43
|
+
|
|
44
|
+
if len(remaining_names) == 0:
|
|
45
|
+
return cst.RemovalSentinel.REMOVE
|
|
46
|
+
elif len(remaining_names) < len(updated_node.names):
|
|
47
|
+
return updated_node.with_changes(names=remaining_names)
|
|
48
|
+
|
|
23
49
|
# Transform requests.packages.urllib3 imports
|
|
24
50
|
if module_name == "requests.packages.urllib3" or module_name.startswith(
|
|
25
51
|
"requests.packages.urllib3."
|
|
@@ -146,6 +172,53 @@ class RequestsTransformer(BaseTransformer):
|
|
|
146
172
|
|
|
147
173
|
return updated_node
|
|
148
174
|
|
|
175
|
+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Add the queued top-level imports to the module.

    Each distinct module name in ``self._imports_to_add`` becomes one
    ``import <module>`` statement. A name queued several times (for example, two
    ``from requests.packages import urllib3`` statements in one file) is imported
    only once. The new statements go after the module docstring and any
    ``from __future__`` imports, which must stay at the top of the file.
    """
    if not self._imports_to_add:
        return updated_node

    # dict.fromkeys de-duplicates while preserving first-queued order.
    new_imports = [
        cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name(name))])])
        for name in dict.fromkeys(self._imports_to_add)
    ]

    # Insert after a leading docstring and any __future__ imports.
    insert_pos = 0
    for index, statement in enumerate(updated_node.body):
        if not isinstance(statement, cst.SimpleStatementLine) or len(statement.body) != 1:
            break
        small_stmt = statement.body[0]
        is_docstring = (
            index == 0
            and isinstance(small_stmt, cst.Expr)
            and isinstance(small_stmt.value, (cst.SimpleString, cst.ConcatenatedString))
        )
        is_future_import = (
            isinstance(small_stmt, cst.ImportFrom)
            and isinstance(small_stmt.module, cst.Name)
            and small_stmt.module.value == "__future__"
        )
        if not (is_docstring or is_future_import):
            break
        insert_pos = index + 1

    body = list(updated_node.body)
    return updated_node.with_changes(body=body[:insert_pos] + new_imports + body[insert_pos:])
|
|
221
|
+
|
|
149
222
|
def _get_module_name(self, module: cst.BaseExpression) -> str:
|
|
150
223
|
"""Get the full module name from a Name or Attribute node."""
|
|
151
224
|
if isinstance(module, cst.Name):
|