codeshift 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. codeshift/__init__.py +8 -0
  2. codeshift/analyzer/__init__.py +5 -0
  3. codeshift/analyzer/risk_assessor.py +388 -0
  4. codeshift/api/__init__.py +1 -0
  5. codeshift/api/auth.py +182 -0
  6. codeshift/api/config.py +73 -0
  7. codeshift/api/database.py +215 -0
  8. codeshift/api/main.py +103 -0
  9. codeshift/api/models/__init__.py +55 -0
  10. codeshift/api/models/auth.py +108 -0
  11. codeshift/api/models/billing.py +92 -0
  12. codeshift/api/models/migrate.py +42 -0
  13. codeshift/api/models/usage.py +116 -0
  14. codeshift/api/routers/__init__.py +5 -0
  15. codeshift/api/routers/auth.py +440 -0
  16. codeshift/api/routers/billing.py +395 -0
  17. codeshift/api/routers/migrate.py +304 -0
  18. codeshift/api/routers/usage.py +291 -0
  19. codeshift/api/routers/webhooks.py +289 -0
  20. codeshift/cli/__init__.py +5 -0
  21. codeshift/cli/commands/__init__.py +7 -0
  22. codeshift/cli/commands/apply.py +352 -0
  23. codeshift/cli/commands/auth.py +842 -0
  24. codeshift/cli/commands/diff.py +221 -0
  25. codeshift/cli/commands/scan.py +368 -0
  26. codeshift/cli/commands/upgrade.py +436 -0
  27. codeshift/cli/commands/upgrade_all.py +518 -0
  28. codeshift/cli/main.py +221 -0
  29. codeshift/cli/quota.py +210 -0
  30. codeshift/knowledge/__init__.py +50 -0
  31. codeshift/knowledge/cache.py +167 -0
  32. codeshift/knowledge/generator.py +231 -0
  33. codeshift/knowledge/models.py +151 -0
  34. codeshift/knowledge/parser.py +270 -0
  35. codeshift/knowledge/sources.py +388 -0
  36. codeshift/knowledge_base/__init__.py +17 -0
  37. codeshift/knowledge_base/loader.py +102 -0
  38. codeshift/knowledge_base/models.py +110 -0
  39. codeshift/migrator/__init__.py +23 -0
  40. codeshift/migrator/ast_transforms.py +256 -0
  41. codeshift/migrator/engine.py +395 -0
  42. codeshift/migrator/llm_migrator.py +320 -0
  43. codeshift/migrator/transforms/__init__.py +19 -0
  44. codeshift/migrator/transforms/fastapi_transformer.py +174 -0
  45. codeshift/migrator/transforms/pandas_transformer.py +236 -0
  46. codeshift/migrator/transforms/pydantic_v1_to_v2.py +637 -0
  47. codeshift/migrator/transforms/requests_transformer.py +218 -0
  48. codeshift/migrator/transforms/sqlalchemy_transformer.py +175 -0
  49. codeshift/scanner/__init__.py +6 -0
  50. codeshift/scanner/code_scanner.py +352 -0
  51. codeshift/scanner/dependency_parser.py +473 -0
  52. codeshift/utils/__init__.py +5 -0
  53. codeshift/utils/api_client.py +266 -0
  54. codeshift/utils/cache.py +318 -0
  55. codeshift/utils/config.py +71 -0
  56. codeshift/utils/llm_client.py +221 -0
  57. codeshift/validator/__init__.py +6 -0
  58. codeshift/validator/syntax_checker.py +183 -0
  59. codeshift/validator/test_runner.py +224 -0
  60. codeshift-0.2.0.dist-info/METADATA +326 -0
  61. codeshift-0.2.0.dist-info/RECORD +65 -0
  62. codeshift-0.2.0.dist-info/WHEEL +5 -0
  63. codeshift-0.2.0.dist-info/entry_points.txt +2 -0
  64. codeshift-0.2.0.dist-info/licenses/LICENSE +21 -0
  65. codeshift-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,266 @@
