codeshift 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. codeshift/__init__.py +8 -0
  2. codeshift/analyzer/__init__.py +5 -0
  3. codeshift/analyzer/risk_assessor.py +388 -0
  4. codeshift/api/__init__.py +1 -0
  5. codeshift/api/auth.py +182 -0
  6. codeshift/api/config.py +73 -0
  7. codeshift/api/database.py +215 -0
  8. codeshift/api/main.py +103 -0
  9. codeshift/api/models/__init__.py +55 -0
  10. codeshift/api/models/auth.py +108 -0
  11. codeshift/api/models/billing.py +92 -0
  12. codeshift/api/models/migrate.py +42 -0
  13. codeshift/api/models/usage.py +116 -0
  14. codeshift/api/routers/__init__.py +5 -0
  15. codeshift/api/routers/auth.py +440 -0
  16. codeshift/api/routers/billing.py +395 -0
  17. codeshift/api/routers/migrate.py +304 -0
  18. codeshift/api/routers/usage.py +291 -0
  19. codeshift/api/routers/webhooks.py +289 -0
  20. codeshift/cli/__init__.py +5 -0
  21. codeshift/cli/commands/__init__.py +7 -0
  22. codeshift/cli/commands/apply.py +352 -0
  23. codeshift/cli/commands/auth.py +842 -0
  24. codeshift/cli/commands/diff.py +221 -0
  25. codeshift/cli/commands/scan.py +368 -0
  26. codeshift/cli/commands/upgrade.py +436 -0
  27. codeshift/cli/commands/upgrade_all.py +518 -0
  28. codeshift/cli/main.py +221 -0
  29. codeshift/cli/quota.py +210 -0
  30. codeshift/knowledge/__init__.py +50 -0
  31. codeshift/knowledge/cache.py +167 -0
  32. codeshift/knowledge/generator.py +231 -0
  33. codeshift/knowledge/models.py +151 -0
  34. codeshift/knowledge/parser.py +270 -0
  35. codeshift/knowledge/sources.py +388 -0
  36. codeshift/knowledge_base/__init__.py +17 -0
  37. codeshift/knowledge_base/loader.py +102 -0
  38. codeshift/knowledge_base/models.py +110 -0
  39. codeshift/migrator/__init__.py +23 -0
  40. codeshift/migrator/ast_transforms.py +256 -0
  41. codeshift/migrator/engine.py +395 -0
  42. codeshift/migrator/llm_migrator.py +320 -0
  43. codeshift/migrator/transforms/__init__.py +19 -0
  44. codeshift/migrator/transforms/fastapi_transformer.py +174 -0
  45. codeshift/migrator/transforms/pandas_transformer.py +236 -0
  46. codeshift/migrator/transforms/pydantic_v1_to_v2.py +637 -0
  47. codeshift/migrator/transforms/requests_transformer.py +218 -0
  48. codeshift/migrator/transforms/sqlalchemy_transformer.py +175 -0
  49. codeshift/scanner/__init__.py +6 -0
  50. codeshift/scanner/code_scanner.py +352 -0
  51. codeshift/scanner/dependency_parser.py +473 -0
  52. codeshift/utils/__init__.py +5 -0
  53. codeshift/utils/api_client.py +266 -0
  54. codeshift/utils/cache.py +318 -0
  55. codeshift/utils/config.py +71 -0
  56. codeshift/utils/llm_client.py +221 -0
  57. codeshift/validator/__init__.py +6 -0
  58. codeshift/validator/syntax_checker.py +183 -0
  59. codeshift/validator/test_runner.py +224 -0
  60. codeshift-0.2.0.dist-info/METADATA +326 -0
  61. codeshift-0.2.0.dist-info/RECORD +65 -0
  62. codeshift-0.2.0.dist-info/WHEEL +5 -0
  63. codeshift-0.2.0.dist-info/entry_points.txt +2 -0
  64. codeshift-0.2.0.dist-info/licenses/LICENSE +21 -0
  65. codeshift-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,221 @@
