codeshift 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/cli/commands/apply.py +24 -2
- codeshift/cli/package_manager.py +102 -0
- codeshift/knowledge/generator.py +11 -1
- codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
- codeshift/knowledge_base/libraries/attrs.yaml +181 -0
- codeshift/knowledge_base/libraries/celery.yaml +244 -0
- codeshift/knowledge_base/libraries/click.yaml +195 -0
- codeshift/knowledge_base/libraries/django.yaml +355 -0
- codeshift/knowledge_base/libraries/flask.yaml +270 -0
- codeshift/knowledge_base/libraries/httpx.yaml +183 -0
- codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
- codeshift/knowledge_base/libraries/numpy.yaml +429 -0
- codeshift/knowledge_base/libraries/pytest.yaml +192 -0
- codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
- codeshift/migrator/engine.py +60 -0
- codeshift/migrator/transforms/__init__.py +2 -0
- codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
- codeshift/migrator/transforms/attrs_transformer.py +570 -0
- codeshift/migrator/transforms/celery_transformer.py +546 -0
- codeshift/migrator/transforms/click_transformer.py +526 -0
- codeshift/migrator/transforms/django_transformer.py +852 -0
- codeshift/migrator/transforms/fastapi_transformer.py +12 -7
- codeshift/migrator/transforms/flask_transformer.py +505 -0
- codeshift/migrator/transforms/httpx_transformer.py +419 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
- codeshift/migrator/transforms/numpy_transformer.py +413 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
- codeshift/migrator/transforms/pytest_transformer.py +351 -0
- codeshift/migrator/transforms/requests_transformer.py +74 -1
- codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/RECORD +36 -15
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,51 @@
|
|
|
1
1
|
"""SQLAlchemy 1.x to 2.0 transformation using LibCST."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
3
5
|
import libcst as cst
|
|
4
6
|
|
|
5
7
|
from codeshift.migrator.ast_transforms import BaseTransformer
|
|
6
8
|
|
|
7
9
|
|
|
10
|
+
class _FilterByArg:
|
|
11
|
+
"""Marker class to represent a filter_by argument that needs model reference."""
|
|
12
|
+
|
|
13
|
+
def __init__(self, key: str, value: cst.BaseExpression) -> None:
|
|
14
|
+
self.key = key
|
|
15
|
+
self.value = value
|
|
16
|
+
|
|
17
|
+
|
|
8
18
|
class SQLAlchemyTransformer(BaseTransformer):
|
|
9
19
|
"""Transform SQLAlchemy 1.x code to 2.0."""
|
|
10
20
|
|
|
11
21
|
def __init__(self) -> None:
    """Initialise the transformer with every import flag and tracker cleared."""
    super().__init__()
    # Imports that the rewritten code will require; raised by the individual rewrites.
    self._needs_select_import = self._needs_func_import = self._needs_text_import = False
    # Imports already observed in the module being transformed.
    self._has_declarative_base_import = self._has_text_import = False
    # Names bound by `x = create_engine(...)` assignments (filled in by visit_Assign).
    self._engine_var_names: set[str] = set()
    # Name that was assigned the result of declarative_base(), once one has been seen.
    self._declarative_base_var_name: str | None = None
|
|
32
|
+
|
|
33
|
+
def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
    """Note whether ``text`` is already imported from ``sqlalchemy``.

    Always returns True so that traversal continues into the node.
    """
    if node.module is None:
        # Relative import without a module part (`from . import x`): nothing to record.
        return True

    if self._get_module_name(node.module) != "sqlalchemy":
        return True
    if isinstance(node.names, cst.ImportStar):
        # `from sqlalchemy import *` lists no individual names to inspect.
        return True

    imports_text = any(
        isinstance(alias, cst.ImportAlias)
        and isinstance(alias.name, cst.Name)
        and alias.name.value == "text"
        for alias in node.names
    )
    if imports_text:
        self._has_text_import = True

    return True
|
|
16
49
|
|
|
17
50
|
def leave_ImportFrom(
|
|
18
51
|
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
|
|
@@ -23,18 +56,18 @@ class SQLAlchemyTransformer(BaseTransformer):
|
|
|
23
56
|
|
|
24
57
|
module_name = self._get_module_name(original_node.module)
|
|
25
58
|
|
|
26
|
-
# Transform declarative_base import
|
|
59
|
+
# Transform declarative_base import from sqlalchemy.ext.declarative
|
|
27
60
|
if module_name == "sqlalchemy.ext.declarative":
|
|
28
61
|
if isinstance(updated_node.names, cst.ImportStar):
|
|
29
62
|
return updated_node
|
|
30
63
|
|
|
31
64
|
new_names = []
|
|
32
|
-
|
|
65
|
+
found_declarative_base = False
|
|
33
66
|
|
|
34
67
|
for name in updated_node.names:
|
|
35
68
|
if isinstance(name, cst.ImportAlias):
|
|
36
69
|
if isinstance(name.name, cst.Name) and name.name.value == "declarative_base":
|
|
37
|
-
|
|
70
|
+
found_declarative_base = True
|
|
38
71
|
self.record_change(
|
|
39
72
|
description="Import DeclarativeBase from sqlalchemy.orm instead of declarative_base",
|
|
40
73
|
line_number=1,
|
|
@@ -43,33 +76,48 @@ class SQLAlchemyTransformer(BaseTransformer):
|
|
|
43
76
|
transform_name="import_declarative_base",
|
|
44
77
|
)
|
|
45
78
|
self._has_declarative_base_import = True
|
|
46
|
-
# Return updated import from sqlalchemy.orm
|
|
47
|
-
return updated_node.with_changes(
|
|
48
|
-
module=cst.Attribute(
|
|
49
|
-
value=cst.Name("sqlalchemy"),
|
|
50
|
-
attr=cst.Name("orm"),
|
|
51
|
-
),
|
|
52
|
-
names=[cst.ImportAlias(name=cst.Name("DeclarativeBase"))],
|
|
53
|
-
)
|
|
54
79
|
else:
|
|
55
80
|
new_names.append(name)
|
|
56
81
|
else:
|
|
57
82
|
new_names.append(name)
|
|
58
83
|
|
|
59
|
-
if
|
|
60
|
-
|
|
84
|
+
if found_declarative_base:
|
|
85
|
+
# Add DeclarativeBase to the list
|
|
86
|
+
new_names.insert(0, cst.ImportAlias(name=cst.Name("DeclarativeBase")))
|
|
87
|
+
# Change module to sqlalchemy.orm
|
|
88
|
+
return updated_node.with_changes(
|
|
89
|
+
module=cst.Attribute(
|
|
90
|
+
value=cst.Name("sqlalchemy"),
|
|
91
|
+
attr=cst.Name("orm"),
|
|
92
|
+
),
|
|
93
|
+
names=new_names,
|
|
94
|
+
)
|
|
61
95
|
|
|
62
|
-
# Handle
|
|
96
|
+
# Handle declarative_base import from sqlalchemy.orm and backref removal
|
|
63
97
|
if module_name == "sqlalchemy.orm":
|
|
64
98
|
if isinstance(updated_node.names, cst.ImportStar):
|
|
65
99
|
return updated_node
|
|
66
100
|
|
|
67
101
|
new_names = []
|
|
68
|
-
|
|
102
|
+
found_declarative_base = False
|
|
103
|
+
found_backref = False
|
|
69
104
|
|
|
70
105
|
for name in updated_node.names:
|
|
71
106
|
if isinstance(name, cst.ImportAlias):
|
|
72
|
-
if isinstance(name.name, cst.Name)
|
|
107
|
+
name_value = name.name.value if isinstance(name.name, cst.Name) else None
|
|
108
|
+
|
|
109
|
+
if name_value == "declarative_base":
|
|
110
|
+
found_declarative_base = True
|
|
111
|
+
self.record_change(
|
|
112
|
+
description="Replace declarative_base import with DeclarativeBase",
|
|
113
|
+
line_number=1,
|
|
114
|
+
original="from sqlalchemy.orm import declarative_base",
|
|
115
|
+
replacement="from sqlalchemy.orm import DeclarativeBase",
|
|
116
|
+
transform_name="import_declarative_base",
|
|
117
|
+
)
|
|
118
|
+
self._has_declarative_base_import = True
|
|
119
|
+
elif name_value == "backref":
|
|
120
|
+
found_backref = True
|
|
73
121
|
self.record_change(
|
|
74
122
|
description="Remove backref import (use back_populates instead)",
|
|
75
123
|
line_number=1,
|
|
@@ -77,37 +125,87 @@ class SQLAlchemyTransformer(BaseTransformer):
|
|
|
77
125
|
replacement="# backref removed, use back_populates",
|
|
78
126
|
transform_name="remove_backref_import",
|
|
79
127
|
)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
new_names.append(name)
|
|
128
|
+
else:
|
|
129
|
+
new_names.append(name)
|
|
83
130
|
else:
|
|
84
131
|
new_names.append(name)
|
|
85
132
|
|
|
86
|
-
if
|
|
133
|
+
if found_declarative_base or found_backref:
|
|
134
|
+
# Add DeclarativeBase if we found declarative_base
|
|
135
|
+
if found_declarative_base:
|
|
136
|
+
new_names.insert(0, cst.ImportAlias(name=cst.Name("DeclarativeBase")))
|
|
137
|
+
|
|
87
138
|
if new_names:
|
|
139
|
+
# Fix trailing comma: ensure last item has no trailing comma
|
|
140
|
+
if new_names:
|
|
141
|
+
last_item = new_names[-1]
|
|
142
|
+
if (
|
|
143
|
+
hasattr(last_item, "comma")
|
|
144
|
+
and last_item.comma != cst.MaybeSentinel.DEFAULT
|
|
145
|
+
):
|
|
146
|
+
new_names[-1] = last_item.with_changes(comma=cst.MaybeSentinel.DEFAULT)
|
|
88
147
|
return updated_node.with_changes(names=new_names)
|
|
89
|
-
|
|
90
|
-
|
|
148
|
+
else:
|
|
149
|
+
# No imports left, remove the line
|
|
150
|
+
return cst.RemovalSentinel.REMOVE
|
|
91
151
|
|
|
92
152
|
return updated_node
|
|
93
153
|
|
|
94
|
-
def
|
|
95
|
-
"""
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
154
|
+
def visit_Assign(self, node: cst.Assign) -> bool:
    """Remember names bound by ``name = create_engine(...)`` assignments.

    Only single-target assignments to a plain name whose right-hand side is a
    bare ``create_engine(...)`` call are recorded. Always returns True so that
    traversal continues.
    """
    if len(node.targets) != 1:
        return True

    lhs = node.targets[0].target
    rhs = node.value
    if not (isinstance(lhs, cst.Name) and isinstance(rhs, cst.Call)):
        return True

    callee = rhs.func
    if isinstance(callee, cst.Name) and callee.value == "create_engine":
        self._engine_var_names.add(lhs.value)
    return True
