codeshift 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. codeshift/cli/commands/apply.py +24 -2
  2. codeshift/cli/package_manager.py +102 -0
  3. codeshift/knowledge/generator.py +11 -1
  4. codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
  5. codeshift/knowledge_base/libraries/attrs.yaml +181 -0
  6. codeshift/knowledge_base/libraries/celery.yaml +244 -0
  7. codeshift/knowledge_base/libraries/click.yaml +195 -0
  8. codeshift/knowledge_base/libraries/django.yaml +355 -0
  9. codeshift/knowledge_base/libraries/flask.yaml +270 -0
  10. codeshift/knowledge_base/libraries/httpx.yaml +183 -0
  11. codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
  12. codeshift/knowledge_base/libraries/numpy.yaml +429 -0
  13. codeshift/knowledge_base/libraries/pytest.yaml +192 -0
  14. codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
  15. codeshift/migrator/engine.py +60 -0
  16. codeshift/migrator/transforms/__init__.py +2 -0
  17. codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
  18. codeshift/migrator/transforms/attrs_transformer.py +570 -0
  19. codeshift/migrator/transforms/celery_transformer.py +546 -0
  20. codeshift/migrator/transforms/click_transformer.py +526 -0
  21. codeshift/migrator/transforms/django_transformer.py +852 -0
  22. codeshift/migrator/transforms/fastapi_transformer.py +12 -7
  23. codeshift/migrator/transforms/flask_transformer.py +505 -0
  24. codeshift/migrator/transforms/httpx_transformer.py +419 -0
  25. codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
  26. codeshift/migrator/transforms/numpy_transformer.py +413 -0
  27. codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
  28. codeshift/migrator/transforms/pytest_transformer.py +351 -0
  29. codeshift/migrator/transforms/requests_transformer.py +74 -1
  30. codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
  31. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/METADATA +46 -4
  32. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/RECORD +36 -15
  33. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/WHEEL +0 -0
  34. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/entry_points.txt +0 -0
  35. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/licenses/LICENSE +0 -0
  36. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,515 @@