1
+ """PyResolve API client for LLM-powered migrations.
2
+
3
+ This client calls the PyResolve API instead of Anthropic directly,
4
+ ensuring that LLM features are gated behind the subscription model.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+
9
+ import httpx
10
+
11
+ from codeshift.cli.commands.auth import get_api_key, get_api_url
12
+
13
+
14
@dataclass
class APIResponse:
    """Response from the PyResolve API.

    Produced by PyResolveAPIClient for every call, including failures; callers
    inspect ``success``/``error`` instead of catching exceptions.

    Attributes:
        success: True only when the server answered 200 and reported success.
        content: Result text. For migrations this is the migrated code, or the
            unchanged input code on failure; for explanations it is the
            explanation text, or "" on failure.
        error: Human-readable error message, or None.
        usage: Usage information returned by the server, if any. Its schema is
            defined server-side (not visible here).
        cached: Whether the server reported the result as cached.
    """

    success: bool
    content: str
    error: str | None = None
    usage: dict | None = None
    cached: bool = False
23
+
24
+
25
class PyResolveAPIClient:
    """Client for interacting with the PyResolve API for LLM migrations.

    This client routes all LLM calls through the PyResolve API,
    which handles:
    - Authentication and authorization
    - Quota checking and billing
    - Server-side Anthropic API calls

    The public methods do not raise for HTTP-level problems. Error statuses,
    network failures and malformed response bodies are all reported through
    ``APIResponse.success`` and ``APIResponse.error``.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_url: str | None = None,
        timeout: int = 60,
    ):
        """Initialize the API client.

        Args:
            api_key: PyResolve API key. Defaults to stored credentials.
            api_url: API base URL. Defaults to stored URL.
            timeout: Request timeout in seconds.
        """
        self.api_key = api_key or get_api_key()
        self.api_url = api_url or get_api_url()
        self.timeout = timeout

    @property
    def is_available(self) -> bool:
        """Check if the API client is available (has API key)."""
        return bool(self.api_key)

    def _make_request(
        self,
        endpoint: str,
        payload: dict,
    ) -> httpx.Response:
        """Make a request to the API.

        Args:
            endpoint: API endpoint (e.g., '/migrate/code')
            payload: Request payload

        Returns:
            HTTP response

        Raises:
            ValueError: If no API key is configured.
            httpx.RequestError: On network errors
        """
        if not self.api_key:
            raise ValueError("API key not configured. Run 'codeshift login' to authenticate.")

        return httpx.post(
            f"{self.api_url}{endpoint}",
            headers={"X-API-Key": self.api_key},
            json=payload,
            timeout=self.timeout,
        )

    @staticmethod
    def _json_object(response: httpx.Response) -> dict | None:
        """Decode the response body as a JSON object.

        Returns:
            The decoded dict, or None when the body is not valid JSON or is
            valid JSON but not an object. Callers can then report a clean
            error instead of crashing.
        """
        try:
            data = response.json()
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _status_error(status_code: int) -> str:
        """Return the message for error statuses every endpoint reports alike."""
        if status_code == 401:
            return "Authentication failed. Run 'codeshift login' to re-authenticate."
        if status_code == 503:
            return "LLM service temporarily unavailable. Please try again later."
        return f"API error: {status_code}"

    def migrate_code(
        self,
        code: str,
        library: str,
        from_version: str,
        to_version: str,
        context: str | None = None,
    ) -> APIResponse:
        """Migrate code using the PyResolve API.

        Args:
            code: Source code to migrate
            library: Library being upgraded
            from_version: Current version
            to_version: Target version
            context: Optional context about the migration

        Returns:
            APIResponse with the migrated code. On any failure ``success`` is
            False, ``content`` is the unchanged input ``code`` and ``error``
            describes the problem.
        """
        if not self.is_available:
            return APIResponse(
                success=False,
                content=code,
                error="Not authenticated. Run 'codeshift login' to use LLM migrations.",
            )

        try:
            response = self._make_request(
                "/migrate/code",
                {
                    "code": code,
                    "library": library,
                    "from_version": from_version,
                    "to_version": to_version,
                    "context": context,
                },
            )
        except httpx.RequestError as e:
            return APIResponse(
                success=False,
                content=code,
                error=f"Network error: {str(e)}",
            )

        status = response.status_code

        if status == 200:
            data = self._json_object(response)
            if data is None:
                return APIResponse(
                    success=False,
                    content=code,
                    error="Invalid response from API.",
                )
            migrated = data.get("migrated_code")
            return APIResponse(
                success=data.get("success", False),
                # Fall back to the input if the server omitted or nulled the field.
                content=code if migrated is None else migrated,
                error=data.get("error"),
                usage=data.get("usage"),
                cached=data.get("cached", False),
            )

        if status == 402:
            data = self._json_object(response) or {}
            detail = data.get("detail")
            # Error bodies may carry ``detail`` as a plain string; only a dict
            # carries the usage/limit fields used below.
            if not isinstance(detail, dict):
                detail = {}
            return APIResponse(
                success=False,
                content=code,
                error=(
                    f"LLM quota exceeded. Current usage: {detail.get('current_usage', '?')}, "
                    f"Limit: {detail.get('limit', '?')}. "
                    f"Upgrade at {detail.get('upgrade_url', 'https://codeshift.dev/pricing')}"
                ),
            )

        if status == 403:
            return APIResponse(
                success=False,
                content=code,
                error="LLM migrations require Pro tier or higher. Run 'codeshift upgrade-plan' to upgrade.",
            )

        return APIResponse(
            success=False,
            content=code,
            error=self._status_error(status),
        )

    def explain_change(
        self,
        original: str,
        transformed: str,
        library: str,
    ) -> APIResponse:
        """Get an explanation of a migration change.

        Args:
            original: Original code
            transformed: Transformed code
            library: Library being upgraded

        Returns:
            APIResponse with the explanation. On any failure ``success`` is
            False, ``content`` is "" and ``error`` describes the problem.
        """
        if not self.is_available:
            return APIResponse(
                success=False,
                content="",
                error="Not authenticated. Run 'codeshift login' to use this feature.",
            )

        try:
            response = self._make_request(
                "/migrate/explain",
                {
                    "original_code": original,
                    "transformed_code": transformed,
                    "library": library,
                },
            )
        except httpx.RequestError as e:
            return APIResponse(
                success=False,
                content="",
                error=f"Network error: {str(e)}",
            )

        status = response.status_code

        if status == 200:
            data = self._json_object(response)
            if data is None:
                return APIResponse(
                    success=False,
                    content="",
                    error="Invalid response from API.",
                )
            return APIResponse(
                success=data.get("success", False),
                content=data.get("explanation") or "",
                error=data.get("error"),
            )

        if status == 402:
            return APIResponse(
                success=False,
                content="",
                error="LLM quota exceeded. Upgrade your plan to continue.",
            )

        if status == 403:
            return APIResponse(
                success=False,
                content="",
                error="This feature requires Pro tier or higher.",
            )

        # 401 / 503 / anything else: same messages as migrate_code.
        return APIResponse(
            success=False,
            content="",
            error=self._status_error(status),
        )
