codeshift 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. codeshift/__init__.py +8 -0
  2. codeshift/analyzer/__init__.py +5 -0
  3. codeshift/analyzer/risk_assessor.py +388 -0
  4. codeshift/api/__init__.py +1 -0
  5. codeshift/api/auth.py +182 -0
  6. codeshift/api/config.py +73 -0
  7. codeshift/api/database.py +215 -0
  8. codeshift/api/main.py +103 -0
  9. codeshift/api/models/__init__.py +55 -0
  10. codeshift/api/models/auth.py +108 -0
  11. codeshift/api/models/billing.py +92 -0
  12. codeshift/api/models/migrate.py +42 -0
  13. codeshift/api/models/usage.py +116 -0
  14. codeshift/api/routers/__init__.py +5 -0
  15. codeshift/api/routers/auth.py +440 -0
  16. codeshift/api/routers/billing.py +395 -0
  17. codeshift/api/routers/migrate.py +304 -0
  18. codeshift/api/routers/usage.py +291 -0
  19. codeshift/api/routers/webhooks.py +289 -0
  20. codeshift/cli/__init__.py +5 -0
  21. codeshift/cli/commands/__init__.py +7 -0
  22. codeshift/cli/commands/apply.py +352 -0
  23. codeshift/cli/commands/auth.py +842 -0
  24. codeshift/cli/commands/diff.py +221 -0
  25. codeshift/cli/commands/scan.py +368 -0
  26. codeshift/cli/commands/upgrade.py +436 -0
  27. codeshift/cli/commands/upgrade_all.py +518 -0
  28. codeshift/cli/main.py +221 -0
  29. codeshift/cli/quota.py +210 -0
  30. codeshift/knowledge/__init__.py +50 -0
  31. codeshift/knowledge/cache.py +167 -0
  32. codeshift/knowledge/generator.py +231 -0
  33. codeshift/knowledge/models.py +151 -0
  34. codeshift/knowledge/parser.py +270 -0
  35. codeshift/knowledge/sources.py +388 -0
  36. codeshift/knowledge_base/__init__.py +17 -0
  37. codeshift/knowledge_base/loader.py +102 -0
  38. codeshift/knowledge_base/models.py +110 -0
  39. codeshift/migrator/__init__.py +23 -0
  40. codeshift/migrator/ast_transforms.py +256 -0
  41. codeshift/migrator/engine.py +395 -0
  42. codeshift/migrator/llm_migrator.py +320 -0
  43. codeshift/migrator/transforms/__init__.py +19 -0
  44. codeshift/migrator/transforms/fastapi_transformer.py +174 -0
  45. codeshift/migrator/transforms/pandas_transformer.py +236 -0
  46. codeshift/migrator/transforms/pydantic_v1_to_v2.py +637 -0
  47. codeshift/migrator/transforms/requests_transformer.py +218 -0
  48. codeshift/migrator/transforms/sqlalchemy_transformer.py +175 -0
  49. codeshift/scanner/__init__.py +6 -0
  50. codeshift/scanner/code_scanner.py +352 -0
  51. codeshift/scanner/dependency_parser.py +473 -0
  52. codeshift/utils/__init__.py +5 -0
  53. codeshift/utils/api_client.py +266 -0
  54. codeshift/utils/cache.py +318 -0
  55. codeshift/utils/config.py +71 -0
  56. codeshift/utils/llm_client.py +221 -0
  57. codeshift/validator/__init__.py +6 -0
  58. codeshift/validator/syntax_checker.py +183 -0
  59. codeshift/validator/test_runner.py +224 -0
  60. codeshift-0.2.0.dist-info/METADATA +326 -0
  61. codeshift-0.2.0.dist-info/RECORD +65 -0
  62. codeshift-0.2.0.dist-info/WHEEL +5 -0
  63. codeshift-0.2.0.dist-info/entry_points.txt +2 -0
  64. codeshift-0.2.0.dist-info/licenses/LICENSE +21 -0
  65. codeshift-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,218 @@