1
+ """Anthropic Claude client wrapper for LLM-based migrations."""
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+
6
+ from anthropic import Anthropic
7
+
8
+
9
+ @dataclass
10
+ class LLMResponse:
11
+ """Response from the LLM."""
12
+
13
+ content: str
14
+ model: str
15
+ usage: dict
16
+ success: bool
17
+ error: str | None = None
18
+
19
+
20
class LLMClient:
    """Client for interacting with Anthropic's Claude API.

    The Anthropic SDK client is created lazily (see ``client``). An
    ``LLMClient`` can therefore be constructed without an API key;
    ``generate`` then reports the missing key through an unsuccessful
    ``LLMResponse`` instead of raising.
    """

    DEFAULT_MODEL = "claude-sonnet-4-20250514"
    MAX_TOKENS = 4096

    # Library-specific hints appended to the system prompt built by
    # ``migrate_code``. Keys are lower-cased library names. Libraries without
    # an entry get only the generic guidelines: these hints are API renames
    # for one library and would actively mislead a migration of another
    # (e.g. they would turn a pandas ``.copy()`` into ``.model_copy()``).
    MIGRATION_HINTS: dict[str, str] = {
        "pydantic": (
            "- Config class -> model_config = ConfigDict(...)\n"
            "- @validator -> @field_validator with @classmethod\n"
            "- @root_validator -> @model_validator with @classmethod\n"
            "- .dict() -> .model_dump()\n"
            "- .json() -> .model_dump_json()\n"
            "- .schema() -> .model_json_schema()\n"
            "- .parse_obj() -> .model_validate()\n"
            "- .parse_raw() -> .model_validate_json()\n"
            "- .copy() -> .model_copy()\n"
            "- orm_mode -> from_attributes\n"
            "- Field(regex=...) -> Field(pattern=...)\n"
        ),
    }

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
    ):
        """Initialize the LLM client.

        Args:
            api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var.
            model: Model to use. Defaults to claude-sonnet-4-20250514.
        """
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        self.model = model or self.DEFAULT_MODEL
        self._client: Anthropic | None = None

    @property
    def client(self) -> Anthropic:
        """Get or create the Anthropic client.

        Raises:
            ValueError: If no API key was supplied or found in the environment.
        """
        if self._client is None:
            if not self.api_key:
                raise ValueError(
                    "Anthropic API key not found. Set ANTHROPIC_API_KEY environment variable "
                    "or pass api_key to LLMClient."
                )
            self._client = Anthropic(api_key=self.api_key)
        return self._client

    @property
    def is_available(self) -> bool:
        """Check if the LLM client is available (API key is set)."""
        return bool(self.api_key)

    def _error_response(self, message: str) -> LLMResponse:
        """Build an unsuccessful LLMResponse carrying ``message`` as its error."""
        return LLMResponse(
            content="",
            model=self.model,
            usage={},
            success=False,
            error=message,
        )

    def generate(
        self,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        temperature: float = 0.0,
    ) -> LLMResponse:
        """Generate a response from the LLM.

        Args:
            prompt: The user prompt
            system_prompt: Optional system prompt; omitted from the request
                when None or empty
            max_tokens: Maximum tokens in response (defaults to MAX_TOKENS)
            temperature: Sampling temperature (0.0 for deterministic)

        Returns:
            LLMResponse with the generated content. Any exception raised while
            calling the API or reading its response is reported as an
            unsuccessful LLMResponse rather than propagated.
        """
        if not self.is_available:
            return self._error_response("API key not configured")

        request: dict = {
            "model": self.model,
            "max_tokens": max_tokens or self.MAX_TOKENS,
            "temperature": temperature,
            "messages": [{"role": "user", "content": prompt}],
        }
        # The system prompt is optional in the Messages API; send it only when
        # there is one instead of sending an empty string.
        if system_prompt:
            request["system"] = system_prompt

        try:
            response = self.client.messages.create(**request)

            # Only text blocks carry a ``text`` attribute; other block types
            # contribute nothing to the migrated code.
            content = "".join(
                block.text for block in response.content if hasattr(block, "text")
            )

            return LLMResponse(
                content=content,
                model=response.model,
                usage={
                    "input_tokens": response.usage.input_tokens,
                    "output_tokens": response.usage.output_tokens,
                },
                success=True,
            )

        except Exception as e:
            # Deliberate best-effort boundary: callers branch on
            # ``LLMResponse.success`` rather than catching SDK exceptions.
            return self._error_response(str(e))

    def _build_migration_system_prompt(
        self,
        library: str,
        from_version: str,
        to_version: str,
    ) -> str:
        """Return the system prompt for ``migrate_code``.

        The prompt holds the generic guidelines, plus this library's entry in
        MIGRATION_HINTS when one exists.
        """
        prompt = (
            "You are an expert Python developer specializing in code migrations.\n"
            f"Your task is to migrate Python code from {library} v{from_version} to v{to_version}.\n"
            "\n"
            "Guidelines:\n"
            "1. Only modify code that needs to change for the migration\n"
            "2. Preserve all comments, formatting, and code style where possible\n"
            "3. Add brief inline comments explaining non-obvious changes\n"
            "4. If you're unsure about a change, add a TODO comment\n"
            "5. Return ONLY the migrated code, no explanations before or after\n"
        )
        hints = self.MIGRATION_HINTS.get(library.strip().lower())
        if hints:
            prompt += (
                "\n"
                f"Important {library} v{from_version} to v{to_version} changes:\n"
                f"{hints}"
            )
        return prompt

    def migrate_code(
        self,
        code: str,
        library: str,
        from_version: str,
        to_version: str,
        context: str | None = None,
    ) -> LLMResponse:
        """Use the LLM to migrate code.

        Args:
            code: The source code to migrate
            library: The library being upgraded
            from_version: Current version
            to_version: Target version
            context: Optional context about the migration

        Returns:
            LLMResponse with the migrated code
        """
        system_prompt = self._build_migration_system_prompt(library, from_version, to_version)
        context_line = f"Context: {context}" if context else ""

        prompt = (
            f"Migrate the following Python code from {library} v{from_version} to v{to_version}.\n"
            "\n"
            f"{context_line}\n"
            "\n"
            "Code to migrate:\n"
            "```python\n"
            f"{code}\n"
            "```\n"
            "\n"
            "Return only the migrated Python code:"
        )

        return self.generate(prompt, system_prompt=system_prompt)

    def explain_change(
        self,
        original: str,
        transformed: str,
        library: str,
    ) -> LLMResponse:
        """Use the LLM to explain a migration change.

        Args:
            original: Original code
            transformed: Transformed code
            library: The library being upgraded

        Returns:
            LLMResponse with the explanation (capped at 500 output tokens)
        """
        system_prompt = (
            "You are an expert Python developer.\n"
            "Explain code changes clearly and concisely for other developers.\n"
            "Focus on the 'why' not just the 'what'."
        )

        prompt = (
            f"Explain the following {library} migration change:\n"
            "\n"
            "Original:\n"
            "```python\n"
            f"{original}\n"
            "```\n"
            "\n"
            "Migrated:\n"
            "```python\n"
            f"{transformed}\n"
            "```\n"
            "\n"
            "Provide a brief explanation (2-3 sentences) of what changed and why:"
        )

        return self.generate(prompt, system_prompt=system_prompt, max_tokens=500)
