codeshift 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. codeshift/cli/commands/apply.py +24 -2
  2. codeshift/cli/package_manager.py +102 -0
  3. codeshift/knowledge/generator.py +11 -1
  4. codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
  5. codeshift/knowledge_base/libraries/attrs.yaml +181 -0
  6. codeshift/knowledge_base/libraries/celery.yaml +244 -0
  7. codeshift/knowledge_base/libraries/click.yaml +195 -0
  8. codeshift/knowledge_base/libraries/django.yaml +355 -0
  9. codeshift/knowledge_base/libraries/flask.yaml +270 -0
  10. codeshift/knowledge_base/libraries/httpx.yaml +183 -0
  11. codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
  12. codeshift/knowledge_base/libraries/numpy.yaml +429 -0
  13. codeshift/knowledge_base/libraries/pytest.yaml +192 -0
  14. codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
  15. codeshift/migrator/engine.py +60 -0
  16. codeshift/migrator/transforms/__init__.py +2 -0
  17. codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
  18. codeshift/migrator/transforms/attrs_transformer.py +570 -0
  19. codeshift/migrator/transforms/celery_transformer.py +546 -0
  20. codeshift/migrator/transforms/click_transformer.py +526 -0
  21. codeshift/migrator/transforms/django_transformer.py +852 -0
  22. codeshift/migrator/transforms/fastapi_transformer.py +12 -7
  23. codeshift/migrator/transforms/flask_transformer.py +505 -0
  24. codeshift/migrator/transforms/httpx_transformer.py +419 -0
  25. codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
  26. codeshift/migrator/transforms/numpy_transformer.py +413 -0
  27. codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
  28. codeshift/migrator/transforms/pytest_transformer.py +351 -0
  29. codeshift/migrator/transforms/requests_transformer.py +74 -1
  30. codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
  31. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/METADATA +46 -4
  32. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/RECORD +36 -15
  33. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/WHEEL +0 -0
  34. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/entry_points.txt +0 -0
  35. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/licenses/LICENSE +0 -0
  36. {codeshift-0.3.3.dist-info → codeshift-0.3.5.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,51 @@
1
1
  """SQLAlchemy 1.x to 2.0 transformation using LibCST."""
2
2
 
3
+ from collections.abc import Sequence
4
+
3
5
  import libcst as cst
4
6
 
5
7
  from codeshift.migrator.ast_transforms import BaseTransformer
6
8
 
7
9
 
10
+ class _FilterByArg:
11
+ """Marker class to represent a filter_by argument that needs model reference."""
12
+
13
+ def __init__(self, key: str, value: cst.BaseExpression) -> None:
14
+ self.key = key
15
+ self.value = value
16
+
17
+
8
18
  class SQLAlchemyTransformer(BaseTransformer):
9
19
  """Transform SQLAlchemy 1.x code to 2.0."""
10
20
 
11
21
  def __init__(self) -> None:
12
22
  super().__init__()
13
23
  self._needs_select_import = False
24
+ self._needs_func_import = False
14
25
  self._needs_text_import = False
15
26
  self._has_declarative_base_import = False
27
+ self._has_text_import = False
28
+ # Track declarative_base variable name for transformation
29
+ self._declarative_base_var_name: str | None = None
30
+ # Track engine variable names from create_engine() calls
31
+ self._engine_var_names: set[str] = set()
32
+
33
+ def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
34
+ """Track existing imports."""
35
+ if node.module is None:
36
+ return True
37
+
38
+ module_name = self._get_module_name(node.module)
39
+
40
+ # Track if text is already imported from sqlalchemy
41
+ if module_name == "sqlalchemy":
42
+ if not isinstance(node.names, cst.ImportStar):
43
+ for name in node.names:
44
+ if isinstance(name, cst.ImportAlias):
45
+ if isinstance(name.name, cst.Name) and name.name.value == "text":
46
+ self._has_text_import = True
47
+
48
+ return True
16
49
 
17
50
  def leave_ImportFrom(
18
51
  self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
@@ -23,18 +56,18 @@ class SQLAlchemyTransformer(BaseTransformer):
23
56
 
24
57
  module_name = self._get_module_name(original_node.module)
25
58
 
26
- # Transform declarative_base import
59
+ # Transform declarative_base import from sqlalchemy.ext.declarative
27
60
  if module_name == "sqlalchemy.ext.declarative":
28
61
  if isinstance(updated_node.names, cst.ImportStar):
29
62
  return updated_node
30
63
 
31
64
  new_names = []
32
- changed = False
65
+ found_declarative_base = False
33
66
 
34
67
  for name in updated_node.names:
35
68
  if isinstance(name, cst.ImportAlias):
36
69
  if isinstance(name.name, cst.Name) and name.name.value == "declarative_base":
37
- # Change to DeclarativeBase from sqlalchemy.orm
70
+ found_declarative_base = True
38
71
  self.record_change(
39
72
  description="Import DeclarativeBase from sqlalchemy.orm instead of declarative_base",
40
73
  line_number=1,
@@ -43,33 +76,48 @@ class SQLAlchemyTransformer(BaseTransformer):
43
76
  transform_name="import_declarative_base",
44
77
  )
45
78
  self._has_declarative_base_import = True
46
- # Return updated import from sqlalchemy.orm
47
- return updated_node.with_changes(
48
- module=cst.Attribute(
49
- value=cst.Name("sqlalchemy"),
50
- attr=cst.Name("orm"),
51
- ),
52
- names=[cst.ImportAlias(name=cst.Name("DeclarativeBase"))],
53
- )
54
79
  else:
55
80
  new_names.append(name)
56
81
  else:
57
82
  new_names.append(name)
58
83
 
59
- if changed and new_names:
60
- return updated_node.with_changes(names=new_names)
84
+ if found_declarative_base:
85
+ # Add DeclarativeBase to the list
86
+ new_names.insert(0, cst.ImportAlias(name=cst.Name("DeclarativeBase")))
87
+ # Change module to sqlalchemy.orm
88
+ return updated_node.with_changes(
89
+ module=cst.Attribute(
90
+ value=cst.Name("sqlalchemy"),
91
+ attr=cst.Name("orm"),
92
+ ),
93
+ names=new_names,
94
+ )
61
95
 
62
- # Handle backref import removal
96
+ # Handle declarative_base import from sqlalchemy.orm and backref removal
63
97
  if module_name == "sqlalchemy.orm":
64
98
  if isinstance(updated_node.names, cst.ImportStar):
65
99
  return updated_node
66
100
 
67
101
  new_names = []
68
- changed = False
102
+ found_declarative_base = False
103
+ found_backref = False
69
104
 
70
105
  for name in updated_node.names:
71
106
  if isinstance(name, cst.ImportAlias):
72
- if isinstance(name.name, cst.Name) and name.name.value == "backref":
107
+ name_value = name.name.value if isinstance(name.name, cst.Name) else None
108
+
109
+ if name_value == "declarative_base":
110
+ found_declarative_base = True
111
+ self.record_change(
112
+ description="Replace declarative_base import with DeclarativeBase",
113
+ line_number=1,
114
+ original="from sqlalchemy.orm import declarative_base",
115
+ replacement="from sqlalchemy.orm import DeclarativeBase",
116
+ transform_name="import_declarative_base",
117
+ )
118
+ self._has_declarative_base_import = True
119
+ elif name_value == "backref":
120
+ found_backref = True
73
121
  self.record_change(
74
122
  description="Remove backref import (use back_populates instead)",
75
123
  line_number=1,
@@ -77,37 +125,87 @@ class SQLAlchemyTransformer(BaseTransformer):
77
125
  replacement="# backref removed, use back_populates",
78
126
  transform_name="remove_backref_import",
79
127
  )
80
- changed = True
81
- continue
82
- new_names.append(name)
128
+ else:
129
+ new_names.append(name)
83
130
  else:
84
131
  new_names.append(name)
85
132
 
86
- if changed:
133
+ if found_declarative_base or found_backref:
134
+ # Add DeclarativeBase if we found declarative_base
135
+ if found_declarative_base:
136
+ new_names.insert(0, cst.ImportAlias(name=cst.Name("DeclarativeBase")))
137
+
87
138
  if new_names:
139
+ # Fix trailing comma: ensure last item has no trailing comma
140
+ if new_names:
141
+ last_item = new_names[-1]
142
+ if (
143
+ hasattr(last_item, "comma")
144
+ and last_item.comma != cst.MaybeSentinel.DEFAULT
145
+ ):
146
+ new_names[-1] = last_item.with_changes(comma=cst.MaybeSentinel.DEFAULT)
88
147
  return updated_node.with_changes(names=new_names)
89
- # If no names left, remove the import
90
- return cst.RemovalSentinel.REMOVE
148
+ else:
149
+ # No imports left, remove the line
150
+ return cst.RemovalSentinel.REMOVE
91
151
 
92
152
  return updated_node
93
153
 
94
- def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
95
- """Transform SQLAlchemy function calls."""
96
- # Handle declarative_base() -> class Base(DeclarativeBase): pass
97
- # This is complex - we just record the need for change
98
- if isinstance(updated_node.func, cst.Name):
99
- func_name = updated_node.func.value
154
+ def visit_Assign(self, node: cst.Assign) -> bool:
155
+ """Track engine variable names from create_engine() calls."""
156
+ if len(node.targets) == 1:
157
+ target = node.targets[0].target
158
+ if isinstance(target, cst.Name) and isinstance(node.value, cst.Call):
159
+ call = node.value
160
+ # Check if this is a create_engine() call
161
+ if isinstance(call.func, cst.Name) and call.func.value == "create_engine":
162
+ self._engine_var_names.add(target.value)
163
+ return True
100
164
 
101
- if func_name == "declarative_base":
102
- self.record_change(
103
- description="Replace declarative_base() with class Base(DeclarativeBase): pass",
104
- line_number=1,
105
- original="Base = declarative_base()",
106
- replacement="class Base(DeclarativeBase): pass",
107
- transform_name="declarative_base_to_class",
108
- confidence=0.8,
109
- notes="Manual review recommended - create class inheriting from DeclarativeBase",
110
- )
165
+ def leave_SimpleStatementLine(
166
+ self,
167
+ original_node: cst.SimpleStatementLine,
168
+ updated_node: cst.SimpleStatementLine,
169
+ ) -> cst.SimpleStatementLine | cst.ClassDef | cst.RemovalSentinel:
170
+ """Transform assignment statements like Base = declarative_base()."""
171
+ # Check if this is an assignment with declarative_base() call
172
+ if len(updated_node.body) == 1:
173
+ stmt = updated_node.body[0]
174
+ if isinstance(stmt, cst.Assign) and len(stmt.targets) == 1:
175
+ target = stmt.targets[0].target
176
+ if isinstance(target, cst.Name) and isinstance(stmt.value, cst.Call):
177
+ call = stmt.value
178
+ if isinstance(call.func, cst.Name) and call.func.value == "declarative_base":
179
+ var_name = target.value
180
+ self._declarative_base_var_name = var_name
181
+
182
+ self.record_change(
183
+ description=f"Replace {var_name} = declarative_base() with class {var_name}(DeclarativeBase): pass",
184
+ line_number=1,
185
+ original=f"{var_name} = declarative_base()",
186
+ replacement=f"class {var_name}(DeclarativeBase):\n pass",
187
+ transform_name="declarative_base_to_class",
188
+ confidence=1.0,
189
+ )
190
+
191
+ # Create a class definition: class Base(DeclarativeBase): pass
192
+ class_def = cst.ClassDef(
193
+ name=cst.Name(var_name),
194
+ bases=[cst.Arg(value=cst.Name("DeclarativeBase"))],
195
+ body=cst.IndentedBlock(
196
+ body=[cst.SimpleStatementLine(body=[cst.Pass()])]
197
+ ),
198
+ )
199
+ return class_def
200
+
201
+ return updated_node
202
+
203
+ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
204
+ """Transform SQLAlchemy function calls."""
205
+ # Handle session.query() transformations
206
+ transformed = self._transform_query_call(updated_node)
207
+ if transformed is not None:
208
+ return transformed
111
209
 
112
210
  # Handle create_engine future flag
113
211
  if isinstance(updated_node.func, cst.Name) and updated_node.func.value == "create_engine":
@@ -128,10 +226,392 @@ class SQLAlchemyTransformer(BaseTransformer):
128
226
  new_args.append(arg)
129
227
 
130
228
  if changed:
229
+ # Fix trailing comma: remove trailing comma from last argument if present
230
+ if new_args:
231
+ last_arg = new_args[-1]
232
+ # Remove trailing comma from the last argument
233
+ if last_arg.comma != cst.MaybeSentinel.DEFAULT:
234
+ new_args[-1] = last_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
131
235
  return updated_node.with_changes(args=new_args)
132
236
 
237
+ # Handle execute() with raw SQL string - wrap with text()
238
+ if (
239
+ isinstance(updated_node.func, cst.Attribute)
240
+ and updated_node.func.attr.value == "execute"
241
+ ):
242
+ # Check if this is engine.execute() - which requires manual migration
243
+ # to use with engine.connect() as conn: conn.execute()
244
+ if self._is_engine_execute_call(updated_node):
245
+ self.record_change(
246
+ description="engine.execute() is removed in SQLAlchemy 2.0. "
247
+ "Use 'with engine.connect() as conn: conn.execute(...)' instead",
248
+ line_number=1,
249
+ original="engine.execute(...)",
250
+ replacement="with engine.connect() as conn:\n conn.execute(...)",
251
+ transform_name="engine_execute_to_connect",
252
+ confidence=0.5,
253
+ notes="MANUAL MIGRATION REQUIRED: This transformation requires "
254
+ "restructuring the code to use a context manager. The execute() call "
255
+ "must be moved inside a 'with engine.connect() as conn:' block, and "
256
+ "raw SQL strings should be wrapped with text(). If the result is used, "
257
+ "ensure proper handling within the context manager scope.",
258
+ )
259
+ # Don't transform the code - just record the warning
260
+ # The code still needs to have text() wrapping applied if applicable
261
+ # Fall through to the text wrapping logic below
262
+
263
+ if updated_node.args:
264
+ first_arg = updated_node.args[0]
265
+ # Check if the first argument is a string literal (raw SQL)
266
+ if isinstance(first_arg.value, cst.SimpleString) or isinstance(
267
+ first_arg.value, cst.ConcatenatedString
268
+ ):
269
+ # Wrap the string with text()
270
+ self._needs_text_import = True
271
+ text_call = cst.Call(
272
+ func=cst.Name("text"),
273
+ args=[cst.Arg(value=first_arg.value)],
274
+ )
275
+ new_first_arg = first_arg.with_changes(value=text_call)
276
+ new_args = [new_first_arg] + list(updated_node.args[1:])
277
+
278
+ self.record_change(
279
+ description="Wrap raw SQL string with text()",
280
+ line_number=1,
281
+ original='execute("...")',
282
+ replacement='execute(text("..."))',
283
+ transform_name="wrap_execute_with_text",
284
+ )
285
+
286
+ return updated_node.with_changes(args=new_args)
287
+
133
288
  return updated_node
134
289
 
290
+ def _is_engine_execute_call(self, node: cst.Call) -> bool:
291
+ """Check if a call is engine.execute() where engine is likely a SQLAlchemy engine.
292
+
293
+ Uses heuristics:
294
+ 1. The variable is known to be assigned from create_engine()
295
+ 2. The variable name contains 'engine' (case insensitive)
296
+
297
+ Args:
298
+ node: The Call node to check (already verified to be *.execute())
299
+
300
+ Returns:
301
+ True if this appears to be an engine.execute() call
302
+ """
303
+ if not isinstance(node.func, cst.Attribute):
304
+ return False
305
+
306
+ caller = node.func.value
307
+ if not isinstance(caller, cst.Name):
308
+ return False
309
+
310
+ var_name = caller.value
311
+
312
+ # Check if this variable was assigned from create_engine()
313
+ if var_name in self._engine_var_names:
314
+ return True
315
+
316
+ # Heuristic: check if variable name contains 'engine'
317
+ if "engine" in var_name.lower():
318
+ return True
319
+
320
+ return False
321
+
322
+ def _transform_query_call(self, node: cst.Call) -> cst.BaseExpression | None:
323
+ """Transform session.query(...) patterns to session.execute(select(...)).
324
+
325
+ Handles patterns like:
326
+ - session.query(Model).all() -> session.execute(select(Model)).scalars().all()
327
+ - session.query(Model).first() -> session.execute(select(Model)).scalars().first()
328
+ - session.query(Model).one() -> session.execute(select(Model)).scalars().one()
329
+ - session.query(Model).filter(...).all() -> session.execute(select(Model).where(...)).scalars().all()
330
+ - session.query(Model).get(id) -> session.get(Model, id)
331
+ - session.query(Model).count() -> session.execute(select(func.count()).select_from(Model)).scalar()
332
+
333
+ Returns:
334
+ Transformed node if this is a query pattern, None otherwise.
335
+ """
336
+ # Check if this is a method call (has Attribute as func)
337
+ if not isinstance(node.func, cst.Attribute):
338
+ return None
339
+
340
+ # Find the terminal method being called (.all(), .first(), .get(), etc.)
341
+ terminal_method = node.func.attr.value
342
+
343
+ # Methods that end a query chain
344
+ terminal_methods = {"all", "first", "one", "one_or_none", "get", "count", "scalar"}
345
+ if terminal_method not in terminal_methods:
346
+ return None
347
+
348
+ # Walk up the chain to find .query() and collect intermediate methods
349
+ chain_info = self._parse_query_chain(node)
350
+ if chain_info is None:
351
+ return None
352
+
353
+ session_node, model_node, filters, terminal_method, terminal_args = chain_info
354
+
355
+ # Handle .get(id) - transforms to session.get(Model, id)
356
+ if terminal_method == "get":
357
+ return self._transform_query_get(session_node, model_node, terminal_args)
358
+
359
+ # Handle .count() - transforms to session.execute(select(func.count()).select_from(Model)).scalar()
360
+ if terminal_method == "count":
361
+ return self._transform_query_count(session_node, model_node, filters)
362
+
363
+ # Handle .all(), .first(), .one(), .one_or_none(), .scalar()
364
+ return self._transform_query_execute(session_node, model_node, filters, terminal_method)
365
+
366
+ def _parse_query_chain(self, node: cst.Call) -> (
367
+ tuple[
368
+ cst.BaseExpression,
369
+ cst.BaseExpression,
370
+ list[cst.BaseExpression | _FilterByArg],
371
+ str,
372
+ Sequence[cst.Arg],
373
+ ]
374
+ | None
375
+ ):
376
+ """Parse a query chain to extract session, model, filters, and terminal method.
377
+
378
+ Returns:
379
+ Tuple of (session_node, model_node, filters, terminal_method, terminal_args) or None
380
+ """
381
+ if not isinstance(node.func, cst.Attribute):
382
+ return None
383
+
384
+ terminal_method = node.func.attr.value
385
+ terminal_args = node.args
386
+ current = node.func.value # Move past the terminal method call
387
+ filters: list[cst.BaseExpression | _FilterByArg] = []
388
+
389
+ # Walk up the chain collecting .filter() and .filter_by() calls
390
+ while True:
391
+ if isinstance(current, cst.Call):
392
+ func = current.func
393
+ if isinstance(func, cst.Attribute):
394
+ method_name = func.attr.value
395
+
396
+ if method_name == "query":
397
+ # Found the root .query() call
398
+ session_node = func.value
399
+ if current.args:
400
+ model_node = current.args[0].value
401
+ return (
402
+ session_node,
403
+ model_node,
404
+ list(reversed(filters)), # Reverse to get correct order
405
+ terminal_method,
406
+ terminal_args,
407
+ )
408
+ return None
409
+
410
+ elif method_name == "filter":
411
+ # Collect filter arguments
412
+ for arg in current.args:
413
+ filters.append(arg.value)
414
+ current = func.value
415
+
416
+ elif method_name == "filter_by":
417
+ # Convert filter_by(key=val) to Model.key == val
418
+ # Store the kwargs for handling during transform
419
+ for arg in current.args:
420
+ if arg.keyword is not None:
421
+ filters.append(_FilterByArg(arg.keyword.value, arg.value))
422
+ current = func.value
423
+
424
+ elif method_name in {
425
+ "order_by",
426
+ "limit",
427
+ "offset",
428
+ "distinct",
429
+ "group_by",
430
+ "having",
431
+ "join",
432
+ "outerjoin",
433
+ }:
434
+ # Skip these for now - they can be added to the select() later
435
+ current = func.value
436
+
437
+ else:
438
+ # Unknown method, not a query chain we can handle
439
+ return None
440
+ else:
441
+ return None
442
+ elif isinstance(current, cst.Attribute):
443
+ # This might be something like query.Model or session.query
444
+ return None
445
+ else:
446
+ return None
447
+
448
+ def _transform_query_get(
449
+ self,
450
+ session_node: cst.BaseExpression,
451
+ model_node: cst.BaseExpression,
452
+ args: Sequence[cst.Arg],
453
+ ) -> cst.Call:
454
+ """Transform session.query(Model).get(id) to session.get(Model, id)."""
455
+ self._needs_select_import = True
456
+
457
+ self.record_change(
458
+ description="Convert session.query(Model).get(id) to session.get(Model, id)",
459
+ line_number=1,
460
+ original="session.query(Model).get(id)",
461
+ replacement="session.get(Model, id)",
462
+ transform_name="query_get_to_session_get",
463
+ )
464
+
465
+ # Build session.get(Model, id)
466
+ new_args = [cst.Arg(value=model_node)]
467
+ new_args.extend(args)
468
+
469
+ return cst.Call(
470
+ func=cst.Attribute(value=session_node, attr=cst.Name("get")),
471
+ args=new_args,
472
+ )
473
+
474
+ def _transform_query_count(
475
+ self,
476
+ session_node: cst.BaseExpression,
477
+ model_node: cst.BaseExpression,
478
+ filters: list[cst.BaseExpression | _FilterByArg],
479
+ ) -> cst.Call:
480
+ """Transform session.query(Model).count() to session.execute(select(func.count()).select_from(Model)).scalar()."""
481
+ self._needs_select_import = True
482
+ self._needs_func_import = True
483
+
484
+ self.record_change(
485
+ description="Convert session.query(Model).count() to session.execute(select(func.count()).select_from(Model)).scalar()",
486
+ line_number=1,
487
+ original="session.query(Model).count()",
488
+ replacement="session.execute(select(func.count()).select_from(Model)).scalar()",
489
+ transform_name="query_count_to_select_count",
490
+ )
491
+
492
+ # Build func.count()
493
+ func_count = cst.Call(
494
+ func=cst.Attribute(value=cst.Name("func"), attr=cst.Name("count")),
495
+ args=[],
496
+ )
497
+
498
+ # Build select(func.count())
499
+ select_call = cst.Call(
500
+ func=cst.Name("select"),
501
+ args=[cst.Arg(value=func_count)],
502
+ )
503
+
504
+ # Add .select_from(Model)
505
+ select_from = cst.Call(
506
+ func=cst.Attribute(value=select_call, attr=cst.Name("select_from")),
507
+ args=[cst.Arg(value=model_node)],
508
+ )
509
+
510
+ # Add .where() if there are filters
511
+ current: cst.BaseExpression = select_from
512
+ for filter_expr in filters:
513
+ if isinstance(filter_expr, _FilterByArg):
514
+ # Convert filter_by to where with Model.attr == val
515
+ where_condition = cst.Comparison(
516
+ left=cst.Attribute(value=model_node, attr=cst.Name(filter_expr.key)),
517
+ comparisons=[
518
+ cst.ComparisonTarget(
519
+ operator=cst.Equal(),
520
+ comparator=filter_expr.value,
521
+ )
522
+ ],
523
+ )
524
+ current = cst.Call(
525
+ func=cst.Attribute(value=current, attr=cst.Name("where")),
526
+ args=[cst.Arg(value=where_condition)],
527
+ )
528
+ else:
529
+ current = cst.Call(
530
+ func=cst.Attribute(value=current, attr=cst.Name("where")),
531
+ args=[cst.Arg(value=filter_expr)],
532
+ )
533
+
534
+ # Build session.execute(...)
535
+ execute_call = cst.Call(
536
+ func=cst.Attribute(value=session_node, attr=cst.Name("execute")),
537
+ args=[cst.Arg(value=current)],
538
+ )
539
+
540
+ # Add .scalar()
541
+ return cst.Call(
542
+ func=cst.Attribute(value=execute_call, attr=cst.Name("scalar")),
543
+ args=[],
544
+ )
545
+
546
+ def _transform_query_execute(
547
+ self,
548
+ session_node: cst.BaseExpression,
549
+ model_node: cst.BaseExpression,
550
+ filters: list[cst.BaseExpression | _FilterByArg],
551
+ terminal_method: str,
552
+ ) -> cst.Call:
553
+ """Transform session.query(Model).all/first/one() to session.execute(select(Model)).scalars().all/first/one()."""
554
+ self._needs_select_import = True
555
+
556
+ original = f"session.query(Model).{terminal_method}()"
557
+ replacement = f"session.execute(select(Model)).scalars().{terminal_method}()"
558
+
559
+ self.record_change(
560
+ description=f"Convert {original} to {replacement}",
561
+ line_number=1,
562
+ original=original,
563
+ replacement=replacement,
564
+ transform_name=f"query_{terminal_method}_to_select",
565
+ )
566
+
567
+ # Build select(Model)
568
+ select_call = cst.Call(
569
+ func=cst.Name("select"),
570
+ args=[cst.Arg(value=model_node)],
571
+ )
572
+
573
+ # Add .where() for each filter
574
+ current: cst.BaseExpression = select_call
575
+ for filter_expr in filters:
576
+ if isinstance(filter_expr, _FilterByArg):
577
+ # Convert filter_by to where with Model.attr == val
578
+ where_condition = cst.Comparison(
579
+ left=cst.Attribute(value=model_node, attr=cst.Name(filter_expr.key)),
580
+ comparisons=[
581
+ cst.ComparisonTarget(
582
+ operator=cst.Equal(),
583
+ comparator=filter_expr.value,
584
+ )
585
+ ],
586
+ )
587
+ current = cst.Call(
588
+ func=cst.Attribute(value=current, attr=cst.Name("where")),
589
+ args=[cst.Arg(value=where_condition)],
590
+ )
591
+ else:
592
+ current = cst.Call(
593
+ func=cst.Attribute(value=current, attr=cst.Name("where")),
594
+ args=[cst.Arg(value=filter_expr)],
595
+ )
596
+
597
+ # Build session.execute(...)
598
+ execute_call = cst.Call(
599
+ func=cst.Attribute(value=session_node, attr=cst.Name("execute")),
600
+ args=[cst.Arg(value=current)],
601
+ )
602
+
603
+ # Add .scalars()
604
+ scalars_call = cst.Call(
605
+ func=cst.Attribute(value=execute_call, attr=cst.Name("scalars")),
606
+ args=[],
607
+ )
608
+
609
+ # Add terminal method (.all(), .first(), .one(), etc.)
610
+ return cst.Call(
611
+ func=cst.Attribute(value=scalars_call, attr=cst.Name(terminal_method)),
612
+ args=[],
613
+ )
614
+
135
615
  def leave_Attribute(
136
616
  self, original_node: cst.Attribute, updated_node: cst.Attribute
137
617
  ) -> cst.Attribute:
@@ -151,6 +631,168 @@ class SQLAlchemyTransformer(BaseTransformer):
151
631
  return ""
152
632
 
153
633
 
634
+ class SQLAlchemyImportTransformer(BaseTransformer):
635
+ """Separate transformer for handling import additions.
636
+
637
+ This runs after the main transformer to add any missing imports.
638
+ """
639
+
640
+ def __init__(
641
+ self,
642
+ needs_select_import: bool = False,
643
+ needs_func_import: bool = False,
644
+ needs_text_import: bool = False,
645
+ has_text_import: bool = False,
646
+ ) -> None:
647
+ super().__init__()
648
+ self._needs_select_import = needs_select_import
649
+ self._needs_func_import = needs_func_import
650
+ self._needs_text_import = needs_text_import
651
+ self._has_text_import = has_text_import
652
+ self._has_select_import = False
653
+ self._has_func_import = False
654
+ self._found_sqlalchemy_import = False
655
+
656
+ def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
657
+ """Check existing sqlalchemy imports."""
658
+ if node.module is None:
659
+ return True
660
+
661
+ module_name = self._get_module_name(node.module)
662
+ if module_name == "sqlalchemy":
663
+ self._found_sqlalchemy_import = True
664
+ if not isinstance(node.names, cst.ImportStar):
665
+ for name in node.names:
666
+ if isinstance(name, cst.ImportAlias):
667
+ if isinstance(name.name, cst.Name):
668
+ if name.name.value == "text":
669
+ self._has_text_import = True
670
+ elif name.name.value == "select":
671
+ self._has_select_import = True
672
+ elif name.name.value == "func":
673
+ self._has_func_import = True
674
+
675
+ return True
676
+
677
+ def leave_ImportFrom(
678
+ self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
679
+ ) -> cst.ImportFrom:
680
+ """Add missing imports to sqlalchemy import statement."""
681
+ if updated_node.module is None:
682
+ return updated_node
683
+
684
+ module_name = self._get_module_name(updated_node.module)
685
+ if module_name != "sqlalchemy":
686
+ return updated_node
687
+
688
+ if isinstance(updated_node.names, cst.ImportStar):
689
+ return updated_node
690
+
691
+ new_names = list(updated_node.names)
692
+ changed = False
693
+
694
+ # Add text import if needed and not already present
695
+ if self._needs_text_import and not self._has_text_import:
696
+ new_names.append(cst.ImportAlias(name=cst.Name("text")))
697
+ self._has_text_import = True
698
+ changed = True
699
+
700
+ # Add select import if needed and not already present
701
+ if self._needs_select_import and not self._has_select_import:
702
+ new_names.append(cst.ImportAlias(name=cst.Name("select")))
703
+ self._has_select_import = True
704
+ changed = True
705
+
706
+ self.record_change(
707
+ description="Add 'select' import for query transformation",
708
+ line_number=1,
709
+ original="from sqlalchemy import ...",
710
+ replacement="from sqlalchemy import ..., select",
711
+ transform_name="add_select_import",
712
+ )
713
+
714
+ # Add func import if needed and not already present
715
+ if self._needs_func_import and not self._has_func_import:
716
+ new_names.append(cst.ImportAlias(name=cst.Name("func")))
717
+ self._has_func_import = True
718
+ changed = True
719
+
720
+ self.record_change(
721
+ description="Add 'func' import for count transformation",
722
+ line_number=1,
723
+ original="from sqlalchemy import ...",
724
+ replacement="from sqlalchemy import ..., func",
725
+ transform_name="add_func_import",
726
+ )
727
+
728
+ if changed:
729
+ return updated_node.with_changes(names=new_names)
730
+
731
+ return updated_node
732
+
733
+ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
734
+ """Add sqlalchemy import if not found but needed."""
735
+ if self._found_sqlalchemy_import:
736
+ return updated_node
737
+
738
+ needs_import = (
739
+ (self._needs_select_import and not self._has_select_import)
740
+ or (self._needs_func_import and not self._has_func_import)
741
+ or (self._needs_text_import and not self._has_text_import)
742
+ )
743
+
744
+ if not needs_import:
745
+ return updated_node
746
+
747
+ # Build the import names
748
+ import_names = []
749
+ if self._needs_select_import and not self._has_select_import:
750
+ import_names.append(cst.ImportAlias(name=cst.Name("select")))
751
+ self.record_change(
752
+ description="Add 'select' import for query transformation",
753
+ line_number=1,
754
+ original="",
755
+ replacement="from sqlalchemy import select",
756
+ transform_name="add_select_import",
757
+ )
758
+ if self._needs_func_import and not self._has_func_import:
759
+ import_names.append(cst.ImportAlias(name=cst.Name("func")))
760
+ self.record_change(
761
+ description="Add 'func' import for count transformation",
762
+ line_number=1,
763
+ original="",
764
+ replacement="from sqlalchemy import func",
765
+ transform_name="add_func_import",
766
+ )
767
+ if self._needs_text_import and not self._has_text_import:
768
+ import_names.append(cst.ImportAlias(name=cst.Name("text")))
769
+
770
+ if not import_names:
771
+ return updated_node
772
+
773
+ # Create the import statement
774
+ new_import = cst.SimpleStatementLine(
775
+ body=[
776
+ cst.ImportFrom(
777
+ module=cst.Name("sqlalchemy"),
778
+ names=import_names,
779
+ )
780
+ ]
781
+ )
782
+
783
+ # Add at the beginning of the module (after any existing imports)
784
+ new_body = [new_import] + list(updated_node.body)
785
+ return updated_node.with_changes(body=new_body)
786
+
787
+ def _get_module_name(self, module: cst.BaseExpression) -> str:
788
+ """Get the full module name from an Attribute or Name node."""
789
+ if isinstance(module, cst.Name):
790
+ return str(module.value)
791
+ elif isinstance(module, cst.Attribute):
792
+ return f"{self._get_module_name(module.value)}.{module.attr.value}"
793
+ return ""
794
+
795
+
154
796
  def transform_sqlalchemy(source_code: str) -> tuple[str, list]:
155
797
  """Transform SQLAlchemy code from 1.x to 2.0.
156
798
 
@@ -170,6 +812,17 @@ def transform_sqlalchemy(source_code: str) -> tuple[str, list]:
170
812
 
171
813
  try:
172
814
  transformed_tree = tree.visit(transformer)
173
- return transformed_tree.code, transformer.changes
815
+
816
+ # Second pass: add missing imports
817
+ import_transformer = SQLAlchemyImportTransformer(
818
+ needs_select_import=transformer._needs_select_import,
819
+ needs_func_import=transformer._needs_func_import,
820
+ needs_text_import=transformer._needs_text_import,
821
+ has_text_import=transformer._has_text_import,
822
+ )
823
+ final_tree = transformed_tree.visit(import_transformer)
824
+
825
+ all_changes = transformer.changes + import_transformer.changes
826
+ return final_tree.code, all_changes
174
827
  except Exception:
175
828
  return source_code, []