codeshift 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65):
  1. codeshift/__init__.py +8 -0
  2. codeshift/analyzer/__init__.py +5 -0
  3. codeshift/analyzer/risk_assessor.py +388 -0
  4. codeshift/api/__init__.py +1 -0
  5. codeshift/api/auth.py +182 -0
  6. codeshift/api/config.py +73 -0
  7. codeshift/api/database.py +215 -0
  8. codeshift/api/main.py +103 -0
  9. codeshift/api/models/__init__.py +55 -0
  10. codeshift/api/models/auth.py +108 -0
  11. codeshift/api/models/billing.py +92 -0
  12. codeshift/api/models/migrate.py +42 -0
  13. codeshift/api/models/usage.py +116 -0
  14. codeshift/api/routers/__init__.py +5 -0
  15. codeshift/api/routers/auth.py +440 -0
  16. codeshift/api/routers/billing.py +395 -0
  17. codeshift/api/routers/migrate.py +304 -0
  18. codeshift/api/routers/usage.py +291 -0
  19. codeshift/api/routers/webhooks.py +289 -0
  20. codeshift/cli/__init__.py +5 -0
  21. codeshift/cli/commands/__init__.py +7 -0
  22. codeshift/cli/commands/apply.py +352 -0
  23. codeshift/cli/commands/auth.py +842 -0
  24. codeshift/cli/commands/diff.py +221 -0
  25. codeshift/cli/commands/scan.py +368 -0
  26. codeshift/cli/commands/upgrade.py +436 -0
  27. codeshift/cli/commands/upgrade_all.py +518 -0
  28. codeshift/cli/main.py +221 -0
  29. codeshift/cli/quota.py +210 -0
  30. codeshift/knowledge/__init__.py +50 -0
  31. codeshift/knowledge/cache.py +167 -0
  32. codeshift/knowledge/generator.py +231 -0
  33. codeshift/knowledge/models.py +151 -0
  34. codeshift/knowledge/parser.py +270 -0
  35. codeshift/knowledge/sources.py +388 -0
  36. codeshift/knowledge_base/__init__.py +17 -0
  37. codeshift/knowledge_base/loader.py +102 -0
  38. codeshift/knowledge_base/models.py +110 -0
  39. codeshift/migrator/__init__.py +23 -0
  40. codeshift/migrator/ast_transforms.py +256 -0
  41. codeshift/migrator/engine.py +395 -0
  42. codeshift/migrator/llm_migrator.py +320 -0
  43. codeshift/migrator/transforms/__init__.py +19 -0
  44. codeshift/migrator/transforms/fastapi_transformer.py +174 -0
  45. codeshift/migrator/transforms/pandas_transformer.py +236 -0
  46. codeshift/migrator/transforms/pydantic_v1_to_v2.py +637 -0
  47. codeshift/migrator/transforms/requests_transformer.py +218 -0
  48. codeshift/migrator/transforms/sqlalchemy_transformer.py +175 -0
  49. codeshift/scanner/__init__.py +6 -0
  50. codeshift/scanner/code_scanner.py +352 -0
  51. codeshift/scanner/dependency_parser.py +473 -0
  52. codeshift/utils/__init__.py +5 -0
  53. codeshift/utils/api_client.py +266 -0
  54. codeshift/utils/cache.py +318 -0
  55. codeshift/utils/config.py +71 -0
  56. codeshift/utils/llm_client.py +221 -0
  57. codeshift/validator/__init__.py +6 -0
  58. codeshift/validator/syntax_checker.py +183 -0
  59. codeshift/validator/test_runner.py +224 -0
  60. codeshift-0.2.0.dist-info/METADATA +326 -0
  61. codeshift-0.2.0.dist-info/RECORD +65 -0
  62. codeshift-0.2.0.dist-info/WHEEL +5 -0
  63. codeshift-0.2.0.dist-info/entry_points.txt +2 -0
  64. codeshift-0.2.0.dist-info/licenses/LICENSE +21 -0
  65. codeshift-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,637 @@
