codeshift 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. codeshift/cli/commands/apply.py +24 -2
  2. codeshift/cli/package_manager.py +102 -0
  3. codeshift/knowledge/generator.py +11 -1
  4. codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
  5. codeshift/knowledge_base/libraries/attrs.yaml +181 -0
  6. codeshift/knowledge_base/libraries/celery.yaml +244 -0
  7. codeshift/knowledge_base/libraries/click.yaml +195 -0
  8. codeshift/knowledge_base/libraries/django.yaml +355 -0
  9. codeshift/knowledge_base/libraries/flask.yaml +270 -0
  10. codeshift/knowledge_base/libraries/httpx.yaml +183 -0
  11. codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
  12. codeshift/knowledge_base/libraries/numpy.yaml +429 -0
  13. codeshift/knowledge_base/libraries/pytest.yaml +192 -0
  14. codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
  15. codeshift/migrator/engine.py +60 -0
  16. codeshift/migrator/transforms/__init__.py +2 -0
  17. codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
  18. codeshift/migrator/transforms/attrs_transformer.py +570 -0
  19. codeshift/migrator/transforms/celery_transformer.py +546 -0
  20. codeshift/migrator/transforms/click_transformer.py +526 -0
  21. codeshift/migrator/transforms/django_transformer.py +852 -0
  22. codeshift/migrator/transforms/fastapi_transformer.py +12 -7
  23. codeshift/migrator/transforms/flask_transformer.py +505 -0
  24. codeshift/migrator/transforms/httpx_transformer.py +419 -0
  25. codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
  26. codeshift/migrator/transforms/numpy_transformer.py +413 -0
  27. codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
  28. codeshift/migrator/transforms/pytest_transformer.py +351 -0
  29. codeshift/migrator/transforms/requests_transformer.py +74 -1
  30. codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
  31. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
  32. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/RECORD +36 -15
  33. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
  34. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
  35. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
  36. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,413 @@
