codeshift 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/cli/commands/apply.py +24 -2
- codeshift/cli/package_manager.py +102 -0
- codeshift/knowledge/generator.py +11 -1
- codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
- codeshift/knowledge_base/libraries/attrs.yaml +181 -0
- codeshift/knowledge_base/libraries/celery.yaml +244 -0
- codeshift/knowledge_base/libraries/click.yaml +195 -0
- codeshift/knowledge_base/libraries/django.yaml +355 -0
- codeshift/knowledge_base/libraries/flask.yaml +270 -0
- codeshift/knowledge_base/libraries/httpx.yaml +183 -0
- codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
- codeshift/knowledge_base/libraries/numpy.yaml +429 -0
- codeshift/knowledge_base/libraries/pytest.yaml +192 -0
- codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
- codeshift/migrator/engine.py +60 -0
- codeshift/migrator/transforms/__init__.py +2 -0
- codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
- codeshift/migrator/transforms/attrs_transformer.py +570 -0
- codeshift/migrator/transforms/celery_transformer.py +546 -0
- codeshift/migrator/transforms/click_transformer.py +526 -0
- codeshift/migrator/transforms/django_transformer.py +852 -0
- codeshift/migrator/transforms/fastapi_transformer.py +12 -7
- codeshift/migrator/transforms/flask_transformer.py +505 -0
- codeshift/migrator/transforms/httpx_transformer.py +419 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
- codeshift/migrator/transforms/numpy_transformer.py +413 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
- codeshift/migrator/transforms/pytest_transformer.py +351 -0
- codeshift/migrator/transforms/requests_transformer.py +74 -1
- codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
- {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/METADATA +46 -4
- {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/RECORD +36 -15
- {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/WHEEL +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/top_level.txt +0 -0
|
@@ -59,14 +59,19 @@ class FastAPITransformer(BaseTransformer):
|
|
|
59
59
|
)
|
|
60
60
|
return updated_node.with_changes(module=cst.Name("fastapi"))
|
|
61
61
|
|
|
62
|
-
#
|
|
63
|
-
|
|
62
|
+
# NOTE: starlette.status imports are intentionally NOT transformed.
|
|
63
|
+
# FastAPI does not export status constants (HTTP_200_OK, etc.) directly.
|
|
64
|
+
# These imports should remain as `from starlette.status import ...`
|
|
65
|
+
# since FastAPI depends on Starlette and these imports work correctly.
|
|
66
|
+
|
|
67
|
+
# Transform starlette.background imports (BackgroundTasks)
|
|
68
|
+
if module_name == "starlette.background":
|
|
64
69
|
self.record_change(
|
|
65
|
-
description="Import
|
|
70
|
+
description="Import BackgroundTasks from fastapi instead of starlette.background",
|
|
66
71
|
line_number=1,
|
|
67
72
|
original=f"from {module_name}",
|
|
68
|
-
replacement="from fastapi
|
|
69
|
-
transform_name="
|
|
73
|
+
replacement="from fastapi",
|
|
74
|
+
transform_name="starlette_to_fastapi_background",
|
|
70
75
|
)
|
|
71
76
|
return updated_node.with_changes(module=cst.Name("fastapi"))
|
|
72
77
|
|
|
@@ -74,10 +79,10 @@ class FastAPITransformer(BaseTransformer):
|
|
|
74
79
|
|
|
75
80
|
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
|
|
76
81
|
"""Transform FastAPI function calls."""
|
|
77
|
-
# Handle Field, Query, Path, Body regex -> pattern
|
|
82
|
+
# Handle Field, Query, Path, Body, Header, Cookie regex -> pattern
|
|
78
83
|
if isinstance(updated_node.func, cst.Name):
|
|
79
84
|
func_name = updated_node.func.value
|
|
80
|
-
if func_name in ("Field", "Query", "Path", "Body"):
|
|
85
|
+
if func_name in ("Field", "Query", "Path", "Body", "Header", "Cookie"):
|
|
81
86
|
new_args = []
|
|
82
87
|
changed = False
|
|
83
88
|
for arg in updated_node.args:
|
|
@@ -0,0 +1,505 @@
|
|
|
1
|
+
"""Flask transformation using LibCST for Flask 1.x to 2.x/3.x migrations."""
|
|
2
|
+
|
|
3
|
+
import libcst as cst
|
|
4
|
+
|
|
5
|
+
from codeshift.migrator.ast_transforms import BaseTransformer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FlaskTransformer(BaseTransformer):
    """Transform Flask code for version upgrades (1.x to 2.x/3.x).

    This is the first pass of ``transform_flask``. Besides rewriting nodes in
    place, it sets private ``_needs_*`` flags that tell the second pass (the
    import adder) which replacement imports must be inserted. It also records
    which local names existing ``from markupsafe import ...`` statements bind.
    """

    # Names removed from ``flask``:
    # imported name -> (new module, "needs" flag attribute,
    #                   "has flask import" flag attribute, transform name).
    _FLASK_MOVED_NAMES: dict[str, tuple[str, str, str, str]] = {
        "escape": (
            "markupsafe",
            "_needs_markupsafe_escape",
            "_has_flask_escape_import",
            "flask_escape_to_markupsafe",
        ),
        "Markup": (
            "markupsafe",
            "_needs_markupsafe_markup",
            "_has_flask_markup_import",
            "flask_markup_to_markupsafe",
        ),
        "safe_join": (
            "werkzeug.utils",
            "_needs_werkzeug_safe_join",
            "_has_flask_safe_join_import",
            "flask_safe_join_to_werkzeug",
        ),
    }

    # Deprecated context stacks that are dropped from ``flask.globals`` imports.
    _DEPRECATED_GLOBALS: tuple[str, ...] = ("_app_ctx_stack", "_request_ctx_stack")

    # send_file keyword renames: old keyword -> (new keyword, transform name).
    _SEND_FILE_RENAMES: dict[str, tuple[str, str]] = {
        "attachment_filename": (
            "download_name",
            "send_file_attachment_filename_to_download_name",
        ),
        "cache_timeout": ("max_age", "send_file_cache_timeout_to_max_age"),
        "add_etags": ("etag", "send_file_add_etags_to_etag"),
    }

    def __init__(self) -> None:
        super().__init__()
        # Replacement imports the second (import-adding) pass must insert.
        self._needs_markupsafe_escape = False
        self._needs_markupsafe_markup = False
        self._needs_werkzeug_safe_join = False
        self._needs_json_import = False
        # Which of the removed names were imported from flask.
        self._has_flask_escape_import = False
        self._has_flask_markup_import = False
        self._has_flask_safe_join_import = False
        # Whether a ``from markupsafe import ...`` exists, and the local names it binds.
        self._has_markupsafe_import = False
        self._markupsafe_import_names: set[str] = set()

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.BaseSmallStatement | cst.RemovalSentinel:
        """Transform Flask imports to their new locations.

        Dispatches ``from flask import ...`` and ``from flask.globals import ...``
        to dedicated handlers. Records the names bound by
        ``from markupsafe import ...`` so the import adder does not duplicate
        them. Relative imports are never touched: ``from .flask import x``
        refers to a local module, not to Flask.
        """
        if updated_node.module is None or updated_node.relative:
            return updated_node

        module_name = self._get_module_name(updated_node.module)

        if module_name == "flask":
            return self._handle_flask_import(updated_node)

        if module_name == "flask.globals":
            return self._handle_flask_globals_import(updated_node)

        if module_name == "markupsafe":
            self._has_markupsafe_import = True
            if not isinstance(updated_node.names, cst.ImportStar):
                for alias in updated_node.names:
                    # Record the name actually bound here. For
                    # ``from markupsafe import escape as esc`` that is ``esc``,
                    # so a moved ``escape`` still receives its own import.
                    bound = self._bound_name(alias)
                    if bound:
                        self._markupsafe_import_names.add(bound)

        return updated_node

    def _handle_flask_import(self, node: cst.ImportFrom) -> cst.ImportFrom | cst.RemovalSentinel:
        """Handle imports from the flask module.

        Drops ``escape``, ``Markup`` and ``safe_join`` (all removed from flask)
        from the statement. For each dropped name, it sets the matching
        ``_needs_*`` / ``_has_flask_*`` flags so the import-adding pass
        re-imports it from markupsafe or werkzeug.utils. Star imports are
        returned unchanged. Returns REMOVE when no names are left.

        NOTE(review): an aliased import such as ``from flask import escape as esc``
        is moved without its alias (the adder imports only the bare name), so
        ``esc`` ends up unbound. Such imports need manual review.
        """
        if isinstance(node.names, cst.ImportStar):
            return node

        kept: list[cst.ImportAlias] = []
        for alias in node.names:
            imported_name = self._get_name_value(alias.name)
            target = self._FLASK_MOVED_NAMES.get(imported_name) if imported_name else None
            if target is None:
                kept.append(alias)
                continue

            new_module, needs_flag, has_flag, transform_name = target
            setattr(self, needs_flag, True)
            setattr(self, has_flag, True)
            self.record_change(
                description=f"Move '{imported_name}' import from flask to {new_module}",
                line_number=1,
                original=f"from flask import {imported_name}",
                replacement=f"from {new_module} import {imported_name}",
                transform_name=transform_name,
            )

        return self._rebuild_import(node, kept)

    def _handle_flask_globals_import(
        self, node: cst.ImportFrom
    ) -> cst.BaseSmallStatement | cst.RemovalSentinel:
        """Handle imports from flask.globals (deprecated context stacks).

        Drops ``_app_ctx_stack`` and ``_request_ctx_stack`` from the import and
        records a change pointing to ``flask.g``. Usages of those names are not
        rewritten here. Star imports are returned unchanged. Returns REMOVE when
        no names are left.
        """
        if isinstance(node.names, cst.ImportStar):
            return node

        kept: list[cst.ImportAlias] = []
        for alias in node.names:
            imported_name = self._get_name_value(alias.name)
            if imported_name not in self._DEPRECATED_GLOBALS:
                kept.append(alias)
                continue

            self.record_change(
                description=f"Remove deprecated '{imported_name}' import, use flask.g instead",
                line_number=1,
                original=f"from flask.globals import {imported_name}",
                replacement="from flask import g",
                transform_name=f"{imported_name.lstrip('_')}_to_g",
            )

        return self._rebuild_import(node, kept)

    def _rebuild_import(
        self, node: cst.ImportFrom, kept: list[cst.ImportAlias]
    ) -> cst.ImportFrom | cst.RemovalSentinel:
        """Return ``node`` limited to ``kept`` names, or REMOVE if none remain.

        When names were dropped from an unparenthesized import, the comma still
        attached to the new last name is cleared. Otherwise
        ``from flask import Flask, escape`` minus ``escape`` would render as
        ``from flask import Flask,``, which is a SyntaxError. A trailing comma
        is legal inside parentheses, so parenthesized imports keep it and keep
        their layout.
        """
        if len(kept) == len(node.names):
            return node
        if not kept:
            return cst.RemovalSentinel.REMOVE
        if node.lpar is None:
            kept[-1] = kept[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        return node.with_changes(names=kept)

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Transform Flask function calls (send_file, send_from_directory, config.from_json)."""
        if self._is_call_to(updated_node, "send_file"):
            return self._transform_send_file(updated_node)

        if self._is_call_to(updated_node, "send_from_directory"):
            return self._transform_send_from_directory(updated_node)

        if self._is_method_call(updated_node, "from_json"):
            return self._transform_from_json(updated_node)

        return updated_node

    def _transform_send_file(self, node: cst.Call) -> cst.Call:
        """Rename deprecated send_file() keywords to their Flask 2.x names."""
        new_args: list[cst.Arg] = []
        changed = False

        for arg in node.args:
            keyword = arg.keyword.value if isinstance(arg.keyword, cst.Name) else None
            rename = self._SEND_FILE_RENAMES.get(keyword) if keyword else None
            if rename is None:
                new_args.append(arg)
                continue

            new_name, transform_name = rename
            new_args.append(arg.with_changes(keyword=cst.Name(new_name)))
            changed = True
            self.record_change(
                description=f"Rename send_file({keyword}=...) to send_file({new_name}=...)",
                line_number=1,
                original=f"send_file({keyword}=...)",
                replacement=f"send_file({new_name}=...)",
                transform_name=transform_name,
            )

        return node.with_changes(args=new_args) if changed else node

    def _transform_send_from_directory(self, node: cst.Call) -> cst.Call:
        """Rename send_from_directory(filename=...) to send_from_directory(path=...)."""
        new_args: list[cst.Arg] = []
        changed = False

        for arg in node.args:
            if isinstance(arg.keyword, cst.Name) and arg.keyword.value == "filename":
                new_args.append(arg.with_changes(keyword=cst.Name("path")))
                changed = True
                self.record_change(
                    description="Rename send_from_directory(filename=...) to send_from_directory(path=...)",
                    line_number=1,
                    original="send_from_directory(filename=...)",
                    replacement="send_from_directory(path=...)",
                    transform_name="send_from_directory_filename_to_path",
                )
            else:
                new_args.append(arg)

        return node.with_changes(args=new_args) if changed else node

    @staticmethod
    def _tight_equal() -> cst.AssignEqual:
        """Return an ``=`` token with no surrounding spaces (keyword-argument style)."""
        return cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )

    def _transform_from_json(self, node: cst.Call) -> cst.Call:
        """Transform ``config.from_json(...)`` into ``config.from_file(..., load=json.load)``.

        Only calls whose receiver is a name or attribute spelled ``config`` are
        rewritten, and only when at least one argument is present. All original
        arguments (e.g. ``silent``) are preserved. A positional second argument,
        which is ``silent`` for ``from_json``, is turned into a keyword, because
        the second positional parameter of ``from_file`` is ``load``. Sets
        ``_needs_json_import`` so the second pass adds ``import json``.
        """
        func = node.func
        if not isinstance(func, cst.Attribute) or func.attr.value != "from_json":
            return node

        receiver = func.value
        on_config = (isinstance(receiver, cst.Attribute) and receiver.attr.value == "config") or (
            isinstance(receiver, cst.Name) and receiver.value == "config"
        )
        if not on_config or not node.args:
            return node

        first_is_plain = node.args[0].keyword is None and not node.args[0].star
        new_args: list[cst.Arg] = []
        for index, arg in enumerate(node.args):
            positional_silent = (
                index == 1 and first_is_plain and arg.keyword is None and not arg.star
            )
            if positional_silent:
                new_args.append(
                    arg.with_changes(keyword=cst.Name("silent"), equal=self._tight_equal())
                )
            else:
                new_args.append(arg)

        new_args.append(
            cst.Arg(
                keyword=cst.Name("load"),
                value=cst.Attribute(value=cst.Name("json"), attr=cst.Name("load")),
                equal=self._tight_equal(),
            )
        )

        self._needs_json_import = True
        self.record_change(
            description="Convert config.from_json() to config.from_file() with json.load",
            line_number=1,
            original="config.from_json('file.json')",
            replacement="config.from_file('file.json', load=json.load)",
            transform_name="config_from_json_to_from_file",
        )

        new_func = func.with_changes(attr=cst.Name("from_file"))
        return node.with_changes(func=new_func, args=new_args)

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Rewrite ``<app>.env`` to ``<app>.debug`` for app, application and current_app."""
        if updated_node.attr.value != "env":
            return updated_node

        owner = updated_node.value
        if isinstance(owner, cst.Name) and owner.value in ("app", "application", "current_app"):
            self.record_change(
                description="Convert app.env to app.debug (env property deprecated)",
                line_number=1,
                original="app.env",
                replacement="app.debug",
                transform_name="app_env_to_debug",
            )
            return updated_node.with_changes(attr=cst.Name("debug"))

        return updated_node

    def _is_call_to(self, node: cst.Call, func_name: str) -> bool:
        """Check whether a Call invokes the bare name ``func_name``."""
        if isinstance(node.func, cst.Name):
            return bool(node.func.value == func_name)
        return False

    def _is_method_call(self, node: cst.Call, method_name: str) -> bool:
        """Check whether a Call invokes an attribute named ``method_name``."""
        if isinstance(node.func, cst.Attribute):
            return bool(node.func.attr.value == method_name)
        return False

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the dotted module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node (None for anything else)."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None

    def _bound_name(self, alias: cst.ImportAlias) -> str | None:
        """Return the local name an ImportAlias binds (its ``as`` name if present)."""
        if alias.asname is not None:
            return self._get_name_value(alias.asname.name)
        return self._get_name_value(alias.name)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class FlaskImportAdder(cst.CSTTransformer):
    """Adds new imports needed after Flask transformation.

    This is the second pass of ``transform_flask``. The ``needs_*`` flags come
    from the first pass. This pass inserts ``from markupsafe import ...``,
    ``from werkzeug.utils import safe_join`` and ``import json`` when each is
    needed and not already provided by the module.
    """

    def __init__(
        self,
        needs_markupsafe_escape: bool = False,
        needs_markupsafe_markup: bool = False,
        needs_werkzeug_safe_join: bool = False,
        needs_json_import: bool = False,
        has_markupsafe_import: bool = False,
        existing_markupsafe_names: set[str] | None = None,
    ) -> None:
        super().__init__()
        self.needs_markupsafe_escape = needs_markupsafe_escape
        self.needs_markupsafe_markup = needs_markupsafe_markup
        self.needs_werkzeug_safe_join = needs_werkzeug_safe_join
        self.needs_json_import = needs_json_import
        # Kept for API compatibility; insertion decisions use existing_markupsafe_names.
        self.has_markupsafe_import = has_markupsafe_import
        # Local names already bound by ``from markupsafe import ...``.
        self.existing_markupsafe_names = existing_markupsafe_names or set()
        self._added_imports = False
        # True only when the module binds the name ``json``. ``from json import x``
        # does not bind it, so ``json.load`` would still be unresolved.
        self._has_json_import = False
        # True only when ``safe_join`` itself is already imported from
        # werkzeug.utils. Other names from that module do not count.
        self._has_werkzeug_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Record whether ``safe_join`` is already imported from werkzeug.utils."""
        if node.module is None or node.relative:
            return True
        if self._get_module_name(node.module) != "werkzeug.utils":
            return True

        if isinstance(node.names, cst.ImportStar):
            self._has_werkzeug_import = True
            return True

        for alias in node.names:
            imported = alias.name.value if isinstance(alias.name, cst.Name) else None
            bound = imported
            if alias.asname is not None:
                target = alias.asname.name
                bound = target.value if isinstance(target, cst.Name) else None
            if imported == "safe_join" and bound == "safe_join":
                self._has_werkzeug_import = True
        return True

    def visit_Import(self, node: cst.Import) -> bool:
        """Record whether a plain ``import`` already binds the name ``json``.

        ``import json`` and ``import json.decoder`` both bind ``json``, while
        ``import json as j`` binds only ``j``.
        """
        for alias in node.names:
            if alias.asname is not None:
                target = alias.asname.name
                bound = target.value if isinstance(target, cst.Name) else None
            else:
                bound = self._get_module_name(alias.name).split(".")[0]
            if bound == "json":
                self._has_json_import = True
        return True

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Insert the required imports near the top of the module (once)."""
        if self._added_imports:
            return updated_node

        new_imports = self._build_new_imports()
        if not new_imports:
            return updated_node

        new_body = list(updated_node.body)
        insert_idx = self._find_insert_index(new_body)
        new_body[insert_idx:insert_idx] = new_imports

        self._added_imports = True
        return updated_node.with_changes(body=new_body)

    def _build_new_imports(self) -> list[cst.SimpleStatementLine]:
        """Build the import statements still missing, in markupsafe/werkzeug/json order."""
        new_imports: list[cst.SimpleStatementLine] = []

        markupsafe_names: list[cst.ImportAlias] = []
        if self.needs_markupsafe_escape and "escape" not in self.existing_markupsafe_names:
            markupsafe_names.append(cst.ImportAlias(name=cst.Name("escape")))
        if self.needs_markupsafe_markup and "Markup" not in self.existing_markupsafe_names:
            markupsafe_names.append(cst.ImportAlias(name=cst.Name("Markup")))
        if markupsafe_names:
            new_imports.append(
                cst.SimpleStatementLine(
                    body=[cst.ImportFrom(module=cst.Name("markupsafe"), names=markupsafe_names)]
                )
            )

        if self.needs_werkzeug_safe_join and not self._has_werkzeug_import:
            new_imports.append(
                cst.SimpleStatementLine(
                    body=[
                        cst.ImportFrom(
                            module=cst.Attribute(
                                value=cst.Name("werkzeug"),
                                attr=cst.Name("utils"),
                            ),
                            names=[cst.ImportAlias(name=cst.Name("safe_join"))],
                        )
                    ]
                )
            )

        if self.needs_json_import and not self._has_json_import:
            new_imports.append(
                cst.SimpleStatementLine(
                    body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("json"))])]
                )
            )

        return new_imports

    def _is_future_import(self, stmt: cst.BaseSmallStatement) -> bool:
        """Whether ``stmt`` is a ``from __future__ import ...`` statement."""
        return (
            isinstance(stmt, cst.ImportFrom)
            and not stmt.relative
            and stmt.module is not None
            and self._get_module_name(stmt.module) == "__future__"
        )

    def _find_insert_index(self, body: list[cst.BaseStatement]) -> int:
        """Find where new imports go in the module body.

        The position is after any leading bare-string statements (the module
        docstring) and any ``from __future__`` imports, and before the first
        other import or statement. Future imports must precede every statement
        except the docstring, so inserting ahead of them would be a SyntaxError.
        """
        insert_idx = 0
        for i, stmt in enumerate(body):
            if not isinstance(stmt, cst.SimpleStatementLine) or not stmt.body:
                return i
            first = stmt.body[0]
            if self._is_future_import(first):
                insert_idx = i + 1
                continue
            if isinstance(first, cst.Import | cst.ImportFrom):
                return i
            if isinstance(first, cst.Expr) and isinstance(first.value, cst.SimpleString):
                # Docstring (or another bare string): keep new imports below it.
                insert_idx = i + 1
                continue
            return i
        return insert_idx

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the dotted module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def transform_flask(source_code: str) -> tuple[str, list]:
    """Transform Flask code for version upgrades.

    Runs two LibCST passes. ``FlaskTransformer`` rewrites imports, calls and
    attributes. ``FlaskImportAdder`` then inserts the replacement imports the
    first pass asked for.

    Args:
        source_code: The source code to transform

    Returns:
        Tuple of (transformed_code, list of changes). If the source cannot be
        parsed, or if either pass fails, the original source is returned with
        an empty change list. This matters because the first pass removes names
        such as ``escape`` from flask imports, and returning its output without
        the second pass's replacement imports would leave unbound names.
    """
    try:
        tree = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return source_code, []

    # First pass: main transformations
    transformer = FlaskTransformer()
    transformer.set_source(source_code)

    try:
        transformed_tree = tree.visit(transformer)
    except Exception:
        return source_code, []

    needs_new_imports = (
        transformer._needs_markupsafe_escape
        or transformer._needs_markupsafe_markup
        or transformer._needs_werkzeug_safe_join
        or transformer._needs_json_import
    )
    if not needs_new_imports:
        # Nothing to add; the second pass would be a no-op.
        return transformed_tree.code, transformer.changes

    # Second pass: add missing imports
    import_adder = FlaskImportAdder(
        needs_markupsafe_escape=transformer._needs_markupsafe_escape,
        needs_markupsafe_markup=transformer._needs_markupsafe_markup,
        needs_werkzeug_safe_join=transformer._needs_werkzeug_safe_join,
        needs_json_import=transformer._needs_json_import,
        has_markupsafe_import=transformer._has_markupsafe_import,
        existing_markupsafe_names=transformer._markupsafe_import_names,
    )

    try:
        final_tree = transformed_tree.visit(import_adder)
    except Exception:
        # Without the replacement imports the first-pass output is broken code;
        # fall back to the untouched source, as for a first-pass failure.
        return source_code, []

    return final_tree.code, transformer.changes
|