codeshift 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/cli/commands/apply.py +24 -2
- codeshift/cli/commands/upgrade_all.py +4 -1
- codeshift/cli/package_manager.py +102 -0
- codeshift/knowledge/generator.py +11 -1
- codeshift/knowledge_base/libraries/aiohttp.yaml +186 -0
- codeshift/knowledge_base/libraries/attrs.yaml +181 -0
- codeshift/knowledge_base/libraries/celery.yaml +244 -0
- codeshift/knowledge_base/libraries/click.yaml +195 -0
- codeshift/knowledge_base/libraries/django.yaml +355 -0
- codeshift/knowledge_base/libraries/flask.yaml +270 -0
- codeshift/knowledge_base/libraries/httpx.yaml +183 -0
- codeshift/knowledge_base/libraries/marshmallow.yaml +238 -0
- codeshift/knowledge_base/libraries/numpy.yaml +429 -0
- codeshift/knowledge_base/libraries/pytest.yaml +192 -0
- codeshift/knowledge_base/libraries/sqlalchemy.yaml +2 -1
- codeshift/migrator/engine.py +60 -0
- codeshift/migrator/transforms/__init__.py +2 -0
- codeshift/migrator/transforms/aiohttp_transformer.py +608 -0
- codeshift/migrator/transforms/attrs_transformer.py +570 -0
- codeshift/migrator/transforms/celery_transformer.py +546 -0
- codeshift/migrator/transforms/click_transformer.py +526 -0
- codeshift/migrator/transforms/django_transformer.py +852 -0
- codeshift/migrator/transforms/fastapi_transformer.py +12 -7
- codeshift/migrator/transforms/flask_transformer.py +505 -0
- codeshift/migrator/transforms/httpx_transformer.py +419 -0
- codeshift/migrator/transforms/marshmallow_transformer.py +515 -0
- codeshift/migrator/transforms/numpy_transformer.py +413 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +53 -8
- codeshift/migrator/transforms/pytest_transformer.py +351 -0
- codeshift/migrator/transforms/requests_transformer.py +74 -1
- codeshift/migrator/transforms/sqlalchemy_transformer.py +692 -39
- codeshift/scanner/dependency_parser.py +1 -1
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/METADATA +46 -4
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/RECORD +38 -17
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/WHEEL +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {codeshift-0.3.2.dist-info → codeshift-0.3.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
"""NumPy 1.x to 2.0 transformation using LibCST."""
|
|
2
|
+
|
|
3
|
+
import libcst as cst
|
|
4
|
+
|
|
5
|
+
from codeshift.migrator.ast_transforms import BaseTransformer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class NumPyTransformer(BaseTransformer):
    """Transform NumPy 1.x code to 2.0.

    Handles the following breaking changes:
    - Type alias removals (np.bool, np.int, np.float, np.complex, np.object, np.str)
    - Function renames (alltrue, sometrue, product, cumproduct, trapz, in1d, row_stack, msort)
    - Constant renames (Inf, Infinity, infty, NaN, PINF, NINF, PZERO, NZERO)
    - Other deprecated/removed functions

    Only attribute accesses of the form ``<alias>.<name>`` are rewritten, where
    ``<alias>`` is ``np``, ``numpy`` or a name bound by ``import numpy as <alias>``.
    Function renames are applied to call expressions only.

    NOTE(review): every change is recorded with ``line_number=1``; the real
    source position is not resolved in this class.
    """

    # Type alias mappings: old_name -> new_name
    TYPE_ALIAS_MAPPINGS = {
        # Python builtin shadows (high priority)
        "bool": "bool_",
        "int": "int_",
        "float": "float64",
        "complex": "complex128",
        "object": "object_",
        "str": "str_",
        # Other type aliases
        "unicode_": "str_",
        "string_": "bytes_",
        "float_": "float64",
        "complex_": "complex128",
        "cfloat": "complex128",
        "singlecomplex": "complex64",
        "longfloat": "longdouble",
        "longcomplex": "clongdouble",
        "clongfloat": "clongdouble",
    }

    # Function renames: old_name -> new_name
    FUNCTION_RENAMES = {
        "alltrue": "all",
        "sometrue": "any",
        "product": "prod",
        "cumproduct": "cumprod",
        "trapz": "trapezoid",
        "in1d": "isin",
        "row_stack": "vstack",
        "issubsctype": "issubdtype",
    }

    # Constant renames: old_name -> new_name
    CONSTANT_RENAMES = {
        "Inf": "inf",
        "Infinity": "inf",
        "infty": "inf",
        "NaN": "nan",
        "PINF": "inf",
    }

    # Constants that need special handling (replacement with expressions).
    # Kept for reference/introspection; leave_Attribute builds the replacement
    # nodes explicitly rather than parsing these strings.
    CONSTANT_SPECIAL = {
        "NINF": "-np.inf",  # Requires special handling
        "PZERO": "0.0",
        "NZERO": "-0.0",
    }

    def __init__(self) -> None:
        super().__init__()
        # Names that are treated as referring to the numpy module.
        self._numpy_aliases: set[str] = {"np", "numpy"}
        # Set when a numpy import is seen; not read anywhere in this class.
        self._has_numpy_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Note ``from numpy... import ...`` statements; always descend."""
        if node.module is None:
            return True

        module_name = self._get_module_name(node.module)
        if module_name == "numpy" or module_name.startswith("numpy."):
            self._has_numpy_import = True
        return True

    def visit_Import(self, node: cst.Import) -> bool:
        """Track numpy import aliases (e.g., import numpy as np)."""
        if isinstance(node.names, cst.ImportStar):
            return True

        for alias in node.names:
            if not isinstance(alias, cst.ImportAlias):
                continue
            if self._get_name_value(alias.name) != "numpy":
                continue
            self._has_numpy_import = True
            asname = alias.asname
            if isinstance(asname, cst.AsName) and isinstance(asname.name, cst.Name):
                self._numpy_aliases.add(asname.name.value)
        return True

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Rewrite removed numpy type aliases and constants.

        Returns the node unchanged unless it is ``<numpy alias>.<name>`` with
        ``<name>`` in one of the mapping tables or one of NINF/PZERO/NZERO.
        """
        attr_name = updated_node.attr.value

        # Check if this is a numpy attribute access
        if not self._is_numpy_attribute(updated_node):
            return updated_node

        # Handle type alias removals
        if attr_name in self.TYPE_ALIAS_MAPPINGS:
            new_attr = self.TYPE_ALIAS_MAPPINGS[attr_name]
            self.record_change(
                description=f"Replace numpy.{attr_name} with numpy.{new_attr}",
                line_number=1,
                original=f"numpy.{attr_name}",
                replacement=f"numpy.{new_attr}",
                transform_name=f"{attr_name}_to_{new_attr}",
            )
            return updated_node.with_changes(attr=cst.Name(new_attr))

        # Handle constant renames
        if attr_name in self.CONSTANT_RENAMES:
            new_attr = self.CONSTANT_RENAMES[attr_name]
            self.record_change(
                description=f"Replace numpy.{attr_name} with numpy.{new_attr}",
                line_number=1,
                original=f"numpy.{attr_name}",
                replacement=f"numpy.{new_attr}",
                transform_name=f"{attr_name}_to_{new_attr}",
            )
            return updated_node.with_changes(attr=cst.Name(new_attr))

        # Handle NINF -> -np.inf
        if attr_name == "NINF":
            self.record_change(
                description="Replace numpy.NINF with -numpy.inf",
                line_number=1,
                original="numpy.NINF",
                replacement="-numpy.inf",
                transform_name="NINF_to_neg_inf",
            )
            return self._parenthesized_negation(
                updated_node.with_changes(attr=cst.Name("inf"))
            )

        # Handle PZERO -> 0.0
        if attr_name == "PZERO":
            self.record_change(
                description="Replace numpy.PZERO with 0.0",
                line_number=1,
                original="numpy.PZERO",
                replacement="0.0",
                transform_name="PZERO_to_zero",
            )
            return cst.Float("0.0")

        # Handle NZERO -> -0.0
        if attr_name == "NZERO":
            self.record_change(
                description="Replace numpy.NZERO with -0.0",
                line_number=1,
                original="numpy.NZERO",
                replacement="-0.0",
                transform_name="NZERO_to_neg_zero",
            )
            return self._parenthesized_negation(cst.Float("0.0"))

        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Rewrite calls to renamed or removed numpy functions.

        Only calls whose callee is ``<numpy alias>.<name>`` are considered.
        """
        attr = updated_node.func
        # Handle direct numpy function calls like np.alltrue(), np.product(), etc.
        if not isinstance(attr, cst.Attribute) or not self._is_numpy_attribute(attr):
            return updated_node

        func_name = attr.attr.value

        # Handle function renames
        if func_name in self.FUNCTION_RENAMES:
            new_func = self.FUNCTION_RENAMES[func_name]
            self.record_change(
                description=f"Replace numpy.{func_name}() with numpy.{new_func}()",
                line_number=1,
                original=f"numpy.{func_name}()",
                replacement=f"numpy.{new_func}()",
                transform_name=f"{func_name}_to_{new_func}",
            )
            return updated_node.with_changes(func=attr.with_changes(attr=cst.Name(new_func)))

        # Handle msort(a) -> sort(a, axis=0)
        if func_name == "msort":
            self.record_change(
                description="Replace numpy.msort(a) with numpy.sort(a, axis=0)",
                line_number=1,
                original="numpy.msort(a)",
                replacement="numpy.sort(a, axis=0)",
                transform_name="msort_to_sort_axis0",
            )
            return updated_node.with_changes(
                func=attr.with_changes(attr=cst.Name("sort")),
                args=[*updated_node.args, self._keyword_arg("axis", cst.Integer("0"))],
            )

        # Handle asfarray(a) -> asarray(a, dtype=float)
        if func_name == "asfarray":
            self.record_change(
                description="Replace numpy.asfarray(a) with numpy.asarray(a, dtype=float)",
                line_number=1,
                original="numpy.asfarray(a)",
                replacement="numpy.asarray(a, dtype=float)",
                transform_name="asfarray_to_asarray",
            )
            # dtype may be given by keyword or as the second positional
            # argument; appending dtype=float in the positional case would
            # yield a call that passes dtype twice (TypeError at runtime).
            # NOTE(review): numpy.asfarray documents coercing integer dtypes to
            # float64, which asarray does not do - an explicitly passed dtype
            # is forwarded unchanged here; confirm that is acceptable.
            positional_count = sum(
                1 for arg in updated_node.args if arg.keyword is None and arg.star == ""
            )
            has_dtype = positional_count >= 2 or any(
                isinstance(arg.keyword, cst.Name) and arg.keyword.value == "dtype"
                for arg in updated_node.args
            )
            new_args = list(updated_node.args)
            if not has_dtype:
                new_args.append(self._keyword_arg("dtype", cst.Name("float")))
            return updated_node.with_changes(
                func=attr.with_changes(attr=cst.Name("asarray")), args=new_args
            )

        # Handle issubclass_(arg1, arg2) -> issubclass(arg1, arg2)
        if func_name == "issubclass_":
            self.record_change(
                description="Replace numpy.issubclass_() with builtin issubclass()",
                line_number=1,
                original="numpy.issubclass_()",
                replacement="issubclass()",
                transform_name="issubclass__to_builtin",
            )
            return updated_node.with_changes(func=cst.Name("issubclass"))

        return updated_node

    @staticmethod
    def _parenthesized_negation(operand: cst.BaseExpression) -> cst.UnaryOperation:
        """Return ``(-operand)`` with explicit parentheses.

        The parentheses keep the replacement a single operand: a bare unary
        minus binds looser than ``**``, attribute access, calls and
        subscripts, so e.g. ``np.NINF ** 2`` must not become ``-np.inf ** 2``.
        """
        return cst.UnaryOperation(
            operator=cst.Minus(),
            expression=operand,
            lpar=[cst.LeftParen()],
            rpar=[cst.RightParen()],
        )

    @staticmethod
    def _keyword_arg(name: str, value: cst.BaseExpression) -> cst.Arg:
        """Build a ``name=value`` call argument with no spaces around ``=``."""
        return cst.Arg(
            keyword=cst.Name(name),
            value=value,
            equal=cst.AssignEqual(
                whitespace_before=cst.SimpleWhitespace(""),
                whitespace_after=cst.SimpleWhitespace(""),
            ),
        )

    def _is_numpy_attribute(self, node: cst.Attribute) -> bool:
        """Check if an Attribute node is accessing numpy.

        Handles both 'numpy.X' and 'np.X' patterns (plus any tracked alias);
        nested accesses such as ``np.random.X`` do not match.
        """
        if isinstance(node.value, cst.Name):
            return node.value.value in self._numpy_aliases
        return False

    def _get_module_name(self, node: cst.BaseExpression) -> str:
        """Get the full dotted module name from an Attribute or Name node.

        Returns an empty string for any other node type.
        """
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            base = self._get_module_name(node.value)
            return f"{base}.{node.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the (possibly dotted) name from a Name or Attribute node.

        Returns None for any other node type.
        """
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            return self._get_module_name(node)
        return None
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class NumPyImportTransformer(BaseTransformer):
    """Transform numpy imports (e.g., from numpy import bool -> from numpy import bool_).

    Only ``from numpy import ...`` is handled; submodule imports and star
    imports are left untouched.

    When a renamed symbol was imported without an alias, the old local name is
    preserved with ``as`` (``from numpy import prod as product``) so that
    existing uses of that name in the module keep resolving. Names that
    coincide with Python builtins are the exception: no alias is added, so the
    builtin of the same name is what the rest of the module sees.

    NOTE(review): every change is recorded with ``line_number=1``; the real
    source position is not resolved in this class.
    """

    # Import name mappings
    IMPORT_MAPPINGS = {
        "bool": "bool_",
        "int": "int_",
        "float": "float64",
        "complex": "complex128",
        "object": "object_",
        "str": "str_",
        "unicode_": "str_",
        "string_": "bytes_",
        "float_": "float64",
        "complex_": "complex128",
        "cfloat": "complex128",
        "singlecomplex": "complex64",
        "longfloat": "longdouble",
        "longcomplex": "clongdouble",
        "clongfloat": "clongdouble",
        "alltrue": "all",
        "sometrue": "any",
        "product": "prod",
        "cumproduct": "cumprod",
        "trapz": "trapezoid",
        "in1d": "isin",
        "row_stack": "vstack",
        "issubsctype": "issubdtype",
        "Inf": "inf",
        "Infinity": "inf",
        "infty": "inf",
        "NaN": "nan",
        "PINF": "inf",
    }

    # Old names that are also Python builtins. Re-binding them to the numpy
    # scalar type via ``as`` would shadow the builtin, so they get no alias.
    _BUILTIN_SHADOWS = frozenset({"bool", "int", "float", "complex", "object", "str"})

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Rename removed symbols in ``from numpy import ...`` statements.

        Returns the node unchanged for relative imports without a module,
        modules other than ``numpy``, star imports, or when no imported name
        appears in ``IMPORT_MAPPINGS``.
        """
        if updated_node.module is None:
            return updated_node

        if self._get_module_name(updated_node.module) != "numpy":
            return updated_node

        if isinstance(updated_node.names, cst.ImportStar):
            return updated_node

        new_names = []
        changed = False

        for alias in updated_node.names:
            imported_name = (
                self._get_name_value(alias.name) if isinstance(alias, cst.ImportAlias) else None
            )
            if imported_name not in self.IMPORT_MAPPINGS:
                new_names.append(alias)
                continue

            new_import_name = self.IMPORT_MAPPINGS[imported_name]
            updates = {"name": cst.Name(new_import_name)}
            replacement = f"from numpy import {new_import_name}"

            # Without an alias, renaming the import would unbind the name the
            # rest of the module uses (NameError at the first use site), and
            # several old names mapping to one new name would produce
            # duplicate imports. Keep the old local binding via ``as``.
            if alias.asname is None and imported_name not in self._BUILTIN_SHADOWS:
                updates["asname"] = cst.AsName(name=cst.Name(imported_name))
                replacement = f"{replacement} as {imported_name}"

            new_names.append(alias.with_changes(**updates))
            changed = True

            self.record_change(
                description=f"Replace 'from numpy import {imported_name}' with '{new_import_name}'",
                line_number=1,
                original=f"from numpy import {imported_name}",
                replacement=replacement,
                transform_name=f"import_{imported_name}_to_{new_import_name}",
            )

        if changed:
            return updated_node.with_changes(names=new_names)

        return updated_node

    def _get_module_name(self, node: cst.BaseExpression) -> str:
        """Get the full dotted module name from an Attribute or Name node.

        Returns an empty string for any other node type.
        """
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            base = self._get_module_name(node.value)
            return f"{base}.{node.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node (None for anything else)."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def transform_numpy(source_code: str) -> tuple[str, list]:
    """Transform NumPy code from 1.x to 2.0.

    Runs two passes over the parsed module: ``NumPyImportTransformer`` first,
    then ``NumPyTransformer``. The operation is best-effort: if the source does
    not parse, or either pass raises, the original source is returned with an
    empty change list (partial results are discarded) and the failure is
    logged instead of propagated.

    Args:
        source_code: The source code to transform

    Returns:
        Tuple of (transformed_code, list of changes)
    """
    # Function-scope import so the block does not depend on file-level imports.
    import logging

    logger = logging.getLogger(__name__)

    try:
        tree = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        logger.debug("NumPy transform skipped: source could not be parsed", exc_info=True)
        return source_code, []

    all_changes: list = []

    try:
        # First pass: transform imports
        import_transformer = NumPyImportTransformer()
        import_transformer.set_source(source_code)
        tree = tree.visit(import_transformer)
        all_changes.extend(import_transformer.changes)

        # Second pass: main transformations (fed the output of the first pass)
        transformer = NumPyTransformer()
        transformer.set_source(tree.code)
        tree = tree.visit(transformer)
        all_changes.extend(transformer.changes)

        return tree.code, all_changes
    except Exception:
        # Deliberate best-effort boundary: never crash the caller, but make the
        # failure visible so it is not mistaken for "nothing to migrate".
        logger.warning("NumPy transform failed; returning source unchanged", exc_info=True)
        return source_code, []
|
|
@@ -219,19 +219,64 @@ class PydanticV1ToV2Transformer(BaseTransformer):
|
|
|
219
219
|
return updated_node
|
|
220
220
|
|
|
221
221
|
def _transform_validator_decorator(self, node: cst.Decorator) -> cst.Decorator:
|
|
222
|
-
"""Transform @validator("field") to @field_validator("field").
|
|
222
|
+
"""Transform @validator("field") to @field_validator("field").
|
|
223
|
+
|
|
224
|
+
Also handles pre=True -> mode="before" and pre=False -> mode="after".
|
|
225
|
+
"""
|
|
223
226
|
if isinstance(node.decorator, cst.Call):
|
|
224
227
|
# @validator("field_name", ...)
|
|
225
|
-
|
|
228
|
+
# Check for pre=True/False and convert to mode="before"/"after"
|
|
229
|
+
mode: str | None = None
|
|
230
|
+
new_args = []
|
|
226
231
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
232
|
+
for arg in node.decorator.args:
|
|
233
|
+
if isinstance(arg.keyword, cst.Name) and arg.keyword.value == "pre":
|
|
234
|
+
# Found pre argument - determine mode
|
|
235
|
+
if isinstance(arg.value, cst.Name):
|
|
236
|
+
if arg.value.value == "True":
|
|
237
|
+
mode = "before"
|
|
238
|
+
elif arg.value.value == "False":
|
|
239
|
+
mode = "after"
|
|
240
|
+
# Skip adding this argument (we'll add mode instead if needed)
|
|
241
|
+
else:
|
|
242
|
+
# Keep other arguments
|
|
243
|
+
new_args.append(arg)
|
|
244
|
+
|
|
245
|
+
# Add mode argument if pre was present
|
|
246
|
+
if mode is not None:
|
|
247
|
+
new_args.append(
|
|
248
|
+
cst.Arg(
|
|
249
|
+
keyword=cst.Name("mode"),
|
|
250
|
+
value=cst.SimpleString(f'"{mode}"'),
|
|
251
|
+
equal=cst.AssignEqual(
|
|
252
|
+
whitespace_before=cst.SimpleWhitespace(""),
|
|
253
|
+
whitespace_after=cst.SimpleWhitespace(""),
|
|
254
|
+
),
|
|
255
|
+
)
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
new_call = cst.Call(
|
|
259
|
+
func=cst.Name("field_validator"),
|
|
260
|
+
args=new_args,
|
|
233
261
|
)
|
|
234
262
|
|
|
263
|
+
if mode is not None:
|
|
264
|
+
self.record_change(
|
|
265
|
+
description=f"Convert @validator to @field_validator with mode='{mode}'",
|
|
266
|
+
line_number=1,
|
|
267
|
+
original="@validator(..., pre=...)",
|
|
268
|
+
replacement=f'@field_validator(..., mode="{mode}")',
|
|
269
|
+
transform_name="validator_to_field_validator",
|
|
270
|
+
)
|
|
271
|
+
else:
|
|
272
|
+
self.record_change(
|
|
273
|
+
description="Convert @validator to @field_validator",
|
|
274
|
+
line_number=1,
|
|
275
|
+
original="@validator(...)",
|
|
276
|
+
replacement="@field_validator(...)",
|
|
277
|
+
transform_name="validator_to_field_validator",
|
|
278
|
+
)
|
|
279
|
+
|
|
235
280
|
return node.with_changes(decorator=new_call)
|
|
236
281
|
else:
|
|
237
282
|
# @validator without arguments (shouldn't happen but handle it)
|