codeshift 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. codeshift/__init__.py +8 -0
  2. codeshift/analyzer/__init__.py +5 -0
  3. codeshift/analyzer/risk_assessor.py +388 -0
  4. codeshift/api/__init__.py +1 -0
  5. codeshift/api/auth.py +182 -0
  6. codeshift/api/config.py +73 -0
  7. codeshift/api/database.py +215 -0
  8. codeshift/api/main.py +103 -0
  9. codeshift/api/models/__init__.py +55 -0
  10. codeshift/api/models/auth.py +108 -0
  11. codeshift/api/models/billing.py +92 -0
  12. codeshift/api/models/migrate.py +42 -0
  13. codeshift/api/models/usage.py +116 -0
  14. codeshift/api/routers/__init__.py +5 -0
  15. codeshift/api/routers/auth.py +440 -0
  16. codeshift/api/routers/billing.py +395 -0
  17. codeshift/api/routers/migrate.py +304 -0
  18. codeshift/api/routers/usage.py +291 -0
  19. codeshift/api/routers/webhooks.py +289 -0
  20. codeshift/cli/__init__.py +5 -0
  21. codeshift/cli/commands/__init__.py +7 -0
  22. codeshift/cli/commands/apply.py +352 -0
  23. codeshift/cli/commands/auth.py +842 -0
  24. codeshift/cli/commands/diff.py +221 -0
  25. codeshift/cli/commands/scan.py +368 -0
  26. codeshift/cli/commands/upgrade.py +436 -0
  27. codeshift/cli/commands/upgrade_all.py +518 -0
  28. codeshift/cli/main.py +221 -0
  29. codeshift/cli/quota.py +210 -0
  30. codeshift/knowledge/__init__.py +50 -0
  31. codeshift/knowledge/cache.py +167 -0
  32. codeshift/knowledge/generator.py +231 -0
  33. codeshift/knowledge/models.py +151 -0
  34. codeshift/knowledge/parser.py +270 -0
  35. codeshift/knowledge/sources.py +388 -0
  36. codeshift/knowledge_base/__init__.py +17 -0
  37. codeshift/knowledge_base/loader.py +102 -0
  38. codeshift/knowledge_base/models.py +110 -0
  39. codeshift/migrator/__init__.py +23 -0
  40. codeshift/migrator/ast_transforms.py +256 -0
  41. codeshift/migrator/engine.py +395 -0
  42. codeshift/migrator/llm_migrator.py +320 -0
  43. codeshift/migrator/transforms/__init__.py +19 -0
  44. codeshift/migrator/transforms/fastapi_transformer.py +174 -0
  45. codeshift/migrator/transforms/pandas_transformer.py +236 -0
  46. codeshift/migrator/transforms/pydantic_v1_to_v2.py +637 -0
  47. codeshift/migrator/transforms/requests_transformer.py +218 -0
  48. codeshift/migrator/transforms/sqlalchemy_transformer.py +175 -0
  49. codeshift/scanner/__init__.py +6 -0
  50. codeshift/scanner/code_scanner.py +352 -0
  51. codeshift/scanner/dependency_parser.py +473 -0
  52. codeshift/utils/__init__.py +5 -0
  53. codeshift/utils/api_client.py +266 -0
  54. codeshift/utils/cache.py +318 -0
  55. codeshift/utils/config.py +71 -0
  56. codeshift/utils/llm_client.py +221 -0
  57. codeshift/validator/__init__.py +6 -0
  58. codeshift/validator/syntax_checker.py +183 -0
  59. codeshift/validator/test_runner.py +224 -0
  60. codeshift-0.2.0.dist-info/METADATA +326 -0
  61. codeshift-0.2.0.dist-info/RECORD +65 -0
  62. codeshift-0.2.0.dist-info/WHEEL +5 -0
  63. codeshift-0.2.0.dist-info/entry_points.txt +2 -0
  64. codeshift-0.2.0.dist-info/licenses/LICENSE +21 -0
  65. codeshift-0.2.0.dist-info/top_level.txt +1 -0
