codeshift 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/cli/commands/apply.py +24 -2
- codeshift/cli/package_manager.py +102 -0
- codeshift/knowledge/generator.py +11 -1
- codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
- codeshift/knowledge_base/libraries/attrs.yaml +181 -0
- codeshift/knowledge_base/libraries/celery.yaml +244 -0
- codeshift/knowledge_base/libraries/click.yaml +195 -0
- codeshift/knowledge_base/libraries/django.yaml +355 -0
- codeshift/knowledge_base/libraries/flask.yaml +270 -0
- codeshift/knowledge_base/libraries/httpx.yaml +183 -0
- codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
- codeshift/knowledge_base/libraries/numpy.yaml +429 -0
- codeshift/knowledge_base/libraries/pytest.yaml +192 -0
- codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
- codeshift/migrator/engine.py +60 -0
- codeshift/migrator/transforms/__init__.py +2 -0
- codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
- codeshift/migrator/transforms/attrs_transformer.py +570 -0
- codeshift/migrator/transforms/celery_transformer.py +546 -0
- codeshift/migrator/transforms/click_transformer.py +526 -0
- codeshift/migrator/transforms/django_transformer.py +852 -0
- codeshift/migrator/transforms/fastapi_transformer.py +12 -7
- codeshift/migrator/transforms/flask_transformer.py +505 -0
- codeshift/migrator/transforms/httpx_transformer.py +419 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
- codeshift/migrator/transforms/numpy_transformer.py +413 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
- codeshift/migrator/transforms/pytest_transformer.py +351 -0
- codeshift/migrator/transforms/requests_transformer.py +74 -1
- codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/RECORD +36 -15
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,546 @@
|
|
|
1
|
+
"""Celery 4.x to 5.x transformation using LibCST."""
|
|
2
|
+
|
|
3
|
+
import libcst as cst
|
|
4
|
+
|
|
5
|
+
from codeshift.migrator.ast_transforms import BaseTransformer
|
|
6
|
+
|
|
7
|
+
# Mapping of old uppercase config keys to new lowercase keys (Celery 4.0 renamed
# its settings; 5.x drops the old names).
# NOTE(review): some CELERY_-prefixed keys below (e.g. CELERY_BROKER_URL,
# CELERY_TASK_ALWAYS_EAGER) are also the names used with
# ``app.config_from_object(..., namespace="CELERY")``; renaming them may be wrong
# for projects configured that way — confirm per project.
CONFIG_KEY_MAPPINGS = {
    # Result backend and broker
    "CELERY_RESULT_BACKEND": "result_backend",
    "CELERY_BROKER_URL": "broker_url",
    "BROKER_URL": "broker_url",
    # Task settings (legacy Celery 3.x/4.x names)
    "CELERY_ALWAYS_EAGER": "task_always_eager",
    "CELERY_EAGER_PROPAGATES_EXCEPTIONS": "task_eager_propagates",
    "CELERY_IGNORE_RESULT": "task_ignore_result",
    "CELERY_TRACK_STARTED": "task_track_started",
    "CELERY_ACKS_LATE": "task_acks_late",
    "CELERY_ANNOTATIONS": "task_annotations",
    # Task settings (CELERY_TASK_ prefixed spellings)
    "CELERY_TASK_ALWAYS_EAGER": "task_always_eager",
    "CELERY_TASK_EAGER_PROPAGATES": "task_eager_propagates",
    "CELERY_TASK_IGNORE_RESULT": "task_ignore_result",
    "CELERY_TASK_TRACK_STARTED": "task_track_started",
    "CELERY_TASK_TIME_LIMIT": "task_time_limit",
    "CELERY_TASK_SOFT_TIME_LIMIT": "task_soft_time_limit",
    "CELERY_TASK_ACKS_LATE": "task_acks_late",
    "CELERY_TASK_SERIALIZER": "task_serializer",
    "CELERY_TASK_ANNOTATIONS": "task_annotations",
    # Result settings
    "CELERY_RESULT_SERIALIZER": "result_serializer",
    "CELERY_RESULT_EXPIRES": "result_expires",
    "CELERY_TASK_RESULT_EXPIRES": "result_expires",
    # General settings
    "CELERY_ACCEPT_CONTENT": "accept_content",
    "CELERY_TIMEZONE": "timezone",
    "CELERY_ENABLE_UTC": "enable_utc",
    "CELERY_IMPORTS": "imports",
    "CELERY_INCLUDE": "include",
    # Worker settings (CELERYD_ prefix)
    "CELERYD_CONCURRENCY": "worker_concurrency",
    "CELERYD_PREFETCH_MULTIPLIER": "worker_prefetch_multiplier",
    "CELERYD_MAX_TASKS_PER_CHILD": "worker_max_tasks_per_child",
    "CELERYD_DISABLE_RATE_LIMITS": "worker_disable_rate_limits",
    "CELERY_DISABLE_RATE_LIMITS": "worker_disable_rate_limits",
    # The CELERYD_ time limits are *task* settings in the new scheme; there is
    # no ``worker_task_time_limit`` / ``worker_task_soft_time_limit`` setting.
    "CELERYD_TASK_TIME_LIMIT": "task_time_limit",
    "CELERYD_TASK_SOFT_TIME_LIMIT": "task_soft_time_limit",
    # Beat settings
    "CELERY_BEAT_SCHEDULE": "beat_schedule",
    "CELERY_BEAT_SCHEDULER": "beat_scheduler",
    "CELERYBEAT_SCHEDULE": "beat_schedule",
    "CELERYBEAT_SCHEDULER": "beat_scheduler",
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class CeleryTransformer(BaseTransformer):
    """Transform Celery 4.x code to 5.x.

    In a single LibCST pass this rewrites:

    * imports from the removed ``celery.task`` / ``celery.decorators`` modules into
      ``from celery import ...`` — ``task`` becomes ``shared_task`` (``as`` aliases
      are kept so call sites still resolve) and ``periodic_task`` is dropped;
    * ``celery.task.schedules`` -> ``celery.schedules`` and
      ``celery.utils.encoding`` -> ``kombu.utils.encoding``;
    * old uppercase config names from ``CONFIG_KEY_MAPPINGS`` used as assignment
      targets, attribute names, or string subscript keys;
    * bare ``@task`` / ``@task(...)`` decorators to ``shared_task`` once an import
      from a removed module has been rewritten.

    Only absolute imports are considered: ``from .celery import app`` names a
    local module, not the Celery package. The ``_needs_*`` / ``_has_*`` flags are
    read afterwards by ``transform_celery`` to configure the import pass.
    """

    def __init__(self) -> None:
        super().__init__()
        # Names the migrated code needs from ``celery`` (set for star imports of
        # removed modules, and for unaliased names placed in a rewritten import).
        self._needs_shared_task_import = False
        self._needs_task_import = False
        # Names known to be imported from ``celery`` already.
        self._has_shared_task_import = False
        self._has_task_import = False
        # Set once an import from a removed module is rewritten; gates the
        # ``@task`` -> ``@shared_task`` decorator rename. NOTE(review): decorators
        # that appear before such an import in the file are not renamed.
        self._removed_celery_task_import = False
        self._removed_celery_decorators_import = False

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom | cst.RemovalSentinel:
        """Transform absolute imports from Celery modules that moved or were removed."""
        # ``from .celery import app`` refers to a local module that happens to be
        # named ``celery``; only absolute imports name the Celery package.
        if updated_node.module is None or updated_node.relative:
            return updated_node

        module_name = self._get_module_name(updated_node.module)

        # celery.task and celery.decorators were removed in 5.0.
        if module_name == "celery.task":
            return self._transform_celery_task_import(updated_node)
        if module_name == "celery.decorators":
            return self._transform_celery_decorators_import(updated_node)
        if module_name == "celery.task.schedules":
            return self._transform_schedules_import(updated_node)
        if module_name == "celery.utils.encoding":
            return self._transform_encoding_import(updated_node)

        if module_name == "celery":
            # Remember what is already imported so the import pass skips it.
            self._track_celery_imports(updated_node)

        return updated_node

    def _transform_celery_task_import(
        self, node: cst.ImportFrom
    ) -> cst.ImportFrom | cst.RemovalSentinel:
        """Rewrite ``from celery.task import ...`` as ``from celery import ...``.

        ``Task`` and any other names are re-imported from ``celery`` unchanged,
        ``task`` becomes ``shared_task`` (keeping any ``as`` alias), and
        ``periodic_task`` is dropped. The statement is removed only when no names
        remain. A star import is removed and ``shared_task``/``Task`` are requested
        from the import pass instead.
        """
        if isinstance(node.names, cst.ImportStar):
            self.record_change(
                description="Remove 'from celery.task import *' (module removed)",
                line_number=1,
                original="from celery.task import *",
                replacement="from celery import shared_task, Task",
                transform_name="import_celery_task_star",
            )
            self._needs_shared_task_import = True
            self._needs_task_import = True
            self._removed_celery_task_import = True
            return cst.RemovalSentinel.REMOVE

        kept_aliases: list[cst.ImportAlias] = []
        for alias in node.names:
            if not isinstance(alias.name, cst.Name):
                continue
            imported_name = alias.name.value

            if imported_name == "Task":
                # from celery.task import Task -> from celery import Task
                kept_aliases.append(alias)
                if alias.asname is None:
                    self._needs_task_import = True
                    self._has_task_import = True
                self.record_change(
                    description="Change 'from celery.task import Task' to 'from celery import Task'",
                    line_number=1,
                    original="from celery.task import Task",
                    replacement="from celery import Task",
                    transform_name="import_task_class",
                )
            elif imported_name == "task":
                # from celery.task import task -> from celery import shared_task
                kept_aliases.append(alias.with_changes(name=cst.Name("shared_task")))
                if alias.asname is None:
                    self._needs_shared_task_import = True
                    self._has_shared_task_import = True
                self.record_change(
                    description="Change 'from celery.task import task' to 'from celery import shared_task'",
                    line_number=1,
                    original="from celery.task import task",
                    replacement="from celery import shared_task",
                    transform_name="import_task_decorator",
                )
            elif imported_name == "periodic_task":
                # periodic_task has no 5.x equivalent; it is dropped and flagged.
                self.record_change(
                    description="Remove 'periodic_task' import (use beat_schedule config instead)",
                    line_number=1,
                    original="from celery.task import periodic_task",
                    replacement="# Configure periodic tasks via beat_schedule",
                    transform_name="import_periodic_task_removed",
                    confidence=0.8,
                    notes="periodic_task decorator removed; use beat_schedule configuration",
                )
            else:
                # Assumed to be importable from the top-level ``celery`` package.
                kept_aliases.append(alias)
                self.record_change(
                    description=(
                        f"Change 'from celery.task import {imported_name}' "
                        f"to 'from celery import {imported_name}'"
                    ),
                    line_number=1,
                    original=f"from celery.task import {imported_name}",
                    replacement=f"from celery import {imported_name}",
                    transform_name="import_celery_task_other",
                    confidence=0.7,
                    notes=f"Verify that '{imported_name}' is exported by 'celery' in 5.x",
                )

        self._removed_celery_task_import = True
        return self._reimport_from_celery(node, kept_aliases)

    def _transform_celery_decorators_import(
        self, node: cst.ImportFrom
    ) -> cst.ImportFrom | cst.RemovalSentinel:
        """Rewrite ``from celery.decorators import ...`` as ``from celery import ...``.

        ``task`` becomes ``shared_task`` (keeping any ``as`` alias), ``periodic_task``
        is dropped, and any other name is re-imported from ``celery`` unchanged. The
        statement is removed only when no names remain. A star import is removed and
        ``shared_task`` is requested from the import pass instead.
        """
        if isinstance(node.names, cst.ImportStar):
            self.record_change(
                description="Remove 'from celery.decorators import *' (module removed)",
                line_number=1,
                original="from celery.decorators import *",
                replacement="from celery import shared_task",
                transform_name="import_celery_decorators_star",
            )
            self._needs_shared_task_import = True
            self._removed_celery_decorators_import = True
            return cst.RemovalSentinel.REMOVE

        kept_aliases: list[cst.ImportAlias] = []
        for alias in node.names:
            if not isinstance(alias.name, cst.Name):
                continue
            imported_name = alias.name.value

            if imported_name == "task":
                kept_aliases.append(alias.with_changes(name=cst.Name("shared_task")))
                if alias.asname is None:
                    self._needs_shared_task_import = True
                    self._has_shared_task_import = True
                self.record_change(
                    description="Change 'from celery.decorators import task' to 'from celery import shared_task'",
                    line_number=1,
                    original="from celery.decorators import task",
                    replacement="from celery import shared_task",
                    transform_name="import_decorators_task",
                )
            elif imported_name == "periodic_task":
                self.record_change(
                    description="Remove 'periodic_task' import (use beat_schedule config instead)",
                    line_number=1,
                    original="from celery.decorators import periodic_task",
                    replacement="# Configure periodic tasks via beat_schedule",
                    transform_name="import_periodic_task_removed",
                    confidence=0.8,
                )
            else:
                # Keep unknown names rather than silently deleting them.
                kept_aliases.append(alias)

        self._removed_celery_decorators_import = True
        return self._reimport_from_celery(node, kept_aliases)

    @staticmethod
    def _reimport_from_celery(
        node: cst.ImportFrom, aliases: list[cst.ImportAlias]
    ) -> cst.ImportFrom | cst.RemovalSentinel:
        """Point ``node`` at ``celery`` with ``aliases``; remove it if none are left."""
        if not aliases:
            return cst.RemovalSentinel.REMOVE
        # The original last name may have been dropped, leaving a trailing comma
        # that is invalid in an unparenthesized import; clear it.
        last_alias = aliases[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        return node.with_changes(
            module=cst.Name("celery"),
            names=[*aliases[:-1], last_alias],
        )

    def _transform_schedules_import(self, node: cst.ImportFrom) -> cst.ImportFrom:
        """Transform celery.task.schedules -> celery.schedules."""
        self.record_change(
            description="Change 'celery.task.schedules' to 'celery.schedules'",
            line_number=1,
            original="from celery.task.schedules import ...",
            replacement="from celery.schedules import ...",
            transform_name="import_schedules",
        )
        new_module = cst.Attribute(value=cst.Name("celery"), attr=cst.Name("schedules"))
        return node.with_changes(module=new_module)

    def _transform_encoding_import(self, node: cst.ImportFrom) -> cst.ImportFrom:
        """Transform celery.utils.encoding -> kombu.utils.encoding."""
        self.record_change(
            description="Change 'celery.utils.encoding' to 'kombu.utils.encoding'",
            line_number=1,
            original="from celery.utils.encoding import ...",
            replacement="from kombu.utils.encoding import ...",
            transform_name="import_utils_encoding",
        )
        new_module = cst.Attribute(
            value=cst.Attribute(value=cst.Name("kombu"), attr=cst.Name("utils")),
            attr=cst.Name("encoding"),
        )
        return node.with_changes(module=new_module)

    def _track_celery_imports(self, node: cst.ImportFrom) -> None:
        """Record whether ``shared_task`` / ``Task`` are already imported from celery."""
        if isinstance(node.names, cst.ImportStar):
            self._has_shared_task_import = True
            self._has_task_import = True
            return

        for alias in node.names:
            if isinstance(alias.name, cst.Name):
                if alias.name.value == "shared_task":
                    self._has_shared_task_import = True
                elif alias.name.value == "Task":
                    self._has_task_import = True

    def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.Assign:
        """Rename ``Name`` assignment targets that use old uppercase config names.

        Every matching target is renamed, so chained assignments such as
        ``BROKER_URL = CELERY_BROKER_URL = "..."`` are fully converted.
        """
        new_targets: list[cst.AssignTarget] = []
        changed = False
        for assign_target in updated_node.targets:
            target = assign_target.target
            if isinstance(target, cst.Name) and target.value in CONFIG_KEY_MAPPINGS:
                var_name = target.value
                new_name = CONFIG_KEY_MAPPINGS[var_name]
                self.record_change(
                    description=f"Rename config '{var_name}' to '{new_name}'",
                    line_number=1,
                    original=f"{var_name} = ...",
                    replacement=f"{new_name} = ...",
                    transform_name=f"config_{new_name}",
                )
                # with_changes keeps any parentheses around the target name.
                assign_target = assign_target.with_changes(target=target.with_changes(value=new_name))
                changed = True
            new_targets.append(assign_target)

        return updated_node.with_changes(targets=new_targets) if changed else updated_node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Rename attribute access like ``app.conf.CELERY_RESULT_BACKEND``."""
        attr_name = updated_node.attr.value
        new_name = CONFIG_KEY_MAPPINGS.get(attr_name)
        if new_name is None:
            return updated_node

        self.record_change(
            description=f"Rename config attribute '{attr_name}' to '{new_name}'",
            line_number=1,
            original=f".{attr_name}",
            replacement=f".{new_name}",
            transform_name=f"attr_{new_name}",
        )
        return updated_node.with_changes(attr=cst.Name(new_name))

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.BaseExpression:
        """Rename string keys in subscripts like ``app.conf['CELERY_RESULT_BACKEND']``.

        Subscripts on ``os.environ`` / ``environ`` are skipped: those keys are
        environment-variable names, not Celery settings. Bytes literals are skipped,
        and string prefixes/quote styles (``r''``, ``\"\"\"...\"\"\"``) are honoured.
        """
        if len(updated_node.slice) != 1 or self._is_environ(updated_node.value):
            return updated_node

        element = updated_node.slice[0]
        index = element.slice
        if not isinstance(index, cst.Index) or not isinstance(index.value, cst.SimpleString):
            return updated_node

        string_node = index.value
        if "b" in string_node.prefix.lower():
            return updated_node

        key = string_node.raw_value
        new_key = CONFIG_KEY_MAPPINGS.get(key)
        if new_key is None:
            return updated_node

        quote = string_node.quote
        new_string = string_node.with_changes(value=f"{string_node.prefix}{quote}{new_key}{quote}")

        self.record_change(
            description=f"Rename config key '{key}' to '{new_key}'",
            line_number=1,
            original=f"['{key}']",
            replacement=f"['{new_key}']",
            transform_name=f"subscript_{new_key}",
        )

        # Reuse the existing element/index nodes so surrounding whitespace survives.
        new_element = element.with_changes(slice=index.with_changes(value=new_string))
        return updated_node.with_changes(slice=[new_element])

    @staticmethod
    def _is_environ(node: cst.BaseExpression) -> bool:
        """True for ``environ`` or ``<anything>.environ`` (e.g. ``os.environ``)."""
        if isinstance(node, cst.Name):
            return node.value == "environ"
        return isinstance(node, cst.Attribute) and node.attr.value == "environ"

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Rename ``@task`` / ``@task(...)`` to ``shared_task`` after a removed-module import."""
        if not (self._removed_celery_task_import or self._removed_celery_decorators_import):
            return updated_node

        decorator = updated_node.decorator
        if isinstance(decorator, cst.Name) and decorator.value == "task":
            self.record_change(
                description="Rename @task decorator to @shared_task",
                line_number=1,
                original="@task",
                replacement="@shared_task",
                transform_name="decorator_task_to_shared_task",
            )
            return updated_node.with_changes(decorator=cst.Name("shared_task"))

        if (
            isinstance(decorator, cst.Call)
            and isinstance(decorator.func, cst.Name)
            and decorator.func.value == "task"
        ):
            self.record_change(
                description="Rename @task(...) decorator to @shared_task(...)",
                line_number=1,
                original="@task(...)",
                replacement="@shared_task(...)",
                transform_name="decorator_task_to_shared_task",
            )
            new_call = decorator.with_changes(func=cst.Name("shared_task"))
            return updated_node.with_changes(decorator=new_call)

        return updated_node

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the dotted module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        if isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class CeleryImportTransformer(BaseTransformer):
    """Separate transformer for adding missing Celery imports.

    This runs after the main transformer to add any missing imports. Each
    missing ``shared_task`` / ``Task`` name is appended to an existing
    non-star absolute ``from celery import ...`` statement, in practice the first
    one left in traversal order. When the module has no ``celery`` import at all,
    a new ``from celery import ...`` line is inserted after the module docstring
    and any ``from __future__`` imports. Placing it above them would turn the
    docstring into a plain expression, or make the ``__future__`` import a
    SyntaxError.

    Relative imports such as ``from .celery import app`` name a local module and
    are ignored.
    """

    def __init__(
        self,
        needs_shared_task_import: bool = False,
        needs_task_import: bool = False,
        has_shared_task_import: bool = False,
        has_task_import: bool = False,
    ) -> None:
        super().__init__()
        # What the first pass determined the code requires / already has.
        self._needs_shared_task_import = needs_shared_task_import
        self._needs_task_import = needs_task_import
        self._has_shared_task_import = has_shared_task_import
        self._has_task_import = has_task_import
        # Whether any absolute ``from celery import`` exists that can be extended.
        self._found_celery_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Record absolute ``celery`` imports and the names they already provide."""
        if self._is_celery_import(node):
            self._found_celery_import = True
            if not isinstance(node.names, cst.ImportStar):
                for alias in node.names:
                    if isinstance(alias.name, cst.Name):
                        if alias.name.value == "shared_task":
                            self._has_shared_task_import = True
                        elif alias.name.value == "Task":
                            self._has_task_import = True
        return True

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Append still-missing names to an existing ``from celery import ...``."""
        if not self._is_celery_import(updated_node):
            return updated_node
        if isinstance(updated_node.names, cst.ImportStar):
            return updated_node

        new_names = list(updated_node.names)
        changed = False

        if self._needs_shared_task_import and not self._has_shared_task_import:
            new_names.append(cst.ImportAlias(name=cst.Name("shared_task")))
            self._has_shared_task_import = True
            changed = True
            self.record_change(
                description="Add 'shared_task' import",
                line_number=1,
                original="from celery import ...",
                replacement="from celery import ..., shared_task",
                transform_name="add_shared_task_import",
            )

        if self._needs_task_import and not self._has_task_import:
            new_names.append(cst.ImportAlias(name=cst.Name("Task")))
            self._has_task_import = True
            changed = True
            self.record_change(
                description="Add 'Task' import",
                line_number=1,
                original="from celery import ...",
                replacement="from celery import ..., Task",
                transform_name="add_task_import",
            )

        return updated_node.with_changes(names=new_names) if changed else updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Insert ``from celery import ...`` when no existing celery import was found."""
        if self._found_celery_import:
            return updated_node

        import_names: list[cst.ImportAlias] = []
        if self._needs_shared_task_import and not self._has_shared_task_import:
            import_names.append(cst.ImportAlias(name=cst.Name("shared_task")))
            self.record_change(
                description="Add 'shared_task' import from celery",
                line_number=1,
                original="",
                replacement="from celery import shared_task",
                transform_name="add_shared_task_import",
            )
        if self._needs_task_import and not self._has_task_import:
            import_names.append(cst.ImportAlias(name=cst.Name("Task")))
            self.record_change(
                description="Add 'Task' import from celery",
                line_number=1,
                original="",
                replacement="from celery import Task",
                transform_name="add_task_import",
            )

        if not import_names:
            return updated_node

        new_import = cst.SimpleStatementLine(
            body=[cst.ImportFrom(module=cst.Name("celery"), names=import_names)]
        )
        body = list(updated_node.body)
        body.insert(self._insertion_index(body), new_import)
        return updated_node.with_changes(body=body)

    def _insertion_index(self, body: list[cst.BaseStatement]) -> int:
        """Return the index just past a leading docstring and ``__future__`` imports."""
        index = 1 if body and self._is_docstring(body[0]) else 0
        while index < len(body) and self._is_future_import(body[index]):
            index += 1
        return index

    @staticmethod
    def _is_docstring(statement: cst.BaseStatement) -> bool:
        """True if ``statement`` is a lone string-literal expression statement."""
        return (
            isinstance(statement, cst.SimpleStatementLine)
            and len(statement.body) == 1
            and isinstance(statement.body[0], cst.Expr)
            and isinstance(statement.body[0].value, (cst.SimpleString, cst.ConcatenatedString))
        )

    @staticmethod
    def _is_future_import(statement: cst.BaseStatement) -> bool:
        """True if ``statement`` consists only of ``from __future__ import ...``."""
        return (
            isinstance(statement, cst.SimpleStatementLine)
            and bool(statement.body)
            and all(
                isinstance(small, cst.ImportFrom)
                and not small.relative
                and isinstance(small.module, cst.Name)
                and small.module.value == "__future__"
                for small in statement.body
            )
        )

    def _is_celery_import(self, node: cst.ImportFrom) -> bool:
        """True for absolute ``from celery import ...`` (not ``from .celery``)."""
        return (
            node.module is not None
            and not node.relative
            and self._get_module_name(node.module) == "celery"
        )

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full module name from an Attribute or Name node."""
        if isinstance(module, cst.Name):
            return str(module.value)
        if isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def transform_celery(source_code: str) -> tuple[str, list]:
    """Migrate Celery 4.x source code to the Celery 5.x API.

    Two LibCST passes run over the parsed module. ``CeleryTransformer`` rewrites
    imports, configuration names and decorators. ``CeleryImportTransformer`` then
    adds the ``celery`` imports that the first pass reported as needed.

    Args:
        source_code: The Python source text to migrate.

    Returns:
        ``(new_source, changes)``. ``changes`` holds the first pass's change
        records followed by the second pass's. If the source does not parse, or
        either pass raises, the input text is returned unchanged with an empty
        list.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return source_code, []

    main_pass = CeleryTransformer()
    main_pass.set_source(source_code)

    try:
        rewritten = module.visit(main_pass)

        # Hand the import bookkeeping from the first pass to the second.
        carried_state = {
            "needs_shared_task_import": main_pass._needs_shared_task_import,
            "needs_task_import": main_pass._needs_task_import,
            "has_shared_task_import": main_pass._has_shared_task_import,
            "has_task_import": main_pass._has_task_import,
        }
        import_pass = CeleryImportTransformer(**carried_state)
        finished = rewritten.visit(import_pass)

        combined_changes = main_pass.changes + import_pass.changes
        return finished.code, combined_changes
    except Exception:
        # Best effort: on any failure hand back the untouched input, never a
        # partially migrated result.
        return source_code, []
|