codeshift 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. codeshift/cli/commands/apply.py +24 -2
  2. codeshift/cli/commands/upgrade_all.py +4 -1
  3. codeshift/cli/package_manager.py +102 -0
  4. codeshift/knowledge/generator.py +11 -1
  5. codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
  6. codeshift/knowledge_base/libraries/attrs.yaml +181 -0
  7. codeshift/knowledge_base/libraries/celery.yaml +244 -0
  8. codeshift/knowledge_base/libraries/click.yaml +195 -0
  9. codeshift/knowledge_base/libraries/django.yaml +355 -0
  10. codeshift/knowledge_base/libraries/flask.yaml +270 -0
  11. codeshift/knowledge_base/libraries/httpx.yaml +183 -0
  12. codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
  13. codeshift/knowledge_base/libraries/numpy.yaml +429 -0
  14. codeshift/knowledge_base/libraries/pytest.yaml +192 -0
  15. codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
  16. codeshift/migrator/engine.py +60 -0
  17. codeshift/migrator/transforms/__init__.py +2 -0
  18. codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
  19. codeshift/migrator/transforms/attrs_transformer.py +570 -0
  20. codeshift/migrator/transforms/celery_transformer.py +546 -0
  21. codeshift/migrator/transforms/click_transformer.py +526 -0
  22. codeshift/migrator/transforms/django_transformer.py +852 -0
  23. codeshift/migrator/transforms/fastapi_transformer.py +12 -7
  24. codeshift/migrator/transforms/flask_transformer.py +505 -0
  25. codeshift/migrator/transforms/httpx_transformer.py +419 -0
  26. codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
  27. codeshift/migrator/transforms/numpy_transformer.py +413 -0
  28. codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
  29. codeshift/migrator/transforms/pytest_transformer.py +351 -0
  30. codeshift/migrator/transforms/requests_transformer.py +74 -1
  31. codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
  32. codeshift/scanner/dependency_parser.py +1 -1
  33. {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
  34. {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/RECORD +38 -17
  35. {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
  36. {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
  37. {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
  38. {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
@@ -59,14 +59,19 @@ class FastAPITransformer(BaseTransformer):
59
59
  )
60
60
  return updated_node.with_changes(module=cst.Name("fastapi"))
61
61
 
62
- # Transform starlette.status imports
63
- if module_name == "starlette.status":
62
+ # NOTE: starlette.status imports are intentionally NOT transformed.
63
+ # FastAPI does not export status constants (HTTP_200_OK, etc.) directly.
64
+ # These imports should remain as `from starlette.status import ...`
65
+ # since FastAPI depends on Starlette and these imports work correctly.
66
+
67
+ # Transform starlette.background imports (BackgroundTasks)
68
+ if module_name == "starlette.background":
64
69
  self.record_change(
65
- description="Import status from fastapi instead of starlette",
70
+ description="Import BackgroundTasks from fastapi instead of starlette.background",
66
71
  line_number=1,
67
72
  original=f"from {module_name}",
68
- replacement="from fastapi import status",
69
- transform_name="starlette_to_fastapi_status",
73
+ replacement="from fastapi",
74
+ transform_name="starlette_to_fastapi_background",
70
75
  )
71
76
  return updated_node.with_changes(module=cst.Name("fastapi"))
72
77
 
@@ -74,10 +79,10 @@ class FastAPITransformer(BaseTransformer):
74
79
 
75
80
  def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
76
81
  """Transform FastAPI function calls."""
77
- # Handle Field, Query, Path, Body regex -> pattern
82
+ # Handle Field, Query, Path, Body, Header, Cookie regex -> pattern
78
83
  if isinstance(updated_node.func, cst.Name):
79
84
  func_name = updated_node.func.value
80
- if func_name in ("Field", "Query", "Path", "Body"):
85
+ if func_name in ("Field", "Query", "Path", "Body", "Header", "Cookie"):
81
86
  new_args = []
82
87
  changed = False
83
88
  for arg in updated_node.args:
@@ -0,0 +1,505 @@
1
+ """Flask transformation using LibCST for Flask 1.x to 2.x/3.x migrations."""
2
+
3
+ import libcst as cst
4
+
5
+ from codeshift.migrator.ast_transforms import BaseTransformer
6
+
7
+
8
class FlaskTransformer(BaseTransformer):
    """Transform Flask code for version upgrades (1.x to 2.x/3.x).

    This is the first of two passes (see ``transform_flask``). It rewrites the
    tree and records, in the ``_needs_*`` flags, which imports the second pass
    (``FlaskImportAdder``) must add:

    * ``escape`` / ``Markup`` imported from ``flask`` are dropped from the flask
      import and flagged for re-import from ``markupsafe``.
    * ``safe_join`` imported from ``flask`` is dropped and flagged for
      re-import from ``werkzeug.utils``.
    * ``_app_ctx_stack`` / ``_request_ctx_stack`` imported from
      ``flask.globals`` are dropped.
    * ``send_file`` / ``send_from_directory`` keyword arguments are renamed.
    * ``<obj>.config.from_json(...)`` becomes
      ``<obj>.config.from_file(..., load=json.load)`` and ``json`` is flagged.
    * ``app.env`` / ``application.env`` / ``current_app.env`` become ``.debug``.

    NOTE(review): aliases on moved imports (``from flask import escape as e``)
    are not carried over, because the second pass imports the bare name. Usages
    of the removed context-stack names are not rewritten either.
    """

    # Names that left the ``flask`` namespace: name -> (new module, transform name).
    _MOVED_FLASK_NAMES: dict[str, tuple[str, str]] = {
        "escape": ("markupsafe", "flask_escape_to_markupsafe"),
        "Markup": ("markupsafe", "flask_markup_to_markupsafe"),
        "safe_join": ("werkzeug.utils", "flask_safe_join_to_werkzeug"),
    }

    # Deprecated context-stack globals whose imports are removed.
    _REMOVED_GLOBALS: tuple[str, ...] = ("_app_ctx_stack", "_request_ctx_stack")

    # send_file() keyword renames: old -> (new, transform name).
    _SEND_FILE_RENAMES: dict[str, tuple[str, str]] = {
        "filename_or_fp": ("path_or_file", "send_file_filename_or_fp_to_path_or_file"),
        "attachment_filename": (
            "download_name",
            "send_file_attachment_filename_to_download_name",
        ),
        "cache_timeout": ("max_age", "send_file_cache_timeout_to_max_age"),
        "add_etags": ("etag", "send_file_add_etags_to_etag"),
    }

    # send_from_directory() keyword renames: old -> (new, transform name).
    _SEND_FROM_DIRECTORY_RENAMES: dict[str, tuple[str, str]] = {
        "filename": ("path", "send_from_directory_filename_to_path"),
    }

    # Bare receiver names whose ``.env`` attribute is rewritten to ``.debug``.
    _APP_NAMES: tuple[str, ...] = ("app", "application", "current_app")

    def __init__(self) -> None:
        super().__init__()
        # Imports the second pass must add.
        self._needs_markupsafe_escape = False
        self._needs_markupsafe_markup = False
        self._needs_werkzeug_safe_join = False
        self._needs_json_import = False
        # Which moved names were found in ``from flask import ...``.
        self._has_flask_escape_import = False
        self._has_flask_markup_import = False
        self._has_flask_safe_join_import = False
        # Existing ``from markupsafe import ...``: whether one exists, and the
        # local names it binds (the alias if present, else the imported name).
        self._has_markupsafe_import = False
        self._markupsafe_import_names: set[str] = set()
        # True while the children of an import statement are being visited,
        # so dotted module names are not mistaken for ``app.env`` access.
        self._in_import = False

    # -- import statements ---------------------------------------------------

    def visit_Import(self, node: cst.Import) -> bool | None:
        """Mark that the nodes visited next belong to an ``import`` statement."""
        self._in_import = True
        return super().visit_Import(node)

    def leave_Import(
        self, original_node: cst.Import, updated_node: cst.Import
    ) -> cst.BaseSmallStatement | cst.RemovalSentinel:
        """Clear the in-import marker; otherwise defer to the base class."""
        self._in_import = False
        return super().leave_Import(original_node, updated_node)

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool | None:
        """Mark that the nodes visited next belong to a ``from ... import``."""
        self._in_import = True
        return super().visit_ImportFrom(node)

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.BaseSmallStatement | cst.RemovalSentinel:
        """Transform Flask imports to their new locations.

        Returns the (possibly trimmed) import, or ``RemovalSentinel.REMOVE``
        when every imported name was moved away.
        """
        self._in_import = False

        # ``from . import x`` has no module; ``from .flask import x`` is a
        # local module that merely shares the name. Neither is Flask itself.
        if updated_node.module is None or updated_node.relative:
            return updated_node

        module_name = self._get_module_name(updated_node.module)
        if module_name == "flask":
            return self._handle_flask_import(updated_node)
        if module_name == "flask.globals":
            return self._handle_flask_globals_import(updated_node)
        if module_name == "markupsafe":
            self._record_markupsafe_import(updated_node)
        return updated_node

    def _handle_flask_import(self, node: cst.ImportFrom) -> cst.ImportFrom | cst.RemovalSentinel:
        """Drop names that moved out of ``flask`` and flag their new imports.

        Returns the import unchanged when nothing moved, ``REMOVE`` when every
        name moved, otherwise the import without the moved names.
        """
        if isinstance(node.names, cst.ImportStar):
            return node

        kept: list[cst.ImportAlias] = []
        moved_any = False
        for alias in node.names:
            imported_name = self._get_name_value(alias.name)
            if imported_name not in self._MOVED_FLASK_NAMES:
                kept.append(alias)
                continue

            moved_any = True
            self._flag_moved_flask_name(imported_name)
            new_module, transform_name = self._MOVED_FLASK_NAMES[imported_name]
            self.record_change(
                description=f"Move '{imported_name}' import from flask to {new_module}",
                line_number=1,
                original=f"from flask import {imported_name}",
                replacement=f"from {new_module} import {imported_name}",
                transform_name=transform_name,
            )

        if not moved_any:
            return node
        if not kept:
            # Every name moved: drop the whole ``from flask import`` statement.
            return cst.RemovalSentinel.REMOVE
        return self._with_names(node, kept)

    def _flag_moved_flask_name(self, name: str) -> None:
        """Set the bookkeeping flags for a name removed from a flask import."""
        if name == "escape":
            self._needs_markupsafe_escape = True
            self._has_flask_escape_import = True
        elif name == "Markup":
            self._needs_markupsafe_markup = True
            self._has_flask_markup_import = True
        elif name == "safe_join":
            self._needs_werkzeug_safe_join = True
            self._has_flask_safe_join_import = True

    def _handle_flask_globals_import(
        self, node: cst.ImportFrom
    ) -> cst.BaseSmallStatement | cst.RemovalSentinel:
        """Remove imports of deprecated context stacks from ``flask.globals``."""
        if isinstance(node.names, cst.ImportStar):
            return node

        kept: list[cst.ImportAlias] = []
        removed_any = False
        for alias in node.names:
            imported_name = self._get_name_value(alias.name)
            if imported_name not in self._REMOVED_GLOBALS:
                kept.append(alias)
                continue

            removed_any = True
            self.record_change(
                description=f"Remove deprecated '{imported_name}' import, use flask.g instead",
                line_number=1,
                original=f"from flask.globals import {imported_name}",
                replacement="from flask import g",
                transform_name=f"{imported_name.lstrip('_')}_to_g",
            )

        if not removed_any:
            return node
        if not kept:
            return cst.RemovalSentinel.REMOVE
        return self._with_names(node, kept)

    def _record_markupsafe_import(self, node: cst.ImportFrom) -> None:
        """Remember an existing ``from markupsafe import ...`` and what it binds."""
        self._has_markupsafe_import = True
        if isinstance(node.names, cst.ImportStar):
            return
        for alias in node.names:
            bound = self._bound_name(alias)
            if bound:
                self._markupsafe_import_names.add(bound)

    @staticmethod
    def _with_names(node: cst.ImportFrom, names: list[cst.ImportAlias]) -> cst.ImportFrom:
        """Return *node* importing only *names*.

        The last kept alias may still carry the comma that separated it from
        a removed name. ``from x import a,`` without parentheses is invalid,
        so that comma is dropped.
        """
        if node.rpar is None and names:
            last = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
            names = [*names[:-1], last]
        return node.with_changes(names=names)

    # -- calls -----------------------------------------------------------------

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Transform Flask function calls."""
        if self._is_call_to(updated_node, "send_file"):
            return self._transform_send_file(updated_node)
        if self._is_call_to(updated_node, "send_from_directory"):
            return self._transform_send_from_directory(updated_node)
        # app.config.from_json -> app.config.from_file
        if self._is_method_call(updated_node, "from_json"):
            return self._transform_from_json(updated_node)
        return updated_node

    def _transform_send_file(self, node: cst.Call) -> cst.Call:
        """Rename send_file() keyword arguments removed in Flask 2.x."""
        return self._rename_keywords(node, "send_file", self._SEND_FILE_RENAMES)

    def _transform_send_from_directory(self, node: cst.Call) -> cst.Call:
        """Rename send_from_directory(filename=...) to send_from_directory(path=...)."""
        return self._rename_keywords(
            node, "send_from_directory", self._SEND_FROM_DIRECTORY_RENAMES
        )

    def _rename_keywords(
        self,
        node: cst.Call,
        func_name: str,
        renames: dict[str, tuple[str, str]],
    ) -> cst.Call:
        """Rename keyword arguments of *node* per *renames* (old -> (new, transform)).

        Returns *node* itself when no keyword matched.
        """
        new_args: list[cst.Arg] = []
        changed = False
        for arg in node.args:
            keyword = arg.keyword.value if isinstance(arg.keyword, cst.Name) else None
            if keyword not in renames:
                new_args.append(arg)
                continue

            new_name, transform_name = renames[keyword]
            new_args.append(arg.with_changes(keyword=cst.Name(new_name)))
            changed = True
            self.record_change(
                description=f"Rename {func_name}({keyword}=...) to {func_name}({new_name}=...)",
                line_number=1,
                original=f"{func_name}({keyword}=...)",
                replacement=f"{func_name}({new_name}=...)",
                transform_name=transform_name,
            )

        return node.with_changes(args=new_args) if changed else node

    def _transform_from_json(self, node: cst.Call) -> cst.Call:
        """Rewrite ``<x>.config.from_json(...)`` as ``from_file(..., load=json.load)``.

        Only calls whose receiver is ``config`` or ``<something>.config`` are
        rewritten. All original arguments are kept. A positional second
        argument (``silent`` in ``from_json``) becomes ``silent=...`` so it
        does not bind to ``from_file``'s positional ``load`` parameter.
        """
        func = node.func
        if not isinstance(func, cst.Attribute) or func.attr.value != "from_json":
            return node

        receiver = func.value
        is_config_call = (
            isinstance(receiver, cst.Attribute) and receiver.attr.value == "config"
        ) or (isinstance(receiver, cst.Name) and receiver.value == "config")
        if not is_config_call or not node.args:
            return node

        file_arg, *extra_args = node.args
        if extra_args and extra_args[0].keyword is None and not extra_args[0].star:
            extra_args[0] = extra_args[0].with_changes(
                keyword=cst.Name("silent"), equal=self._tight_equal()
            )

        load_arg = cst.Arg(
            keyword=cst.Name("load"),
            value=cst.Attribute(value=cst.Name("json"), attr=cst.Name("load")),
            equal=self._tight_equal(),
        )

        # The second pass adds ``import json`` unless it is already bound.
        self._needs_json_import = True
        self.record_change(
            description="Convert config.from_json() to config.from_file() with json.load",
            line_number=1,
            original="config.from_json('file.json')",
            replacement="config.from_file('file.json', load=json.load)",
            transform_name="config_from_json_to_from_file",
        )
        return node.with_changes(
            func=func.with_changes(attr=cst.Name("from_file")),
            args=[file_arg, *extra_args, load_arg],
        )

    @staticmethod
    def _tight_equal() -> cst.AssignEqual:
        """Return a fresh ``=`` with no surrounding spaces (PEP 8 keyword style)."""
        return cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )

    # -- attribute access -------------------------------------------------------

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Rewrite ``app.env`` (and similar receivers) to ``app.debug``.

        Dotted module names inside import statements are left alone, so
        ``from app.env import x`` is not touched.
        """
        if self._in_import or updated_node.attr.value != "env":
            return updated_node

        receiver = updated_node.value
        if isinstance(receiver, cst.Name) and receiver.value in self._APP_NAMES:
            self.record_change(
                description="Convert app.env to app.debug (env property deprecated)",
                line_number=1,
                original="app.env",
                replacement="app.debug",
                transform_name="app_env_to_debug",
            )
            return updated_node.with_changes(attr=cst.Name("debug"))
        return updated_node

    # -- helpers ---------------------------------------------------------------

    def _is_call_to(self, node: cst.Call, func_name: str) -> bool:
        """Check whether *node* calls ``func_name`` or ``flask.func_name``."""
        func = node.func
        if isinstance(func, cst.Name):
            return bool(func.value == func_name)
        return bool(
            isinstance(func, cst.Attribute)
            and func.attr.value == func_name
            and isinstance(func.value, cst.Name)
            and func.value.value == "flask"
        )

    def _is_method_call(self, node: cst.Call, method_name: str) -> bool:
        """Check whether *node* calls a method named *method_name* on any object."""
        if isinstance(node.func, cst.Attribute):
            return bool(node.func.attr.value == method_name)
        return False

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the dotted module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        if isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node, or None for other nodes."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None

    def _bound_name(self, alias: cst.ImportAlias) -> str | None:
        """Return the local name an ImportAlias binds (its ``as`` name if present)."""
        if alias.asname is not None:
            return self._get_name_value(alias.asname.name)
        return self._get_name_value(alias.name)
337
+
338
+
339
class FlaskImportAdder(cst.CSTTransformer):
    """Add the imports needed after a ``FlaskTransformer`` pass.

    Depending on the flags it is constructed with, this inserts
    ``from markupsafe import escape, Markup``,
    ``from werkzeug.utils import safe_join`` and/or ``import json`` near the
    top of the module. New imports go after any leading string statements
    (the docstring) and any ``from __future__`` imports, which Python requires
    to come first. They go before the first other statement.
    """

    def __init__(
        self,
        needs_markupsafe_escape: bool = False,
        needs_markupsafe_markup: bool = False,
        needs_werkzeug_safe_join: bool = False,
        needs_json_import: bool = False,
        has_markupsafe_import: bool = False,
        existing_markupsafe_names: set[str] | None = None,
    ) -> None:
        """Configure which imports to add.

        Args:
            needs_markupsafe_escape: add ``escape`` from markupsafe.
            needs_markupsafe_markup: add ``Markup`` from markupsafe.
            needs_werkzeug_safe_join: add ``safe_join`` from werkzeug.utils.
            needs_json_import: add ``import json``.
            has_markupsafe_import: whether the module already imports from
                markupsafe (stored only; the insertion logic does not read it).
            existing_markupsafe_names: names already bound by an existing
                ``from markupsafe import``; these are not imported again.
        """
        super().__init__()
        self.needs_markupsafe_escape = needs_markupsafe_escape
        self.needs_markupsafe_markup = needs_markupsafe_markup
        self.needs_werkzeug_safe_join = needs_werkzeug_safe_join
        self.needs_json_import = needs_json_import
        self.has_markupsafe_import = has_markupsafe_import
        self.existing_markupsafe_names = existing_markupsafe_names or set()
        self._added_imports = False
        # True only when the bare name ``json`` is bound by ``import json``.
        self._has_json_import = False
        # True only when ``safe_join`` is bound from werkzeug.utils.
        self._has_werkzeug_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Track whether ``safe_join`` is already imported from werkzeug.utils.

        ``from json import ...`` is deliberately ignored: it does not bind the
        name ``json`` that ``json.load`` needs.
        """
        if node.module is None or node.relative:
            return True
        if self._get_module_name(node.module) != "werkzeug.utils":
            return True

        if isinstance(node.names, cst.ImportStar):
            # A star import is treated as providing safe_join.
            self._has_werkzeug_import = True
            return True
        for alias in node.names:
            if (
                alias.asname is None
                and isinstance(alias.name, cst.Name)
                and alias.name.value == "safe_join"
            ):
                self._has_werkzeug_import = True
        return True

    def visit_Import(self, node: cst.Import) -> bool:
        """Track an unaliased ``import json`` (``import json as j`` does not bind ``json``).

        NOTE(review): imports nested inside functions are counted too, even
        though they do not bind ``json`` at module level.
        """
        for alias in node.names:
            if (
                alias.asname is None
                and isinstance(alias.name, cst.Name)
                and alias.name.value == "json"
            ):
                self._has_json_import = True
        return True

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Insert the required imports near the top of the module."""
        if self._added_imports:
            return updated_node

        new_imports = self._build_imports()
        if not new_imports:
            return updated_node

        body = list(updated_node.body)
        insert_idx = self._insertion_index(body)
        body[insert_idx:insert_idx] = new_imports

        self._added_imports = True
        return updated_node.with_changes(body=body)

    def _build_imports(self) -> list[cst.SimpleStatementLine]:
        """Build the import statements that are needed and not already present."""
        statements: list[cst.SimpleStatementLine] = []

        markupsafe_names = [
            cst.ImportAlias(name=cst.Name(name))
            for name, needed in (
                ("escape", self.needs_markupsafe_escape),
                ("Markup", self.needs_markupsafe_markup),
            )
            if needed and name not in self.existing_markupsafe_names
        ]
        if markupsafe_names:
            statements.append(
                cst.SimpleStatementLine(
                    body=[cst.ImportFrom(module=cst.Name("markupsafe"), names=markupsafe_names)]
                )
            )

        if self.needs_werkzeug_safe_join and not self._has_werkzeug_import:
            statements.append(
                cst.SimpleStatementLine(
                    body=[
                        cst.ImportFrom(
                            module=cst.Attribute(
                                value=cst.Name("werkzeug"),
                                attr=cst.Name("utils"),
                            ),
                            names=[cst.ImportAlias(name=cst.Name("safe_join"))],
                        )
                    ]
                )
            )

        if self.needs_json_import and not self._has_json_import:
            statements.append(
                cst.SimpleStatementLine(
                    body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("json"))])]
                )
            )

        return statements

    def _insertion_index(self, body: list[cst.BaseStatement]) -> int:
        """Return where new imports go in the module *body*.

        Leading string statements (the docstring) and ``from __future__``
        imports are skipped. The result is the index of the first other
        statement, or ``len(body)`` if there is none.
        """
        index = 0
        for i, stmt in enumerate(body):
            first = (
                stmt.body[0]
                if isinstance(stmt, cst.SimpleStatementLine) and stmt.body
                else None
            )
            is_docstring = isinstance(first, cst.Expr) and isinstance(
                first.value, cst.SimpleString
            )
            if is_docstring or self._is_future_import(first):
                index = i + 1
                continue
            return i
        return index

    @staticmethod
    def _is_future_import(node: cst.CSTNode | None) -> bool:
        """Return True for ``from __future__ import ...``."""
        return (
            isinstance(node, cst.ImportFrom)
            and not node.relative
            and isinstance(node.module, cst.Name)
            and node.module.value == "__future__"
        )

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the dotted module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        if isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
466
+
467
+
468
def transform_flask(source_code: str) -> tuple[str, list]:
    """Run the Flask upgrade passes over *source_code*.

    Pass one (``FlaskTransformer``) rewrites imports, calls and attributes.
    Pass two (``FlaskImportAdder``) inserts the imports pass one flagged.

    Args:
        source_code: Python source text to rewrite.

    Returns:
        ``(code, changes)``. If the source does not parse, or pass one raises,
        the input comes back unchanged with an empty change list. If only
        pass two raises, the output of pass one is returned together with its
        recorded changes.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return source_code, []

    rewriter = FlaskTransformer()
    rewriter.set_source(source_code)
    try:
        rewritten = module.visit(rewriter)
    except Exception:
        # Best effort: a failed rewrite leaves the file untouched.
        return source_code, []

    adder = FlaskImportAdder(
        needs_markupsafe_escape=rewriter._needs_markupsafe_escape,
        needs_markupsafe_markup=rewriter._needs_markupsafe_markup,
        needs_werkzeug_safe_join=rewriter._needs_werkzeug_safe_join,
        needs_json_import=rewriter._needs_json_import,
        has_markupsafe_import=rewriter._has_markupsafe_import,
        existing_markupsafe_names=rewriter._markupsafe_import_names,
    )
    try:
        final_code = rewritten.visit(adder).code
    except Exception:
        # Import insertion failed: keep the rewrite without the new imports.
        final_code = rewritten.code
    return final_code, rewriter.changes