codeshift/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ PyResolve - AI-powered CLI tool for migrating Python code to handle breaking dependency changes.
3
+
4
+ Don't just flag the update. Fix the break.
5
+ """
6
+
7
+ __version__ = "0.2.0"
8
+ __author__ = "PyResolve Team"
@@ -0,0 +1,5 @@
1
+ """Analyzer module for assessing migration risk and complexity."""
2
+
3
+ from codeshift.analyzer.risk_assessor import RiskAssessment, RiskAssessor
4
+
5
+ __all__ = ["RiskAssessor", "RiskAssessment"]
@@ -0,0 +1,388 @@
1
+ """Risk assessment for migration changes."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+
6
+ from codeshift.knowledge_base.models import BreakingChange, Severity
7
+ from codeshift.migrator.ast_transforms import TransformResult
8
+
9
+
10
class RiskLevel(Enum):
    """Risk level for a migration.

    Members are totally ordered: LOW < MEDIUM < HIGH < CRITICAL, so
    ``max()``/``sorted()`` work on collections of levels.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

    @property
    def _rank(self) -> int:
        """Ordinal position of this level, used by the comparison operators."""
        return ("low", "medium", "high", "critical").index(self.value)

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, RiskLevel):
            return NotImplemented
        return self._rank < other._rank

    def __le__(self, other: object) -> bool:
        # Previously implemented as `self == other or self < other`, which
        # leaked a truthy NotImplemented for foreign types; return it
        # explicitly so Python can try the reflected operation instead.
        if not isinstance(other, RiskLevel):
            return NotImplemented
        return self._rank <= other._rank

    def __gt__(self, other: object) -> bool:
        # Previously `not self <= other`, which collapsed NotImplemented to
        # a plain False for non-RiskLevel operands.
        if not isinstance(other, RiskLevel):
            return NotImplemented
        return self._rank > other._rank

    def __ge__(self, other: object) -> bool:
        if not isinstance(other, RiskLevel):
            return NotImplemented
        return self._rank >= other._rank
34
+
35
+
36
@dataclass
class RiskFactor:
    """A factor contributing to migration risk."""

    # Short human-readable factor name, e.g. "Change Complexity". Also used
    # (lowercased, spaces -> underscores) to look up its weight.
    name: str
    # One-line explanation of why this factor scored the way it did.
    description: str
    severity: RiskLevel
    score: float  # 0.0 to 1.0; higher means safer / more confident
    # Optional actionable advice shown when the factor is risky.
    mitigation: str | None = None
45
+
46
+
47
@dataclass
class RiskAssessment:
    """Overall risk assessment for a migration."""

    overall_risk: RiskLevel
    confidence_score: float  # 0.0 to 1.0 (how confident we are in the migration)
    factors: list[RiskFactor] = field(default_factory=list)
    recommendations: list[str] = field(default_factory=list)

    @property
    def is_safe(self) -> bool:
        """Check if the migration is considered safe."""
        acceptable_risk = self.overall_risk in (RiskLevel.LOW, RiskLevel.MEDIUM)
        return acceptable_risk and self.confidence_score >= 0.7

    @property
    def summary(self) -> str:
        """Get a summary of the risk assessment."""
        emoji_by_level = {
            RiskLevel.LOW: "✅",
            RiskLevel.MEDIUM: "⚠️",
            RiskLevel.HIGH: "🔶",
            RiskLevel.CRITICAL: "🔴",
        }
        badge = emoji_by_level.get(self.overall_risk, "❓")
        label = self.overall_risk.value.title()
        return f"{badge} {label} risk (confidence: {self.confidence_score:.0%})"
73
+
74
+
75
class RiskAssessor:
    """Assesses risk of migration changes.

    Combines several weighted heuristic factors — transform determinism,
    change complexity, file criticality, breaking-change severity, and
    test coverage — into a single RiskAssessment.
    """

    def __init__(self) -> None:
        """Initialize the risk assessor."""
        # Weights for different risk factors. Keys MUST match the
        # RiskFactor names produced by the _assess_* helpers after
        # lowercasing and replacing spaces with underscores (see
        # _calculate_overall_risk). The determinism entry was previously
        # keyed "deterministic_transform", which never matched
        # "Transform Determinism", so that factor silently fell back to
        # the default equal weighting.
        self.weights = {
            "transform_determinism": 0.3,
            "test_coverage": 0.25,
            "change_complexity": 0.2,
            "file_criticality": 0.15,
            "breaking_change_severity": 0.1,
        }

    def assess(
        self,
        results: list[TransformResult],
        breaking_changes: list[BreakingChange] | None = None,
        test_coverage: float | None = None,
    ) -> RiskAssessment:
        """Assess the risk of a migration.

        Args:
            results: List of transform results
            breaking_changes: List of breaking changes being addressed
            test_coverage: Optional test coverage percentage (0.0 to 1.0)

        Returns:
            RiskAssessment with overall risk evaluation
        """
        factors = []
        recommendations = []

        # Factor 1: Transform determinism
        factors.append(self._assess_determinism(results))

        # Factor 2: Change complexity
        factors.append(self._assess_complexity(results))

        # Factor 3: File criticality (heuristic based on file names/paths)
        factors.append(self._assess_file_criticality(results))

        # Factor 4: Breaking change severity (only when changes are known)
        if breaking_changes:
            factors.append(self._assess_breaking_change_severity(breaking_changes))

        # Factor 5: Test coverage (only when a measurement was supplied)
        if test_coverage is not None:
            factors.append(self._assess_test_coverage(test_coverage))
        else:
            recommendations.append("Run tests with coverage to improve confidence score")

        # Calculate overall risk and confidence
        overall_risk, confidence = self._calculate_overall_risk(factors)

        # Surface mitigations of risky factors as recommendations
        for factor in factors:
            if factor.severity in (RiskLevel.HIGH, RiskLevel.CRITICAL) and factor.mitigation:
                recommendations.append(factor.mitigation)

        # Standard recommendations for anything above low risk
        if overall_risk != RiskLevel.LOW:
            recommendations.append("Review the diff carefully before applying changes")
            recommendations.append("Run your full test suite after applying changes")
            recommendations.append("Consider applying changes incrementally to isolate issues")

        return RiskAssessment(
            overall_risk=overall_risk,
            confidence_score=confidence,
            factors=factors,
            recommendations=recommendations,
        )

    def _assess_determinism(self, results: list[TransformResult]) -> RiskFactor:
        """Assess risk based on transform determinism."""
        total_changes = sum(r.change_count for r in results)
        if total_changes == 0:
            return RiskFactor(
                name="Transform Determinism",
                description="No changes to assess",
                severity=RiskLevel.LOW,
                score=1.0,
            )

        # All our transforms are deterministic, so this is low risk.
        # In the future, LLM-based transforms would increase risk.
        return RiskFactor(
            name="Transform Determinism",
            description="All changes use deterministic AST transforms",
            severity=RiskLevel.LOW,
            score=0.9,
        )

    def _assess_complexity(self, results: list[TransformResult]) -> RiskFactor:
        """Assess risk based on change complexity (volume of edits)."""
        total_changes = sum(r.change_count for r in results)
        total_files = len(results)

        if total_changes == 0:
            return RiskFactor(
                name="Change Complexity",
                description="No changes",
                severity=RiskLevel.LOW,
                score=1.0,
            )

        # More changes = higher risk; thresholds are heuristic.
        if total_changes > 100:
            severity = RiskLevel.HIGH
            score = 0.4
            description = f"Large migration: {total_changes} changes across {total_files} files"
            mitigation = "Consider migrating in smaller batches"
        elif total_changes > 50:
            severity = RiskLevel.MEDIUM
            score = 0.6
            description = f"Medium migration: {total_changes} changes across {total_files} files"
            mitigation = "Review changes carefully"
        elif total_changes > 20:
            severity = RiskLevel.LOW
            score = 0.8
            description = f"Small migration: {total_changes} changes across {total_files} files"
            mitigation = None
        else:
            severity = RiskLevel.LOW
            score = 0.9
            description = f"Minimal migration: {total_changes} changes across {total_files} files"
            mitigation = None

        return RiskFactor(
            name="Change Complexity",
            description=description,
            severity=severity,
            score=score,
            mitigation=mitigation,
        )

    def _assess_file_criticality(self, results: list[TransformResult]) -> RiskFactor:
        """Assess risk based on which files are being modified.

        Uses simple substring heuristics on file names/paths to spot files
        that tend to be business-critical (auth, billing, config, ...).
        """
        critical_patterns = [
            "auth",
            "security",
            "payment",
            "billing",
            "config",
            "settings",
            "main",
            "app",
            "core",
            "database",
            "db",
            "migration",
        ]

        critical_files = []
        for result in results:
            file_name = result.file_path.name.lower()
            file_path_str = str(result.file_path).lower()

            for pattern in critical_patterns:
                if pattern in file_name or pattern in file_path_str:
                    critical_files.append(result.file_path)
                    break  # one match is enough; avoid double-counting a file

        if not critical_files:
            return RiskFactor(
                name="File Criticality",
                description="No critical files identified",
                severity=RiskLevel.LOW,
                score=0.9,
            )

        # Risk scales with the fraction of touched files that look critical.
        ratio = len(critical_files) / len(results) if results else 0

        if ratio > 0.5:
            severity = RiskLevel.HIGH
            score = 0.4
        elif ratio > 0.2:
            severity = RiskLevel.MEDIUM
            score = 0.6
        else:
            severity = RiskLevel.LOW
            score = 0.8

        return RiskFactor(
            name="File Criticality",
            description=f"{len(critical_files)} critical file(s) affected: {', '.join(f.name for f in critical_files[:3])}",
            severity=severity,
            score=score,
            mitigation=(
                "Pay extra attention to critical files during review"
                if severity != RiskLevel.LOW
                else None
            ),
        )

    def _assess_breaking_change_severity(
        self, breaking_changes: list[BreakingChange]
    ) -> RiskFactor:
        """Assess risk based on breaking change severity."""
        if not breaking_changes:
            return RiskFactor(
                name="Breaking Change Severity",
                description="No breaking changes to assess",
                severity=RiskLevel.LOW,
                score=1.0,
            )

        severity_counts = {
            Severity.LOW: 0,
            Severity.MEDIUM: 0,
            Severity.HIGH: 0,
            Severity.CRITICAL: 0,
        }

        for change in breaking_changes:
            severity_counts[change.severity] += 1

        # Any critical change dominates; otherwise escalate on counts.
        if severity_counts[Severity.CRITICAL] > 0:
            severity = RiskLevel.CRITICAL
            score = 0.2
        elif severity_counts[Severity.HIGH] > 2:
            severity = RiskLevel.HIGH
            score = 0.4
        elif severity_counts[Severity.HIGH] > 0 or severity_counts[Severity.MEDIUM] > 3:
            severity = RiskLevel.MEDIUM
            score = 0.6
        else:
            severity = RiskLevel.LOW
            score = 0.8

        return RiskFactor(
            name="Breaking Change Severity",
            description=f"Addressing {len(breaking_changes)} breaking changes",
            severity=severity,
            score=score,
            mitigation=(
                "Ensure thorough testing of affected functionality"
                if severity != RiskLevel.LOW
                else None
            ),
        )

    def _assess_test_coverage(self, coverage: float) -> RiskFactor:
        """Assess risk based on test coverage (0.0 to 1.0)."""
        # Bind up front: the LOW branches previously left `mitigation`
        # unbound and only avoided UnboundLocalError because the trailing
        # conditional expression short-circuits. Initialize explicitly.
        mitigation: str | None = None
        if coverage >= 0.8:
            severity = RiskLevel.LOW
            score = 0.9
            description = f"Good test coverage: {coverage:.0%}"
        elif coverage >= 0.6:
            severity = RiskLevel.LOW
            score = 0.7
            description = f"Moderate test coverage: {coverage:.0%}"
        elif coverage >= 0.4:
            severity = RiskLevel.MEDIUM
            score = 0.5
            description = f"Low test coverage: {coverage:.0%}"
            mitigation = "Consider adding more tests before migration"
        else:
            severity = RiskLevel.HIGH
            score = 0.3
            description = f"Poor test coverage: {coverage:.0%}"
            mitigation = "Strongly recommend adding tests before migration"

        return RiskFactor(
            name="Test Coverage",
            description=description,
            severity=severity,
            score=score,
            mitigation=mitigation if severity != RiskLevel.LOW else None,
        )

    def _calculate_overall_risk(self, factors: list[RiskFactor]) -> tuple[RiskLevel, float]:
        """Calculate overall risk level and confidence score.

        Confidence is the weight-normalized average of factor scores;
        overall risk combines the worst single factor with the mean
        severity across factors.
        """
        if not factors:
            return RiskLevel.LOW, 1.0

        # Calculate weighted average score
        total_weight = 0.0
        weighted_score = 0.0

        for factor in factors:
            # Use factor name to look up weight, default to equal weighting
            weight = self.weights.get(factor.name.lower().replace(" ", "_"), 1.0 / len(factors))
            weighted_score += factor.score * weight
            total_weight += weight

        confidence = weighted_score / total_weight if total_weight > 0 else 0.5

        # Determine overall risk based on worst factor and average
        worst_severity = max(f.severity for f in factors)
        severity_values = {
            RiskLevel.LOW: 1,
            RiskLevel.MEDIUM: 2,
            RiskLevel.HIGH: 3,
            RiskLevel.CRITICAL: 4,
        }

        avg_severity = sum(severity_values[f.severity] for f in factors) / len(factors)

        if worst_severity == RiskLevel.CRITICAL or avg_severity > 3:
            overall = RiskLevel.CRITICAL
        elif worst_severity == RiskLevel.HIGH or avg_severity > 2.5:
            overall = RiskLevel.HIGH
        elif avg_severity > 1.5:
            overall = RiskLevel.MEDIUM
        else:
            overall = RiskLevel.LOW

        return overall, confidence
@@ -0,0 +1 @@
1
+ """PyResolve Billing API."""
codeshift/api/auth.py ADDED
@@ -0,0 +1,182 @@
1
+ """Authentication utilities and dependencies for the PyResolve API."""
2
+
3
+ import hashlib
4
+ import secrets
5
+ from collections.abc import Awaitable, Callable
6
+ from typing import Annotated
7
+
8
+ from fastapi import Depends, HTTPException, Security, status
9
+ from fastapi.security import APIKeyHeader
10
+
11
+ from codeshift.api.config import get_settings
12
+ from codeshift.api.database import get_database
13
+
14
+ # API Key header scheme
15
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
16
+
17
+
18
def generate_api_key() -> tuple[str, str, str]:
    """Generate a new API key.

    Returns:
        Tuple of (full_key, key_prefix, key_hash)
    """
    settings = get_settings()

    # 32 random bytes (256 bits of entropy), URL-safe encoded, appended to
    # the configured key prefix.
    full_key = f"{settings.api_key_prefix}{secrets.token_urlsafe(32)}"

    # The first 12 characters (prefix included) identify the key in listings
    # without exposing the secret part; only the hash is ever stored.
    return full_key, full_key[:12], hash_api_key(full_key)
39
+
40
+
41
def hash_api_key(api_key: str) -> str:
    """Hash an API key using SHA-256."""
    digest = hashlib.sha256()
    digest.update(api_key.encode())
    return digest.hexdigest()
44
+
45
+
46
+ class AuthenticatedUser:
47
+ """Authenticated user context."""
48
+
49
+ def __init__(
50
+ self,
51
+ user_id: str,
52
+ email: str,
53
+ tier: str,
54
+ api_key_id: str | None = None,
55
+ scopes: list[str] | None = None,
56
+ ):
57
+ self.user_id = user_id
58
+ self.email = email
59
+ self.tier = tier
60
+ self.api_key_id = api_key_id
61
+ self.scopes = scopes or []
62
+
63
+ def has_scope(self, scope: str) -> bool:
64
+ """Check if user has a specific scope."""
65
+ return scope in self.scopes or "admin" in self.scopes
66
+
67
+
68
async def get_current_user(
    api_key: Annotated[str | None, Security(api_key_header)] = None,
) -> AuthenticatedUser:
    """Validate API key and return the authenticated user.

    Flow: hash the presented key, look it up by hash, reject expired keys,
    record last-use, then build the user context from the attached profile.

    Raises:
        HTTPException: If API key is invalid or missing
    """
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Hash the provided key — only hashes are stored, never raw keys.
    key_hash = hash_api_key(api_key)

    # Look up the key in the database
    db = get_database()
    api_key_record = db.get_api_key_by_hash(key_hash)

    if not api_key_record:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Check if key is expired (done before touching last-used so an expired
    # key never counts as activity).
    if api_key_record.get("expires_at"):
        from datetime import datetime, timezone

        expires_at = api_key_record["expires_at"]
        if isinstance(expires_at, str):
            # Normalize a trailing "Z" to "+00:00" — datetime.fromisoformat
            # does not accept the "Z" suffix on older Python versions.
            expires_at = datetime.fromisoformat(expires_at.replace("Z", "+00:00"))
        if expires_at < datetime.now(timezone.utc):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API key has expired",
                headers={"WWW-Authenticate": "ApiKey"},
            )

    # Update last used timestamp (best-effort bookkeeping)
    db.update_api_key_last_used(api_key_record["id"])

    # Get profile data — presumably a joined "profiles" row with at least
    # email/tier; verify against the database layer's query shape.
    profile = api_key_record.get("profiles", {})

    return AuthenticatedUser(
        user_id=api_key_record["user_id"],
        email=profile.get("email", ""),
        tier=profile.get("tier", "free"),  # unknown tier defaults to "free"
        api_key_id=api_key_record["id"],
        scopes=api_key_record.get("scopes", []),
    )
124
+
125
+
126
async def get_optional_user(
    api_key: Annotated[str | None, Security(api_key_header)] = None,
) -> AuthenticatedUser | None:
    """Get the current user if authenticated, otherwise return None.

    This is useful for endpoints that work both authenticated and unauthenticated.
    """
    if api_key:
        try:
            return await get_current_user(api_key)
        except HTTPException:
            # Invalid/expired key degrades to anonymous instead of erroring.
            return None
    return None
140
+
141
+
142
def require_scope(scope: str) -> Callable[..., Awaitable[AuthenticatedUser]]:
    """Dependency that requires a specific scope."""

    async def check_scope(
        user: Annotated[AuthenticatedUser, Depends(get_current_user)],
    ) -> AuthenticatedUser:
        # Pass through when the scope is granted; otherwise 403.
        if user.has_scope(scope):
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Scope '{scope}' required",
        )

    return check_scope
156
+
157
+
158
def require_tier(minimum_tier: str) -> Callable[..., Awaitable[AuthenticatedUser]]:
    """Dependency that requires a minimum tier."""
    # Ordered tier ladder; unknown tiers rank as free (level 0).
    tier_levels = {"free": 0, "pro": 1, "unlimited": 2, "enterprise": 3}

    async def check_tier(
        user: Annotated[AuthenticatedUser, Depends(get_current_user)],
    ) -> AuthenticatedUser:
        if tier_levels.get(user.tier, 0) >= tier_levels.get(minimum_tier, 0):
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"This feature requires {minimum_tier} tier or higher",
        )

    return check_tier
176
+
177
+
178
# Type aliases for dependency injection.
# CurrentUser: requires a valid API key (401 otherwise).
CurrentUser = Annotated[AuthenticatedUser, Depends(get_current_user)]
# OptionalUser: resolves to None for anonymous or invalid-key requests.
OptionalUser = Annotated[AuthenticatedUser | None, Depends(get_optional_user)]
# ProUser / UnlimitedUser: gate endpoints on a minimum billing tier (403 below it).
ProUser = Annotated[AuthenticatedUser, Depends(require_tier("pro"))]
UnlimitedUser = Annotated[AuthenticatedUser, Depends(require_tier("unlimited"))]
@@ -0,0 +1,73 @@
1
+ """API configuration settings."""
2
+
3
+ from functools import lru_cache
4
+
5
+ from pydantic_settings import BaseSettings
6
+
7
+
8
+ class APISettings(BaseSettings):
9
+ """Configuration settings for the Codeshift API."""
10
+
11
+ # Supabase
12
+ supabase_url: str = ""
13
+ supabase_anon_key: str = ""
14
+ supabase_service_role_key: str = ""
15
+
16
+ # Stripe
17
+ stripe_secret_key: str = ""
18
+ stripe_webhook_secret: str = ""
19
+ stripe_price_id_pro: str = ""
20
+ stripe_price_id_unlimited: str = ""
21
+
22
+ # Anthropic (for server-side LLM calls)
23
+ anthropic_api_key: str = ""
24
+
25
+ # API settings
26
+ codeshift_api_url: str = "https://py-resolve.replit.app"
27
+ api_key_prefix: str = "cs_"
28
+
29
+ # Tier quotas
30
+ tier_free_files: int = 100
31
+ tier_free_llm_calls: int = 50
32
+ tier_pro_files: int = 1000
33
+ tier_pro_llm_calls: int = 500
34
+ tier_unlimited_files: int = 999999999
35
+ tier_unlimited_llm_calls: int = 999999999
36
+
37
+ # Environment
38
+ environment: str = "development"
39
+
40
+ model_config = {
41
+ "env_prefix": "",
42
+ "env_file": ".env",
43
+ "extra": "ignore",
44
+ }
45
+
46
+ @property
47
+ def is_production(self) -> bool:
48
+ """Check if running in production."""
49
+ return self.environment == "production"
50
+
51
+ def get_tier_limits(self, tier: str) -> dict[str, int]:
52
+ """Get quota limits for a tier."""
53
+ limits = {
54
+ "free": {
55
+ "files_per_month": self.tier_free_files,
56
+ "llm_calls_per_month": self.tier_free_llm_calls,
57
+ },
58
+ "pro": {
59
+ "files_per_month": self.tier_pro_files,
60
+ "llm_calls_per_month": self.tier_pro_llm_calls,
61
+ },
62
+ "unlimited": {
63
+ "files_per_month": self.tier_unlimited_files,
64
+ "llm_calls_per_month": self.tier_unlimited_llm_calls,
65
+ },
66
+ }
67
+ return limits.get(tier, limits["free"])
68
+
69
+
70
+ @lru_cache
71
+ def get_settings() -> APISettings:
72
+ """Get cached API settings."""
73
+ return APISettings()