codeshift 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. codeshift/cli/commands/apply.py +24 -2
  2. codeshift/cli/package_manager.py +102 -0
  3. codeshift/knowledge/generator.py +11 -1
  4. codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
  5. codeshift/knowledge_base/libraries/attrs.yaml +181 -0
  6. codeshift/knowledge_base/libraries/celery.yaml +244 -0
  7. codeshift/knowledge_base/libraries/click.yaml +195 -0
  8. codeshift/knowledge_base/libraries/django.yaml +355 -0
  9. codeshift/knowledge_base/libraries/flask.yaml +270 -0
  10. codeshift/knowledge_base/libraries/httpx.yaml +183 -0
  11. codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
  12. codeshift/knowledge_base/libraries/numpy.yaml +429 -0
  13. codeshift/knowledge_base/libraries/pytest.yaml +192 -0
  14. codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
  15. codeshift/migrator/engine.py +60 -0
  16. codeshift/migrator/transforms/__init__.py +2 -0
  17. codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
  18. codeshift/migrator/transforms/attrs_transformer.py +570 -0
  19. codeshift/migrator/transforms/celery_transformer.py +546 -0
  20. codeshift/migrator/transforms/click_transformer.py +526 -0
  21. codeshift/migrator/transforms/django_transformer.py +852 -0
  22. codeshift/migrator/transforms/fastapi_transformer.py +12 -7
  23. codeshift/migrator/transforms/flask_transformer.py +505 -0
  24. codeshift/migrator/transforms/httpx_transformer.py +419 -0
  25. codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
  26. codeshift/migrator/transforms/numpy_transformer.py +413 -0
  27. codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
  28. codeshift/migrator/transforms/pytest_transformer.py +351 -0
  29. codeshift/migrator/transforms/requests_transformer.py +74 -1
  30. codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
  31. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
  32. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/RECORD +36 -15
  33. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
  34. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
  35. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
  36. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,351 @@
1
+ """Pytest 6.x to 7.x/8.x transformation using LibCST."""
2
+
3
+ import libcst as cst
4
+ from libcst import matchers as m
5
+
6
+ from codeshift.migrator.ast_transforms import BaseTransformer
7
+
8
+
9
class _LocalNameRenamer(cst.CSTTransformer):
    """Rename every reference to one local variable inside a function body.

    Used after a function parameter has been renamed, so the body keeps
    referring to the same object. Only bare names are rewritten: attribute
    names (``obj.old``) and keyword-argument names (``f(old=...)``) are put
    back to their original spelling, and import statements are not visited.
    """

    def __init__(self, old_name: str, new_name: str) -> None:
        super().__init__()
        self._old_name = old_name
        self._new_name = new_name

    def visit_Import(self, node: cst.Import) -> bool:
        # ``import old`` binds a module name, not the renamed parameter.
        return False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        return False

    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        if updated_node.value == self._old_name:
            return updated_node.with_changes(value=self._new_name)
        return updated_node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.Attribute:
        # In ``something.old`` the ``old`` part is an attribute lookup, not a use
        # of the local variable, so restore it (the object part stays renamed).
        return updated_node.with_changes(attr=original_node.attr)

    def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg:
        # In ``f(old=...)`` the keyword names the callee's parameter.
        return updated_node.with_changes(keyword=original_node.keyword)