1
+ """Requests library transformation using LibCST."""
2
+
3
+ import libcst as cst
4
+
5
+ from codeshift.migrator.ast_transforms import BaseTransformer
6
+
7
+
8
class RequestsTransformer(BaseTransformer):
    """Transform Requests library code for version upgrades.

    Rewrites performed:

    * ``requests.packages.urllib3`` imports and attribute accesses are pointed
      at the top-level ``urllib3`` package.
    * ``from requests.compat import ...`` is pointed at ``urllib.parse`` when
      every imported name is one of the URL helpers listed in
      ``_COMPAT_URLLIB_PARSE_NAMES``; mixed imports are left untouched so the
      other names keep resolving.

    HTTP calls made through ``requests.<method>()`` or through a receiver that
    heuristically looks like a session, without an explicit ``timeout``, are
    only recorded -- never rewritten.

    NOTE(review): every change is recorded with ``line_number=1``; real
    positions would require LibCST position metadata.
    """

    # Dotted path of the urllib3 copy exposed through requests, and the
    # top-level package that replaces it.
    _URLLIB3_SHIM = "requests.packages.urllib3"
    _URLLIB3 = "urllib3"

    # URL helpers that requests.compat re-exports from urllib.parse.
    _COMPAT_URLLIB_PARSE_NAMES = frozenset(
        {
            "quote",
            "quote_plus",
            "unquote",
            "unquote_plus",
            "urldefrag",
            "urlencode",
            "urljoin",
            "urlparse",
            "urlsplit",
            "urlunparse",
        }
    )

    # HTTP helpers checked for a missing ``timeout`` argument, both on the
    # ``requests`` module and on session-like receivers.
    _HTTP_METHODS = frozenset(
        {"get", "post", "put", "delete", "patch", "head", "options", "request"}
    )

    # Short receiver names treated as sessions, in addition to any name that
    # contains "session" (case-insensitive).
    _SESSION_LIKE_NAMES = frozenset({"s", "sess", "client"})

    def __init__(self) -> None:
        super().__init__()
        # True while the children of an ImportFrom are visited. The import's
        # module path is rewritten (and recorded) by leave_ImportFrom, so
        # leave_Attribute must not also rewrite/record it.
        self._in_import_from = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> bool | None:
        """Flag that the attribute nodes visited next belong to an import."""
        self._in_import_from = True
        return super().visit_ImportFrom(node)

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Rewrite ``requests.packages.urllib3`` and ``requests.compat`` imports."""
        self._in_import_from = False

        # Relative imports (``from .requests.compat import x``) refer to local
        # modules, not to the requests package.
        if original_node.module is None or original_node.relative:
            return updated_node

        module_name = self._get_module_name(original_node.module)

        if self._is_urllib3_shim_path(module_name):
            new_module_name = self._strip_urllib3_shim(module_name)
            self.record_change(
                description="Import urllib3 directly instead of through requests.packages",
                line_number=1,
                original=f"from {module_name}",
                replacement=f"from {new_module_name}",
                transform_name="urllib3_import_fix",
            )
            return updated_node.with_changes(module=self._build_module_node(new_module_name))

        if module_name == "requests.compat":
            return self._rewrite_compat_import(updated_node)

        return updated_node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.Attribute | cst.Name:
        """Rewrite ``requests.packages.urllib3[.x...]`` attribute access to ``urllib3``."""
        if self._in_import_from:
            return updated_node

        # None when the chain is not rooted at a plain name (e.g. ``f().a.b``);
        # such chains must not be rebuilt, or the call would be dropped.
        attr_str = self._get_dotted_name(updated_node)
        if attr_str is None or not self._is_urllib3_shim_path(attr_str):
            return updated_node

        new_attr_str = self._strip_urllib3_shim(attr_str)
        self.record_change(
            description="Access urllib3 directly instead of through requests.packages",
            line_number=1,
            original=attr_str,
            replacement=new_attr_str,
            transform_name="urllib3_attribute_fix",
        )
        return self._build_name_or_attribute_node(new_attr_str)

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Record HTTP calls made without an explicit ``timeout`` (no rewrite)."""
        func = updated_node.func
        # Both checks below need ``<name>.<method>(...)``.
        if not isinstance(func, cst.Attribute) or not isinstance(func.value, cst.Name):
            return updated_node
        if self._has_timeout_argument(updated_node):
            return updated_node

        method_name = func.attr.value
        if method_name not in self._HTTP_METHODS:
            return updated_node

        receiver = func.value.value
        if receiver == "requests":
            self.record_change(
                description=f"requests.{method_name}() called without explicit timeout",
                line_number=1,
                original=f"requests.{method_name}(...)",
                replacement=f"requests.{method_name}(..., timeout=30)",
                transform_name=f"{method_name}_add_explicit_timeout",
                confidence=0.7,
                notes="Consider adding explicit timeout parameter",
            )
        elif self._looks_like_session(receiver):
            self.record_change(
                description=f"Session.{method_name}() called without explicit timeout",
                line_number=1,
                original=f"session.{method_name}(...)",
                replacement=f"session.{method_name}(..., timeout=30)",
                transform_name=f"session_{method_name}_add_timeout",
                confidence=0.6,
                notes="Consider adding explicit timeout parameter",
            )

        return updated_node

    def _rewrite_compat_import(self, node: cst.ImportFrom) -> cst.ImportFrom:
        """Point a ``requests.compat`` import at ``urllib.parse`` when that is safe.

        The module is only switched when *every* imported name is a known
        urllib.parse helper; otherwise the other names (which urllib.parse
        does not provide) would fail to import, so the node is returned as-is.
        """
        if isinstance(node.names, cst.ImportStar):
            return node

        imported: list[str] = []
        for alias in node.names:
            if not isinstance(alias.name, cst.Name):
                return node
            imported.append(alias.name.value)

        if not imported or not all(name in self._COMPAT_URLLIB_PARSE_NAMES for name in imported):
            return node

        for import_name in imported:
            self.record_change(
                description=f"Import {import_name} from urllib.parse instead of requests.compat",
                line_number=1,
                original=f"from requests.compat import {import_name}",
                replacement=f"from urllib.parse import {import_name}",
                transform_name=f"compat_{import_name}_fix",
            )
        # Only the module changes, so any ``as`` aliases are preserved.
        return node.with_changes(
            module=cst.Attribute(value=cst.Name("urllib"), attr=cst.Name("parse"))
        )

    def _is_urllib3_shim_path(self, dotted: str) -> bool:
        """Return True for ``requests.packages.urllib3`` or one of its submodules."""
        return dotted == self._URLLIB3_SHIM or dotted.startswith(self._URLLIB3_SHIM + ".")

    def _strip_urllib3_shim(self, dotted: str) -> str:
        """Replace the leading ``requests.packages.urllib3`` prefix with ``urllib3``."""
        return self._URLLIB3 + dotted[len(self._URLLIB3_SHIM):]

    @staticmethod
    def _has_timeout_argument(call: cst.Call) -> bool:
        """Return True if the call passes ``timeout=`` or may pass it via ``**kwargs``."""
        return any(
            arg.star == "**"
            or (isinstance(arg.keyword, cst.Name) and arg.keyword.value == "timeout")
            for arg in call.args
        )

    def _looks_like_session(self, receiver: str) -> bool:
        """Heuristic: does the receiver name look like a requests Session?"""
        lowered = receiver.lower()
        return "session" in lowered or lowered in self._SESSION_LIKE_NAMES

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""

    def _get_dotted_name(self, node: cst.BaseExpression) -> str | None:
        """Return ``a.b.c`` for a pure Name/Attribute chain, else None."""
        if isinstance(node, cst.Name):
            return node.value
        if isinstance(node, cst.Attribute):
            base = self._get_dotted_name(node.value)
            return None if base is None else f"{base}.{node.attr.value}"
        return None

    def _get_full_attribute(self, node: cst.Attribute) -> str:
        """Get the attribute path as a string.

        A non-Name/Attribute base (e.g. a call) is omitted from the result;
        use ``_get_dotted_name`` when that distinction matters.
        """
        if isinstance(node.value, cst.Name):
            return f"{node.value.value}.{node.attr.value}"
        elif isinstance(node.value, cst.Attribute):
            return f"{self._get_full_attribute(node.value)}.{node.attr.value}"
        return str(node.attr.value)

    def _build_name_or_attribute_node(self, name_str: str) -> cst.Name | cst.Attribute:
        """Build a Name (one part) or nested Attribute from a dotted string."""
        parts = name_str.split(".")
        result: cst.Name | cst.Attribute = cst.Name(parts[0])
        for part in parts[1:]:
            result = cst.Attribute(value=result, attr=cst.Name(part))
        return result

    def _build_module_node(self, module_name: str) -> cst.Name | cst.Attribute:
        """Build a module node from a dotted name string."""
        return self._build_name_or_attribute_node(module_name)

    def _build_attribute_node(self, attr_str: str) -> cst.Attribute:
        """Build an Attribute node from a dotted string with at least two parts."""
        result = self._build_name_or_attribute_node(attr_str)
        # Narrowing only: callers are expected to pass a dotted (2+ part) path.
        assert isinstance(result, cst.Attribute)
        return result