210
+
211
+
212
# Shared client returned by get_llm_client(); stays None until first requested.
_default_client: LLMClient | None = None


def get_llm_client() -> LLMClient:
    """Return the module-wide LLMClient, constructing it on first call.

    The client is built with no arguments, so it reads ANTHROPIC_API_KEY and
    uses the default model. Later calls return the same instance.
    """
    global _default_client
    if _default_client is not None:
        return _default_client
    _default_client = LLMClient()
    return _default_client
@@ -0,0 +1,6 @@
1
+ """Validator module for checking transformed code."""
2
+
3
+ from codeshift.validator.syntax_checker import SyntaxChecker, SyntaxCheckResult
4
+ from codeshift.validator.test_runner import TestResult, TestRunner
5
+
6
# Public API of the validator package: the syntax-checking and test-running
# classes re-exported from the submodules imported above.
__all__ = ["SyntaxChecker", "SyntaxCheckResult", "TestRunner", "TestResult"]
@@ -0,0 +1,183 @@
1
+ """Syntax checker for validating transformed code."""
2
+
3
+ import ast
4
+ import sys
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+
8
+
9
+ @dataclass
10
+ class SyntaxIssue:
11
+ """Represents a syntax error in code."""
12
+
13
+ message: str
14
+ line_number: int
15
+ column: int
16
+ line_text: str | None = None
17
+
18
+
19
@dataclass
class SyntaxCheckResult:
    """Verdict of checking one source string or one file for syntax errors."""

    is_valid: bool  # True when no syntax problem was found
    file_path: Path | None = None  # the checked file, or None for in-memory code
    errors: list[SyntaxIssue] = field(default_factory=list)  # problems found, if any
    warnings: list[str] = field(default_factory=list)  # NOTE(review): not filled in by SyntaxChecker here

    @property
    def error_count(self) -> int:
        """How many entries ``errors`` holds."""
        return len(self.errors)