1
+ """Pydantic v1 to v2 transformation using LibCST."""
2
+
3
+ from typing import Any
4
+
5
+ import libcst as cst
6
+ from libcst import matchers as m
7
+
8
+ from codeshift.migrator.ast_transforms import BaseTransformer
9
+
10
+
11
class PydanticV1ToV2Transformer(BaseTransformer):
    """Transform Pydantic v1 code to v2.

    The rewrite is purely syntactic (no type inference), so it only touches
    constructs it can recognise locally:

    * inner ``class Config`` -> ``model_config = ConfigDict(...)``
    * ``@validator`` / ``@root_validator`` -> ``@field_validator`` /
      ``@model_validator``, with ``@classmethod`` added below them
    * deprecated model methods (``.dict()``, ``.parse_obj()``, ...) -> ``model_*``
    * ``Field(regex=..., min_items=..., max_items=...)`` keyword renames
    * ``__fields__`` / ``__validators__`` attribute renames
    * ``validator`` / ``root_validator`` imports

    Constructs that cannot be translated faithfully are left untouched. Most
    v1 spellings (``class Config``, ``@validator``, ``.json()``, ``.copy()``,
    ``.schema_json()``) still exist in v2 as deprecated aliases, so untouched
    code keeps working, whereas a wrong rewrite breaks it.
    """

    # v1 ``Config`` keys that were only renamed in v2's ``ConfigDict``.
    _CONFIG_RENAMES: dict[str, str] = {
        "orm_mode": "from_attributes",
        "allow_population_by_field_name": "populate_by_name",
        "anystr_strip_whitespace": "str_strip_whitespace",
        "anystr_lower": "str_to_lower",
        "anystr_upper": "str_to_upper",
        "min_anystr_length": "str_min_length",
        "max_anystr_length": "str_max_length",
        "schema_extra": "json_schema_extra",
        "validate_all": "validate_default",
        "keep_untouched": "ignored_types",
    }

    # Base classes that mark a class definition as a Pydantic model.
    _MODEL_BASE_NAMES = frozenset({"BaseModel", "BaseSettings", "GenericModel"})

    # Deprecated v1 methods renamed whatever the receiver is; these names are
    # rare outside Pydantic. ``schema_json`` is intentionally absent: v2's
    # ``model_json_schema`` returns a dict, not a JSON string.
    _METHOD_RENAMES: dict[str, str] = {
        "dict": "model_dump",
        "parse_obj": "model_validate",
        "parse_raw": "model_validate_json",
        "schema": "model_json_schema",
        "update_forward_refs": "model_rebuild",
    }

    # Renamed only when the receiver is recognisably a model, because
    # ``response.json()`` and ``some_dict.copy()`` are ubiquitous.
    _MODEL_ONLY_METHOD_RENAMES: dict[str, str] = {
        "json": "model_dump_json",
        "copy": "model_copy",
    }

    # Class-level calls that return a model instance; used to spot variables
    # that are bound to models.
    _MODEL_CONSTRUCTORS = frozenset(
        {
            "parse_obj",
            "parse_raw",
            "parse_file",
            "construct",
            "model_validate",
            "model_validate_json",
            "model_construct",
        }
    )

    # v1 ``Field`` keyword -> (v2 keyword, transform name).
    _FIELD_KWARG_RENAMES: dict[str, tuple[str, str]] = {
        "regex": ("pattern", "field_regex_to_pattern"),
        "min_items": ("min_length", "field_min_items_to_min_length"),
        "max_items": ("max_length", "field_max_items_to_max_length"),
    }

    # v1 dunder attribute -> v2 attribute.
    _ATTRIBUTE_RENAMES: dict[str, str] = {
        "__fields__": "model_fields",
        "__validators__": "__pydantic_decorators__",
    }

    # v1 import name -> (v2 import name, transform name).
    _IMPORT_RENAMES: dict[str, tuple[str, str]] = {
        "validator": ("field_validator", "import_validator_to_field_validator"),
        "root_validator": ("model_validator", "import_root_validator_to_model_validator"),
    }

    def __init__(self) -> None:
        super().__init__()
        # Track what needs to be imported
        self._needs_config_dict = False
        self._needs_field_validator = False
        self._needs_model_validator = False
        self._has_validator_import = False
        self._has_root_validator_import = False
        # Track classes that have inner Config
        self._classes_with_config: dict[str, dict] = {}
        # Innermost class being visited (None at module level). It is kept in
        # sync with ``_class_stack`` so that a nested class such as ``Config``
        # does not make us forget the enclosing model.
        self._current_class: str | None = None
        self._class_stack: list[str] = []
        # Classes defined in this module that derive from a Pydantic model.
        self._model_classes: set[str] = set()
        # Names seen bound to a model instance (annotated params/variables and
        # ``x = Model(...)`` assignments). This is module-wide, not scope-aware:
        # a deliberate heuristic for deciding where ``.json()``/``.copy()``
        # may be renamed.
        self._model_variables: set[str] = set()
        # Track position info
        self._line_offset = 0

    # ------------------------------------------------------------------
    # Class tracking and ``class Config`` conversion
    # ------------------------------------------------------------------

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        """Push the class onto the class stack and note whether it is a model."""
        name = node.name.value
        if any(self._refers_to_model_class(base.value) for base in node.bases):
            self._model_classes.add(name)
        self._class_stack.append(name)
        self._current_class = name
        return True

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Replace an inner ``class Config`` with ``model_config = ConfigDict(...)``.

        The Config class is kept unchanged in two cases. The first is when it
        contains something that cannot be expressed as ``ConfigDict`` keyword
        arguments (methods, compound statements, non-name targets). The second
        is when no option survives the v1 -> v2 mapping.
        """
        if self._class_stack:
            self._class_stack.pop()
        self._current_class = self._class_stack[-1] if self._class_stack else None

        new_body = list(updated_node.body.body)
        config_index = next(
            (
                index
                for index, item in enumerate(new_body)
                if isinstance(item, cst.ClassDef) and item.name.value == "Config"
            ),
            None,
        )
        if config_index is None:
            return updated_node

        config_class = new_body[config_index]
        config_dict = self._extract_config_options(config_class)
        if not config_dict:
            return updated_node

        self._needs_config_dict = True

        # Carry over any comments / blank lines that sat above ``class Config:``.
        new_body[config_index] = self._create_model_config(config_dict).with_changes(
            leading_lines=config_class.leading_lines
        )

        self.record_change(
            description="Convert inner Config class to model_config = ConfigDict(...)",
            line_number=1,  # Approximate
            original="class Config: ...",
            replacement="model_config = ConfigDict(...)",
            transform_name="config_to_configdict",
        )

        return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body))

    def visit_Param(self, node: cst.Param) -> bool:
        """Remember parameters annotated with a model class (``user: User``)."""
        if node.annotation is not None and self._refers_to_model_class(
            node.annotation.annotation
        ):
            self._model_variables.add(node.name.value)
        return True

    def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
        """Remember variables annotated with a model class (``user: User = ...``)."""
        if isinstance(node.target, cst.Name) and self._refers_to_model_class(
            node.annotation.annotation
        ):
            self._model_variables.add(node.target.value)
        return True

    def visit_Assign(self, node: cst.Assign) -> bool:
        """Remember variables assigned a freshly built model (``user = User(...)``)."""
        if self._constructs_model(node.value):
            for assign_target in node.targets:
                if isinstance(assign_target.target, cst.Name):
                    self._model_variables.add(assign_target.target.value)
        return True

    def _extract_config_options(self, config_class: cst.ClassDef) -> dict[str, Any] | None:
        """Collect a Config class's options as ``{v2_name: value}``.

        Values are the original libcst expressions. As a result, strings with
        any quoting, references such as ``alias_generator = to_camel``, and dict
        literals are all reproduced verbatim.

        Returns:
            The mapped options. Returns ``None`` when the class contains
            something that cannot become a ``ConfigDict`` argument; in that case
            the caller must keep the Config class unchanged.
        """
        suite = config_class.body
        if isinstance(suite, cst.SimpleStatementSuite):
            # One-liner: ``class Config: orm_mode = True``
            statements = list(suite.body)
        else:
            statements = []
            for item in suite.body:
                if not isinstance(item, cst.SimpleStatementLine):
                    # e.g. ``def schema_extra(...)``: dropping it would lose behaviour.
                    return None
                statements.extend(item.body)

        options: dict[str, Any] = {}
        for stmt in statements:
            if isinstance(stmt, (cst.Pass, cst.Expr)):
                # ``pass`` and docstrings carry no configuration.
                continue
            if isinstance(stmt, cst.Assign):
                targets = [assign_target.target for assign_target in stmt.targets]
            elif isinstance(stmt, cst.AnnAssign) and stmt.value is not None:
                targets = [stmt.target]
            else:
                return None
            for target in targets:
                if not isinstance(target, cst.Name):
                    return None
                mapped_name, mapped_value = self._map_config_option(target.value, stmt.value)
                if mapped_name:
                    options[mapped_name] = mapped_value

        return options

    def _map_config_option(self, name: str, value: Any) -> tuple[str | None, Any]:
        """Map a v1 Config option to its v2 ConfigDict equivalent.

        Args:
            name: The v1 option name.
            value: The option value. This is either a libcst expression (as
                produced by ``_extract_config_options``) or a plain Python value.

        Returns:
            ``(v2_name, v2_value)``, or ``(None, None)`` when the option should
            be dropped. Unknown options are passed through unchanged.
        """
        if name == "allow_mutation":
            # allow_mutation=False -> frozen=True; allow_mutation=True is v2's default.
            if value is False:
                return ("frozen", True)
            if self._is_name(value, "False"):
                return ("frozen", cst.Name("True"))
            return (None, None)

        if name == "underscore_attrs_are_private":
            # This is the default in v2
            return (None, None)

        if name == "extra":
            return ("extra", self._normalize_extra(value))

        return (self._CONFIG_RENAMES.get(name, name), value)

    def _normalize_extra(self, value: Any) -> Any:
        """Turn ``Extra.forbid`` / ``pydantic.Extra.forbid`` into ``"forbid"``."""
        if (
            isinstance(value, cst.Attribute)
            and self._get_module_name(value.value).split(".")[-1] == "Extra"
        ):
            return cst.SimpleString(f'"{value.attr.value}"')
        return value

    def _extract_value(self, node: cst.BaseExpression) -> Any:
        """Extract a Python value from a simple literal CST node.

        Returns ``None`` for expressions that are not simple literals. For a
        bare name it returns the name as a string; for ``None`` it returns
        ``None``.
        """
        if isinstance(node, cst.Name):
            literals = {"True": True, "False": False, "None": None}
            return literals.get(node.value, node.value)  # Name -> str for enums etc.
        if isinstance(node, (cst.SimpleString, cst.Integer, cst.Float)):
            # evaluated_value handles prefixes, triple quotes, escapes, 0x10, 1_000.
            return node.evaluated_value
        return None

    def _create_model_config(self, config_dict: dict[str, Any]) -> cst.SimpleStatementLine:
        """Build a ``model_config = ConfigDict(key=value, ...)`` statement."""
        no_space = cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )
        args = [
            cst.Arg(keyword=cst.Name(key), value=self._to_expression(value), equal=no_space)
            for key, value in config_dict.items()
        ]
        return cst.SimpleStatementLine(
            body=[
                cst.Assign(
                    targets=[cst.AssignTarget(target=cst.Name("model_config"))],
                    value=cst.Call(func=cst.Name("ConfigDict"), args=args),
                )
            ]
        )

    @staticmethod
    def _to_expression(value: Any) -> cst.BaseExpression:
        """Return *value* as a CST expression, using libcst nodes verbatim."""
        if isinstance(value, cst.BaseExpression):
            return value
        if isinstance(value, bool):
            return cst.Name("True" if value else "False")
        if isinstance(value, str):
            import json

            # A JSON string is a double-quoted, properly escaped Python literal;
            # the old f'"{value}"' produced invalid code for embedded quotes.
            return cst.SimpleString(json.dumps(value, ensure_ascii=False))
        if isinstance(value, int):
            return cst.Integer(str(value))
        if isinstance(value, float):
            return cst.Float(str(value))
        return cst.Name(str(value))

    # ------------------------------------------------------------------
    # Validators
    # ------------------------------------------------------------------

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Transform @validator to @field_validator and @root_validator to @model_validator."""
        if m.matches(
            updated_node.decorator,
            m.Call(func=m.Name("validator")) | m.Name("validator"),
        ):
            self._needs_field_validator = True
            return self._transform_validator_decorator(updated_node)

        if m.matches(
            updated_node.decorator,
            m.Call(func=m.Name("root_validator")) | m.Name("root_validator"),
        ):
            self._needs_model_validator = True
            return self._transform_root_validator_decorator(updated_node)

        return updated_node

    def _transform_validator_decorator(self, node: cst.Decorator) -> cst.Decorator:
        """Transform ``@validator("field", ...)`` into ``@field_validator("field", ...)``.

        v1-only keywords are translated where v2 has an equivalent:

        * ``pre=True`` becomes ``mode="before"``.
        * ``pre=False`` is dropped, since v2's default mode is ``"after"``.
        * ``allow_reuse`` is dropped; v2 has no such option.

        ``always`` and ``each_item`` have no ``field_validator`` equivalent.
        They are kept so that v2 rejects them loudly instead of the semantics
        silently changing.
        """
        if isinstance(node.decorator, cst.Call):
            new_call = node.decorator.with_changes(
                func=cst.Name("field_validator"),
                args=self._convert_validator_args(node.decorator),
            )

            self.record_change(
                description="Convert @validator to @field_validator",
                line_number=1,
                original="@validator(...)",
                replacement="@field_validator(...)",
                transform_name="validator_to_field_validator",
            )

            return node.with_changes(decorator=new_call)

        # @validator without arguments (shouldn't happen but handle it)
        self.record_change(
            description="Convert @validator to @field_validator",
            line_number=1,
            original="@validator",
            replacement="@field_validator",
            transform_name="validator_to_field_validator",
        )
        return node.with_changes(decorator=cst.Name("field_validator"))

    def _convert_validator_args(self, call: cst.Call) -> list[cst.Arg]:
        """Translate v1 ``@validator`` keyword arguments to ``field_validator`` ones."""
        converted: list[cst.Arg] = []
        for arg in call.args:
            keyword = arg.keyword.value if arg.keyword is not None else None
            if keyword == "allow_reuse":
                continue
            if keyword == "pre" and self._is_name(arg.value, "False"):
                continue
            if keyword == "pre" and self._is_name(arg.value, "True"):
                arg = arg.with_changes(
                    keyword=cst.Name("mode"), value=cst.SimpleString('"before"')
                )
            converted.append(arg)
        return converted

    def _transform_root_validator_decorator(self, node: cst.Decorator) -> cst.Decorator:
        """Transform every ``@root_validator`` form into ``@model_validator(mode="before")``.

        v1 root validators are ``(cls, values)`` functions that receive a dict.
        In v2, only ``mode="before"`` validators receive the raw input, while
        ``mode="after"`` validators receive the model instance. The unchanged
        ``(cls, values)`` body would therefore break under ``"after"``, so
        ``pre=True``, ``pre=False`` and the bare form all emit ``"before"``.

        NOTE(review): v1 post-validators (``pre=False``, which is the v1
        default) now see raw input instead of validated field values. Check
        these by hand.
        """
        mode = "before"

        new_decorator = cst.Call(
            func=cst.Name("model_validator"),
            args=[
                cst.Arg(
                    keyword=cst.Name("mode"),
                    value=cst.SimpleString(f'"{mode}"'),
                    equal=cst.AssignEqual(
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after=cst.SimpleWhitespace(""),
                    ),
                )
            ],
        )

        self.record_change(
            description=f"Convert @root_validator to @model_validator(mode='{mode}')",
            line_number=1,
            original="@root_validator",
            replacement=f'@model_validator(mode="{mode}")',
            transform_name="root_validator_to_model_validator",
        )

        return node.with_changes(decorator=new_decorator)

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Add ``@classmethod`` below ``@field_validator`` / ``@model_validator``.

        ``@classmethod`` is appended last, i.e. closest to ``def``, below the
        Pydantic decorator. Nothing is added in three cases: the function is
        already a classmethod or staticmethod; it has no validator decorator;
        or its first parameter is ``self`` (an instance-method model validator,
        which must stay one).
        """
        decorator_names = {self._decorator_name(d) for d in updated_node.decorators}
        if not decorator_names & {"field_validator", "model_validator"}:
            return updated_node
        if decorator_names & {"classmethod", "staticmethod"}:
            return updated_node
        if self._first_param_name(updated_node) == "self":
            return updated_node

        classmethod_decorator = cst.Decorator(decorator=cst.Name("classmethod"))
        new_decorators = [*updated_node.decorators, classmethod_decorator]

        self.record_change(
            description="Add @classmethod decorator to validator",
            line_number=1,
            original="def method(cls, ...)",
            replacement="@classmethod\ndef method(cls, ...)",
            transform_name="add_classmethod",
        )

        return updated_node.with_changes(decorators=new_decorators)

    @staticmethod
    def _decorator_name(decorator: cst.Decorator) -> str | None:
        """Return the bare name of ``@name`` or ``@name(...)``, else ``None``."""
        expr = decorator.decorator
        if isinstance(expr, cst.Call):
            expr = expr.func
        return expr.value if isinstance(expr, cst.Name) else None

    @staticmethod
    def _first_param_name(func: cst.FunctionDef) -> str | None:
        """Return the name of the first positional parameter, if any."""
        positional = [*func.params.posonly_params, *func.params.params]
        return positional[0].name.value if positional else None

    # ------------------------------------------------------------------
    # Calls, attributes, imports
    # ------------------------------------------------------------------

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression:
        """Rename deprecated model methods and ``Field`` keywords.

        ``.json()`` and ``.copy()`` are renamed only when the receiver looks
        like a model: ``self`` inside a model class, a model class, a variable
        bound to one, or ``Model(...)``. Other receivers, such as HTTP
        responses or dicts, are left alone.
        """
        func = updated_node.func

        if isinstance(func, cst.Attribute):
            method_name = func.attr.value
            new_method = self._METHOD_RENAMES.get(method_name)
            if (
                new_method is None
                and method_name in self._MODEL_ONLY_METHOD_RENAMES
                and self._is_model_expression(func.value)
            ):
                new_method = self._MODEL_ONLY_METHOD_RENAMES[method_name]

            if new_method is not None:
                self.record_change(
                    description=f"Convert .{method_name}() to .{new_method}()",
                    line_number=1,
                    original=f".{method_name}()",
                    replacement=f".{new_method}()",
                    transform_name=f"{method_name}_to_{new_method}",
                )
                return updated_node.with_changes(
                    func=func.with_changes(attr=cst.Name(new_method))
                )

        if isinstance(func, cst.Name) and func.value == "Field":
            return self._rename_field_kwargs(updated_node)

        return updated_node

    def _rename_field_kwargs(self, call: cst.Call) -> cst.Call:
        """Rename v1 ``Field`` keywords (regex/min_items/max_items) to their v2 names."""
        new_args = []
        changed = False
        for arg in call.args:
            keyword = arg.keyword.value if arg.keyword is not None else None
            if keyword in self._FIELD_KWARG_RENAMES:
                new_keyword, transform_name = self._FIELD_KWARG_RENAMES[keyword]
                arg = arg.with_changes(keyword=cst.Name(new_keyword))
                changed = True
                self.record_change(
                    description=f"Convert Field({keyword}=...) to Field({new_keyword}=...)",
                    line_number=1,
                    original=f"Field({keyword}=...)",
                    replacement=f"Field({new_keyword}=...)",
                    transform_name=transform_name,
                )
            new_args.append(arg)
        return call.with_changes(args=new_args) if changed else call

    def _refers_to_model_class(self, node: cst.BaseExpression) -> bool:
        """Return True if *node* names a Pydantic base class or a known model class."""
        if isinstance(node, cst.Name):
            return node.value in self._MODEL_BASE_NAMES or node.value in self._model_classes
        if isinstance(node, cst.Attribute):
            return node.attr.value in self._MODEL_BASE_NAMES
        return False

    def _constructs_model(self, node: cst.BaseExpression) -> bool:
        """Return True for ``Model(...)`` or ``Model.parse_obj(...)``-style calls."""
        if not isinstance(node, cst.Call):
            return False
        func = node.func
        if isinstance(func, cst.Name):
            return func.value in self._model_classes
        return (
            isinstance(func, cst.Attribute)
            and func.attr.value in self._MODEL_CONSTRUCTORS
            and isinstance(func.value, cst.Name)
            and func.value.value in self._model_classes
        )

    def _is_model_expression(self, node: cst.BaseExpression) -> bool:
        """Heuristically decide whether *node* evaluates to a model class or instance."""
        if isinstance(node, cst.Name):
            if node.value == "self":
                return bool(self._class_stack) and self._class_stack[-1] in self._model_classes
            return node.value in self._model_classes or node.value in self._model_variables
        return self._constructs_model(node)

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        """Transform attribute access like __fields__ to model_fields."""
        attr_name = updated_node.attr.value
        new_attr = self._ATTRIBUTE_RENAMES.get(attr_name)
        if new_attr is None:
            return updated_node

        self.record_change(
            description=f"Convert {attr_name} to {new_attr}",
            line_number=1,
            original=attr_name,
            replacement=new_attr,
            transform_name=f"{attr_name}_rename",
        )
        return updated_node.with_changes(attr=cst.Name(new_attr))

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Rename ``validator``/``root_validator`` imports and add ``ConfigDict``.

        Aliased imports (``validator as v``) are not renamed, because
        ``@v(...)`` usages are not rewritten. The alias must therefore keep
        pointing at the (deprecated, still available) v1 decorator.
        """
        if updated_node.module is None:
            return updated_node
        if self._get_module_name(updated_node.module) != "pydantic":
            return updated_node
        if isinstance(updated_node.names, cst.ImportStar):
            return updated_node

        new_names = []
        changed = False

        for alias in updated_node.names:
            imported_name = self._get_name_value(alias.name)
            rename = self._IMPORT_RENAMES.get(imported_name) if alias.asname is None else None
            if rename is None:
                new_names.append(alias)
                continue

            new_name, transform_name = rename
            if imported_name == "validator":
                self._has_validator_import = True
            else:
                self._has_root_validator_import = True
            new_names.append(alias.with_changes(name=cst.Name(new_name)))
            changed = True

            self.record_change(
                description=f"Convert '{imported_name}' import to '{new_name}'",
                line_number=1,
                original=f"from pydantic import {imported_name}",
                replacement=f"from pydantic import {new_name}",
                transform_name=transform_name,
            )

        # Add ConfigDict import if needed (only when a Config was converted
        # before this import statement was reached).
        if self._needs_config_dict and not any(
            self._get_name_value(alias.name) == "ConfigDict" and alias.asname is None
            for alias in new_names
        ):
            new_names.append(cst.ImportAlias(name=cst.Name("ConfigDict")))
            changed = True

        return updated_node.with_changes(names=new_names) if changed else updated_node

    @staticmethod
    def _is_name(node: Any, value: str) -> bool:
        """Return True if *node* is the bare name *value* (e.g. ``True``)."""
        return isinstance(node, cst.Name) and node.value == value

    def _get_module_name(self, node: cst.BaseExpression) -> str:
        """Get the full module name from an Attribute or Name node."""
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            base = self._get_module_name(node.value)
            return f"{base}.{node.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None