249
+
250
+
251
# Process-wide client shared by get_api_client(). reset_api_client() sets it
# back to None, so the next get_api_client() call builds a new client.
_default_client: PyResolveAPIClient | None = None
253
+
254
+
255
def get_api_client() -> PyResolveAPIClient:
    """Return the shared PyResolveAPIClient, creating it on first use."""
    global _default_client
    client = _default_client
    if client is None:
        # First call (or first call after reset_api_client): build and remember it.
        client = _default_client = PyResolveAPIClient()
    return client
261
+
262
+
263
def reset_api_client() -> None:
    """Forget the shared client so the next get_api_client() builds a new one.

    Intended for use after the stored credentials change (login/logout).
    """
    global _default_client
    _default_client = None
@@ -0,0 +1,318 @@
1
+ """Cache for LLM responses and expensive operations."""
2
+
3
+ import hashlib
4
+ import json
5
+ import time
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+
11
+ @dataclass
12
+ class CacheEntry:
13
+ """A cached entry with metadata."""
14
+
15
+ key: str
16
+ value: Any
17
+ created_at: float
18
+ expires_at: float | None = None
19
+ hits: int = 0
20
+
21
+ @property
22
+ def is_expired(self) -> bool:
23
+ """Check if the entry has expired."""
24
+ if self.expires_at is None:
25
+ return False
26
+ return time.time() > self.expires_at
27
+
28
+
29
+ class Cache:
30
+ """Simple file-based cache for PyResolve."""
31
+
32
+ def __init__(
33
+ self,
34
+ cache_dir: Path | None = None,
35
+ default_ttl: int | None = None,
36
+ ):
37
+ """Initialize the cache.
38
+
39
+ Args:
40
+ cache_dir: Directory to store cache files.
41
+ Defaults to ~/.codeshift/cache
42
+ default_ttl: Default time-to-live in seconds. None means no expiration.
43
+ """
44
+ if cache_dir is None:
45
+ cache_dir = Path.home() / ".codeshift" / "cache"
46
+ self.cache_dir = cache_dir
47
+ self.default_ttl = default_ttl
48
+ self._memory_cache: dict[str, CacheEntry] = {}
49
+
50
+ def _ensure_dir(self) -> None:
51
+ """Ensure the cache directory exists."""
52
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
53
+
54
+ def _make_key(self, *args: Any) -> str:
55
+ """Create a cache key from arguments."""
56
+ key_data = json.dumps(args, sort_keys=True, default=str)
57
+ return hashlib.sha256(key_data.encode()).hexdigest()[:16]
58
+
59
+ def _get_cache_path(self, key: str) -> Path:
60
+ """Get the file path for a cache key."""
61
+ return self.cache_dir / f"{key}.json"
62
+
63
+ def get(self, key: str) -> Any | None:
64
+ """Get a value from the cache.
65
+
66
+ Args:
67
+ key: The cache key
68
+
69
+ Returns:
70
+ The cached value or None if not found/expired
71
+ """
72
+ # Check memory cache first
73
+ if key in self._memory_cache:
74
+ entry = self._memory_cache[key]
75
+ if not entry.is_expired:
76
+ entry.hits += 1
77
+ return entry.value
78
+ else:
79
+ del self._memory_cache[key]
80
+
81
+ # Check file cache
82
+ cache_path = self._get_cache_path(key)
83
+ if cache_path.exists():
84
+ try:
85
+ data = json.loads(cache_path.read_text())
86
+ entry = CacheEntry(
87
+ key=data["key"],
88
+ value=data["value"],
89
+ created_at=data["created_at"],
90
+ expires_at=data.get("expires_at"),
91
+ hits=data.get("hits", 0),
92
+ )
93
+
94
+ if entry.is_expired:
95
+ cache_path.unlink()
96
+ return None
97
+
98
+ # Store in memory cache
99
+ entry.hits += 1
100
+ self._memory_cache[key] = entry
101
+
102
+ # Update file with new hit count
103
+ data["hits"] = entry.hits
104
+ cache_path.write_text(json.dumps(data))
105
+
106
+ return entry.value
107
+
108
+ except (json.JSONDecodeError, KeyError):
109
+ # Invalid cache file, remove it
110
+ cache_path.unlink()
111
+
112
+ return None
113
+
114
+ def set(self, key: str, value: Any, ttl: int | None = None) -> None:
115
+ """Set a value in the cache.
116
+
117
+ Args:
118
+ key: The cache key
119
+ value: The value to cache
120
+ ttl: Time-to-live in seconds. None uses default, 0 means no expiration.
121
+ """
122
+ self._ensure_dir()
123
+
124
+ if ttl is None:
125
+ ttl = self.default_ttl
126
+
127
+ now = time.time()
128
+ expires_at = now + ttl if ttl else None
129
+
130
+ entry = CacheEntry(
131
+ key=key,
132
+ value=value,
133
+ created_at=now,
134
+ expires_at=expires_at,
135
+ hits=0,
136
+ )
137
+
138
+ # Store in memory
139
+ self._memory_cache[key] = entry
140
+
141
+ # Store in file
142
+ cache_path = self._get_cache_path(key)
143
+ cache_path.write_text(
144
+ json.dumps(
145
+ {
146
+ "key": entry.key,
147
+ "value": entry.value,
148
+ "created_at": entry.created_at,
149
+ "expires_at": entry.expires_at,
150
+ "hits": entry.hits,
151
+ }
152
+ )
153
+ )
154
+
155
+ def delete(self, key: str) -> bool:
156
+ """Delete a value from the cache.
157
+
158
+ Args:
159
+ key: The cache key
160
+
161
+ Returns:
162
+ True if the key was found and deleted
163
+ """
164
+ found = False
165
+
166
+ if key in self._memory_cache:
167
+ del self._memory_cache[key]
168
+ found = True
169
+
170
+ cache_path = self._get_cache_path(key)
171
+ if cache_path.exists():
172
+ cache_path.unlink()
173
+ found = True
174
+
175
+ return found
176
+
177
+ def clear(self) -> int:
178
+ """Clear all cached entries.
179
+
180
+ Returns:
181
+ Number of entries cleared
182
+ """
183
+ count = len(self._memory_cache)
184
+ self._memory_cache.clear()
185
+
186
+ if self.cache_dir.exists():
187
+ for cache_file in self.cache_dir.glob("*.json"):
188
+ cache_file.unlink()
189
+ count += 1
190
+
191
+ return count
192
+
193
+ def cleanup_expired(self) -> int:
194
+ """Remove all expired entries.
195
+
196
+ Returns:
197
+ Number of entries removed
198
+ """
199
+ count = 0
200
+
201
+ # Clean memory cache
202
+ expired_keys = [k for k, v in self._memory_cache.items() if v.is_expired]
203
+ for key in expired_keys:
204
+ del self._memory_cache[key]
205
+ count += 1
206
+
207
+ # Clean file cache
208
+ if self.cache_dir.exists():
209
+ for cache_file in self.cache_dir.glob("*.json"):
210
+ try:
211
+ data = json.loads(cache_file.read_text())
212
+ expires_at = data.get("expires_at")
213
+ if expires_at and time.time() > expires_at:
214
+ cache_file.unlink()
215
+ count += 1
216
+ except (json.JSONDecodeError, KeyError):
217
+ cache_file.unlink()
218
+ count += 1
219
+
220
+ return count
221
+
222
+ def stats(self) -> dict:
223
+ """Get cache statistics.
224
+
225
+ Returns:
226
+ Dictionary with cache stats
227
+ """
228
+ memory_entries = len(self._memory_cache)
229
+ file_entries = len(list(self.cache_dir.glob("*.json"))) if self.cache_dir.exists() else 0
230
+ total_hits = sum(e.hits for e in self._memory_cache.values())
231
+
232
+ return {
233
+ "memory_entries": memory_entries,
234
+ "file_entries": file_entries,
235
+ "total_hits": total_hits,
236
+ "cache_dir": str(self.cache_dir),
237
+ }
238
+
239
+
240
class LLMCache(Cache):
    """Cache dedicated to LLM migration results.

    Files live under ``~/.codeshift/cache/llm`` unless another directory is
    supplied, and entries expire after seven days by default.
    """

    def __init__(
        self,
        cache_dir: Path | None = None,
        default_ttl: int = 86400 * 7,  # 7 days default
    ):
        """Create the LLM cache.

        Args:
            cache_dir: Where cache files are written; ``None`` selects
                ``~/.codeshift/cache/llm``.
            default_ttl: Lifetime of new entries in seconds (7 days by default).
        """
        resolved_dir = (
            Path.home() / ".codeshift" / "cache" / "llm" if cache_dir is None else cache_dir
        )
        super().__init__(resolved_dir, default_ttl)

    def _migration_key(
        self,
        code: str,
        library: str,
        from_version: str,
        to_version: str,
    ) -> str:
        """Build the cache key shared by get_migration and set_migration."""
        return self._make_key("migrate", code, library, from_version, to_version)

    def get_migration(
        self,
        code: str,
        library: str,
        from_version: str,
        to_version: str,
    ) -> str | None:
        """Look up a previously cached migration.

        Args:
            code: The source code
            library: Library name
            from_version: Source version
            to_version: Target version

        Returns:
            The cached migrated code, or None on a miss/expiry.
        """
        return self.get(self._migration_key(code, library, from_version, to_version))

    def set_migration(
        self,
        code: str,
        library: str,
        from_version: str,
        to_version: str,
        result: str,
    ) -> None:
        """Store a migration result under its (code, library, versions) key.

        Args:
            code: The source code
            library: Library name
            from_version: Source version
            to_version: Target version
            result: The migrated code
        """
        self.set(self._migration_key(code, library, from_version, to_version), result)