32
+
33
+
34
class SyntaxChecker:
    """Validates Python code syntax with the running interpreter's compiler."""

    def __init__(self, python_version: tuple[int, int] | None = None):
        """Initialize the syntax checker.

        Args:
            python_version: Target Python version as (major, minor).
                Defaults to current Python version.

        NOTE(review): ``python_version`` is only stored. Every check below
        uses the grammar of the interpreter running this code.
        """
        if python_version is None:
            python_version = (sys.version_info.major, sys.version_info.minor)
        self.python_version = python_version

    def check_code(self, source_code: str, filename: str = "<string>") -> SyntaxCheckResult:
        """Check if source code has valid Python syntax.

        Args:
            source_code: The Python source code to check
            filename: Filename used in error messages

        Returns:
            SyntaxCheckResult with validation status. ``file_path`` is left
            unset; ``check_file`` fills it in.
        """
        try:
            # compile() builds the AST *and* runs the symbol-table pass. It
            # rejects everything ast.parse() would, plus errors such as
            # ``return`` outside a function, so a second parse is redundant.
            compile(source_code, filename, "exec")
        except SyntaxError as e:
            issue = SyntaxIssue(
                message=str(e.msg) if e.msg else str(e),
                line_number=e.lineno or 0,
                column=e.offset or 0,
                line_text=e.text,
            )
            return SyntaxCheckResult(is_valid=False, errors=[issue])
        except ValueError as e:
            # Python < 3.12 raises ValueError (not SyntaxError) for source
            # containing NUL bytes; report it as invalid instead of crashing.
            issue = SyntaxIssue(message=str(e), line_number=0, column=0)
            return SyntaxCheckResult(is_valid=False, errors=[issue])

        return SyntaxCheckResult(is_valid=True)

    def check_file(self, file_path: Path) -> SyntaxCheckResult:
        """Check if a Python file has valid syntax.

        Args:
            file_path: Path to the Python file

        Returns:
            SyntaxCheckResult with validation status. A file that cannot be
            opened or decoded is reported as invalid, with a single issue at
            line 0.
        """
        import tokenize

        try:
            # tokenize.open() decodes the file the way the interpreter does:
            # it honours PEP 263 coding cookies and a UTF-8 BOM, defaulting to
            # UTF-8. Path.read_text() would use the locale encoding and keep a
            # leading BOM, which compile() reports as an invalid character.
            with tokenize.open(file_path) as handle:
                source_code = handle.read()
        except Exception as e:
            return SyntaxCheckResult(
                is_valid=False,
                file_path=file_path,
                errors=[
                    SyntaxIssue(
                        message=f"Could not read file: {e}",
                        line_number=0,
                        column=0,
                    )
                ],
            )

        result = self.check_code(source_code, str(file_path))
        result.file_path = file_path
        return result

    def check_directory(
        self, directory: Path, exclude_patterns: list[str] | None = None
    ) -> list[SyntaxCheckResult]:
        """Check all Python files in a directory tree.

        Args:
            directory: Path to the directory
            exclude_patterns: fnmatch-style patterns matched against each
                file's path relative to ``directory`` (with '/' separators)

        Returns:
            List of SyntaxCheckResult for each file with errors, ordered by path
        """
        import fnmatch

        patterns = exclude_patterns or []
        invalid: list[SyntaxCheckResult] = []

        # sorted() gives a deterministic report order; rglob order depends on
        # the filesystem.
        for file_path in sorted(directory.rglob("*.py")):
            # A directory can also be named "*.py"; it is not a source file.
            if not file_path.is_file():
                continue

            relative_path = file_path.relative_to(directory).as_posix()
            if any(fnmatch.fnmatch(relative_path, pattern) for pattern in patterns):
                continue

            result = self.check_file(file_path)
            if not result.is_valid:
                invalid.append(result)

        return invalid

    def validate_transform(self, original: str, transformed: str) -> tuple[bool, list[str]]:
        """Validate that a transformation didn't break syntax.

        Args:
            original: Original source code
            transformed: Transformed source code

        Returns:
            Tuple of (is_valid, list of issues). ``is_valid`` reflects only the
            transformed code. An invalid original is recorded in the issues
            list but does not by itself make the result invalid.
        """
        issues = []

        # Check original syntax (should be valid)
        original_result = self.check_code(original, "<original>")
        if not original_result.is_valid:
            issues.append("Original code has syntax errors")

        # Check transformed syntax
        transformed_result = self.check_code(transformed, "<transformed>")
        if not transformed_result.is_valid:
            for error in transformed_result.errors:
                issues.append(f"Line {error.line_number}: {error.message}")
            return False, issues

        return True, issues
