codeshift 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. codeshift/cli/commands/apply.py +24 -2
  2. codeshift/cli/package_manager.py +102 -0
  3. codeshift/knowledge/generator.py +11 -1
  4. codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
  5. codeshift/knowledge_base/libraries/attrs.yaml +181 -0
  6. codeshift/knowledge_base/libraries/celery.yaml +244 -0
  7. codeshift/knowledge_base/libraries/click.yaml +195 -0
  8. codeshift/knowledge_base/libraries/django.yaml +355 -0
  9. codeshift/knowledge_base/libraries/flask.yaml +270 -0
  10. codeshift/knowledge_base/libraries/httpx.yaml +183 -0
  11. codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
  12. codeshift/knowledge_base/libraries/numpy.yaml +429 -0
  13. codeshift/knowledge_base/libraries/pytest.yaml +192 -0
  14. codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
  15. codeshift/migrator/engine.py +60 -0
  16. codeshift/migrator/transforms/__init__.py +2 -0
  17. codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
  18. codeshift/migrator/transforms/attrs_transformer.py +570 -0
  19. codeshift/migrator/transforms/celery_transformer.py +546 -0
  20. codeshift/migrator/transforms/click_transformer.py +526 -0
  21. codeshift/migrator/transforms/django_transformer.py +852 -0
  22. codeshift/migrator/transforms/fastapi_transformer.py +12 -7
  23. codeshift/migrator/transforms/flask_transformer.py +505 -0
  24. codeshift/migrator/transforms/httpx_transformer.py +419 -0
  25. codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
  26. codeshift/migrator/transforms/numpy_transformer.py +413 -0
  27. codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
  28. codeshift/migrator/transforms/pytest_transformer.py +351 -0
  29. codeshift/migrator/transforms/requests_transformer.py +74 -1
  30. codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
  31. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
  32. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/RECORD +36 -15
  33. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
  34. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
  35. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
  36. {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  """Library-specific transformation modules."""
2
2
 
3
3
  from codeshift.migrator.transforms.fastapi_transformer import FastAPITransformer
4
+ from codeshift.migrator.transforms.marshmallow_transformer import MarshmallowTransformer
4
5
  from codeshift.migrator.transforms.pandas_transformer import (
5
6
  PandasAppendTransformer,
6
7
  PandasTransformer,
@@ -16,4 +17,5 @@ __all__ = [
16
17
  "PandasTransformer",
17
18
  "PandasAppendTransformer",
18
19
  "RequestsTransformer",
20
+ "MarshmallowTransformer",
19
21
  ]
@@ -0,0 +1,608 @@
1
+ """aiohttp 3.7 to 3.9+ transformation using LibCST."""
2
+
3
+ import libcst as cst
4
+
5
+ from codeshift.migrator.ast_transforms import BaseTransformer
6
+
7
+
8
class AiohttpTransformer(BaseTransformer):
    """Transform aiohttp code for version upgrades (3.7 to 3.9+).

    Handles breaking changes including:
    - Removal of loop parameter from ClientSession, TCPConnector, web.Application, etc.
    - Deprecated timeout parameters (read_timeout, conn_timeout)
    - app.loop property deprecation
    - WebSocket timeout parameter rename
    - Response URL attribute changes
    - WebSocket protocol attribute rename

    Call rewrites are composed, so a single call can receive several fixes
    (e.g. ``ClientSession(loop=l, read_timeout=5)`` loses ``loop`` *and*
    gets ``timeout=ClientTimeout(total=5)``).

    ``BasicAuth(...).encode()`` is deliberately NOT rewritten: ``encode()`` is
    still part of aiohttp 3.9's ``BasicAuth`` API, and ``BasicAuth`` (a
    namedtuple) has no ``__str__`` that yields the header value, so
    ``str(BasicAuth(...))`` would not be equivalent.
    """

    # Constructors that dropped their ``loop`` argument -> transform name
    # recorded for the change.
    _LOOP_PARAM_CLASSES = {
        "ClientSession": "remove_loop_param_client_session",
        "TCPConnector": "remove_loop_param_tcp_connector",
        "UnixConnector": "remove_loop_param_unix_connector",
        "Application": "remove_loop_param_web_application",
        "ClientTimeout": "remove_loop_param_client_timeout",
    }

    # Deprecated ClientSession keyword -> ClientTimeout field it maps to.
    _DEPRECATED_TIMEOUTS = {"read_timeout": "total", "conn_timeout": "connect"}

    # Words that suggest an expression holds a WebSocket response object.
    _WS_NAME_TOKENS = frozenset({"ws", "websocket", "ws_response", "websock", "socket"})

    # Receivers whose ``.loop`` is rewritten to ``asyncio.get_event_loop()``.
    _APP_NAMES = frozenset({"app", "application", "request.app"})

    def __init__(self) -> None:
        super().__init__()
        # Set when a rewrite emits ``asyncio.get_event_loop()``.
        self._needs_asyncio_import = False
        # Whether an ``import`` statement binds the bare name ``asyncio``.
        self._has_asyncio_import = False
        # Set when a rewrite emits a bare ``ClientTimeout(...)`` call.
        self._needs_client_timeout_import = False
        # Whether ``from aiohttp import ClientTimeout`` binds that bare name.
        self._has_client_timeout_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Track whether ``ClientTimeout`` is bound by an absolute aiohttp import.

        ``from asyncio import x`` does not bind the name ``asyncio``, so it is
        not counted as an asyncio import.

        Returns:
            False, so names inside import statements are never rewritten
            (a dotted module path must not be turned into a call).
        """
        if node.module is None or node.relative or isinstance(node.names, cst.ImportStar):
            return False
        if self._get_module_name(node.module) == "aiohttp" and any(
            self._binds_client_timeout(alias) for alias in node.names
        ):
            self._has_client_timeout_import = True
        return False

    def visit_Import(self, node: cst.Import) -> bool:
        """Track whether an ``import`` statement binds the bare name ``asyncio``.

        Returns:
            False, so dotted names in the import are never rewritten.
        """
        if any(self._binds_asyncio(alias) for alias in node.names):
            self._has_asyncio_import = True
        return False

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Apply every applicable call rewrite, each to the previous result."""
        node = updated_node
        for rewrite in (
            self._remove_loop_parameter,
            self._transform_deprecated_timeouts,
            self._transform_ws_connect_timeout,
        ):
            rewritten = rewrite(node)
            if rewritten is not None:
                node = rewritten
        return node

    def _remove_loop_parameter(self, node: cst.Call) -> cst.Call | None:
        """Remove the ``loop=`` keyword from aiohttp constructors.

        Handles ClientSession, TCPConnector, UnixConnector, (aiohttp.)web.Application
        and ClientTimeout, called by bare or dotted name.

        Returns:
            The call without ``loop=``, or None when nothing applies.
        """
        func_name = self._get_call_name(node)
        transform_name = next(
            (
                name
                for class_name, name in self._LOOP_PARAM_CLASSES.items()
                if func_name == class_name or func_name.endswith(f".{class_name}")
            ),
            None,
        )
        if transform_name is None:
            return None

        kept = [arg for arg in node.args if not self._has_keyword(arg, "loop")]
        if len(kept) == len(node.args):
            return None

        self.record_change(
            description=f"Remove deprecated loop parameter from {func_name}",
            line_number=1,
            original=f"{func_name}(loop=...)",
            replacement=f"{func_name}()",
            transform_name=transform_name,
        )
        return node.with_changes(args=self._carry_trailing_comma(node.args, kept))

    def _transform_deprecated_timeouts(self, node: cst.Call) -> cst.Call | None:
        """Fold ``read_timeout``/``conn_timeout`` into ``timeout=ClientTimeout(...)``.

        Handles:
        - ClientSession(read_timeout=X) -> ClientSession(timeout=ClientTimeout(total=X))
        - ClientSession(conn_timeout=X) -> ClientSession(timeout=ClientTimeout(connect=X))

        If the call also passes ``timeout=`` it is left unchanged. aiohttp
        rejects that combination, and rewriting it would have to discard the
        user's explicit timeout.

        Returns:
            The rewritten call, or None when nothing applies.
        """
        func_name = self._get_call_name(node)
        if func_name != "ClientSession" and not func_name.endswith(".ClientSession"):
            return None

        # ClientTimeout field -> value, in the order the keywords appeared.
        timeout_values: dict[str, cst.BaseExpression] = {}
        converted: list[tuple[str, str]] = []
        kept: list[cst.Arg] = []
        has_explicit_timeout = False
        for arg in node.args:
            keyword = arg.keyword.value if isinstance(arg.keyword, cst.Name) else None
            if keyword in self._DEPRECATED_TIMEOUTS:
                field = self._DEPRECATED_TIMEOUTS[keyword]
                timeout_values[field] = arg.value
                converted.append((keyword, field))
                continue
            if keyword == "timeout":
                has_explicit_timeout = True
            kept.append(arg)

        if not timeout_values or has_explicit_timeout:
            return None

        for old_name, field in converted:
            self.record_change(
                description=f"Convert {old_name} to ClientTimeout({field}=...)",
                line_number=1,
                original=f"ClientSession({old_name}=...)",
                replacement=f"ClientSession(timeout=ClientTimeout({field}=...))",
                transform_name=f"{old_name}_to_client_timeout",
            )

        self._needs_client_timeout_import = True
        client_timeout = cst.Call(
            func=cst.Name("ClientTimeout"),
            args=[self._keyword_arg(field, value) for field, value in timeout_values.items()],
        )
        kept.append(self._keyword_arg("timeout", client_timeout))
        return node.with_changes(args=self._carry_trailing_comma(node.args, kept))

    def _transform_basicauth_encode(self, node: cst.Call) -> cst.BaseExpression | None:
        """Transform BasicAuth(...).encode() to str(BasicAuth(...)).

        Not invoked by ``leave_Call``: ``BasicAuth.encode()`` still exists in
        aiohttp 3.9 and ``str(BasicAuth(...))`` returns the namedtuple repr, not
        the ``Authorization`` header. Kept only for callers that reference it.
        """
        if not isinstance(node.func, cst.Attribute):
            return None
        if node.func.attr.value != "encode":
            return None

        base = node.func.value
        if not isinstance(base, cst.Call):
            return None
        base_func_name = self._get_call_name(base)
        if base_func_name != "BasicAuth" and not base_func_name.endswith(".BasicAuth"):
            return None

        # Only the zero-argument form matches the old API.
        if node.args:
            return None

        self.record_change(
            description="Convert BasicAuth(...).encode() to str(BasicAuth(...))",
            line_number=1,
            original="BasicAuth(...).encode()",
            replacement="str(BasicAuth(...))",
            transform_name="basicauth_encode_to_str",
        )
        return cst.Call(func=cst.Name("str"), args=[cst.Arg(value=base)])

    def _transform_ws_connect_timeout(self, node: cst.Call) -> cst.Call | None:
        """Rename ``ws_connect(..., timeout=X)`` to ``receive_timeout=X``.

        NOTE(review): aiohttp 3.x documents ws_connect's float ``timeout`` as
        the *close* timeout, distinct from ``receive_timeout``; confirm this
        rename is the intended migration.

        Returns:
            The rewritten call, or None when nothing applies.
        """
        func_name = self._get_call_name(node)
        if func_name != "ws_connect" and not func_name.endswith(".ws_connect"):
            return None

        changed = False
        new_args: list[cst.Arg] = []
        for arg in node.args:
            if self._has_keyword(arg, "timeout"):
                arg = arg.with_changes(keyword=cst.Name("receive_timeout"))
                changed = True
                self.record_change(
                    description="Rename ws_connect timeout parameter to receive_timeout",
                    line_number=1,
                    original="ws_connect(..., timeout=...)",
                    replacement="ws_connect(..., receive_timeout=...)",
                    transform_name="ws_connect_timeout_rename",
                )
            new_args.append(arg)

        return node.with_changes(args=new_args) if changed else None

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Rewrite ``.url_obj``, WebSocket ``.protocol`` and ``app.loop`` accesses."""
        attr_name = updated_node.attr.value

        if attr_name == "url_obj":
            self.record_change(
                description="Rename url_obj attribute to url",
                line_number=1,
                original=".url_obj",
                replacement=".url",
                transform_name="url_obj_to_url",
            )
            return updated_node.with_changes(attr=cst.Name("url"))

        # The receiver's type is unknown, so .protocol is only renamed when the
        # receiver's name looks like a WebSocket response.
        if attr_name == "protocol":
            value_name = self._get_expression_name(updated_node.value)
            if value_name and self._looks_like_websocket(value_name):
                self.record_change(
                    description="Rename WebSocketResponse.protocol to ws_protocol",
                    line_number=1,
                    original=".protocol",
                    replacement=".ws_protocol",
                    transform_name="ws_protocol_rename",
                )
                return updated_node.with_changes(attr=cst.Name("ws_protocol"))

        if attr_name == "loop":
            value_name = self._get_expression_name(updated_node.value)
            if value_name in self._APP_NAMES:
                self._needs_asyncio_import = True
                self.record_change(
                    description=f"Replace {value_name}.loop with asyncio.get_event_loop()",
                    line_number=1,
                    original=f"{value_name}.loop",
                    replacement="asyncio.get_event_loop()",
                    transform_name="app_loop_to_get_event_loop",
                )
                return cst.Call(
                    func=cst.Attribute(
                        value=cst.Name("asyncio"),
                        attr=cst.Name("get_event_loop"),
                    ),
                    args=[],
                )

        return updated_node

    @classmethod
    def _looks_like_websocket(cls, dotted_name: str) -> bool:
        """Heuristically decide whether *dotted_name* refers to a WebSocket.

        Only the final segment is inspected (``self.ws`` -> ``ws``). It must
        equal a known token, or contain one as a snake_case/camelCase word.
        Plain substring matching is avoided because it misfires on names such
        as ``rows`` or ``news``.
        """
        import re

        segment = dotted_name.rsplit(".", 1)[-1]
        if segment.lower() in cls._WS_NAME_TOKENS:
            return True
        words = re.split(r"_|(?<=[a-z0-9])(?=[A-Z])", segment)
        return any(word.lower() in cls._WS_NAME_TOKENS for word in words)

    @staticmethod
    def _has_keyword(arg: cst.Arg, name: str) -> bool:
        """Return True if *arg* is the keyword argument ``name=...``."""
        return isinstance(arg.keyword, cst.Name) and arg.keyword.value == name

    @staticmethod
    def _keyword_arg(name: str, value: cst.BaseExpression) -> cst.Arg:
        """Build ``name=value`` with no spaces around ``=``."""
        return cst.Arg(
            keyword=cst.Name(name),
            value=value,
            equal=cst.AssignEqual(
                whitespace_before=cst.SimpleWhitespace(""),
                whitespace_after=cst.SimpleWhitespace(""),
            ),
        )

    @staticmethod
    def _carry_trailing_comma(
        original_args: "list[cst.Arg] | tuple[cst.Arg, ...]", new_args: list[cst.Arg]
    ) -> list[cst.Arg]:
        """Give the new final argument the original final argument's comma.

        When the old last argument was removed or moved, the argument now last
        may still carry a separator comma (leaving ``f(a=1, )``). Copying the
        old trailing comma, or its absence, keeps the call's closing layout.
        """
        new_args = list(new_args)
        if new_args and original_args and new_args[-1] is not original_args[-1]:
            new_args[-1] = new_args[-1].with_changes(comma=original_args[-1].comma)
        return new_args

    @staticmethod
    def _binds_asyncio(alias: cst.ImportAlias) -> bool:
        """Return True if ``import <alias>`` binds the bare name ``asyncio``.

        ``import asyncio`` and ``import asyncio.sub`` bind it;
        ``import asyncio as aio`` and ``import asyncio_x.y`` do not.
        """
        if alias.asname is None:
            root = alias.name
            while isinstance(root, cst.Attribute):
                root = root.value
            return isinstance(root, cst.Name) and root.value == "asyncio"
        target = alias.asname.name
        return (
            isinstance(alias.name, cst.Name)
            and alias.name.value == "asyncio"
            and isinstance(target, cst.Name)
            and target.value == "asyncio"
        )

    @staticmethod
    def _binds_client_timeout(alias: cst.ImportAlias) -> bool:
        """Return True if ``from aiohttp import <alias>`` binds ``ClientTimeout``."""
        if not (isinstance(alias.name, cst.Name) and alias.name.value == "ClientTimeout"):
            return False
        if alias.asname is None:
            return True
        target = alias.asname.name
        return isinstance(target, cst.Name) and target.value == "ClientTimeout"

    def _get_call_name(self, node: cst.Call) -> str:
        """Get the name of a function/class being called."""
        if isinstance(node.func, cst.Name):
            return str(node.func.value)
        if isinstance(node.func, cst.Attribute):
            return self._get_attribute_chain(node.func)
        return ""

    def _get_attribute_chain(self, node: cst.Attribute) -> str:
        """Get the full attribute chain as a string (e.g. ``aiohttp.web.Application``)."""
        if isinstance(node.value, cst.Name):
            return f"{node.value.value}.{node.attr.value}"
        if isinstance(node.value, cst.Attribute):
            return f"{self._get_attribute_chain(node.value)}.{node.attr.value}"
        return str(node.attr.value)

    def _get_expression_name(self, node: cst.BaseExpression) -> str | None:
        """Get a dotted-name representation of an expression, or None."""
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            base = self._get_expression_name(node.value)
            if base:
                return f"{base}.{node.attr.value}"
            return str(node.attr.value)
        return None

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full module name from a Name or Attribute node."""
        if isinstance(module, cst.Name):
            return str(module.value)
        if isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
409
+
410
+
411
class AiohttpImportTransformer(BaseTransformer):
    """Second pass that adds the imports ``AiohttpTransformer`` rewrites rely on.

    The first pass emits bare ``asyncio.get_event_loop()`` and
    ``ClientTimeout(...)`` references. This pass ensures the module binds
    exactly those names. It either extends the first module-level
    ``from aiohttp import ...`` statement, or inserts new import statements
    after the module docstring and any ``from __future__`` imports. Inserting
    before ``__future__`` is a SyntaxError; inserting before the docstring
    stops it from being a docstring.
    """

    def __init__(
        self,
        needs_asyncio_import: bool = False,
        has_asyncio_import: bool = False,
        needs_client_timeout_import: bool = False,
        has_client_timeout_import: bool = False,
    ) -> None:
        """Store what the first pass needs and what it believes is imported.

        Args:
            needs_asyncio_import: The first pass emitted ``asyncio.`` references.
            has_asyncio_import: Hint that ``asyncio`` is already bound. When a
                whole ``Module`` is visited, ``visit_Module`` recomputes this
                from the module's own top-level imports.
            needs_client_timeout_import: The first pass emitted ``ClientTimeout``.
            has_client_timeout_import: Hint that ``ClientTimeout`` is already
                bound; recomputed the same way as ``has_asyncio_import``.
        """
        super().__init__()
        self._needs_asyncio_import = needs_asyncio_import
        self._has_asyncio_import = has_asyncio_import
        self._needs_client_timeout_import = needs_client_timeout_import
        self._has_client_timeout_import = has_client_timeout_import
        self._found_aiohttp_import = False
        # First module-level, absolute, non-star ``from aiohttp import ...``.
        # ClientTimeout is appended to it instead of adding a new statement.
        self._aiohttp_import_target: cst.ImportFrom | None = None

    def visit_Module(self, node: cst.Module) -> bool:
        """Scan top-level imports before any statement is rewritten.

        Scanning up front means a later ``from aiohttp import ClientTimeout``
        is already known when an earlier aiohttp import is left. Only
        module-level imports count, because an import inside a function or
        class does not bind a module-global name. Missing a binding at worst
        adds a redundant import, never a missing one.
        """
        self._has_asyncio_import = False
        self._has_client_timeout_import = False
        self._found_aiohttp_import = False
        self._aiohttp_import_target = None

        for statement in node.body:
            if not isinstance(statement, cst.SimpleStatementLine):
                continue
            for small in statement.body:
                if isinstance(small, cst.Import):
                    if any(self._binds_asyncio(alias) for alias in small.names):
                        self._has_asyncio_import = True
                elif isinstance(small, cst.ImportFrom):
                    self._scan_import_from(small)
        return True

    def _scan_import_from(self, node: cst.ImportFrom) -> None:
        """Record what a module-level ``from aiohttp import ...`` provides."""
        if node.module is None or node.relative:
            return
        if self._get_module_name(node.module) != "aiohttp":
            return

        self._found_aiohttp_import = True
        if isinstance(node.names, cst.ImportStar):
            # ClientTimeout is listed in aiohttp's __all__, so the star import
            # is relied on to provide it.
            self._has_client_timeout_import = True
            return
        if any(self._binds_client_timeout(alias) for alias in node.names):
            self._has_client_timeout_import = True
        if self._aiohttp_import_target is None:
            self._aiohttp_import_target = node

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Append ``ClientTimeout`` to the chosen module-level aiohttp import."""
        if (
            self._aiohttp_import_target is None
            or original_node is not self._aiohttp_import_target
            or not self._needs_client_timeout_import
            or self._has_client_timeout_import
            or isinstance(updated_node.names, cst.ImportStar)
        ):
            return updated_node

        new_names = [*updated_node.names, cst.ImportAlias(name=cst.Name("ClientTimeout"))]
        self._has_client_timeout_import = True
        self.record_change(
            description="Add 'ClientTimeout' import for timeout transformation",
            line_number=1,
            original="from aiohttp import ...",
            replacement="from aiohttp import ..., ClientTimeout",
            transform_name="add_client_timeout_import",
        )
        return updated_node.with_changes(names=new_names)

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Insert any still-missing imports after the docstring/``__future__`` header."""
        new_statements: list[cst.SimpleStatementLine] = []

        if self._needs_asyncio_import and not self._has_asyncio_import:
            new_statements.append(
                cst.SimpleStatementLine(
                    body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("asyncio"))])]
                )
            )
            self._has_asyncio_import = True
            self.record_change(
                description="Add 'asyncio' import for get_event_loop()",
                line_number=1,
                original="",
                replacement="import asyncio",
                transform_name="add_asyncio_import",
            )

        if self._needs_client_timeout_import and not self._has_client_timeout_import:
            new_statements.append(
                cst.SimpleStatementLine(
                    body=[
                        cst.ImportFrom(
                            module=cst.Name("aiohttp"),
                            names=[cst.ImportAlias(name=cst.Name("ClientTimeout"))],
                        )
                    ]
                )
            )
            self._has_client_timeout_import = True
            self.record_change(
                description="Add 'ClientTimeout' import from aiohttp",
                line_number=1,
                original="",
                replacement="from aiohttp import ClientTimeout",
                transform_name="add_client_timeout_import",
            )

        if not new_statements:
            return updated_node

        body = list(updated_node.body)
        position = self._import_insert_position(body)
        if position > 0:
            # Separate the new imports from the docstring/__future__ header.
            new_statements[0] = new_statements[0].with_changes(leading_lines=[cst.EmptyLine()])
        body[position:position] = new_statements
        return updated_node.with_changes(body=body)

    def _import_insert_position(self, body: list) -> int:
        """Return the index just past the module docstring and ``__future__`` imports."""
        position = 0
        for index, statement in enumerate(body):
            if not isinstance(statement, cst.SimpleStatementLine):
                break
            if index == 0 and self._is_docstring(statement):
                position = 1
                continue
            if statement.body and all(self._is_future_import(s) for s in statement.body):
                position = index + 1
                continue
            break
        return position

    @staticmethod
    def _is_docstring(statement: cst.SimpleStatementLine) -> bool:
        """Return True if *statement* is a lone string-literal expression."""
        return (
            len(statement.body) == 1
            and isinstance(statement.body[0], cst.Expr)
            and isinstance(statement.body[0].value, (cst.SimpleString, cst.ConcatenatedString))
        )

    def _is_future_import(self, small: cst.BaseSmallStatement) -> bool:
        """Return True if *small* is ``from __future__ import ...``."""
        return (
            isinstance(small, cst.ImportFrom)
            and not small.relative
            and small.module is not None
            and self._get_module_name(small.module) == "__future__"
        )

    @staticmethod
    def _binds_asyncio(alias: cst.ImportAlias) -> bool:
        """Return True if ``import <alias>`` binds the bare name ``asyncio``.

        ``import asyncio`` and ``import asyncio.sub`` bind it;
        ``import asyncio as aio`` and ``import asyncio_x.y`` do not.
        """
        if alias.asname is None:
            root = alias.name
            while isinstance(root, cst.Attribute):
                root = root.value
            return isinstance(root, cst.Name) and root.value == "asyncio"
        target = alias.asname.name
        return (
            isinstance(alias.name, cst.Name)
            and alias.name.value == "asyncio"
            and isinstance(target, cst.Name)
            and target.value == "asyncio"
        )

    @staticmethod
    def _binds_client_timeout(alias: cst.ImportAlias) -> bool:
        """Return True if ``from aiohttp import <alias>`` binds ``ClientTimeout``."""
        if not (isinstance(alias.name, cst.Name) and alias.name.value == "ClientTimeout"):
            return False
        if alias.asname is None:
            return True
        target = alias.asname.name
        return isinstance(target, cst.Name) and target.value == "ClientTimeout"

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full module name from an Attribute or Name node."""
        if isinstance(module, cst.Name):
            return str(module.value)
        if isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
574
+
575
+
576
def transform_aiohttp(source_code: str) -> tuple[str, list]:
    """Migrate aiohttp 3.7-era source code to the 3.9+ API.

    Two LibCST passes run in order: ``AiohttpTransformer`` rewrites calls and
    attribute accesses, then ``AiohttpImportTransformer`` adds the imports
    those rewrites need.

    Args:
        source_code: Python source to migrate.

    Returns:
        ``(new_source, changes)``. Returns the untouched source with an empty
        change list if parsing fails or either pass raises.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return source_code, []

    main_pass = AiohttpTransformer()
    main_pass.set_source(source_code)

    try:
        rewritten = module.visit(main_pass)
        # Feed what the first pass saw/needed into the import pass.
        import_pass = AiohttpImportTransformer(
            needs_asyncio_import=main_pass._needs_asyncio_import,
            has_asyncio_import=main_pass._has_asyncio_import,
            needs_client_timeout_import=main_pass._needs_client_timeout_import,
            has_client_timeout_import=main_pass._has_client_timeout_import,
        )
        migrated_code = rewritten.visit(import_pass).code
        recorded = main_pass.changes + import_pass.changes
    except Exception:
        # Best effort: a failed migration must leave the caller's code intact.
        return source_code, []

    return migrated_code, recorded