195
+
196
+
197
def transform_requests(source_code: str) -> tuple[str, list]:
    """Run :class:`RequestsTransformer` over a piece of Python source.

    Args:
        source_code: Python source text to rewrite.

    Returns:
        A ``(new_code, changes)`` pair. If the text does not parse, or the
        transformation raises, the original text is returned together with
        an empty change list.
    """
    unchanged: tuple[str, list] = (source_code, [])

    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return unchanged

    transformer = RequestsTransformer()
    transformer.set_source(source_code)

    # Best-effort: any failure while visiting or rendering leaves the code as-is.
    try:
        return module.visit(transformer).code, transformer.changes
    except Exception:
        return unchanged
@@ -0,0 +1,175 @@
1
+ """SQLAlchemy 1.x to 2.0 transformation using LibCST."""
2
+
3
+ import libcst as cst
4
+
5
+ from codeshift.migrator.ast_transforms import BaseTransformer
6
+
7
+
8
class SQLAlchemyTransformer(BaseTransformer):
    """Transform SQLAlchemy 1.x code to 2.0.

    Rewrites performed:

    * ``from sqlalchemy.ext.declarative import declarative_base`` -- when that
      is the statement's only, unaliased name -- becomes
      ``from sqlalchemy.orm import DeclarativeBase``. Statements that import
      other names too are left intact (and the suggestion is only recorded)
      so those names are not lost.
    * ``backref`` is removed from ``from sqlalchemy.orm import ...``.
    * ``future=...`` is removed from ``create_engine(...)`` /
      ``sqlalchemy.create_engine(...)`` calls.

    ``declarative_base()`` calls are recorded for manual review, not rewritten.

    NOTE(review): removing the ``backref`` import breaks code that still calls
    ``backref(...)``; that usage is not detected here.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): only _has_declarative_base_import is set in this class;
        # the other two flags appear reserved for future transforms.
        self._needs_select_import = False
        self._needs_text_import = False
        self._has_declarative_base_import = False

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom | cst.RemovalSentinel:
        """Transform SQLAlchemy imports."""
        # Relative imports (``from .sqlalchemy.orm import x``) name local
        # modules, not the sqlalchemy package.
        if original_node.module is None or original_node.relative:
            return updated_node
        if isinstance(updated_node.names, cst.ImportStar):
            return updated_node

        module_name = self._get_module_name(original_node.module)
        if module_name == "sqlalchemy.ext.declarative":
            return self._rewrite_declarative_base_import(updated_node)
        if module_name == "sqlalchemy.orm":
            return self._remove_backref_import(updated_node)
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Transform SQLAlchemy function calls."""
        func = updated_node.func

        # declarative_base() -> class Base(DeclarativeBase): pass is a
        # statement-level rewrite, so only the need for it is recorded.
        if isinstance(func, cst.Name) and func.value == "declarative_base":
            self.record_change(
                description="Replace declarative_base() with class Base(DeclarativeBase): pass",
                line_number=1,
                original="Base = declarative_base()",
                replacement="class Base(DeclarativeBase): pass",
                transform_name="declarative_base_to_class",
                confidence=0.8,
                notes="Manual review recommended - create class inheriting from DeclarativeBase",
            )
            return updated_node

        if self._is_create_engine(func):
            return self._drop_future_flag(updated_node)

        return updated_node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.Attribute:
        """Transform SQLAlchemy attribute accesses (currently a no-op hook)."""
        # Method renames such as query-style .all() would need more context
        # than a single Attribute node provides, so nothing is rewritten yet.
        return updated_node

    def _rewrite_declarative_base_import(self, node: cst.ImportFrom) -> cst.ImportFrom:
        """Move a sole ``declarative_base`` import to ``sqlalchemy.orm.DeclarativeBase``.

        If the statement imports other names as well, or aliases
        ``declarative_base``, replacing it would drop those names/aliases, so
        the node is kept unchanged and the change is only recorded as advice.
        """
        names = node.names
        if isinstance(names, cst.ImportStar):
            return node

        target = next(
            (alias for alias in names if self._imports_name(alias, "declarative_base")),
            None,
        )
        if target is None:
            return node

        self._has_declarative_base_import = True

        if len(names) == 1 and target.asname is None:
            self.record_change(
                description="Import DeclarativeBase from sqlalchemy.orm instead of declarative_base",
                line_number=1,
                original="from sqlalchemy.ext.declarative import declarative_base",
                replacement="from sqlalchemy.orm import DeclarativeBase",
                transform_name="import_declarative_base",
            )
            return node.with_changes(
                module=cst.Attribute(value=cst.Name("sqlalchemy"), attr=cst.Name("orm")),
                names=[cst.ImportAlias(name=cst.Name("DeclarativeBase"))],
            )

        self.record_change(
            description="Import DeclarativeBase from sqlalchemy.orm instead of declarative_base",
            line_number=1,
            original="from sqlalchemy.ext.declarative import declarative_base",
            replacement="from sqlalchemy.orm import DeclarativeBase",
            transform_name="import_declarative_base",
            confidence=0.5,
            notes=(
                "Import left unchanged because it also imports other names or "
                "aliases declarative_base; update it manually"
            ),
        )
        return node

    def _remove_backref_import(
        self, node: cst.ImportFrom
    ) -> cst.ImportFrom | cst.RemovalSentinel:
        """Drop ``backref`` from a ``sqlalchemy.orm`` import, removing it if empty."""
        names = node.names
        if isinstance(names, cst.ImportStar):
            return node

        kept: list[cst.ImportAlias] = []
        for alias in names:
            if self._imports_name(alias, "backref"):
                self.record_change(
                    description="Remove backref import (use back_populates instead)",
                    line_number=1,
                    original="backref",
                    replacement="# backref removed, use back_populates",
                    transform_name="remove_backref_import",
                )
            else:
                kept.append(alias)

        if len(kept) == len(names):
            return node
        if not kept:
            return cst.RemovalSentinel.REMOVE

        # Removing the last name leaves its predecessor's comma dangling;
        # ``from m import a,`` is only legal inside parentheses.
        if node.rpar is None:
            kept[-1] = kept[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        return node.with_changes(names=kept)

    def _drop_future_flag(self, call: cst.Call) -> cst.Call:
        """Remove ``future=`` keyword arguments from a create_engine call."""
        new_args: list[cst.Arg] = []
        for arg in call.args:
            if isinstance(arg.keyword, cst.Name) and arg.keyword.value == "future":
                self.record_change(
                    description="Remove future=True from create_engine (now default)",
                    line_number=1,
                    original="create_engine(..., future=True)",
                    replacement="create_engine(...)",
                    transform_name="remove_future_flag",
                )
                continue
            new_args.append(arg)

        if len(new_args) == len(call.args):
            return call

        last = call.args[-1]
        if new_args and new_args[-1] is not last:
            # The flag was the final argument: give its comma (or lack of one)
            # to the new final argument so ``f(url, future=True)`` renders as
            # ``f(url)`` instead of ``f(url, )``.
            new_args[-1] = new_args[-1].with_changes(comma=last.comma)
        return call.with_changes(args=new_args)

    @staticmethod
    def _is_create_engine(func: cst.BaseExpression) -> bool:
        """Match ``create_engine`` or ``sqlalchemy.create_engine`` callees."""
        if isinstance(func, cst.Name):
            return func.value == "create_engine"
        return (
            isinstance(func, cst.Attribute)
            and isinstance(func.value, cst.Name)
            and func.value.value == "sqlalchemy"
            and func.attr.value == "create_engine"
        )

    @staticmethod
    def _imports_name(alias: cst.ImportAlias, name: str) -> bool:
        """Return True if the import alias imports exactly ``name``."""
        return isinstance(alias.name, cst.Name) and alias.name.value == name

    def _get_module_name(self, module: cst.BaseExpression) -> str:
        """Get the full module name from a Name or Attribute node ("" otherwise)."""
        if isinstance(module, cst.Name):
            return str(module.value)
        elif isinstance(module, cst.Attribute):
            return f"{self._get_module_name(module.value)}.{module.attr.value}"
        return ""
152
+
153
+
154
def transform_sqlalchemy(source_code: str) -> tuple[str, list]:
    """Run :class:`SQLAlchemyTransformer` (1.x -> 2.0) over Python source.

    Args:
        source_code: Python source text to rewrite.

    Returns:
        A ``(new_code, changes)`` pair. Unparseable input, or any failure
        during the transformation, yields the untouched text and no changes.
    """
    unchanged: tuple[str, list] = (source_code, [])

    try:
        module = cst.parse_module(source_code)
    except cst.ParserSyntaxError:
        return unchanged

    transformer = SQLAlchemyTransformer()
    transformer.set_source(source_code)

    # Best-effort: rendering happens inside the guard, so codegen errors are
    # handled the same way as visitor errors.
    try:
        return module.visit(transformer).code, transformer.changes
    except Exception:
        return unchanged
@@ -0,0 +1,6 @@
1
+ """Scanner module for finding library usage in code."""
2
+
3
+ from codeshift.scanner.code_scanner import CodeScanner, ImportInfo, UsageInfo
4
+ from codeshift.scanner.dependency_parser import Dependency, DependencyParser
5
+
6
# Public API of the scanner package: the names re-exported from
# code_scanner and dependency_parser by the imports above.
__all__ = [
    "CodeScanner",
    "ImportInfo",
    "UsageInfo",
    "DependencyParser",
    "Dependency",
]