codeshift 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/cli/commands/apply.py +24 -2
- codeshift/cli/package_manager.py +102 -0
- codeshift/knowledge/generator.py +11 -1
- codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
- codeshift/knowledge_base/libraries/attrs.yaml +181 -0
- codeshift/knowledge_base/libraries/celery.yaml +244 -0
- codeshift/knowledge_base/libraries/click.yaml +195 -0
- codeshift/knowledge_base/libraries/django.yaml +355 -0
- codeshift/knowledge_base/libraries/flask.yaml +270 -0
- codeshift/knowledge_base/libraries/httpx.yaml +183 -0
- codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
- codeshift/knowledge_base/libraries/numpy.yaml +429 -0
- codeshift/knowledge_base/libraries/pytest.yaml +192 -0
- codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
- codeshift/migrator/engine.py +60 -0
- codeshift/migrator/transforms/__init__.py +2 -0
- codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
- codeshift/migrator/transforms/attrs_transformer.py +570 -0
- codeshift/migrator/transforms/celery_transformer.py +546 -0
- codeshift/migrator/transforms/click_transformer.py +526 -0
- codeshift/migrator/transforms/django_transformer.py +852 -0
- codeshift/migrator/transforms/fastapi_transformer.py +12 -7
- codeshift/migrator/transforms/flask_transformer.py +505 -0
- codeshift/migrator/transforms/httpx_transformer.py +419 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
- codeshift/migrator/transforms/numpy_transformer.py +413 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
- codeshift/migrator/transforms/pytest_transformer.py +351 -0
- codeshift/migrator/transforms/requests_transformer.py +74 -1
- codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/RECORD +36 -15
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.3.3.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Library-specific transformation modules."""
|
|
2
2
|
|
|
3
3
|
from codeshift.migrator.transforms.fastapi_transformer import FastAPITransformer
|
|
4
|
+
from codeshift.migrator.transforms.marshmallow_transformer import MarshmallowTransformer
|
|
4
5
|
from codeshift.migrator.transforms.pandas_transformer import (
|
|
5
6
|
PandasAppendTransformer,
|
|
6
7
|
PandasTransformer,
|
|
@@ -16,4 +17,5 @@ __all__ = [
|
|
|
16
17
|
"PandasTransformer",
|
|
17
18
|
"PandasAppendTransformer",
|
|
18
19
|
"RequestsTransformer",
|
|
20
|
+
"MarshmallowTransformer",
|
|
19
21
|
]
|
|
@@ -0,0 +1,608 @@
|
|
|
1
|
+
"""aiohttp 3.7 to 3.9+ transformation using LibCST."""
|
|
2
|
+
|
|
3
|
+
import libcst as cst
|
|
4
|
+
|
|
5
|
+
from codeshift.migrator.ast_transforms import BaseTransformer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AiohttpTransformer(BaseTransformer):
    """Transform aiohttp code for version upgrades (3.7 to 3.9+).

    Handles breaking changes including:
    - Removal of loop parameter from ClientSession, TCPConnector, web.Application, etc.
    - BasicAuth.encode() removal
    - Deprecated timeout parameters (read_timeout, conn_timeout)
    - app.loop property deprecation
    - WebSocket timeout parameter rename
    - Response URL attribute changes
    - WebSocket protocol attribute rename

    The call rewrites are applied one after another to the same node, so a
    call needing several migrations (for example
    ``ClientSession(loop=loop, read_timeout=5)``) receives all of them.

    Import bookkeeping: the ``_needs_*`` flags record that a rewrite emitted a
    bare ``asyncio`` / ``ClientTimeout`` reference; the ``_has_*`` flags record
    that the module already binds that exact identifier. ``transform_aiohttp``
    forwards these flags to the import pass.
    """

    # Constructors whose ``loop`` keyword is removed, mapped to the
    # ``transform_name`` recorded for the change.
    _LOOP_PARAM_CLASSES: dict[str, str] = {
        "ClientSession": "remove_loop_param_client_session",
        "TCPConnector": "remove_loop_param_tcp_connector",
        "UnixConnector": "remove_loop_param_unix_connector",
        "Application": "remove_loop_param_web_application",
        "ClientTimeout": "remove_loop_param_client_timeout",
    }

    # Deprecated ClientSession keywords mapped to the ClientTimeout field
    # that replaces each of them.
    _TIMEOUT_FIELD_FOR_KWARG: dict[str, str] = {
        "read_timeout": "total",
        "conn_timeout": "connect",
    }

    # Whole words which, found in the name of the object whose ``.protocol``
    # is read, suggest that object is a WebSocket response.
    _WS_NAME_WORDS = frozenset({"ws", "socket"})

    def __init__(self) -> None:
        super().__init__()
        self._needs_asyncio_import = False
        self._has_asyncio_import = False
        self._needs_client_timeout_import = False
        self._has_client_timeout_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Track whether an absolute aiohttp import binds ``ClientTimeout``.

        ``from asyncio import ...`` is intentionally not treated as importing
        asyncio: it does not bind the bare name ``asyncio`` that the
        ``app.loop`` rewrite emits. Aliased imports
        (``ClientTimeout as CT``) and relative imports are not counted either.
        """
        if node.module is None or node.relative:
            return True
        if self._get_module_name(node.module) != "aiohttp":
            return True
        if isinstance(node.names, cst.ImportStar):
            return True

        for alias in node.names:
            if (
                isinstance(alias.name, cst.Name)
                and alias.name.value == "ClientTimeout"
                and self._bound_name(alias) == "ClientTimeout"
            ):
                self._has_client_timeout_import = True
        return True

    def visit_Import(self, node: cst.Import) -> bool:
        """Track ``import asyncio`` statements that bind the bare name ``asyncio``.

        ``import asyncio.subprocess`` binds ``asyncio`` and counts;
        ``import asyncio as aio`` binds ``aio`` and does not.
        """
        for alias in node.names:
            module_root = self._get_module_name(alias.name).split(".", 1)[0]
            if module_root == "asyncio" and self._bound_name(alias) == "asyncio":
                self._has_asyncio_import = True
        return True

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Apply every applicable aiohttp call rewrite to *updated_node*.

        The Call-to-Call rewrites (loop removal, deprecated timeouts,
        ws_connect timeout rename) run in sequence on the progressively
        rewritten node. The BasicAuth rewrite runs last because it wraps the
        call in ``str(...)``.
        """
        node = updated_node
        for rewrite in (
            self._remove_loop_parameter,
            self._transform_deprecated_timeouts,
            self._transform_ws_connect_timeout,
        ):
            rewritten = rewrite(node)
            if rewritten is not None:
                node = rewritten

        wrapped = self._transform_basicauth_encode(node)
        return wrapped if wrapped is not None else node

    def _remove_loop_parameter(self, node: cst.Call) -> cst.Call | None:
        """Remove loop parameter from aiohttp constructors.

        Handles:
        - ClientSession(loop=...)
        - TCPConnector(loop=...)
        - UnixConnector(loop=...)
        - web.Application(loop=...)
        - aiohttp.web.Application(loop=...)
        - ClientTimeout(loop=...)

        Matching is by name only (bare, or as the last dotted component).

        Returns:
            The call without its ``loop`` keyword, or None when the call is
            not a target constructor or passes no ``loop`` keyword.
        """
        func_name = self._get_call_name(node)

        transform_name = None
        for class_name, candidate in self._LOOP_PARAM_CLASSES.items():
            if func_name == class_name or func_name.endswith(f".{class_name}"):
                transform_name = candidate
                break
        if transform_name is None:
            return None

        kept_args = [arg for arg in node.args if not self._is_keyword(arg, "loop")]
        if len(kept_args) == len(node.args):
            return None

        self.record_change(
            description=f"Remove deprecated loop parameter from {func_name}",
            line_number=1,
            original=f"{func_name}(loop=...)",
            replacement=f"{func_name}()",
            transform_name=transform_name,
        )
        return node.with_changes(args=self._strip_trailing_comma(kept_args))

    def _transform_deprecated_timeouts(self, node: cst.Call) -> cst.Call | None:
        """Transform deprecated read_timeout/conn_timeout to ClientTimeout.

        Handles:
        - ClientSession(read_timeout=X) -> ClientSession(timeout=ClientTimeout(total=X))
        - ClientSession(conn_timeout=X) -> ClientSession(timeout=ClientTimeout(connect=X))

        Both deprecated keywords in one call are merged into a single
        ``ClientTimeout``. A call that also passes an explicit ``timeout=`` is
        left untouched. Merging would require guessing which setting wins,
        and dropping either one would silently lose user configuration.

        Returns:
            The rewritten call, or None when no rewrite applies.
        """
        func_name = self._get_call_name(node)
        if func_name != "ClientSession" and not func_name.endswith(".ClientSession"):
            return None

        kept_args: list[cst.Arg] = []
        # (deprecated keyword, ClientTimeout field) pairs in source order.
        converted: list[tuple[str, str]] = []
        timeout_values: dict[str, cst.BaseExpression] = {}
        has_explicit_timeout = False

        for arg in node.args:
            keyword = arg.keyword.value if isinstance(arg.keyword, cst.Name) else None
            if keyword in self._TIMEOUT_FIELD_FOR_KWARG:
                field = self._TIMEOUT_FIELD_FOR_KWARG[keyword]
                converted.append((keyword, field))
                timeout_values[field] = arg.value
                continue
            if keyword == "timeout":
                has_explicit_timeout = True
            kept_args.append(arg)

        if not converted or has_explicit_timeout:
            return None

        for keyword, field in converted:
            self.record_change(
                description=f"Convert {keyword} to ClientTimeout({field}=...)",
                line_number=1,
                original=f"ClientSession({keyword}=...)",
                replacement=f"ClientSession(timeout=ClientTimeout({field}=...))",
                transform_name=f"{keyword}_to_client_timeout",
            )
        # The new call references the bare name ``ClientTimeout``.
        self._needs_client_timeout_import = True

        client_timeout_call = cst.Call(
            func=cst.Name("ClientTimeout"),
            args=[self._keyword_arg(key, value) for key, value in timeout_values.items()],
        )
        kept_args.append(self._keyword_arg("timeout", client_timeout_call))
        return node.with_changes(args=self._strip_trailing_comma(kept_args))

    def _transform_basicauth_encode(self, node: cst.Call) -> cst.BaseExpression | None:
        """Transform BasicAuth(...).encode() to str(BasicAuth(...)).

        Handles:
        - BasicAuth(user, pass).encode() -> str(BasicAuth(user, pass))

        Only an argument-less ``.encode()`` called directly on a
        ``BasicAuth(...)`` / ``*.BasicAuth(...)`` construction is rewritten.

        NOTE(review): confirm that ``str(BasicAuth(...))`` really yields the
        same header value as ``encode()`` in the target aiohttp version.
        BasicAuth appears to be namedtuple-based, and ``encode()`` appears to
        still be documented, so this rewrite may change runtime behavior.
        """
        if not isinstance(node.func, cst.Attribute):
            return None
        if node.func.attr.value != "encode":
            return None

        base = node.func.value
        if not isinstance(base, cst.Call):
            return None

        base_func_name = self._get_call_name(base)
        if base_func_name != "BasicAuth" and not base_func_name.endswith(".BasicAuth"):
            return None

        # encode() with arguments is not the old no-argument API; leave it.
        if node.args:
            return None

        self.record_change(
            description="Convert BasicAuth(...).encode() to str(BasicAuth(...))",
            line_number=1,
            original="BasicAuth(...).encode()",
            replacement="str(BasicAuth(...))",
            transform_name="basicauth_encode_to_str",
        )
        return cst.Call(func=cst.Name("str"), args=[cst.Arg(value=base)])

    def _transform_ws_connect_timeout(self, node: cst.Call) -> cst.Call | None:
        """Transform ws_connect timeout parameter to receive_timeout.

        Handles:
        - session.ws_connect(url, timeout=X) -> session.ws_connect(url, receive_timeout=X)

        A call that already passes ``receive_timeout=`` is left alone, because
        renaming would repeat that keyword, which is a SyntaxError.
        """
        func_name = self._get_call_name(node)
        if func_name != "ws_connect" and not func_name.endswith(".ws_connect"):
            return None

        keywords = {arg.keyword.value for arg in node.args if isinstance(arg.keyword, cst.Name)}
        if "timeout" not in keywords or "receive_timeout" in keywords:
            return None

        new_args = [
            arg.with_changes(keyword=cst.Name("receive_timeout"))
            if self._is_keyword(arg, "timeout")
            else arg
            for arg in node.args
        ]
        self.record_change(
            description="Rename ws_connect timeout parameter to receive_timeout",
            line_number=1,
            original="ws_connect(..., timeout=...)",
            replacement="ws_connect(..., receive_timeout=...)",
            transform_name="ws_connect_timeout_rename",
        )
        return node.with_changes(args=new_args)

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Transform aiohttp attribute accesses.

        - ``<expr>.url_obj`` -> ``<expr>.url`` (any object)
        - ``<ws-like name>.protocol`` -> ``.ws_protocol`` (name heuristic)
        - ``app.loop`` / ``application.loop`` / ``request.app.loop`` ->
          ``asyncio.get_event_loop()``
        """
        attr_name = updated_node.attr.value

        if attr_name == "url_obj":
            self.record_change(
                description="Rename url_obj attribute to url",
                line_number=1,
                original=".url_obj",
                replacement=".url",
                transform_name="url_obj_to_url",
            )
            return updated_node.with_changes(attr=cst.Name("url"))

        if attr_name == "protocol":
            # Without type information, decide from the receiver's name.
            value_name = self._get_expression_name(updated_node.value)
            if value_name and self._looks_like_websocket(value_name):
                self.record_change(
                    description="Rename WebSocketResponse.protocol to ws_protocol",
                    line_number=1,
                    original=".protocol",
                    replacement=".ws_protocol",
                    transform_name="ws_protocol_rename",
                )
                return updated_node.with_changes(attr=cst.Name("ws_protocol"))

        if attr_name == "loop":
            value_name = self._get_expression_name(updated_node.value)
            if value_name and value_name in {"app", "application", "request.app"}:
                self._needs_asyncio_import = True
                self.record_change(
                    description=f"Replace {value_name}.loop with asyncio.get_event_loop()",
                    line_number=1,
                    original=f"{value_name}.loop",
                    replacement="asyncio.get_event_loop()",
                    transform_name="app_loop_to_get_event_loop",
                )
                return cst.Call(
                    func=cst.Attribute(
                        value=cst.Name("asyncio"),
                        attr=cst.Name("get_event_loop"),
                    ),
                    args=[],
                )

        return updated_node

    def _looks_like_websocket(self, value_name: str) -> bool:
        """Heuristically decide whether *value_name* names a WebSocket response.

        Only the last dotted segment is inspected, split into alphabetic
        words. So ``ws``, ``ws_response``, ``client_ws``, ``socket`` and
        ``websocket`` match, while names that merely contain the letters
        "ws" (``news``, ``rows``, ``views``) do not.
        """
        last_segment = value_name.rsplit(".", 1)[-1].lower()
        words = "".join(ch if ch.isalpha() else " " for ch in last_segment).split()
        return any(word in self._WS_NAME_WORDS or word.startswith("websock") for word in words)

    @staticmethod
    def _is_keyword(arg: cst.Arg, name: str) -> bool:
        """Return True if *arg* is passed as ``name=...``."""
        return isinstance(arg.keyword, cst.Name) and arg.keyword.value == name

    @staticmethod
    def _keyword_arg(name: str, value: cst.BaseExpression) -> cst.Arg:
        """Build a ``name=value`` argument with no spaces around ``=``."""
        return cst.Arg(
            keyword=cst.Name(name),
            value=value,
            equal=cst.AssignEqual(
                whitespace_before=cst.SimpleWhitespace(""),
                whitespace_after=cst.SimpleWhitespace(""),
            ),
        )

    @staticmethod
    def _strip_trailing_comma(args: list[cst.Arg]) -> list[cst.Arg]:
        """Drop an explicit comma after the final argument (in place) and return *args*.

        This is cosmetic: it avoids leaving ``f(a, )`` after arguments are
        removed.
        """
        if args and args[-1].comma is not cst.MaybeSentinel.DEFAULT:
            args[-1] = args[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        return args

    def _bound_name(self, alias: cst.ImportAlias) -> str:
        """Return the identifier an import alias binds in the importing scope.

        ``import a.b`` binds ``a``, ``from m import x`` binds ``x``, and
        ``... as y`` binds ``y``.
        """
        if alias.asname is not None and isinstance(alias.asname.name, cst.Name):
            return str(alias.asname.name.value)
        return self._get_module_name(alias.name).split(".", 1)[0]

    def _get_call_name(self, node: cst.Call) -> str:
        """Get the name of a function/class being called."""
        if isinstance(node.func, cst.Name):
            return str(node.func.value)
        elif isinstance(node.func, cst.Attribute):
            return self._get_attribute_chain(node.func)
        return ""

    def _get_attribute_chain(self, node: cst.Attribute) -> str:
        """Get the full attribute chain as a string."""
        if isinstance(node.value, cst.Name):
            return f"{node.value.value}.{node.attr.value}"
        elif isinstance(node.value, cst.Attribute):
            return f"{self._get_attribute_chain(node.value)}.{node.attr.value}"
        return str(node.attr.value)

    def _get_expression_name(self, node: cst.BaseExpression) -> str | None:
        """Get a simple name representation of an expression."""
        if isinstance(node, cst.Name):
            return str(node.value)
        elif isinstance(node, cst.Attribute):
            base = self._get_expression_name(node.value)
            if base:
                return f"{base}.{node.attr.value}"
            return str(node.attr.value)
        return None

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full module name from a Name or Attribute node."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
class AiohttpImportTransformer(BaseTransformer):
    """Separate transformer for handling import additions.

    This runs after the main transformer to add any missing imports.

    A name counts as already imported only when an import statement binds
    that exact identifier. ``import asyncio as aio`` and
    ``from asyncio import sleep`` do not make the bare name ``asyncio``
    available, and ``from aiohttp import ClientTimeout as CT`` does not bind
    ``ClientTimeout``. Relative imports such as ``from .aiohttp import ...``
    refer to local modules and are ignored.
    """

    def __init__(
        self,
        needs_asyncio_import: bool = False,
        has_asyncio_import: bool = False,
        needs_client_timeout_import: bool = False,
        has_client_timeout_import: bool = False,
    ) -> None:
        super().__init__()
        self._needs_asyncio_import = needs_asyncio_import
        self._has_asyncio_import = has_asyncio_import
        self._needs_client_timeout_import = needs_client_timeout_import
        self._has_client_timeout_import = has_client_timeout_import
        # Set once any absolute ``from aiohttp import ...`` is seen.
        self._found_aiohttp_import = False

    def visit_Import(self, node: cst.Import) -> bool:
        """Record an existing import that binds the bare name ``asyncio``."""
        for alias in node.names:
            module_root = self._get_module_name(alias.name).split(".", 1)[0]
            if module_root == "asyncio" and self._bound_name(alias) == "asyncio":
                self._has_asyncio_import = True
        return True

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Record absolute aiohttp imports and whether they bind ``ClientTimeout``."""
        if not self._is_absolute_aiohttp_import(node):
            return True

        self._found_aiohttp_import = True
        if not isinstance(node.names, cst.ImportStar) and any(
            self._imports_client_timeout(alias) for alias in node.names
        ):
            self._has_client_timeout_import = True
        return True

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Add missing imports to aiohttp import statement.

        ``ClientTimeout`` is appended to the first absolute, non-star
        ``from aiohttp import ...`` statement when it is needed and not
        already bound.
        """
        if not self._is_absolute_aiohttp_import(updated_node):
            return updated_node
        if isinstance(updated_node.names, cst.ImportStar):
            return updated_node
        if not self._needs_client_timeout_import or self._has_client_timeout_import:
            return updated_node

        self._has_client_timeout_import = True
        self.record_change(
            description="Add 'ClientTimeout' import for timeout transformation",
            line_number=1,
            original="from aiohttp import ...",
            replacement="from aiohttp import ..., ClientTimeout",
            transform_name="add_client_timeout_import",
        )
        return updated_node.with_changes(
            names=[*updated_node.names, cst.ImportAlias(name=cst.Name("ClientTimeout"))]
        )

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Add missing imports at module level.

        New imports go just after the module docstring and any
        ``from __future__`` imports. Inserting before a docstring would stop
        it being the module docstring, and inserting before a future import
        is a SyntaxError.
        """
        new_imports: list[cst.SimpleStatementLine] = []

        if self._needs_asyncio_import and not self._has_asyncio_import:
            new_imports.append(
                cst.SimpleStatementLine(
                    body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("asyncio"))])]
                )
            )
            self._has_asyncio_import = True
            self.record_change(
                description="Add 'asyncio' import for get_event_loop()",
                line_number=1,
                original="",
                replacement="import asyncio",
                transform_name="add_asyncio_import",
            )

        # Only needed when no aiohttp from-import existed to extend.
        if (
            self._needs_client_timeout_import
            and not self._has_client_timeout_import
            and not self._found_aiohttp_import
        ):
            new_imports.append(
                cst.SimpleStatementLine(
                    body=[
                        cst.ImportFrom(
                            module=cst.Name("aiohttp"),
                            names=[cst.ImportAlias(name=cst.Name("ClientTimeout"))],
                        )
                    ]
                )
            )
            self.record_change(
                description="Add 'ClientTimeout' import from aiohttp",
                line_number=1,
                original="",
                replacement="from aiohttp import ClientTimeout",
                transform_name="add_client_timeout_import",
            )

        if not new_imports:
            return updated_node

        new_body = list(updated_node.body)
        position = self._import_insert_position(new_body)
        new_body[position:position] = new_imports
        return updated_node.with_changes(body=new_body)

    def _import_insert_position(self, body: list[cst.BaseStatement]) -> int:
        """Return the index just past the module docstring and ``__future__`` imports."""
        position = 0
        for index, stmt in enumerate(body):
            if not isinstance(stmt, cst.SimpleStatementLine):
                break
            if index == 0 and self._is_docstring(stmt):
                position = 1
            elif any(
                isinstance(small, cst.ImportFrom)
                and small.module is not None
                and self._get_module_name(small.module) == "__future__"
                for small in stmt.body
            ):
                position = index + 1
            else:
                break
        return position

    @staticmethod
    def _is_docstring(stmt: cst.SimpleStatementLine) -> bool:
        """Return True if *stmt* is a bare string-literal expression statement."""
        return (
            len(stmt.body) == 1
            and isinstance(stmt.body[0], cst.Expr)
            and isinstance(stmt.body[0].value, (cst.SimpleString, cst.ConcatenatedString))
        )

    def _is_absolute_aiohttp_import(self, node: cst.ImportFrom) -> bool:
        """Return True for ``from aiohttp import ...`` (not relative, not a submodule)."""
        return (
            node.module is not None
            and not node.relative
            and self._get_module_name(node.module) == "aiohttp"
        )

    def _imports_client_timeout(self, alias: cst.ImportAlias) -> bool:
        """Return True if *alias* binds aiohttp's ``ClientTimeout`` under that name."""
        return (
            isinstance(alias.name, cst.Name)
            and alias.name.value == "ClientTimeout"
            and self._bound_name(alias) == "ClientTimeout"
        )

    def _bound_name(self, alias: cst.ImportAlias) -> str:
        """Return the identifier an import alias binds in the importing scope.

        ``import a.b`` binds ``a``, ``from m import x`` binds ``x``, and
        ``... as y`` binds ``y``.
        """
        if alias.asname is not None and isinstance(alias.asname.name, cst.Name):
            return str(alias.asname.name.value)
        return self._get_module_name(alias.name).split(".", 1)[0]

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full module name from an Attribute or Name node."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def transform_aiohttp(source_code: str) -> tuple[str, list]:
    """Migrate aiohttp usage in *source_code* from 3.7-era APIs to 3.9+.

    Two LibCST passes run. ``AiohttpTransformer`` rewrites calls and
    attribute accesses, then ``AiohttpImportTransformer`` adds any imports
    those rewrites now rely on. The import pass is seeded with the first
    pass's import flags.

    Args:
        source_code: The source code to transform

    Returns:
        Tuple of (transformed_code, list of changes). If the source does not
        parse, or either pass (or code generation) raises, the original
        source is returned unchanged with an empty change list.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return source_code, []

    rewrite_pass = AiohttpTransformer()
    rewrite_pass.set_source(source_code)

    try:
        rewritten = module.visit(rewrite_pass)

        import_pass = AiohttpImportTransformer(
            needs_asyncio_import=rewrite_pass._needs_asyncio_import,
            has_asyncio_import=rewrite_pass._has_asyncio_import,
            needs_client_timeout_import=rewrite_pass._needs_client_timeout_import,
            has_client_timeout_import=rewrite_pass._has_client_timeout_import,
        )
        finished = rewritten.visit(import_pass)

        # Code generation stays inside the try so a failure there also falls
        # back to returning the untouched source.
        return finished.code, rewrite_pass.changes + import_pass.changes
    except Exception:
        # Deliberate best effort: a failed migration must not corrupt the input.
        return source_code, []
|