1
+ """NumPy 1.x to 2.0 transformation using LibCST."""
2
+
3
+ import libcst as cst
4
+
5
+ from codeshift.migrator.ast_transforms import BaseTransformer
6
+
7
+
8
class NumPyTransformer(BaseTransformer):
    """Transform NumPy 1.x code to 2.0.

    Handles the following breaking changes:
    - Type alias removals (np.bool, np.int, np.float, np.complex, np.object, np.str)
    - Function renames (alltrue, sometrue, product, cumproduct, trapz, in1d, row_stack, msort)
    - Constant renames (Inf, Infinity, infty, NaN, PINF, NINF, PZERO, NZERO)
    - Other deprecated/removed functions

    Only expressions of the form ``<alias>.<name>`` are rewritten, where
    ``<alias>`` is ``np``, ``numpy`` or a name bound by ``import numpy as <alias>``.
    Function renames are applied to call expressions only.
    """

    # Type alias mappings: old_name -> new_name
    TYPE_ALIAS_MAPPINGS = {
        # Python builtin shadows (high priority)
        "bool": "bool_",
        "int": "int_",
        "float": "float64",
        "complex": "complex128",
        "object": "object_",
        "str": "str_",
        # Other type aliases
        "unicode_": "str_",
        "string_": "bytes_",
        "float_": "float64",
        "complex_": "complex128",
        "cfloat": "complex128",
        "singlecomplex": "complex64",
        "longfloat": "longdouble",
        "longcomplex": "clongdouble",
        "clongfloat": "clongdouble",
    }

    # Function renames: old_name -> new_name
    FUNCTION_RENAMES = {
        "alltrue": "all",
        "sometrue": "any",
        "product": "prod",
        "cumproduct": "cumprod",
        "trapz": "trapezoid",
        "in1d": "isin",
        "row_stack": "vstack",
        "issubsctype": "issubdtype",
    }

    # Constant renames: old_name -> new_name
    CONSTANT_RENAMES = {
        "Inf": "inf",
        "Infinity": "inf",
        "infty": "inf",
        "NaN": "nan",
        "PINF": "inf",
    }

    # Constants that need special handling (replacement with expressions).
    # This table is informational; the replacements themselves are built
    # node-by-node in leave_Attribute.
    CONSTANT_SPECIAL = {
        "NINF": "-np.inf",  # Requires special handling
        "PZERO": "0.0",
        "NZERO": "-0.0",
    }

    def __init__(self) -> None:
        super().__init__()
        # Names treated as "the numpy module"; extended by visit_Import.
        self._numpy_aliases: set[str] = {"np", "numpy"}
        self._has_numpy_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Track numpy imports to detect aliases."""
        if node.module is None:
            return True

        module_name = self._get_module_name(node.module)
        if module_name == "numpy" or module_name.startswith("numpy."):
            self._has_numpy_import = True
        return True

    def visit_Import(self, node: cst.Import) -> bool:
        """Track numpy import aliases (e.g., import numpy as np)."""
        if isinstance(node.names, cst.ImportStar):
            return True

        for alias in node.names:
            if not isinstance(alias, cst.ImportAlias):
                continue
            if self._get_name_value(alias.name) != "numpy":
                continue

            self._has_numpy_import = True
            as_name = alias.asname
            if isinstance(as_name, cst.AsName) and isinstance(as_name.name, cst.Name):
                self._numpy_aliases.add(as_name.name.value)
        return True

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Transform numpy attribute accesses.

        Returns ``updated_node`` untouched unless it is ``<numpy alias>.<name>``
        and ``<name>`` is a removed type alias or constant.
        """
        attr_name = updated_node.attr.value

        # Check if this is a numpy attribute access
        if not self._is_numpy_attribute(updated_node):
            return updated_node

        # Type aliases and simple constant renames only swap the attribute name.
        for mapping in (self.TYPE_ALIAS_MAPPINGS, self.CONSTANT_RENAMES):
            if attr_name in mapping:
                return self._rename_numpy_attr(updated_node, attr_name, mapping[attr_name])

        # Handle NINF -> -np.inf
        if attr_name == "NINF":
            self.record_change(
                description="Replace numpy.NINF with -numpy.inf",
                line_number=1,
                original="numpy.NINF",
                replacement="-numpy.inf",
                transform_name="NINF_to_neg_inf",
            )
            # NOTE(review): the unary minus is emitted without parentheses, so
            # `np.NINF ** 2` would become `-np.inf ** 2`, which binds as
            # `-(np.inf ** 2)`. Confirm whether parenthesising is wanted.
            return cst.UnaryOperation(
                operator=cst.Minus(),
                expression=updated_node.with_changes(attr=cst.Name("inf")),
            )

        # Handle PZERO -> 0.0
        if attr_name == "PZERO":
            self.record_change(
                description="Replace numpy.PZERO with 0.0",
                line_number=1,
                original="numpy.PZERO",
                replacement="0.0",
                transform_name="PZERO_to_zero",
            )
            return cst.Float("0.0")

        # Handle NZERO -> -0.0
        if attr_name == "NZERO":
            self.record_change(
                description="Replace numpy.NZERO with -0.0",
                line_number=1,
                original="numpy.NZERO",
                replacement="-0.0",
                transform_name="NZERO_to_neg_zero",
            )
            return cst.UnaryOperation(
                operator=cst.Minus(),
                expression=cst.Float("0.0"),
            )

        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Transform numpy function calls.

        Only direct calls such as ``np.alltrue(...)`` are handled; any other
        call is returned unchanged.
        """
        func = updated_node.func
        if not isinstance(func, cst.Attribute) or not self._is_numpy_attribute(func):
            return updated_node

        func_name = func.attr.value

        # Handle function renames
        if func_name in self.FUNCTION_RENAMES:
            new_func = self.FUNCTION_RENAMES[func_name]
            self.record_change(
                description=f"Replace numpy.{func_name}() with numpy.{new_func}()",
                line_number=1,
                original=f"numpy.{func_name}()",
                replacement=f"numpy.{new_func}()",
                transform_name=f"{func_name}_to_{new_func}",
            )
            return updated_node.with_changes(func=func.with_changes(attr=cst.Name(new_func)))

        # Handle msort(a) -> sort(a, axis=0)
        if func_name == "msort":
            self.record_change(
                description="Replace numpy.msort(a) with numpy.sort(a, axis=0)",
                line_number=1,
                original="numpy.msort(a)",
                replacement="numpy.sort(a, axis=0)",
                transform_name="msort_to_sort_axis0",
            )
            new_args = list(updated_node.args)
            new_args.append(self._make_keyword_arg("axis", cst.Integer("0")))
            return updated_node.with_changes(
                func=func.with_changes(attr=cst.Name("sort")), args=new_args
            )

        # Handle asfarray(a) -> asarray(a, dtype=float)
        if func_name == "asfarray":
            self.record_change(
                description="Replace numpy.asfarray(a) with numpy.asarray(a, dtype=float)",
                line_number=1,
                original="numpy.asfarray(a)",
                replacement="numpy.asarray(a, dtype=float)",
                transform_name="asfarray_to_asarray",
            )
            new_args = list(updated_node.args)
            # Only add dtype=float when the call does not already supply a
            # dtype, either by keyword or as the second positional argument;
            # otherwise the rewritten call would pass dtype twice.
            if not self._call_passes_dtype(new_args):
                new_args.append(self._make_keyword_arg("dtype", cst.Name("float")))
            return updated_node.with_changes(
                func=func.with_changes(attr=cst.Name("asarray")), args=new_args
            )

        # Handle issubclass_(arg1, arg2) -> issubclass(arg1, arg2)
        if func_name == "issubclass_":
            self.record_change(
                description="Replace numpy.issubclass_() with builtin issubclass()",
                line_number=1,
                original="numpy.issubclass_()",
                replacement="issubclass()",
                transform_name="issubclass__to_builtin",
            )
            return updated_node.with_changes(func=cst.Name("issubclass"))

        return updated_node

    def _rename_numpy_attr(
        self, node: cst.Attribute, old_attr: str, new_attr: str
    ) -> cst.Attribute:
        """Record and apply a plain ``numpy.<old_attr>`` -> ``numpy.<new_attr>`` rename."""
        self.record_change(
            description=f"Replace numpy.{old_attr} with numpy.{new_attr}",
            line_number=1,
            original=f"numpy.{old_attr}",
            replacement=f"numpy.{new_attr}",
            transform_name=f"{old_attr}_to_{new_attr}",
        )
        return node.with_changes(attr=cst.Name(new_attr))

    @staticmethod
    def _make_keyword_arg(name: str, value: cst.BaseExpression) -> cst.Arg:
        """Build a ``name=value`` call argument with no spaces around ``=``."""
        return cst.Arg(
            keyword=cst.Name(name),
            value=value,
            equal=cst.AssignEqual(
                whitespace_before=cst.SimpleWhitespace(""),
                whitespace_after=cst.SimpleWhitespace(""),
            ),
        )

    @staticmethod
    def _call_passes_dtype(args: list[cst.Arg]) -> bool:
        """Return True if ``args`` already supply a dtype to asfarray/asarray.

        A dtype counts as supplied when there is a ``dtype=`` keyword or at
        least two plain positional arguments (dtype is the second parameter).
        ``*args`` / ``**kwargs`` entries are not counted.
        """
        positional = 0
        for arg in args:
            if isinstance(arg.keyword, cst.Name):
                if arg.keyword.value == "dtype":
                    return True
                continue
            if arg.star:
                # Unpacked arguments: their contents are unknown here.
                continue
            positional += 1
        return positional >= 2

    def _is_numpy_attribute(self, node: cst.Attribute) -> bool:
        """Check if an Attribute node is accessing numpy.

        Handles both 'numpy.X' and 'np.X' patterns.
        """
        if isinstance(node.value, cst.Name):
            return node.value.value in self._numpy_aliases
        return False

    def _get_module_name(self, node: cst.BaseExpression) -> str:
        """Get the full module name from an Attribute or Name node."""
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            base = self._get_module_name(node.value)
            return f"{base}.{node.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node (dotted for Attribute nodes)."""
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            return self._get_module_name(node)
        return None