class PytestTransformer(BaseTransformer):
    """Transform pytest 6.x code to 7.x/8.x compatible code.

    Rewrites performed:
    * ``@pytest.yield_fixture`` / ``@pytest.yield_fixture(...)`` -> ``@pytest.fixture``
    * ``setup()`` / ``teardown()`` methods of test classes -> ``setup_method()`` /
      ``teardown_method()``
    * ``tmpdir`` / ``tmpdir_factory`` parameters -> ``tmp_path`` / ``tmp_path_factory``
    * legacy ``path`` / ``startdir`` parameters of selected hooks -> their new names
    * ``pytest.skip/fail/exit(msg=...)`` -> ``reason=...``
    * ``pytest.warns(None)`` -> ``pytest.warns()``
    * ``.fspath`` -> ``.path`` and ``.funcargnames`` -> ``.fixturenames``
    """

    # Nose-style xunit method name -> (replacement name, transform_name).
    _XUNIT_METHOD_RENAMES = {
        "setup": ("setup_method", "setup_to_setup_method"),
        "teardown": ("teardown_method", "teardown_to_teardown_method"),
    }

    # Legacy py.path fixture parameter -> (replacement, description, transform_name).
    _FIXTURE_PARAM_RENAMES = {
        "tmpdir": (
            "tmp_path",
            "Convert tmpdir fixture to tmp_path (pathlib.Path)",
            "tmpdir_to_tmp_path",
        ),
        "tmpdir_factory": (
            "tmp_path_factory",
            "Convert tmpdir_factory fixture to tmp_path_factory",
            "tmpdir_factory_to_tmp_path_factory",
        ),
    }

    # Hook function name -> {old parameter name: new parameter name}.
    _HOOK_PARAM_RENAMES = {
        "pytest_collect_file": {"path": "file_path"},
        "pytest_ignore_collect": {"path": "collection_path"},
        "pytest_pycollect_makemodule": {"path": "module_path"},
        "pytest_report_header": {"startdir": "start_path"},
        "pytest_report_collectionfinish": {"startdir": "start_path"},
    }

    def __init__(self) -> None:
        super().__init__()
        # State of the innermost class currently being visited. The
        # setup/teardown rename no longer reads these (it is decided in
        # leave_ClassDef); they are kept accurate in case other code inspects them.
        self._in_test_class = False
        self._current_class_name: str | None = None
        # (class name, is_test_class) for each enclosing class, innermost last,
        # so leaving a nested class restores the outer class's state.
        self._class_stack: list[tuple[str, bool]] = []

    @staticmethod
    def _is_test_class(node: cst.ClassDef) -> bool:
        """Return True for ``Test*`` classes and classes with a bare ``TestCase`` base."""
        if node.name.value.startswith("Test"):
            return True
        return any(m.matches(base.value, m.Name("TestCase")) for base in node.bases)

    def _sync_class_state(self) -> None:
        """Mirror the innermost enclosing class into the legacy state attributes."""
        if self._class_stack:
            name, is_test = self._class_stack[-1]
            self._in_test_class = is_test
            self._current_class_name = name if is_test else None
        else:
            self._in_test_class = False
            self._current_class_name = None

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        """Push the class onto the class stack before visiting its body."""
        self._class_stack.append((node.name.value, self._is_test_class(node)))
        self._sync_class_state()
        return True

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Pop the class scope and rename setup/teardown methods of test classes."""
        _, is_test_class = self._class_stack.pop()
        self._sync_class_state()
        if is_test_class:
            return self._rename_xunit_methods(updated_node)
        return updated_node

    def _rename_xunit_methods(self, node: cst.ClassDef) -> cst.ClassDef:
        """Rename ``setup``/``teardown`` methods defined directly in *node*'s body.

        Only the class's own methods are considered, not functions nested inside
        them or methods of nested classes (those are handled by their own
        ClassDef). A rename is skipped when the class already defines the target
        name, since two methods with the same name would silently drop one.
        """
        body = node.body
        if not isinstance(body, cst.IndentedBlock):
            # One-line body such as ``class TestX: pass`` cannot contain a def.
            return node

        defined = {
            statement.name.value
            for statement in body.body
            if isinstance(statement, cst.FunctionDef)
        }
        new_statements = []
        changed = False
        for statement in body.body:
            if (
                isinstance(statement, cst.FunctionDef)
                and statement.name.value in self._XUNIT_METHOD_RENAMES
            ):
                old_name = statement.name.value
                new_name, transform_name = self._XUNIT_METHOD_RENAMES[old_name]
                if new_name not in defined:
                    self.record_change(
                        description=(
                            f"Rename {old_name}() to {new_name}() for pytest 8.x compatibility"
                        ),
                        line_number=1,
                        original=f"def {old_name}(self):",
                        replacement=f"def {new_name}(self):",
                        transform_name=transform_name,
                    )
                    statement = statement.with_changes(name=cst.Name(new_name))
                    changed = True
            new_statements.append(statement)

        if not changed:
            return node
        return node.with_changes(body=body.with_changes(body=new_statements))

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Transform @pytest.yield_fixture (bare or called) to @pytest.fixture."""
        yield_fixture = m.Attribute(value=m.Name("pytest"), attr=m.Name("yield_fixture"))
        decorator = updated_node.decorator

        # @pytest.yield_fixture
        if m.matches(decorator, yield_fixture):
            self.record_change(
                description="Convert @pytest.yield_fixture to @pytest.fixture",
                line_number=1,
                original="@pytest.yield_fixture",
                replacement="@pytest.fixture",
                transform_name="yield_fixture_to_fixture",
            )
            return updated_node.with_changes(
                decorator=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("fixture"))
            )

        # @pytest.yield_fixture(...) -- keep the call's arguments unchanged.
        if isinstance(decorator, cst.Call) and m.matches(decorator.func, yield_fixture):
            self.record_change(
                description="Convert @pytest.yield_fixture(...) to @pytest.fixture(...)",
                line_number=1,
                original="@pytest.yield_fixture(...)",
                replacement="@pytest.fixture(...)",
                transform_name="yield_fixture_to_fixture",
            )
            new_func = cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("fixture"))
            return updated_node.with_changes(decorator=decorator.with_changes(func=new_func))

        return updated_node

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Rename legacy fixture and hook parameters, together with their uses."""
        updated_node = self._transform_fixture_params(updated_node)
        return self._transform_hook_params(updated_node)

    @staticmethod
    def _declared_param_names(params: cst.Parameters) -> set[str]:
        """Return every parameter name declared in *params*, including */** params."""
        names = {
            param.name.value
            for param in (*params.posonly_params, *params.params, *params.kwonly_params)
        }
        for star in (params.star_arg, params.star_kwarg):
            if isinstance(star, cst.Param):
                names.add(star.name.value)
        return names

    def _rename_params(
        self, node: cst.FunctionDef, renames: dict[str, str]
    ) -> tuple[cst.FunctionDef, list[tuple[str, str]]]:
        """Rename positional-or-keyword parameters of *node* according to *renames*.

        The function body is rewritten so references to each renamed parameter
        follow the new name; without this the body would use an undefined name.
        A parameter is left alone when its new name is already a parameter
        (duplicate argument -> SyntaxError) or already appears in the body
        (renaming would merge two distinct variables).

        Returns:
            The (possibly) updated function and the ``(old, new)`` pairs applied,
            in parameter order.
        """
        taken = self._declared_param_names(node.params)
        new_params = []
        renamed: list[tuple[str, str]] = []
        for param in node.params.params:
            old_name = param.name.value
            new_name = renames.get(old_name)
            if (
                new_name is None
                or new_name in taken
                or m.findall(node.body, m.Name(new_name))
            ):
                new_params.append(param)
                continue
            taken.add(new_name)
            new_params.append(param.with_changes(name=cst.Name(new_name)))
            renamed.append((old_name, new_name))

        if not renamed:
            return node, renamed

        body = node.body
        for old_name, new_name in renamed:
            body = body.visit(_LocalNameRenamer(old_name, new_name))
        return (
            node.with_changes(params=node.params.with_changes(params=new_params), body=body),
            renamed,
        )

    def _transform_fixture_params(self, node: cst.FunctionDef) -> cst.FunctionDef:
        """Transform tmpdir/tmpdir_factory parameters to tmp_path/tmp_path_factory."""
        renames = {old: spec[0] for old, spec in self._FIXTURE_PARAM_RENAMES.items()}
        node, renamed = self._rename_params(node, renames)
        for old_name, new_name in renamed:
            _, description, transform_name = self._FIXTURE_PARAM_RENAMES[old_name]
            self.record_change(
                description=description,
                line_number=1,
                original=f"def func({old_name}):",
                replacement=f"def func({new_name}):",
                transform_name=transform_name,
            )
        return node

    def _transform_hook_params(self, node: cst.FunctionDef) -> cst.FunctionDef:
        """Transform pytest hook parameters for 7.x/8.x compatibility."""
        func_name = node.name.value
        renames = self._HOOK_PARAM_RENAMES.get(func_name)
        if renames is None:
            return node

        node, renamed = self._rename_params(node, renames)
        for old_name, new_name in renamed:
            self.record_change(
                description=f"Rename {func_name} parameter '{old_name}' to '{new_name}'",
                line_number=1,
                original=f"def {func_name}({old_name}):",
                replacement=f"def {func_name}({new_name}):",
                transform_name=f"hook_{old_name}_to_{new_name}",
            )
        return node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Transform pytest function calls."""
        # pytest.skip(msg=...), pytest.fail(msg=...), pytest.exit(msg=...)
        updated_node = self._transform_msg_to_reason(updated_node)
        # pytest.warns(None)
        return self._transform_warns_none(updated_node)

    def _transform_msg_to_reason(self, node: cst.Call) -> cst.Call:
        """Rename the ``msg`` keyword to ``reason`` in pytest.skip/fail/exit calls."""
        func = node.func
        if not (
            isinstance(func, cst.Attribute)
            and m.matches(func.value, m.Name("pytest"))
            and func.attr.value in ("skip", "fail", "exit")
        ):
            return node
        func_name = func.attr.value

        keywords = {arg.keyword.value for arg in node.args if arg.keyword is not None}
        # If reason= is already passed, renaming msg= would create a duplicate
        # keyword argument (a SyntaxError), so leave such calls for manual review.
        if "msg" not in keywords or "reason" in keywords:
            return node

        new_args = []
        for arg in node.args:
            if arg.keyword is not None and arg.keyword.value == "msg":
                self.record_change(
                    description=f"Rename 'msg' parameter to 'reason' in pytest.{func_name}()",
                    line_number=1,
                    original=f"pytest.{func_name}(msg=...)",
                    replacement=f"pytest.{func_name}(reason=...)",
                    transform_name=f"{func_name}_msg_to_reason",
                )
                arg = arg.with_changes(keyword=cst.Name("reason"))
            new_args.append(arg)
        return node.with_changes(args=new_args)

    def _transform_warns_none(self, node: cst.Call) -> cst.Call:
        """Transform pytest.warns(None) to pytest.warns().

        NOTE(review): pytest.warns() fails when no warning is emitted, whereas
        warns(None) only recorded warnings. Code that used warns(None) to assert
        that *no* warning is raised needs manual attention.
        """
        func = node.func
        if not (
            isinstance(func, cst.Attribute)
            and m.matches(func.value, m.Name("pytest"))
            and func.attr.value == "warns"
        ):
            return node

        # Only a positional None as the first argument is rewritten.
        if (
            node.args
            and node.args[0].keyword is None
            and m.matches(node.args[0].value, m.Name("None"))
        ):
            self.record_change(
                description="Convert pytest.warns(None) to pytest.warns()",
                line_number=1,
                original="pytest.warns(None)",
                replacement="pytest.warns()",
                transform_name="warns_none_to_warns",
            )
            return node.with_changes(args=list(node.args[1:]))

        return node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Transform .fspath to .path and .funcargnames to .fixturenames.

        ``os.fspath`` is the standard-library function, not a pytest node
        attribute, and is left untouched (rewriting it would yield ``os.path``).
        """
        attr_name = updated_node.attr.value

        if attr_name == "fspath" and not m.matches(updated_node.value, m.Name("os")):
            self.record_change(
                description="Convert .fspath (py.path.local) to .path (pathlib.Path)",
                line_number=1,
                original=".fspath",
                replacement=".path",
                transform_name="fspath_to_path",
            )
            return updated_node.with_changes(attr=cst.Name("path"))

        if attr_name == "funcargnames":
            self.record_change(
                description="Convert .funcargnames to .fixturenames",
                line_number=1,
                original=".funcargnames",
                replacement=".fixturenames",
                transform_name="funcargnames_to_fixturenames",
            )
            return updated_node.with_changes(attr=cst.Name("fixturenames"))

        return updated_node

    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.BaseExpression:
        """Leave bare names unchanged.

        References to renamed fixture/hook parameters are rewritten per function
        in ``_rename_params``, where the enclosing parameter list is known.
        """
        return updated_node
328
+
329
+
330
def transform_pytest(source_code: str) -> tuple[str, list]:
    """Rewrite pytest 6.x-era source code so it targets pytest 7.x/8.x.

    Args:
        source_code: Python source text to migrate.

    Returns:
        A ``(new_source, changes)`` pair, where ``changes`` is the change list
        collected by ``PytestTransformer`` while rewriting.

    Raises:
        SyntaxError: If LibCST cannot parse ``source_code``.
        RuntimeError: If visiting the tree or generating the new code fails.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError as parse_error:
        raise SyntaxError(f"Invalid Python syntax: {parse_error}") from parse_error

    pytest_transformer = PytestTransformer()
    pytest_transformer.set_source(source_code)

    try:
        rewritten = module.visit(pytest_transformer)
        # Code generation stays inside the try so its failures are wrapped too.
        result = (rewritten.code, pytest_transformer.changes)
    except Exception as transform_error:
        raise RuntimeError(f"Transform failed: {transform_error}") from transform_error
    return result
@@ -10,16 +10,42 @@ class RequestsTransformer(BaseTransformer):
10
10
 
11
11
  def __init__(self) -> None:
12
12
  super().__init__()
13
+ self._imports_to_add: list[str] = []
13
14
 
14
15
  def leave_ImportFrom(
15
16
  self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
16
- ) -> cst.ImportFrom:
17
+ ) -> cst.ImportFrom | cst.RemovalSentinel:
17
18
  """Transform requests imports."""
18
19
  if original_node.module is None:
19
20
  return updated_node
20
21
 
21
22
  module_name = self._get_module_name(original_node.module)
22
23
 
24
+ # Transform 'from requests.packages import urllib3' to 'import urllib3'
25
+ if module_name == "requests.packages":
26
+ if isinstance(updated_node.names, cst.ImportStar):
27
+ return updated_node
28
+
29
+ remaining_names = []
30
+ for name in updated_node.names:
31
+ if isinstance(name, cst.ImportAlias) and isinstance(name.name, cst.Name):
32
+ if name.name.value == "urllib3":
33
+ self.record_change(
34
+ description="Import urllib3 directly instead of through requests.packages",
35
+ line_number=1,
36
+ original="from requests.packages import urllib3",
37
+ replacement="import urllib3",
38
+ transform_name="urllib3_top_level_import_fix",
39
+ )
40
+ self._imports_to_add.append("urllib3")
41
+ continue
42
+ remaining_names.append(name)
43
+
44
+ if len(remaining_names) == 0:
45
+ return cst.RemovalSentinel.REMOVE
46
+ elif len(remaining_names) < len(updated_node.names):
47
+ return updated_node.with_changes(names=remaining_names)
48
+
23
49
  # Transform requests.packages.urllib3 imports
24
50
  if module_name == "requests.packages.urllib3" or module_name.startswith(
25
51
  "requests.packages.urllib3."
@@ -146,6 +172,53 @@ class RequestsTransformer(BaseTransformer):
146
172
 
147
173
  return updated_node
148
174
 
175
+ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
176
+ """Add new imports at the top of the module."""
177
+ if not self._imports_to_add:
178
+ return updated_node
179
+
180
+ # Create new import statements
181
+ new_imports = []
182
+ for module_name in self._imports_to_add:
183
+ new_import = cst.SimpleStatementLine(
184
+ body=[cst.Import(names=[cst.ImportAlias(name=cst.Name(module_name))])]
185
+ )
186
+ new_imports.append(new_import)
187
+
188
+ # Find the position to insert - after docstrings and __future__ imports
189
+ insert_pos = 0
190
+ for i, statement in enumerate(updated_node.body):
191
+ if isinstance(statement, cst.SimpleStatementLine):
192
+ # Check if it's a docstring
193
+ if (
194
+ i == 0
195
+ and len(statement.body) == 1
196
+ and isinstance(statement.body[0], cst.Expr)
197
+ and isinstance(
198
+ statement.body[0].value, cst.SimpleString | cst.ConcatenatedString
199
+ )
200
+ ):
201
+ insert_pos = i + 1
202
+ continue
203
+ # Check if it's a __future__ import
204
+ if len(statement.body) == 1 and isinstance(statement.body[0], cst.ImportFrom):
205
+ import_from = statement.body[0]
206
+ if (
207
+ isinstance(import_from.module, cst.Name)
208
+ and import_from.module.value == "__future__"
209
+ ):
210
+ insert_pos = i + 1
211
+ continue
212
+ break
213
+
214
+ # Insert the new imports
215
+ new_body = (
216
+ list(updated_node.body[:insert_pos])
217
+ + new_imports
218
+ + list(updated_node.body[insert_pos:])
219
+ )
220
+ return updated_node.with_changes(body=new_body)
221
+
149
222
  def _get_module_name(self, module: cst.BaseExpression) -> str:
150
223
  """Get the full module name from a Name or Attribute node."""
151
224
  if isinstance(module, cst.Name):