|
|
100
164
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
165
|
+
def leave_SimpleStatementLine(
    self,
    original_node: cst.SimpleStatementLine,
    updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine | cst.ClassDef | cst.RemovalSentinel:
    """Rewrite ``Base = declarative_base()`` into ``class Base(DeclarativeBase): pass``.

    Only a statement line holding exactly one single-target assignment of a
    bare ``declarative_base(...)`` call to a plain name is rewritten; every
    other statement is returned unchanged.

    The comments/blank lines above the assignment and any trailing comment on
    it are carried over to the generated class so the rewrite does not delete
    them. When ``declarative_base`` was called with arguments, the class is
    still generated, but the change is recorded with reduced confidence and a
    note, because those arguments are not carried over to the class.

    Args:
        original_node: The statement line before child transformations.
        updated_node: The statement line after child transformations.

    Returns:
        The replacement class definition, or ``updated_node`` when the
        statement is not a ``declarative_base`` assignment.
    """
    if len(updated_node.body) != 1:
        return updated_node

    stmt = updated_node.body[0]
    if not (isinstance(stmt, cst.Assign) and len(stmt.targets) == 1):
        return updated_node

    target = stmt.targets[0].target
    call = stmt.value
    if not (isinstance(target, cst.Name) and isinstance(call, cst.Call)):
        return updated_node
    if not (isinstance(call.func, cst.Name) and call.func.value == "declarative_base"):
        return updated_node

    var_name = target.value
    self._declarative_base_var_name = var_name

    if call.args:
        # The generated class body is just `pass`, so whatever was passed to
        # declarative_base() is lost; flag it instead of claiming full confidence.
        self.record_change(
            description=f"Replace {var_name} = declarative_base(...) with class {var_name}(DeclarativeBase): pass",
            line_number=1,
            original=f"{var_name} = declarative_base(...)",
            replacement=f"class {var_name}(DeclarativeBase):\n pass",
            transform_name="declarative_base_to_class",
            confidence=0.5,
            notes="MANUAL MIGRATION REQUIRED: arguments passed to declarative_base() "
            "(for example metadata= or cls=) are not carried over to the generated "
            "class and must be re-applied by hand.",
        )
    else:
        self.record_change(
            description=f"Replace {var_name} = declarative_base() with class {var_name}(DeclarativeBase): pass",
            line_number=1,
            original=f"{var_name} = declarative_base()",
            replacement=f"class {var_name}(DeclarativeBase):\n pass",
            transform_name="declarative_base_to_class",
            confidence=1.0,
        )

    # Create a class definition: class Base(DeclarativeBase): pass
    return cst.ClassDef(
        name=cst.Name(var_name),
        bases=[cst.Arg(value=cst.Name("DeclarativeBase"))],
        body=cst.IndentedBlock(
            body=[cst.SimpleStatementLine(body=[cst.Pass()])],
            # Keep a trailing `# comment` from the assignment on the class header line.
            header=updated_node.trailing_whitespace,
        ),
        # Keep comments and blank lines that preceded the assignment.
        leading_lines=updated_node.leading_lines,
    )
|
|
202
|
+
|
|
203
|
+
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
|
|
204
|
+
"""Transform SQLAlchemy function calls."""
|
|
205
|
+
# Handle session.query() transformations
|
|
206
|
+
transformed = self._transform_query_call(updated_node)
|
|
207
|
+
if transformed is not None:
|
|
208
|
+
return transformed
|
|
111
209
|
|
|
112
210
|
# Handle create_engine future flag
|
|
113
211
|
if isinstance(updated_node.func, cst.Name) and updated_node.func.value == "create_engine":
|
|
@@ -128,10 +226,392 @@ class SQLAlchemyTransformer(BaseTransformer):
|
|
|
128
226
|
new_args.append(arg)
|
|
129
227
|
|
|
130
228
|
if changed:
|
|
229
|
+
# Fix trailing comma: remove trailing comma from last argument if present
|
|
230
|
+
if new_args:
|
|
231
|
+
last_arg = new_args[-1]
|
|
232
|
+
# Remove trailing comma from the last argument
|
|
233
|
+
if last_arg.comma != cst.MaybeSentinel.DEFAULT:
|
|
234
|
+
new_args[-1] = last_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
|
|
131
235
|
return updated_node.with_changes(args=new_args)
|
|
132
236
|
|
|
237
|
+
# Handle execute() with raw SQL string - wrap with text()
|
|
238
|
+
if (
|
|
239
|
+
isinstance(updated_node.func, cst.Attribute)
|
|
240
|
+
and updated_node.func.attr.value == "execute"
|
|
241
|
+
):
|
|
242
|
+
# Check if this is engine.execute() - which requires manual migration
|
|
243
|
+
# to use with engine.connect() as conn: conn.execute()
|
|
244
|
+
if self._is_engine_execute_call(updated_node):
|
|
245
|
+
self.record_change(
|
|
246
|
+
description="engine.execute() is removed in SQLAlchemy 2.0. "
|
|
247
|
+
"Use 'with engine.connect() as conn: conn.execute(...)' instead",
|
|
248
|
+
line_number=1,
|
|
249
|
+
original="engine.execute(...)",
|
|
250
|
+
replacement="with engine.connect() as conn:\n conn.execute(...)",
|
|
251
|
+
transform_name="engine_execute_to_connect",
|
|
252
|
+
confidence=0.5,
|
|
253
|
+
notes="MANUAL MIGRATION REQUIRED: This transformation requires "
|
|
254
|
+
"restructuring the code to use a context manager. The execute() call "
|
|
255
|
+
"must be moved inside a 'with engine.connect() as conn:' block, and "
|
|
256
|
+
"raw SQL strings should be wrapped with text(). If the result is used, "
|
|
257
|
+
"ensure proper handling within the context manager scope.",
|
|
258
|
+
)
|
|
259
|
+
# Don't transform the code - just record the warning
|
|
260
|
+
# The code still needs to have text() wrapping applied if applicable
|
|
261
|
+
# Fall through to the text wrapping logic below
|
|
262
|
+
|
|
263
|
+
if updated_node.args:
|
|
264
|
+
first_arg = updated_node.args[0]
|
|
265
|
+
# Check if the first argument is a string literal (raw SQL)
|
|
266
|
+
if isinstance(first_arg.value, cst.SimpleString) or isinstance(
|
|
267
|
+
first_arg.value, cst.ConcatenatedString
|
|
268
|
+
):
|
|
269
|
+
# Wrap the string with text()
|
|
270
|
+
self._needs_text_import = True
|
|
271
|
+
text_call = cst.Call(
|
|
272
|
+
func=cst.Name("text"),
|
|
273
|
+
args=[cst.Arg(value=first_arg.value)],
|
|
274
|
+
)
|
|
275
|
+
new_first_arg = first_arg.with_changes(value=text_call)
|
|
276
|
+
new_args = [new_first_arg] + list(updated_node.args[1:])
|
|
277
|
+
|
|
278
|
+
self.record_change(
|
|
279
|
+
description="Wrap raw SQL string with text()",
|
|
280
|
+
line_number=1,
|
|
281
|
+
original='execute("...")',
|
|
282
|
+
replacement='execute(text("..."))',
|
|
283
|
+
transform_name="wrap_execute_with_text",
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
return updated_node.with_changes(args=new_args)
|
|
287
|
+
|
|
133
288
|
return updated_node
|
|
134
289
|
|
|
290
|
+
def _is_engine_execute_call(self, node: cst.Call) -> bool:
    """Guess whether ``node`` is ``<engine>.execute(...)`` on a SQLAlchemy engine.

    The receiver must be a plain name. It counts as an engine when either:
      1. the name was recorded as the target of a ``create_engine()`` assignment, or
      2. the name contains ``engine`` (case-insensitive).

    Args:
        node: The Call node to check (already verified to be *.execute())

    Returns:
        True if this appears to be an engine.execute() call
    """
    callee = node.func
    if not isinstance(callee, cst.Attribute):
        return False

    receiver = callee.value
    if not isinstance(receiver, cst.Name):
        # Attribute chains such as self.engine.execute(...) are not recognised.
        return False

    name = receiver.value
    return name in self._engine_var_names or "engine" in name.lower()
|
|
321
|
+
|
|
322
|
+
def _transform_query_call(self, node: cst.Call) -> cst.BaseExpression | None:
    """Rewrite a ``session.query(...)`` chain that ends in a terminal method.

    The call must be a method call whose method name is one of ``all``,
    ``first``, ``one``, ``one_or_none``, ``get``, ``count`` or ``scalar``, and
    the chain must be accepted by ``_parse_query_chain``. The rewrite is then
    delegated by terminal method:

    - ``get``   -> ``_transform_query_get``
    - ``count`` -> ``_transform_query_count``
    - the rest  -> ``_transform_query_execute``

    Returns:
        Transformed node if this is a query pattern, None otherwise.
    """
    callee = node.func
    if not isinstance(callee, cst.Attribute):
        return None
    if callee.attr.value not in (
        "all",
        "first",
        "one",
        "one_or_none",
        "get",
        "count",
        "scalar",
    ):
        return None

    # Walk back to the root .query() call, gathering filters on the way.
    parsed = self._parse_query_chain(node)
    if parsed is None:
        return None
    session, model, conditions, method, method_args = parsed

    if method == "get":
        result = self._transform_query_get(session, model, method_args)
    elif method == "count":
        result = self._transform_query_count(session, model, conditions)
    else:
        result = self._transform_query_execute(session, model, conditions, method)
    return result
|
|
365
|
+
|
|
366
|
+
def _parse_query_chain(self, node: cst.Call) -> (
    tuple[
        cst.BaseExpression,
        cst.BaseExpression,
        list[cst.BaseExpression | _FilterByArg],
        str,
        Sequence[cst.Arg],
    ]
    | None
):
    """Parse a query chain to extract session, model, filters, and terminal method.

    Starting from the terminal call (``.all()``, ``.get(...)``, ...), the chain
    is walked back towards its root ``<session>.query(Model)`` call, collecting
    ``.filter(...)`` and ``.filter_by(...)`` arguments in source order.

    The chain is rejected (``None``) whenever it contains something the
    returned tuple cannot represent, so that the caller leaves the code
    untouched instead of emitting a query with different semantics:

    - any intermediate method other than ``filter``/``filter_by`` (for example
      ``order_by``, ``limit``, ``offset``, ``distinct``, ``group_by``,
      ``having``, ``join``, ``outerjoin``) — dropping these would change the
      rows the query returns;
    - ``query()`` called with anything but exactly one plain positional
      argument — additional entities would be lost;
    - ``filter(*conds)`` / ``filter(kw=...)`` or ``filter_by(**kw)`` /
      positional ``filter_by`` arguments — they cannot be turned into
      individual conditions.

    Args:
        node: The terminal method call of the candidate chain.

    Returns:
        Tuple of (session_node, model_node, filters, terminal_method, terminal_args) or None
    """
    if not isinstance(node.func, cst.Attribute):
        return None

    terminal_method = node.func.attr.value
    terminal_args = node.args
    current = node.func.value  # Move past the terminal method call
    filters: list[cst.BaseExpression | _FilterByArg] = []

    # Every link of a supported chain is a method call; anything else ends the walk.
    while isinstance(current, cst.Call) and isinstance(current.func, cst.Attribute):
        func = current.func
        method_name = func.attr.value
        call_args = current.args

        if method_name == "query":
            # Found the root .query() call; only a single plain entity is supported.
            if len(call_args) != 1:
                return None
            root_arg = call_args[0]
            if root_arg.keyword is not None or root_arg.star:
                return None
            return (func.value, root_arg.value, filters, terminal_method, terminal_args)

        collected: list[cst.BaseExpression | _FilterByArg] = []
        if method_name == "filter":
            for arg in call_args:
                if arg.keyword is not None or arg.star:
                    return None
                collected.append(arg.value)
        elif method_name == "filter_by":
            # filter_by(key=val) is later rendered as Model.key == val.
            for arg in call_args:
                if arg.keyword is None:
                    return None
                collected.append(_FilterByArg(arg.keyword.value, arg.value))
        else:
            # Unsupported link (order_by, limit, join, ...): do not transform.
            return None

        # The walk goes from the last call to the first, so prepend this call's
        # conditions as a group to keep the overall source order.
        filters[:0] = collected
        current = func.value

    return None
|
|
447
|
+
|
|
448
|
+
def _transform_query_get(
    self,
    session_node: cst.BaseExpression,
    model_node: cst.BaseExpression,
    args: Sequence[cst.Arg],
) -> cst.Call:
    """Transform session.query(Model).get(id) to session.get(Model, id).

    The generated call does not use ``select``, so no ``select`` import is
    requested here; requesting one would add an unused import to modules whose
    only rewrite is this one.

    Args:
        session_node: Expression ``.query()`` was called on.
        model_node: First argument of the ``.query()`` call.
        args: Arguments of the original ``.get(...)`` call, appended after the model.

    Returns:
        The ``<session>.get(Model, ...)`` call node.
    """
    self.record_change(
        description="Convert session.query(Model).get(id) to session.get(Model, id)",
        line_number=1,
        original="session.query(Model).get(id)",
        replacement="session.get(Model, id)",
        transform_name="query_get_to_session_get",
    )

    # Build session.get(Model, id)
    return cst.Call(
        func=cst.Attribute(value=session_node, attr=cst.Name("get")),
        args=[cst.Arg(value=model_node), *args],
    )
|
|
473
|
+
|
|
474
|
+
def _transform_query_count(
    self,
    session_node: cst.BaseExpression,
    model_node: cst.BaseExpression,
    filters: list[cst.BaseExpression | _FilterByArg],
) -> cst.Call:
    """Build ``session.execute(select(func.count()).select_from(Model)).scalar()``.

    Each entry of ``filters`` becomes one chained ``.where(...)`` on the select;
    ``_FilterByArg`` entries are rendered as ``Model.<key> == <value>``. Marks
    both ``select`` and ``func`` as required imports and records the change.
    """

    def call_method(
        receiver: cst.BaseExpression, method: str, *operands: cst.BaseExpression
    ) -> cst.Call:
        # receiver.method(operand, ...)
        return cst.Call(
            func=cst.Attribute(value=receiver, attr=cst.Name(method)),
            args=[cst.Arg(value=operand) for operand in operands],
        )

    self._needs_select_import = True
    self._needs_func_import = True

    self.record_change(
        description="Convert session.query(Model).count() to session.execute(select(func.count()).select_from(Model)).scalar()",
        line_number=1,
        original="session.query(Model).count()",
        replacement="session.execute(select(func.count()).select_from(Model)).scalar()",
        transform_name="query_count_to_select_count",
    )

    # select(func.count()).select_from(Model)
    count_expr = call_method(cst.Name("func"), "count")
    select_expr = cst.Call(func=cst.Name("select"), args=[cst.Arg(value=count_expr)])
    statement: cst.BaseExpression = call_method(select_expr, "select_from", model_node)

    # One .where(...) per collected filter, in order.
    for item in filters:
        if isinstance(item, _FilterByArg):
            condition: cst.BaseExpression = cst.Comparison(
                left=cst.Attribute(value=model_node, attr=cst.Name(item.key)),
                comparisons=[
                    cst.ComparisonTarget(operator=cst.Equal(), comparator=item.value)
                ],
            )
        else:
            condition = item
        statement = call_method(statement, "where", condition)

    # session.execute(<statement>).scalar()
    return call_method(call_method(session_node, "execute", statement), "scalar")
|
|
545
|
+
|
|
546
|
+
def _transform_query_execute(
    self,
    session_node: cst.BaseExpression,
    model_node: cst.BaseExpression,
    filters: list[cst.BaseExpression | _FilterByArg],
    terminal_method: str,
) -> cst.Call:
    """Transform session.query(Model).<terminal>() to session.execute(select(Model)).scalars().<terminal>().

    Each entry of ``filters`` becomes one chained ``.where(...)`` on the select;
    ``_FilterByArg`` entries are rendered as ``Model.<key> == <value>``.

    A terminal ``scalar`` is emitted as ``.scalars().one_or_none()``: the object
    returned by ``.scalars()`` has no ``scalar()`` method, and ``one_or_none()``
    gives the matching "single value or None, error on several rows" result.

    Args:
        session_node: Expression ``.query()`` was called on.
        model_node: First argument of the ``.query()`` call.
        filters: Conditions collected from ``.filter()`` / ``.filter_by()``.
        terminal_method: Name of the method that ended the original chain.

    Returns:
        The rewritten call expression.
    """
    self._needs_select_import = True

    # `.scalars().scalar()` does not exist; map it to the equivalent accessor.
    result_method = "one_or_none" if terminal_method == "scalar" else terminal_method

    original = f"session.query(Model).{terminal_method}()"
    replacement = f"session.execute(select(Model)).scalars().{result_method}()"

    self.record_change(
        description=f"Convert {original} to {replacement}",
        line_number=1,
        original=original,
        replacement=replacement,
        transform_name=f"query_{terminal_method}_to_select",
    )

    # Build select(Model)
    current: cst.BaseExpression = cst.Call(
        func=cst.Name("select"),
        args=[cst.Arg(value=model_node)],
    )

    # Add .where() for each filter
    for filter_expr in filters:
        if isinstance(filter_expr, _FilterByArg):
            # Convert filter_by to where with Model.attr == val
            condition: cst.BaseExpression = cst.Comparison(
                left=cst.Attribute(value=model_node, attr=cst.Name(filter_expr.key)),
                comparisons=[
                    cst.ComparisonTarget(
                        operator=cst.Equal(),
                        comparator=filter_expr.value,
                    )
                ],
            )
        else:
            condition = filter_expr
        current = cst.Call(
            func=cst.Attribute(value=current, attr=cst.Name("where")),
            args=[cst.Arg(value=condition)],
        )

    # Build session.execute(...)
    execute_call = cst.Call(
        func=cst.Attribute(value=session_node, attr=cst.Name("execute")),
        args=[cst.Arg(value=current)],
    )

    # Add .scalars()
    scalars_call = cst.Call(
        func=cst.Attribute(value=execute_call, attr=cst.Name("scalars")),
        args=[],
    )

    # Add the result accessor (.all(), .first(), .one(), .one_or_none())
    return cst.Call(
        func=cst.Attribute(value=scalars_call, attr=cst.Name(result_method)),
        args=[],
    )
|
|
614
|
+
|
|
135
615
|
def leave_Attribute(
|
|
136
616
|
self, original_node: cst.Attribute, updated_node: cst.Attribute
|
|
137
617
|
) -> cst.Attribute:
|
|
@@ -151,6 +631,168 @@ class SQLAlchemyTransformer(BaseTransformer):
|
|
|
151
631
|
return ""
|
|
152
632
|
|
|
153
633
|
|
|
634
|
+
class SQLAlchemyImportTransformer(BaseTransformer):
|
|
635
|
+
"""Separate transformer for handling import additions.
|
|
636
|
+
|
|
637
|
+
This runs after the main transformer to add any missing imports.
|
|
638
|
+
"""
|
|
639
|
+
|
|
640
|
+
def __init__(
    self,
    needs_select_import: bool = False,
    needs_func_import: bool = False,
    needs_text_import: bool = False,
    has_text_import: bool = False,
) -> None:
    """Store which sqlalchemy names must be importable after the rewrite.

    Args:
        needs_select_import: Whether ``select`` must be imported.
        needs_func_import: Whether ``func`` must be imported.
        needs_text_import: Whether ``text`` must be imported.
        has_text_import: Whether ``text`` is already known to be imported.
    """
    super().__init__()
    # Discovered while visiting the module's own import statements.
    self._found_sqlalchemy_import = False
    self._has_select_import = self._has_func_import = False
    # Supplied by the caller.
    self._has_text_import = has_text_import
    self._needs_text_import = needs_text_import
    self._needs_func_import = needs_func_import
    self._needs_select_import = needs_select_import
|
|
655
|
+
|
|
656
|
+
def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
    """Record which of ``text``/``select``/``func`` a ``from sqlalchemy import`` already provides.

    Also notes that a ``from sqlalchemy import ...`` statement exists at all.
    Always returns True so that traversal continues.
    """
    if node.module is None:
        return True
    if self._get_module_name(node.module) != "sqlalchemy":
        return True

    self._found_sqlalchemy_import = True
    if isinstance(node.names, cst.ImportStar):
        # A star import lists no individual names to inspect.
        return True

    # Imported name -> flag attribute that records its presence.
    flag_for = {
        "text": "_has_text_import",
        "select": "_has_select_import",
        "func": "_has_func_import",
    }
    for alias in node.names:
        if not (isinstance(alias, cst.ImportAlias) and isinstance(alias.name, cst.Name)):
            continue
        flag = flag_for.get(alias.name.value)
        if flag is not None:
            setattr(self, flag, True)

    return True
|
|
676
|
+
|
|
677
|
+
def leave_ImportFrom(
    self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
    """Append still-missing ``text``/``select``/``func`` names to a ``from sqlalchemy import``.

    Statements for other modules, relative imports without a module part and
    star imports are returned unchanged. Each name added is marked as present
    so it is not added again; additions of ``select`` and ``func`` are also
    recorded as changes.
    """
    if updated_node.module is None:
        return updated_node
    if self._get_module_name(updated_node.module) != "sqlalchemy":
        return updated_node
    if isinstance(updated_node.names, cst.ImportStar):
        return updated_node

    additions = []

    # text: added without recording a separate change.
    if self._needs_text_import and not self._has_text_import:
        self._has_text_import = True
        additions.append(cst.ImportAlias(name=cst.Name("text")))

    # select
    if self._needs_select_import and not self._has_select_import:
        self._has_select_import = True
        additions.append(cst.ImportAlias(name=cst.Name("select")))
        self.record_change(
            description="Add 'select' import for query transformation",
            line_number=1,
            original="from sqlalchemy import ...",
            replacement="from sqlalchemy import ..., select",
            transform_name="add_select_import",
        )

    # func
    if self._needs_func_import and not self._has_func_import:
        self._has_func_import = True
        additions.append(cst.ImportAlias(name=cst.Name("func")))
        self.record_change(
            description="Add 'func' import for count transformation",
            line_number=1,
            original="from sqlalchemy import ...",
            replacement="from sqlalchemy import ..., func",
            transform_name="add_func_import",
        )

    if not additions:
        return updated_node
    return updated_node.with_changes(names=[*updated_node.names, *additions])
|
|
732
|
+
|
|
733
|
+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Insert ``from sqlalchemy import ...`` when it is needed but absent.

    Nothing is done when ``self._found_sqlalchemy_import`` is set, or when
    every needed name (``select``, ``func``, ``text``) is already flagged
    as imported.

    Args:
        original_node: The module before child transformations (unused).
        updated_node: The module after child transformations.

    Returns:
        ``updated_node`` unchanged, or a copy with one new import statement
        placed after the module docstring and any ``from __future__``
        imports.
    """
    if self._found_sqlalchemy_import:
        return updated_node

    # Collect the names that are required but not imported yet.
    import_names = []
    if self._needs_select_import and not self._has_select_import:
        import_names.append(cst.ImportAlias(name=cst.Name("select")))
        self.record_change(
            description="Add 'select' import for query transformation",
            line_number=1,
            original="",
            replacement="from sqlalchemy import select",
            transform_name="add_select_import",
        )
    if self._needs_func_import and not self._has_func_import:
        import_names.append(cst.ImportAlias(name=cst.Name("func")))
        self.record_change(
            description="Add 'func' import for count transformation",
            line_number=1,
            original="",
            replacement="from sqlalchemy import func",
            transform_name="add_func_import",
        )
    if self._needs_text_import and not self._has_text_import:
        # No change is recorded for 'text' here; only the import is added.
        import_names.append(cst.ImportAlias(name=cst.Name("text")))

    if not import_names:
        return updated_node

    new_import = cst.SimpleStatementLine(
        body=[
            cst.ImportFrom(
                module=cst.Name("sqlalchemy"),
                names=import_names,
            )
        ]
    )

    body = list(updated_node.body)

    # A statement placed before the module docstring would stop it from
    # being the docstring, so step over it when present.
    insert_at = 1 if updated_node.get_docstring(clean=False) is not None else 0

    # ``from __future__`` imports must precede every other statement;
    # inserting ahead of them would turn the output into a SyntaxError.
    while insert_at < len(body):
        statement = body[insert_at]
        is_future_import = isinstance(statement, cst.SimpleStatementLine) and any(
            isinstance(small, cst.ImportFrom)
            and not small.relative
            and isinstance(small.module, cst.Name)
            and small.module.value == "__future__"
            for small in statement.body
        )
        if not is_future_import:
            break
        insert_at += 1

    body.insert(insert_at, new_import)
    return updated_node.with_changes(body=body)
|
|
786
|
+
|
|
787
|
+
def _get_module_name(self, module: cst.BaseExpression) -> str:
    """Return the dotted name spelled by a ``Name``/``Attribute`` chain.

    A ``cst.Name`` yields its value; a ``cst.Attribute`` yields the name of
    its ``value`` part joined to its attribute name with a dot. Any other
    node type yields an empty string.
    """
    if isinstance(module, cst.Attribute):
        # Resolve the left-hand side first, then append the final component.
        prefix = self._get_module_name(module.value)
        return f"{prefix}.{module.attr.value}"
    if isinstance(module, cst.Name):
        return str(module.value)
    return ""
|
|
794
|
+
|
|
795
|
+
|
|
154
796
|
def transform_sqlalchemy(source_code: str) -> tuple[str, list]:
|
|
155
797
|
"""Transform SQLAlchemy code from 1.x to 2.0.
|
|
156
798
|
|
|
@@ -170,6 +812,17 @@ def transform_sqlalchemy(source_code: str) -> tuple[str, list]:
|
|
|
170
812
|
|
|
171
813
|
try:
|
|
172
814
|
transformed_tree = tree.visit(transformer)
|
|
173
|
-
|
|
815
|
+
|
|
816
|
+
# Second pass: add missing imports
|
|
817
|
+
import_transformer = SQLAlchemyImportTransformer(
|
|
818
|
+
needs_select_import=transformer._needs_select_import,
|
|
819
|
+
needs_func_import=transformer._needs_func_import,
|
|
820
|
+
needs_text_import=transformer._needs_text_import,
|
|
821
|
+
has_text_import=transformer._has_text_import,
|
|
822
|
+
)
|
|
823
|
+
final_tree = transformed_tree.visit(import_transformer)
|
|
824
|
+
|
|
825
|
+
all_changes = transformer.changes + import_transformer.changes
|
|
826
|
+
return final_tree.code, all_changes
|
|
174
827
|
except Exception:
|
|
175
828
|
return source_code, []
|