287
+
288
+
289
class NumPyImportTransformer(BaseTransformer):
    """Transform numpy imports of removed or renamed names.

    ``from numpy import product`` becomes ``from numpy import prod as product``.
    The ``as`` clause keeps the original local name bound, because this
    transformer does not rewrite uses of the imported name elsewhere in the
    module. Imports that already carry an alias keep that alias
    (``from numpy import product as p`` -> ``from numpy import prod as p``).
    """

    # Import name mappings
    IMPORT_MAPPINGS = {
        "bool": "bool_",
        "int": "int_",
        "float": "float64",
        "complex": "complex128",
        "object": "object_",
        "str": "str_",
        "unicode_": "str_",
        "string_": "bytes_",
        "float_": "float64",
        "complex_": "complex128",
        "cfloat": "complex128",
        "singlecomplex": "complex64",
        "longfloat": "longdouble",
        "longcomplex": "clongdouble",
        "clongfloat": "clongdouble",
        "alltrue": "all",
        "sometrue": "any",
        "product": "prod",
        "cumproduct": "cumprod",
        "trapz": "trapezoid",
        "in1d": "isin",
        "row_stack": "vstack",
        "issubsctype": "issubdtype",
        "Inf": "inf",
        "Infinity": "inf",
        "infty": "inf",
        "NaN": "nan",
        "PINF": "inf",
    }

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Transform imports from numpy.

        Only ``from numpy import ...`` (exactly the ``numpy`` module, not a
        submodule, and not a star import) is considered.
        """
        if updated_node.module is None:
            return updated_node

        if self._get_module_name(updated_node.module) != "numpy":
            return updated_node

        if isinstance(updated_node.names, cst.ImportStar):
            return updated_node

        new_names = []
        changed = False

        for name in updated_node.names:
            imported_name = (
                self._get_name_value(name.name) if isinstance(name, cst.ImportAlias) else None
            )
            if imported_name not in self.IMPORT_MAPPINGS:
                new_names.append(name)
                continue

            new_import_name = self.IMPORT_MAPPINGS[imported_name]
            replacement = f"from numpy import {new_import_name}"
            if name.asname is None:
                # No alias in the original: bind the new symbol to the old
                # local name so existing references keep resolving to numpy.
                new_name = name.with_changes(
                    name=cst.Name(new_import_name),
                    asname=cst.AsName(name=cst.Name(imported_name)),
                )
                replacement = f"{replacement} as {imported_name}"
            else:
                new_name = name.with_changes(name=cst.Name(new_import_name))

            new_names.append(new_name)
            changed = True

            self.record_change(
                description=f"Replace 'from numpy import {imported_name}' with '{new_import_name}'",
                line_number=1,
                original=f"from numpy import {imported_name}",
                replacement=replacement,
                transform_name=f"import_{imported_name}_to_{new_import_name}",
            )

        if changed:
            return updated_node.with_changes(names=new_names)

        return updated_node

    def _get_module_name(self, node: cst.BaseExpression) -> str:
        """Get the full module name from an Attribute or Name node."""
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            base = self._get_module_name(node.value)
            return f"{base}.{node.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None
380
+
381
+
382
def transform_numpy(source_code: str) -> tuple[str, list]:
    """Transform NumPy code from 1.x to 2.0.

    Runs two passes over the parsed module: NumPyImportTransformer first,
    then NumPyTransformer on the result. This is best-effort: if the source
    cannot be parsed, or either pass raises, the input is returned unchanged
    with an empty change list.

    Args:
        source_code: The source code to transform

    Returns:
        Tuple of (transformed_code, list of changes)
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return source_code, []

    recorded = []

    try:
        # Pass 1: rewrite `from numpy import ...` statements.
        import_pass = NumPyImportTransformer()
        import_pass.set_source(source_code)
        module = module.visit(import_pass)
        recorded.extend(import_pass.changes)

        # Pass 2: rewrite attribute accesses and calls, working from the
        # code produced by pass 1.
        main_pass = NumPyTransformer()
        main_pass.set_source(module.code)
        module = module.visit(main_pass)
        recorded.extend(main_pass.changes)

        return module.code, recorded
    except Exception:
        # Deliberate fallback: any failure leaves the caller's code untouched.
        return source_code, []