298
+
299
+
300
# Default cache instances: created lazily by get_cache() and get_llm_cache()
# on first use, then reused for the rest of the process.
_default_cache: Cache | None = None
_llm_cache: LLMCache | None = None
303
+
304
+
305
def get_cache() -> Cache:
    """Return the shared general-purpose Cache, creating it on first use."""
    global _default_cache
    instance = _default_cache
    if instance is None:
        instance = _default_cache = Cache()
    return instance
311
+
312
+
313
def get_llm_cache() -> LLMCache:
    """Return the shared LLMCache, creating it on first use."""
    global _llm_cache
    instance = _llm_cache
    if instance is None:
        instance = _llm_cache = LLMCache()
    return instance
@@ -0,0 +1,71 @@
1
+ """Configuration management for Codeshift."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+
6
+ import toml
7
+
8
+
9
+ @dataclass
10
+ class ProjectConfig:
11
+ """Configuration loaded from pyproject.toml [tool.codeshift] section."""
12
+
13
+ exclude: list[str] = field(
14
+ default_factory=lambda: [".codeshift/*", "tests/*", ".venv/*", "venv/*"]
15
+ )
16
+ use_llm: bool = True
17
+ anthropic_api_key: str | None = None
18
+ cache_dir: Path = field(default_factory=lambda: Path.home() / ".codeshift" / "cache")
19
+
20
+ @classmethod
21
+ def from_pyproject(cls, project_path: Path) -> "ProjectConfig":
22
+ """Load configuration from pyproject.toml if it exists."""
23
+ pyproject_path = project_path / "pyproject.toml"
24
+ config = cls()
25
+
26
+ if pyproject_path.exists():
27
+ try:
28
+ data = toml.load(pyproject_path)
29
+ codeshift_config = data.get("tool", {}).get("codeshift", {})
30
+
31
+ if "exclude" in codeshift_config:
32
+ config.exclude = codeshift_config["exclude"]
33
+ if "use_llm" in codeshift_config:
34
+ config.use_llm = codeshift_config["use_llm"]
35
+ if "anthropic_api_key" in codeshift_config:
36
+ config.anthropic_api_key = codeshift_config["anthropic_api_key"]
37
+ if "cache_dir" in codeshift_config:
38
+ config.cache_dir = Path(codeshift_config["cache_dir"])
39
+ except Exception:
40
+ # If we can't parse the config, use defaults
41
+ pass
42
+
43
+ return config
44
+
45
+
46
@dataclass
class Config:
    """Runtime configuration for a Codeshift session."""

    project_path: Path
    target_library: str
    target_version: str
    project_config: ProjectConfig = field(default_factory=ProjectConfig)
    state_file: Path | None = None
    dry_run: bool = False
    verbose: bool = False

    def __post_init__(self) -> None:
        """Normalize path fields and fill in derived defaults.

        ``project_path`` and an explicit ``state_file`` may be given as
        strings; both are converted to Path. When ``state_file`` is omitted it
        defaults to ``<project>/.codeshift/state.json``.
        """
        self.project_path = Path(self.project_path)
        if self.state_file is None:
            self.state_file = self.codeshift_dir / "state.json"
        else:
            self.state_file = Path(self.state_file)

    @property
    def codeshift_dir(self) -> Path:
        """Get the .codeshift directory for this project."""
        return self.project_path / ".codeshift"

    def ensure_dirs(self) -> None:
        """Ensure required directories exist.

        Creates the following:
        - the project's ``.codeshift`` directory;
        - the configured cache directory;
        - the parent directory of ``state_file``, which can differ from
          ``.codeshift`` when a custom location was supplied.
        """
        self.codeshift_dir.mkdir(parents=True, exist_ok=True)
        self.project_config.cache_dir.mkdir(parents=True, exist_ok=True)
        if self.state_file is not None:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)