+
514
+
515
class PydanticImportTransformer(BaseTransformer):
    """Separate transformer for handling import additions.

    This runs after the main transformer to add any missing imports. Names
    that are needed but not yet bound are appended to the first
    ``from pydantic import ...`` statement encountered. If the module has no
    such statement (for example it only does ``import pydantic``), a new
    ``from pydantic import ...`` line is inserted after the module docstring
    and any ``from __future__`` imports.
    """

    def __init__(
        self,
        needs_config_dict: bool = False,
        needs_field_validator: bool = False,
        needs_model_validator: bool = False,
    ) -> None:
        super().__init__()
        self.needs_config_dict = needs_config_dict
        self.needs_field_validator = needs_field_validator
        self.needs_model_validator = needs_model_validator
        self._found_pydantic_import = False
        self._has_config_dict = False
        self._has_field_validator = False
        self._has_model_validator = False

    def visit_Module(self, node: cst.Module) -> bool:
        """Pre-scan every ``from pydantic import ...`` in the module.

        Without this, a name imported by a *later* statement would be appended
        a second time to an earlier one.
        """
        for import_node in m.findall(node, m.ImportFrom()):
            if isinstance(import_node, cst.ImportFrom):
                self._record_existing_imports(import_node)
        return True

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
        """Check existing pydantic imports."""
        self._record_existing_imports(node)
        return True

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Add missing names to the first ``from pydantic import ...`` statement."""
        if updated_node.module is None:
            return updated_node
        if self._get_module_name(updated_node.module) != "pydantic":
            return updated_node
        if isinstance(updated_node.names, cst.ImportStar):
            return updated_node

        missing = self._missing_names()
        if not missing:
            return updated_node

        self._mark_imported(missing)
        new_names = [
            *updated_node.names,
            *(cst.ImportAlias(name=cst.Name(name)) for name in missing),
        ]
        return updated_node.with_changes(names=new_names)

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Insert ``from pydantic import ...`` when no existing import could take the names.

        Without this, a module that only does ``import pydantic`` would end up
        referencing an unbound ``ConfigDict`` / ``field_validator`` /
        ``model_validator``.
        """
        missing = self._missing_names()
        if not missing:
            return updated_node

        import_line = cst.SimpleStatementLine(
            body=[
                cst.ImportFrom(
                    module=cst.Name("pydantic"),
                    names=[cst.ImportAlias(name=cst.Name(name)) for name in missing],
                )
            ]
        )
        body = list(updated_node.body)
        position = 0
        while position < len(body) and self._is_module_preamble(body[position], position):
            position += 1
        body.insert(position, import_line)

        self._mark_imported(missing)
        return updated_node.with_changes(body=body)

    def _record_existing_imports(self, node: cst.ImportFrom) -> None:
        """Note which generated names a ``from pydantic import`` already binds."""
        if node.module is None or self._get_module_name(node.module) != "pydantic":
            return

        self._found_pydantic_import = True

        if isinstance(node.names, cst.ImportStar):
            # ``from pydantic import *`` already binds the public names.
            self._has_config_dict = True
            self._has_field_validator = True
            self._has_model_validator = True
            return

        for alias in node.names:
            imported = self._get_name_value(alias.name)
            if alias.asname is not None and self._get_name_value(alias.asname.name) != imported:
                # ``X as Y`` binds Y, but the generated code refers to bare X.
                continue
            if imported == "ConfigDict":
                self._has_config_dict = True
            elif imported == "field_validator":
                self._has_field_validator = True
            elif imported == "model_validator":
                self._has_model_validator = True

    def _missing_names(self) -> list[str]:
        """Return needed-but-unbound names in the order they should be imported."""
        requested = (
            ("ConfigDict", self.needs_config_dict, self._has_config_dict),
            ("field_validator", self.needs_field_validator, self._has_field_validator),
            ("model_validator", self.needs_model_validator, self._has_model_validator),
        )
        return [name for name, needed, present in requested if needed and not present]

    def _mark_imported(self, names: list[str]) -> None:
        """Record that *names* are now imported so they are not added twice."""
        self._has_config_dict |= "ConfigDict" in names
        self._has_field_validator |= "field_validator" in names
        self._has_model_validator |= "model_validator" in names

    @staticmethod
    def _is_module_preamble(statement: cst.CSTNode, index: int) -> bool:
        """Return True for a module docstring (index 0) or a ``from __future__`` import."""
        if not isinstance(statement, cst.SimpleStatementLine) or len(statement.body) != 1:
            return False
        small = statement.body[0]
        if (
            index == 0
            and isinstance(small, cst.Expr)
            and isinstance(small.value, (cst.SimpleString, cst.ConcatenatedString))
        ):
            return True
        return (
            isinstance(small, cst.ImportFrom)
            and isinstance(small.module, cst.Name)
            and small.module.value == "__future__"
        )

    def _get_module_name(self, node: cst.BaseExpression) -> str:
        """Get the full module name from an Attribute or Name node."""
        if isinstance(node, cst.Name):
            return str(node.value)
        if isinstance(node, cst.Attribute):
            base = self._get_module_name(node.value)
            return f"{base}.{node.attr.value}"
        return ""

    def _get_name_value(self, node: cst.BaseExpression) -> str | None:
        """Extract the string value from a Name node."""
        if isinstance(node, cst.Name):
            return str(node.value)
        return None
+
609
+
610
def transform_pydantic_v1_to_v2(source_code: str) -> tuple[str, list]:
    """Migrate a module's source from Pydantic v1 to v2.

    Two passes are made. ``PydanticV1ToV2Transformer`` rewrites the code and
    notes which v2 names it introduced. ``PydanticImportTransformer`` then
    makes sure those names are imported.

    Args:
        source_code: Python source of the module to migrate.

    Returns:
        ``(new_source, changes)``, where ``changes`` is the list recorded by
        the main (first-pass) transformer.

    Raises:
        SyntaxError: If libcst cannot parse ``source_code``.
    """
    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError as exc:
        raise SyntaxError(f"Invalid Python syntax: {exc}") from exc

    main_pass = PydanticV1ToV2Transformer()
    main_pass.set_source(source_code)
    migrated = module.visit(main_pass)

    import_pass = PydanticImportTransformer(
        needs_config_dict=main_pass._needs_config_dict,
        needs_field_validator=main_pass._needs_field_validator,
        needs_model_validator=main_pass._needs_model_validator,
    )
    with_imports = migrated.visit(import_pass)

    return with_imports.code, main_pass.changes