@@ -219,19 +219,64 @@ class PydanticV1ToV2Transformer(BaseTransformer):
219
219
  return updated_node
220
220
 
221
221
  def _transform_validator_decorator(self, node: cst.Decorator) -> cst.Decorator:
222
- """Transform @validator("field") to @field_validator("field")."""
222
+ """Transform @validator("field") to @field_validator("field").
223
+
224
+ Also handles pre=True -> mode="before" and pre=False -> mode="after".
225
+ """
223
226
  if isinstance(node.decorator, cst.Call):
224
227
  # @validator("field_name", ...)
225
- new_call = node.decorator.with_changes(func=cst.Name("field_validator"))
228
+ # Check for pre=True/False and convert to mode="before"/"after"
229
+ mode: str | None = None
230
+ new_args = []
226
231
 
227
- self.record_change(
228
- description="Convert @validator to @field_validator",
229
- line_number=1,
230
- original="@validator(...)",
231
- replacement="@field_validator(...)",
232
- transform_name="validator_to_field_validator",
232
+ for arg in node.decorator.args:
233
+ if isinstance(arg.keyword, cst.Name) and arg.keyword.value == "pre":
234
+ # Found pre argument - determine mode
235
+ if isinstance(arg.value, cst.Name):
236
+ if arg.value.value == "True":
237
+ mode = "before"
238
+ elif arg.value.value == "False":
239
+ mode = "after"
240
+ # Skip adding this argument (we'll add mode instead if needed)
241
+ else:
242
+ # Keep other arguments
243
+ new_args.append(arg)
244
+
245
+ # Add mode argument if pre was present
246
+ if mode is not None:
247
+ new_args.append(
248
+ cst.Arg(
249
+ keyword=cst.Name("mode"),
250
+ value=cst.SimpleString(f'"{mode}"'),
251
+ equal=cst.AssignEqual(
252
+ whitespace_before=cst.SimpleWhitespace(""),
253
+ whitespace_after=cst.SimpleWhitespace(""),
254
+ ),
255
+ )
256
+ )
257
+
258
+ new_call = cst.Call(
259
+ func=cst.Name("field_validator"),
260
+ args=new_args,
233
261
  )
234
262
 
263
+ if mode is not None:
264
+ self.record_change(
265
+ description=f"Convert @validator to @field_validator with mode='{mode}'",
266
+ line_number=1,
267
+ original="@validator(..., pre=...)",
268
+ replacement=f'@field_validator(..., mode="{mode}")',
269
+ transform_name="validator_to_field_validator",
270
+ )
271
+ else:
272
+ self.record_change(
273
+ description="Convert @validator to @field_validator",
274
+ line_number=1,
275
+ original="@validator(...)",
276
+ replacement="@field_validator(...)",
277
+ transform_name="validator_to_field_validator",
278
+ )
279
+
235
280
  return node.with_changes(decorator=new_call)
236
281
  else:
237
282
  # @validator without arguments (shouldn't happen but handle it)