1
+ """Marshmallow 2.x to 3.x transformation using LibCST."""
2
+
3
+ import libcst as cst
4
+
5
+ from codeshift.migrator.ast_transforms import BaseTransformer
6
+
7
+
8
class MarshmallowTransformer(BaseTransformer):
    """Rewrite Marshmallow 2.x idioms into their 3.x equivalents.

    Rewrites performed by this transformer:

    - drop ``pass_many=`` from hook decorators and add ``**kwargs`` to the
      decorated hook methods
    - drop the ``.data`` attribute from ``dump``/``load``/``dumps``/``loads``
      results
    - rename field keyword arguments (``missing``, ``default``, ``load_from``,
      ``dump_to``)
    - drop ``strict=`` from schema instantiation and ``strict`` from ``Meta``
    - rename ``Meta.json_module`` to ``Meta.render_module``
    - turn ``self.fail(...)`` into ``raise self.make_error(...)``

    Every rewrite is reported through ``record_change``. ``line_number`` is
    always reported as ``1`` because node positions are not looked up here.
    All matching is purely name-based (no import or type resolution), so
    unrelated code that happens to use the same names can be rewritten too.
    """

    # Decorators that mark schema hook methods.
    _HOOK_DECORATORS: frozenset[str] = frozenset(
        {"post_load", "pre_load", "post_dump", "pre_dump", "validates_schema"}
    )

    # Callable names treated as marshmallow field constructors.
    _FIELD_TYPES: frozenset[str] = frozenset(
        {
            "Field",
            "String",
            "Str",
            "Integer",
            "Int",
            "Float",
            "Boolean",
            "Bool",
            "DateTime",
            "Date",
            "Time",
            "TimeDelta",
            "Decimal",
            "UUID",
            "Email",
            "URL",
            "Url",
            "Method",
            "Function",
            "Nested",
            "List",
            "Dict",
            "Tuple",
            "Mapping",
            "Raw",
            "Number",
            "Pluck",
            "Constant",
        }
    )

    # Field keyword arguments: v2 name -> v3 name.
    _FIELD_PARAM_RENAMES: dict[str, str] = {
        "missing": "load_default",
        "default": "dump_default",
        "load_from": "data_key",
        "dump_to": "data_key",
    }

    def __init__(self) -> None:
        super().__init__()
        # Track methods decorated with pass_many for signature update
        # (kept for compatibility; nothing in this class reads it).
        self._methods_needing_kwargs: set[str] = set()
        # Whether the most recently visited decorator carried pass_many
        self._current_decorator_has_pass_many = False

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _without_keyword(args, keyword: str):
        """Return *args* without the ``keyword=`` argument, or None if absent.

        When the removed argument was the last one, the comma that used to
        separate it from the new last argument is reset so the generated code
        does not end in a dangling ``, )``. A trailing comma the author wrote
        after an argument that is kept is left alone.
        """
        kept = [
            arg
            for arg in args
            if not (isinstance(arg.keyword, cst.Name) and arg.keyword.value == keyword)
        ]
        if len(kept) == len(args):
            return None
        if kept and kept[-1] is not args[-1] and isinstance(kept[-1].comma, cst.Comma):
            kept[-1] = kept[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        return kept

    @staticmethod
    def _is_self_method_call(node: cst.BaseExpression, method: str) -> bool:
        """Return True when *node* is a call of the form ``self.<method>(...)``."""
        if not isinstance(node, cst.Call):
            return False
        func = node.func
        return (
            isinstance(func, cst.Attribute)
            and func.attr.value == method
            and isinstance(func.value, cst.Name)
            and func.value.value == "self"
        )

    # ------------------------------------------------------------------
    # Hook decorators and hook method signatures
    # ------------------------------------------------------------------

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Remove the ``pass_many`` argument from hook decorators.

        Handles:
        - @post_load(pass_many=True) -> @post_load
        - @pre_load(pass_many=True) -> @pre_load
        - @post_dump(pass_many=True) -> @post_dump
        - @pre_dump(pass_many=True) -> @pre_dump
        - @validates_schema(pass_many=True) -> @validates_schema

        If other arguments remain the decorator stays a call; otherwise it is
        reduced to the bare decorator name.

        NOTE(review): marshmallow 3's hook decorators still appear to accept
        ``pass_many``; confirm that dropping it (rather than only adjusting
        the method signature) is the intended migration.
        """
        self._current_decorator_has_pass_many = False

        call = updated_node.decorator
        if not isinstance(call, cst.Call):
            return updated_node

        func_name = self._get_decorator_name(call.func)
        if func_name not in self._HOOK_DECORATORS:
            return updated_node

        remaining = self._without_keyword(call.args, "pass_many")
        if remaining is None:
            return updated_node

        self._current_decorator_has_pass_many = True
        self.record_change(
            description=f"Remove pass_many parameter from @{func_name}",
            line_number=1,
            original=f"@{func_name}(pass_many=True)",
            replacement=f"@{func_name}",
            transform_name=f"{func_name}_pass_many",
        )

        if remaining:
            # Still have other arguments, keep as call
            return updated_node.with_changes(decorator=call.with_changes(args=remaining))
        # No more arguments, simplify to just the name
        return updated_node.with_changes(decorator=call.func)

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Add ``**kwargs`` to every function carrying a hook decorator.

        This applies to all hook-decorated functions, not only those whose
        decorator had ``pass_many``. Functions that already declare a
        ``**`` parameter are left unchanged.
        """
        is_hook = False
        for decorator in updated_node.decorators:
            expr = decorator.decorator
            # Accept @name, @module.name and the called forms of both.
            target = expr.func if isinstance(expr, cst.Call) else expr
            if self._get_decorator_name(target) in self._HOOK_DECORATORS:
                is_hook = True
                break

        if not is_hook:
            return updated_node

        params = updated_node.params
        if params.star_kwarg is not None:
            return updated_node

        self.record_change(
            description="Add **kwargs to method signature for many/partial args",
            line_number=1,
            original=f"def {updated_node.name.value}(self, ...)",
            replacement=f"def {updated_node.name.value}(self, ..., **kwargs)",
            transform_name="add_kwargs_to_decorated_method",
        )

        widened = params.with_changes(star_kwarg=cst.Param(name=cst.Name("kwargs")))
        return updated_node.with_changes(params=widened)

    # ------------------------------------------------------------------
    # Calls
    # ------------------------------------------------------------------

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Apply the call-level rewrites; the first one that changes the node wins."""
        call_transforms = (
            self._transform_data_access,
            self._transform_field_params,
            self._transform_schema_instantiation,
            self._transform_fail_to_make_error,
        )
        for transform in call_transforms:
            result = transform(updated_node)
            if result is not updated_node:
                return result
        return updated_node

    def leave_Expr(
        self, original_node: cst.Expr, updated_node: cst.Expr
    ) -> cst.BaseSmallStatement:
        """Wrap a bare ``self.fail(...)`` statement in ``raise``.

        ``leave_Call`` has already renamed the call to ``self.make_error(...)``
        by the time the enclosing expression statement is left. Without the
        ``raise`` the statement would only build the error and discard it, so
        the failure would be silently lost.
        """
        was_fail = self._is_self_method_call(original_node.value, "fail")
        if was_fail and self._is_self_method_call(updated_node.value, "make_error"):
            return cst.Raise(exc=updated_node.value, semicolon=updated_node.semicolon)
        return updated_node

    def _transform_data_access(self, node: cst.Call) -> cst.BaseExpression:
        """Placeholder for the ``.data`` rewrite; always returns *node* unchanged.

        In v2: result = schema.dump(obj).data
        In v3: result = schema.dump(obj)

        The actual rewrite happens in ``leave_Attribute`` because ``.data`` is
        an attribute access wrapped around the call.
        """
        return node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Replace ``<obj>.dump(...).data`` (and load/dumps/loads) with the call itself."""
        if updated_node.attr.value != "data":
            return updated_node

        call = updated_node.value
        if not isinstance(call, cst.Call) or not isinstance(call.func, cst.Attribute):
            return updated_node

        method_name = call.func.attr.value
        if method_name not in {"dump", "load", "dumps", "loads"}:
            return updated_node

        self.record_change(
            description=f"Remove .data access from {method_name}() - v3 returns data directly",
            line_number=1,
            original=f"schema.{method_name}(...).data",
            replacement=f"schema.{method_name}(...)",
            transform_name=f"{method_name}_data_to_{method_name}",
        )
        return call

    def _transform_field_params(self, node: cst.Call) -> cst.Call:
        """Rename v2 field keyword arguments on field constructor calls.

        - missing -> load_default
        - default -> dump_default
        - load_from -> data_key
        - dump_to -> data_key

        A rename is skipped when it would collide with another keyword of the
        same call (for example both ``load_from`` and ``dump_to``, which both
        map to ``data_key``): emitting the same keyword twice is invalid
        Python, so such calls are left for manual migration.

        Returns *node* itself when nothing was renamed.
        """
        func_name = self._get_call_func_name(node.func)
        if func_name is None or func_name not in self._FIELD_TYPES:
            return node

        renames = self._FIELD_PARAM_RENAMES
        # Keyword names as they would read after renaming, used to spot collisions.
        resulting_names = [
            renames.get(arg.keyword.value, arg.keyword.value)
            for arg in node.args
            if isinstance(arg.keyword, cst.Name)
        ]

        new_args = []
        changed = False
        for arg in node.args:
            old_name = arg.keyword.value if isinstance(arg.keyword, cst.Name) else None
            new_name = renames.get(old_name) if old_name is not None else None
            if new_name is None or resulting_names.count(new_name) > 1:
                new_args.append(arg)
                continue

            new_args.append(arg.with_changes(keyword=cst.Name(new_name)))
            changed = True
            self.record_change(
                description=f"Rename field parameter '{old_name}' to '{new_name}'",
                line_number=1,
                original=f"{func_name}({old_name}=...)",
                replacement=f"{func_name}({new_name}=...)",
                transform_name=f"{old_name}_to_{new_name}",
            )

        return node.with_changes(args=new_args) if changed else node

    def _transform_schema_instantiation(self, node: cst.Call) -> cst.Call:
        """Remove the ``strict`` keyword from schema instantiation.

        In v2: UserSchema(strict=True)
        In v3: UserSchema()  # strict is always True

        Only callables whose name ends in ``Schema`` are touched. ``strict=``
        is also a keyword of unrelated callables (``zip``, ``json.loads``),
        and stripping it there would silently change their behaviour.

        Returns *node* itself when nothing was removed.
        """
        func_name = self._get_call_func_name(node.func)
        if func_name is None or not func_name.endswith("Schema"):
            return node

        remaining = self._without_keyword(node.args, "strict")
        if remaining is None:
            return node

        self.record_change(
            description="Remove 'strict' parameter - schemas are always strict in v3",
            line_number=1,
            original="Schema(strict=True)",
            replacement="Schema()",
            transform_name="remove_schema_strict",
        )
        return node.with_changes(args=remaining)

    def _transform_fail_to_make_error(self, node: cst.Call) -> cst.BaseExpression:
        """Rename ``self.fail(...)`` to ``self.make_error(...)``.

        Only the call itself is rewritten here. When the call is a bare
        statement, ``leave_Expr`` wraps the result in ``raise``. Any other
        usage (e.g. ``return self.fail(...)``) keeps its surrounding
        statement and needs manual review.

        NOTE(review): the receiver's class is not checked, so ``self.fail``
        calls in non-field classes (such as test cases) are rewritten too.
        """
        if not self._is_self_method_call(node, "fail"):
            return node

        self.record_change(
            description="Replace self.fail() with self.make_error() - wrap in raise",
            line_number=1,
            original="self.fail(key)",
            replacement="raise self.make_error(key)",
            transform_name="fail_to_make_error",
            notes=(
                "Bare statement calls are wrapped in raise automatically; "
                "other usages must be wrapped in a raise statement manually"
            ),
        )
        return node.with_changes(func=node.func.with_changes(attr=cst.Name("make_error")))

    # ------------------------------------------------------------------
    # Meta inner class
    # ------------------------------------------------------------------

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Rewrite the first inner ``Meta`` class of a class, if there is one."""
        body_items = list(updated_node.body.body)

        for index, item in enumerate(body_items):
            if isinstance(item, cst.ClassDef) and item.name.value == "Meta":
                break
        else:
            return updated_node

        rewritten_meta = self._transform_meta_class(item)
        if rewritten_meta is item:
            return updated_node

        body_items[index] = rewritten_meta
        return updated_node.with_changes(body=updated_node.body.with_changes(body=body_items))

    def _transform_meta_class(self, meta_class: cst.ClassDef) -> cst.ClassDef:
        """Transform Meta class attributes.

        - Remove strict = True
        - Rename json_module to render_module

        Only simple statement lines are inspected; everything else is copied
        through. Returns *meta_class* itself when nothing changed.
        """
        new_body_items: list[cst.BaseStatement | cst.BaseSmallStatement] = []
        changed = False

        for item in meta_class.body.body:
            if not isinstance(item, cst.SimpleStatementLine):
                new_body_items.append(item)
                continue

            rewritten, was_changed = self._transform_meta_statement(item)
            changed = changed or was_changed
            # None means every statement on the line was removed.
            if rewritten is not None:
                new_body_items.append(rewritten)

        if not changed:
            return meta_class

        # A class body may not be empty: fall back to ``pass``.
        if not new_body_items:
            new_body_items = [cst.SimpleStatementLine(body=[cst.Pass()])]

        return meta_class.with_changes(body=meta_class.body.with_changes(body=new_body_items))

    def _transform_meta_statement(
        self, stmt: cst.SimpleStatementLine
    ) -> tuple[cst.SimpleStatementLine | None, bool]:
        """Transform one statement line of a Meta class.

        Assignments to ``strict`` are dropped and assignments to
        ``json_module`` are retargeted to ``render_module``. Only the first
        matching target of each assignment is considered.

        Returns:
            Tuple of (transformed_statement or None if removed, was_changed)
        """
        kept: list[cst.BaseSmallStatement] = []
        changed = False

        for small in stmt.body:
            if not isinstance(small, cst.Assign):
                kept.append(small)
                continue

            drop = False
            replacement = None
            for target in small.targets:
                if not isinstance(target.target, cst.Name):
                    continue
                name = target.target.value

                if name == "strict":
                    changed = True
                    drop = True
                    self.record_change(
                        description="Remove Meta.strict - schemas are always strict in v3",
                        line_number=1,
                        original="strict = True",
                        replacement="# removed",
                        transform_name="remove_meta_strict",
                    )
                    break

                if name == "json_module":
                    changed = True
                    renamed = target.with_changes(target=cst.Name("render_module"))
                    # Swap only the matching target so the other targets of a
                    # chained assignment are preserved.
                    replacement = small.with_changes(
                        targets=[renamed if t is target else t for t in small.targets]
                    )
                    self.record_change(
                        description="Rename Meta.json_module to Meta.render_module",
                        line_number=1,
                        original="json_module = ...",
                        replacement="render_module = ...",
                        transform_name="json_module_to_render_module",
                    )
                    break

            if drop:
                continue
            kept.append(replacement if replacement is not None else small)

        if not changed:
            return stmt, False
        if kept:
            return stmt.with_changes(body=kept), True
        return None, True

    # ------------------------------------------------------------------
    # Name extraction
    # ------------------------------------------------------------------

    def _get_decorator_name(self, node: cst.BaseExpression) -> str | None:
        """Return the decorator's name: the identifier, or the last attribute segment.

        Returns None for any other expression kind.
        """
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            return str(node.attr.value)
        return None

    def _get_call_func_name(self, node: cst.BaseExpression) -> str | None:
        """Return the called function's name from a call's ``func`` expression.

        For attribute access such as ``fields.String`` this is the last
        segment (``String``). Returns None for any other expression kind.
        """
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            # Handle fields.String, etc.
            return str(node.attr.value)
        return None
469
+
470
+
471
class MarshmallowImportTransformer(BaseTransformer):
    """Second-pass transformer reserved for Marshmallow import rewrites.

    It is applied after the main transformer and currently leaves every
    import untouched.
    """

    def __init__(self) -> None:
        super().__init__()

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Return the ``from ... import ...`` statement unchanged.

        No import changes are applied for the v2 to v3 migration; this hook
        exists so import rewrites have a place to go if they become necessary.
        """
        return updated_node
489
+
490
+
491
def transform_marshmallow(source_code: str) -> tuple[str, list]:
    """Migrate Marshmallow 2.x source code to 3.x.

    The module is parsed once and then visited by two transformers in
    sequence: the main rewrite pass followed by the import pass.

    Args:
        source_code: The source code to transform

    Returns:
        Tuple of (transformed_code, list of changes), where the changes of
        the main pass come before those of the import pass.

    Raises:
        SyntaxError: If ``source_code`` cannot be parsed.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError as exc:
        raise SyntaxError(f"Invalid Python syntax: {exc}") from exc

    # Pass 1: the Marshmallow API rewrites.
    main_pass = MarshmallowTransformer()
    main_pass.set_source(source_code)
    after_main = module.visit(main_pass)

    # Pass 2: import adjustments (currently minimal).
    import_pass = MarshmallowImportTransformer()
    result_module = after_main.visit(import_pass)

    return result_module.code, main_pass.changes + import_pass.changes