168
+
169
+
170
def quick_syntax_check(source_code: str) -> bool:
    """Quick check if code has valid Python syntax.

    Args:
        source_code: The Python source code to check

    Returns:
        True if the code compiles, False otherwise. This includes source
        containing NUL bytes, which Python < 3.12 rejects with ValueError
        rather than SyntaxError.
    """
    try:
        compile(source_code, "<string>", "exec")
    except (SyntaxError, ValueError):
        return False
    return True
@@ -0,0 +1,224 @@
1
+ """Test runner for validating migrations."""
2
+
3
+ import subprocess
4
+ import sys
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+
9
+ @dataclass
10
+ class TestResult:
11
+ """Result of running tests."""
12
+
13
+ success: bool
14
+ exit_code: int
15
+ stdout: str = ""
16
+ stderr: str = ""
17
+ tests_run: int = 0
18
+ tests_passed: int = 0
19
+ tests_failed: int = 0
20
+ tests_skipped: int = 0
21
+ duration: float = 0.0
22
+ error_message: str | None = None
23
+
24
+ @property
25
+ def summary(self) -> str:
26
+ """Get a summary string of the test results."""
27
+ if self.success:
28
+ return f"✓ {self.tests_passed}/{self.tests_run} tests passed"
29
+ return f"✗ {self.tests_failed}/{self.tests_run} tests failed"
30
+
31
+
32
class TestRunner:
    """Runs project tests to validate migrations."""

    # The class name starts with "Test"; stop pytest from trying to collect it.
    __test__ = False

    def __init__(
        self,
        project_path: Path,
        test_command: list[str] | None = None,
        timeout: int = 300,
    ):
        """Initialize the test runner.

        Args:
            project_path: Path to the project root; tests run with it as the
                working directory.
            test_command: Custom test command as an argv list. When omitted
                (or empty), ``_detect_test_command`` supplies a pytest command.
            timeout: Maximum time in seconds to run tests
        """
        self.project_path = project_path
        self.test_command = test_command or self._detect_test_command()
        self.timeout = timeout

    def _detect_test_command(self) -> list[str]:
        """Return the default test command: pytest under the current interpreter.

        pytest is used unconditionally. It also runs unittest-style suites,
        and both ``_parse_pytest_output`` and ``run_quick_check``
        (``--collect-only``) depend on pytest's output format and options.
        """
        return [sys.executable, "-m", "pytest", "-v", "--tb=short"]

    def run(
        self,
        specific_tests: list[str] | None = None,
        extra_args: list[str] | None = None,
    ) -> TestResult:
        """Run the project tests.

        Args:
            specific_tests: Test files or node ids, appended after extra_args
            extra_args: Additional arguments to pass to the test runner

        Returns:
            TestResult with the outcome and the elapsed time in ``duration``.
            If the command cannot be run or finished (timeout, missing
            executable, other OS errors), the result is unsuccessful, with
            ``exit_code == -1`` and ``error_message`` set, instead of raising.
        """
        import time

        command = [*self.test_command, *(extra_args or []), *(specific_tests or [])]
        started = time.monotonic()

        try:
            result = subprocess.run(
                command,
                cwd=self.project_path,
                capture_output=True,
                text=True,
                # Undecodable bytes in test output must not abort the whole run.
                errors="replace",
                timeout=self.timeout,
            )
        except subprocess.TimeoutExpired:
            return TestResult(
                success=False,
                exit_code=-1,
                duration=time.monotonic() - started,
                error_message=f"Tests timed out after {self.timeout} seconds",
            )
        except FileNotFoundError as e:
            return TestResult(
                success=False,
                exit_code=-1,
                error_message=f"Test command not found: {e}",
            )
        except Exception as e:
            return TestResult(
                success=False,
                exit_code=-1,
                error_message=f"Error running tests: {e}",
            )

        duration = time.monotonic() - started

        # stderr goes first and stdout last. The reverse scan in
        # _parse_pytest_output then reaches the end of stdout first, which is
        # where pytest prints its final tally. The newline stops the two
        # streams from fusing into one line.
        tests_run, tests_passed, tests_failed, tests_skipped = self._parse_pytest_output(
            result.stderr + "\n" + result.stdout
        )

        return TestResult(
            success=result.returncode == 0,
            exit_code=result.returncode,
            stdout=result.stdout,
            stderr=result.stderr,
            tests_run=tests_run,
            tests_passed=tests_passed,
            tests_failed=tests_failed,
            tests_skipped=tests_skipped,
            duration=duration,
        )

    def _parse_pytest_output(self, output: str) -> tuple[int, int, int, int]:
        """Parse pytest output to extract test counts.

        Only the last line that carries counts is used. pytest prints its final
        tally (e.g. ``=== 1 failed, 5 passed, 2 skipped in 0.12s ===``) after
        everything else. Earlier lines can contain captured test output that
        merely looks like a tally.

        Args:
            output: Combined stdout and stderr from pytest

        Returns:
            Tuple of (total, passed, failed, skipped). Errors are counted as
            failures and included in the total.
        """
        import re

        counts = {"passed": 0, "failed": 0, "skipped": 0, "error": 0}

        for line in reversed(output.splitlines()):
            found = re.findall(r"(\d+) (passed|failed|skipped|errors?)\b", line)
            if not found:
                continue
            for number, label in found:
                # "1 error" and "3 errors" are both error counts.
                key = "error" if label.startswith("error") else label
                counts[key] = int(number)
            break

        passed = counts["passed"]
        failed = counts["failed"]
        skipped = counts["skipped"]
        errors = counts["error"]

        total = passed + failed + skipped + errors
        return total, passed, failed + errors, skipped

    def run_quick_check(self) -> TestResult:
        """Run a quick smoke test (collection only, no execution).

        Returns:
            TestResult indicating if tests can be collected. ``tests_run``
            holds the number of collected tests.
        """
        command = [*self.test_command, "--collect-only", "-q"]

        try:
            result = subprocess.run(
                command,
                cwd=self.project_path,
                capture_output=True,
                text=True,
                errors="replace",
                timeout=60,
            )
        except Exception as e:
            return TestResult(
                success=False,
                exit_code=-1,
                error_message=f"Error collecting tests: {e}",
            )

        return TestResult(
            success=result.returncode == 0,
            exit_code=result.returncode,
            stdout=result.stdout,
            stderr=result.stderr,
            tests_run=self._count_collected_tests(result.stdout),
        )

    @staticmethod
    def _count_collected_tests(output: str) -> int:
        """Return the number of tests reported by ``pytest --collect-only``."""
        import re

        # The default command already contains -v, which cancels the -q added
        # by run_quick_check. pytest then prints a <Module>/<Function> tree
        # with no "path::name" ids to count. Prefer pytest's own tally line.
        tally = re.search(r"(\d+) tests? collected", output) or re.search(
            r"collected (\d+) items?", output
        )
        if tally:
            return int(tally.group(1))

        # Fallback: count node-id lines such as "tests/test_x.py::test_y".
        return sum(
            1 for line in output.splitlines() if "test" in line.lower() and "::" in line
        )
211
+
212
+
213
def run_tests(project_path: Path, timeout: int = 300) -> TestResult:
    """Run a project's test suite using the default test command.

    Shorthand for building a ``TestRunner`` with no custom command and
    calling ``run()`` with no extra arguments.

    Args:
        project_path: Root directory of the project whose tests should run
        timeout: Upper bound, in seconds, on the test run

    Returns:
        The TestResult produced by the run
    """
    return TestRunner(project_path, timeout=timeout).run()