codeshift 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/__init__.py +8 -0
- codeshift/analyzer/__init__.py +5 -0
- codeshift/analyzer/risk_assessor.py +388 -0
- codeshift/api/__init__.py +1 -0
- codeshift/api/auth.py +182 -0
- codeshift/api/config.py +73 -0
- codeshift/api/database.py +215 -0
- codeshift/api/main.py +103 -0
- codeshift/api/models/__init__.py +55 -0
- codeshift/api/models/auth.py +108 -0
- codeshift/api/models/billing.py +92 -0
- codeshift/api/models/migrate.py +42 -0
- codeshift/api/models/usage.py +116 -0
- codeshift/api/routers/__init__.py +5 -0
- codeshift/api/routers/auth.py +440 -0
- codeshift/api/routers/billing.py +395 -0
- codeshift/api/routers/migrate.py +304 -0
- codeshift/api/routers/usage.py +291 -0
- codeshift/api/routers/webhooks.py +289 -0
- codeshift/cli/__init__.py +5 -0
- codeshift/cli/commands/__init__.py +7 -0
- codeshift/cli/commands/apply.py +352 -0
- codeshift/cli/commands/auth.py +842 -0
- codeshift/cli/commands/diff.py +221 -0
- codeshift/cli/commands/scan.py +368 -0
- codeshift/cli/commands/upgrade.py +436 -0
- codeshift/cli/commands/upgrade_all.py +518 -0
- codeshift/cli/main.py +221 -0
- codeshift/cli/quota.py +210 -0
- codeshift/knowledge/__init__.py +50 -0
- codeshift/knowledge/cache.py +167 -0
- codeshift/knowledge/generator.py +231 -0
- codeshift/knowledge/models.py +151 -0
- codeshift/knowledge/parser.py +270 -0
- codeshift/knowledge/sources.py +388 -0
- codeshift/knowledge_base/__init__.py +17 -0
- codeshift/knowledge_base/loader.py +102 -0
- codeshift/knowledge_base/models.py +110 -0
- codeshift/migrator/__init__.py +23 -0
- codeshift/migrator/ast_transforms.py +256 -0
- codeshift/migrator/engine.py +395 -0
- codeshift/migrator/llm_migrator.py +320 -0
- codeshift/migrator/transforms/__init__.py +19 -0
- codeshift/migrator/transforms/fastapi_transformer.py +174 -0
- codeshift/migrator/transforms/pandas_transformer.py +236 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +637 -0
- codeshift/migrator/transforms/requests_transformer.py +218 -0
- codeshift/migrator/transforms/sqlalchemy_transformer.py +175 -0
- codeshift/scanner/__init__.py +6 -0
- codeshift/scanner/code_scanner.py +352 -0
- codeshift/scanner/dependency_parser.py +473 -0
- codeshift/utils/__init__.py +5 -0
- codeshift/utils/api_client.py +266 -0
- codeshift/utils/cache.py +318 -0
- codeshift/utils/config.py +71 -0
- codeshift/utils/llm_client.py +221 -0
- codeshift/validator/__init__.py +6 -0
- codeshift/validator/syntax_checker.py +183 -0
- codeshift/validator/test_runner.py +224 -0
- codeshift-0.2.0.dist-info/METADATA +326 -0
- codeshift-0.2.0.dist-info/RECORD +65 -0
- codeshift-0.2.0.dist-info/WHEEL +5 -0
- codeshift-0.2.0.dist-info/entry_points.txt +2 -0
- codeshift-0.2.0.dist-info/licenses/LICENSE +21 -0
- codeshift-0.2.0.dist-info/top_level.txt +1 -0
codeshift/analyzer/risk_assessor.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""Risk assessment for migration changes."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from enum import Enum
|
|
5
|
+
|
|
6
|
+
from codeshift.knowledge_base.models import BreakingChange, Severity
|
|
7
|
+
from codeshift.migrator.ast_transforms import TransformResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RiskLevel(Enum):
    """Ordered risk scale for a migration, from LOW up to CRITICAL.

    Members compare by their position in the declaration order, so
    ``RiskLevel.LOW < RiskLevel.CRITICAL`` holds and ``max()`` over a collection
    of levels returns the most severe one. Comparing with a non-RiskLevel
    operand is unsupported and ends in ``TypeError``.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

    def _rank(self) -> int:
        # Declaration order doubles as severity order: LOW=0 ... CRITICAL=3.
        return list(RiskLevel).index(self)

    def __lt__(self, other: object) -> bool:
        if isinstance(other, RiskLevel):
            return self._rank() < other._rank()
        return NotImplemented

    def __le__(self, other: object) -> bool:
        # Equal members short-circuit; otherwise defer to the < operator.
        if self == other:
            return True
        return self < other

    def __gt__(self, other: object) -> bool:
        if isinstance(other, RiskLevel):
            return not (self <= other)
        return NotImplemented

    def __ge__(self, other: object) -> bool:
        # Equal members short-circuit; otherwise defer to the > operator.
        if self == other:
            return True
        return self > other
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class RiskFactor:
    """One contributor to a migration's overall risk.

    A factor pairs a display label and explanation with a discrete severity,
    a numeric score and, optionally, advice on how to reduce the risk.
    """

    # Short display label, e.g. "Change Complexity".
    name: str
    # Human-readable explanation of why this factor scored as it did.
    description: str
    # Discrete risk level assigned to this factor.
    severity: RiskLevel
    # Score from 0.0 to 1.0; RiskAssessor assigns higher scores to safer factors.
    score: float
    # Optional suggestion for reducing the risk; None when nothing is suggested.
    mitigation: str | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class RiskAssessment:
    """Aggregate verdict on how risky a migration is.

    Bundles the overall risk level, a confidence score, the individual factors
    that produced it, and human-readable recommendations.
    """

    # Verdict for the migration as a whole.
    overall_risk: RiskLevel
    # 0.0 to 1.0: how confident we are that the migration is correct.
    confidence_score: float
    # Individual contributors that were combined into overall_risk.
    factors: list[RiskFactor] = field(default_factory=list)
    # Suggested follow-up actions for the user.
    recommendations: list[str] = field(default_factory=list)

    @property
    def is_safe(self) -> bool:
        """True when risk is LOW or MEDIUM and confidence is at least 0.7."""
        acceptable_levels = (RiskLevel.LOW, RiskLevel.MEDIUM)
        if self.overall_risk not in acceptable_levels:
            return False
        return self.confidence_score >= 0.7

    @property
    def summary(self) -> str:
        """One-line, icon-prefixed description, e.g. "✅ Low risk (confidence: 85%)"."""
        icon = {
            RiskLevel.LOW: "✅",
            RiskLevel.MEDIUM: "⚠️",
            RiskLevel.HIGH: "🔶",
            RiskLevel.CRITICAL: "🔴",
        }.get(self.overall_risk, "❓")
        label = self.overall_risk.value.title()
        return f"{icon} {label} risk (confidence: {self.confidence_score:.0%})"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class RiskAssessor:
    """Assesses the risk of migration changes.

    Each ``_assess_*`` method turns one aspect of a migration into a
    ``RiskFactor``; ``_calculate_overall_risk`` then combines the factors into
    an overall ``RiskLevel`` and a weighted confidence score.
    """

    # Factor display name -> key in ``self.weights``. The key cannot always be
    # derived mechanically from the name: "Transform Determinism" is weighted
    # under "deterministic_transform", so a lower/underscore conversion of the
    # name misses it and the factor would silently get an equal-share weight.
    _WEIGHT_KEYS: dict[str, str] = {
        "Transform Determinism": "deterministic_transform",
        "Test Coverage": "test_coverage",
        "Change Complexity": "change_complexity",
        "File Criticality": "file_criticality",
        "Breaking Change Severity": "breaking_change_severity",
    }

    # Substrings that mark a file as business-critical. Coarse heuristic:
    # matched as substrings of the lower-cased full path, so e.g. "app" also
    # matches any path containing "webapp" or "happy".
    _CRITICAL_PATTERNS: tuple[str, ...] = (
        "auth",
        "security",
        "payment",
        "billing",
        "config",
        "settings",
        "main",
        "app",
        "core",
        "database",
        "db",
        "migration",
    )

    def __init__(self) -> None:
        """Initialize the risk assessor with the default factor weights."""
        # Relative weight of each factor in the confidence score. A factor whose
        # weight key is absent here gets an equal share (1 / number of factors).
        self.weights = {
            "deterministic_transform": 0.3,
            "test_coverage": 0.25,
            "change_complexity": 0.2,
            "file_criticality": 0.15,
            "breaking_change_severity": 0.1,
        }

    def assess(
        self,
        results: list[TransformResult],
        breaking_changes: list[BreakingChange] | None = None,
        test_coverage: float | None = None,
    ) -> RiskAssessment:
        """Assess the risk of a migration.

        Args:
            results: List of transform results
            breaking_changes: List of breaking changes being addressed
            test_coverage: Optional test coverage fraction (0.0 to 1.0)

        Returns:
            RiskAssessment with overall risk evaluation
        """
        recommendations: list[str] = []

        # Factors 1-3 are always computed from the transform results.
        factors: list[RiskFactor] = [
            self._assess_determinism(results),
            self._assess_complexity(results),
            self._assess_file_criticality(results),
        ]

        # Factor 4: only when the caller knows which breaking changes apply.
        if breaking_changes:
            factors.append(self._assess_breaking_change_severity(breaking_changes))

        # Factor 5: test coverage, when measured.
        if test_coverage is not None:
            factors.append(self._assess_test_coverage(test_coverage))
        else:
            recommendations.append("Run tests with coverage to improve confidence score")

        overall_risk, confidence = self._calculate_overall_risk(factors)

        # Only HIGH/CRITICAL factors contribute their mitigation advice.
        recommendations.extend(
            factor.mitigation
            for factor in factors
            if factor.severity in (RiskLevel.HIGH, RiskLevel.CRITICAL) and factor.mitigation
        )

        # Standard recommendations
        if overall_risk != RiskLevel.LOW:
            recommendations.append("Review the diff carefully before applying changes")
            recommendations.append("Run your full test suite after applying changes")
            recommendations.append("Consider applying changes incrementally to isolate issues")

        return RiskAssessment(
            overall_risk=overall_risk,
            confidence_score=confidence,
            factors=factors,
            recommendations=recommendations,
        )

    def _assess_determinism(self, results: list[TransformResult]) -> RiskFactor:
        """Assess risk based on transform determinism.

        Only the total change count is inspected; any non-empty migration is
        reported as fully deterministic.
        """
        total_changes = sum(r.change_count for r in results)
        if total_changes == 0:
            return RiskFactor(
                name="Transform Determinism",
                description="No changes to assess",
                severity=RiskLevel.LOW,
                score=1.0,
            )

        # All our transforms are deterministic, so this is low risk
        # In the future, LLM-based transforms would increase risk
        return RiskFactor(
            name="Transform Determinism",
            description="All changes use deterministic AST transforms",
            severity=RiskLevel.LOW,
            score=0.9,
        )

    def _assess_complexity(self, results: list[TransformResult]) -> RiskFactor:
        """Assess risk based on change volume: more changes means higher risk."""
        total_changes = sum(r.change_count for r in results)
        total_files = len(results)

        if total_changes == 0:
            return RiskFactor(
                name="Change Complexity",
                description="No changes",
                severity=RiskLevel.LOW,
                score=1.0,
            )

        mitigation: str | None = None
        if total_changes > 100:
            severity = RiskLevel.HIGH
            score = 0.4
            description = f"Large migration: {total_changes} changes across {total_files} files"
            mitigation = "Consider migrating in smaller batches"
        elif total_changes > 50:
            severity = RiskLevel.MEDIUM
            score = 0.6
            description = f"Medium migration: {total_changes} changes across {total_files} files"
            mitigation = "Review changes carefully"
        elif total_changes > 20:
            severity = RiskLevel.LOW
            score = 0.8
            description = f"Small migration: {total_changes} changes across {total_files} files"
        else:
            severity = RiskLevel.LOW
            score = 0.9
            description = f"Minimal migration: {total_changes} changes across {total_files} files"

        return RiskFactor(
            name="Change Complexity",
            description=description,
            severity=severity,
            score=score,
            mitigation=mitigation,
        )

    def _assess_file_criticality(self, results: list[TransformResult]) -> RiskFactor:
        """Assess risk based on which files are being modified.

        A file counts as critical when any ``_CRITICAL_PATTERNS`` entry occurs in
        its lower-cased path; severity grows with the share of critical files.
        """
        critical_files = []
        for result in results:
            # The file name is the tail of the path, so one check covers both.
            path_lower = str(result.file_path).lower()
            if any(pattern in path_lower for pattern in self._CRITICAL_PATTERNS):
                critical_files.append(result.file_path)

        if not critical_files:
            return RiskFactor(
                name="File Criticality",
                description="No critical files identified",
                severity=RiskLevel.LOW,
                score=0.9,
            )

        # critical_files is non-empty here, so results is non-empty too.
        ratio = len(critical_files) / len(results)

        if ratio > 0.5:
            severity = RiskLevel.HIGH
            score = 0.4
        elif ratio > 0.2:
            severity = RiskLevel.MEDIUM
            score = 0.6
        else:
            severity = RiskLevel.LOW
            score = 0.8

        return RiskFactor(
            name="File Criticality",
            description=f"{len(critical_files)} critical file(s) affected: {', '.join(f.name for f in critical_files[:3])}",
            severity=severity,
            score=score,
            mitigation=(
                "Pay extra attention to critical files during review"
                if severity != RiskLevel.LOW
                else None
            ),
        )

    def _assess_breaking_change_severity(
        self, breaking_changes: list[BreakingChange]
    ) -> RiskFactor:
        """Assess risk based on the severities of the breaking changes addressed."""
        from collections import Counter

        if not breaking_changes:
            return RiskFactor(
                name="Breaking Change Severity",
                description="No breaking changes to assess",
                severity=RiskLevel.LOW,
                score=1.0,
            )

        # Counter yields 0 for absent severities and tolerates Severity members
        # beyond the four inspected below (a fixed dict would raise KeyError).
        severity_counts = Counter(change.severity for change in breaking_changes)

        if severity_counts[Severity.CRITICAL] > 0:
            severity = RiskLevel.CRITICAL
            score = 0.2
        elif severity_counts[Severity.HIGH] > 2:
            severity = RiskLevel.HIGH
            score = 0.4
        elif severity_counts[Severity.HIGH] > 0 or severity_counts[Severity.MEDIUM] > 3:
            severity = RiskLevel.MEDIUM
            score = 0.6
        else:
            severity = RiskLevel.LOW
            score = 0.8

        return RiskFactor(
            name="Breaking Change Severity",
            description=f"Addressing {len(breaking_changes)} breaking changes",
            severity=severity,
            score=score,
            mitigation=(
                "Ensure thorough testing of affected functionality"
                if severity != RiskLevel.LOW
                else None
            ),
        )

    def _assess_test_coverage(self, coverage: float) -> RiskFactor:
        """Assess risk based on test coverage (a fraction from 0.0 to 1.0)."""
        # Bound up front so every branch below leaves it defined.
        mitigation: str | None = None
        if coverage >= 0.8:
            severity = RiskLevel.LOW
            score = 0.9
            description = f"Good test coverage: {coverage:.0%}"
        elif coverage >= 0.6:
            severity = RiskLevel.LOW
            score = 0.7
            description = f"Moderate test coverage: {coverage:.0%}"
        elif coverage >= 0.4:
            severity = RiskLevel.MEDIUM
            score = 0.5
            description = f"Low test coverage: {coverage:.0%}"
            mitigation = "Consider adding more tests before migration"
        else:
            severity = RiskLevel.HIGH
            score = 0.3
            description = f"Poor test coverage: {coverage:.0%}"
            mitigation = "Strongly recommend adding tests before migration"

        return RiskFactor(
            name="Test Coverage",
            description=description,
            severity=severity,
            score=score,
            mitigation=mitigation,
        )

    def _calculate_overall_risk(self, factors: list[RiskFactor]) -> tuple[RiskLevel, float]:
        """Calculate overall risk level and confidence score.

        Confidence is the weighted average of factor scores. The overall level is
        driven by the worst factor (any CRITICAL/HIGH factor escalates) and by
        the mean severity.

        Returns:
            ``(overall_risk, confidence)``; ``(LOW, 1.0)`` when there are no factors.
        """
        if not factors:
            return RiskLevel.LOW, 1.0

        equal_share = 1.0 / len(factors)
        total_weight = 0.0
        weighted_score = 0.0

        for factor in factors:
            # Explicit mapping first; otherwise derive the key from the name.
            key = self._WEIGHT_KEYS.get(factor.name, factor.name.lower().replace(" ", "_"))
            weight = self.weights.get(key, equal_share)
            weighted_score += factor.score * weight
            total_weight += weight

        confidence = weighted_score / total_weight if total_weight > 0 else 0.5

        # Determine overall risk based on worst factor and average
        worst_severity = max(f.severity for f in factors)
        severity_values = {
            RiskLevel.LOW: 1,
            RiskLevel.MEDIUM: 2,
            RiskLevel.HIGH: 3,
            RiskLevel.CRITICAL: 4,
        }
        avg_severity = sum(severity_values[f.severity] for f in factors) / len(factors)

        if worst_severity == RiskLevel.CRITICAL or avg_severity > 3:
            overall = RiskLevel.CRITICAL
        elif worst_severity == RiskLevel.HIGH or avg_severity > 2.5:
            overall = RiskLevel.HIGH
        elif avg_severity > 1.5:
            overall = RiskLevel.MEDIUM
        else:
            overall = RiskLevel.LOW

        return overall, confidence
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""PyResolve Billing API."""
|
codeshift/api/auth.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Authentication utilities and dependencies for the PyResolve API."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import secrets
|
|
5
|
+
from collections.abc import Awaitable, Callable
|
|
6
|
+
from typing import Annotated
|
|
7
|
+
|
|
8
|
+
from fastapi import Depends, HTTPException, Security, status
|
|
9
|
+
from fastapi.security import APIKeyHeader
|
|
10
|
+
|
|
11
|
+
from codeshift.api.config import get_settings
|
|
12
|
+
from codeshift.api.database import get_database
|
|
13
|
+
|
|
14
|
+
# API Key header scheme
|
|
15
|
+
# Reads the client's API key from the "X-API-Key" request header. auto_error is
# off, so a missing header arrives as None and the auth dependencies decide how
# to respond (get_current_user answers 401, get_optional_user returns None).
api_key_header = APIKeyHeader(
    name="X-API-Key",
    auto_error=False,
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def generate_api_key() -> tuple[str, str, str]:
    """Create a fresh API key.

    Returns:
        ``(full_key, key_prefix, key_hash)``:
        ``full_key`` is the configured ``api_key_prefix`` followed by a random
        URL-safe token; ``key_prefix`` is the first 12 characters of
        ``full_key`` (kept for identifying the key); ``key_hash`` is the value
        from ``hash_api_key(full_key)`` for storage.
    """
    prefix = get_settings().api_key_prefix
    # 32 random bytes (256 bits of entropy), base64url-encoded.
    full_key = prefix + secrets.token_urlsafe(32)
    return full_key, full_key[:12], hash_api_key(full_key)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def hash_api_key(api_key: str) -> str:
    """Return the hex SHA-256 digest of ``api_key`` encoded as UTF-8.

    Keys are looked up by this digest (see get_current_user), so the same
    function must be used when storing and when verifying a key.
    """
    digest = hashlib.sha256()
    digest.update(api_key.encode("utf-8"))
    return digest.hexdigest()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class AuthenticatedUser:
|
|
47
|
+
"""Authenticated user context."""
|
|
48
|
+
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
user_id: str,
|
|
52
|
+
email: str,
|
|
53
|
+
tier: str,
|
|
54
|
+
api_key_id: str | None = None,
|
|
55
|
+
scopes: list[str] | None = None,
|
|
56
|
+
):
|
|
57
|
+
self.user_id = user_id
|
|
58
|
+
self.email = email
|
|
59
|
+
self.tier = tier
|
|
60
|
+
self.api_key_id = api_key_id
|
|
61
|
+
self.scopes = scopes or []
|
|
62
|
+
|
|
63
|
+
def has_scope(self, scope: str) -> bool:
|
|
64
|
+
"""Check if user has a specific scope."""
|
|
65
|
+
return scope in self.scopes or "admin" in self.scopes
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _api_key_is_expired(expires_at: object) -> bool:
    """Return True if an API key's ``expires_at`` value lies in the past.

    Accepts a ``datetime`` or an ISO-8601 string (a trailing "Z" is accepted).
    Naive timestamps are given UTC so the comparison with the aware current
    time cannot raise ``TypeError``.
    """
    from datetime import datetime, timezone

    if isinstance(expires_at, str):
        expires_at = datetime.fromisoformat(expires_at.replace("Z", "+00:00"))
    if isinstance(expires_at, datetime) and expires_at.tzinfo is None:
        # NOTE(review): assumes naive stored timestamps are UTC — confirm.
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    return expires_at < datetime.now(timezone.utc)


async def get_current_user(
    api_key: Annotated[str | None, Security(api_key_header)] = None,
) -> AuthenticatedUser:
    """Validate API key and return the authenticated user.

    The key is hashed and looked up via ``get_database().get_api_key_by_hash``;
    on success the key's last-used timestamp is updated.

    Args:
        api_key: Value of the X-API-Key header, or None when absent.

    Returns:
        The AuthenticatedUser built from the key record and its joined profile.

    Raises:
        HTTPException: 401 if the API key is missing, unknown, or expired.
    """
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Keys are stored hashed, so look up by the hash of the provided key.
    db = get_database()
    api_key_record = db.get_api_key_by_hash(hash_api_key(api_key))

    if not api_key_record:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    expires_at = api_key_record.get("expires_at")
    if expires_at and _api_key_is_expired(expires_at):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key has expired",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    db.update_api_key_last_used(api_key_record["id"])

    # The "profiles" entry may be present but None (no joined profile row);
    # fall back to an empty mapping so the .get() calls below stay safe.
    profile = api_key_record.get("profiles") or {}

    return AuthenticatedUser(
        user_id=api_key_record["user_id"],
        email=profile.get("email", ""),
        tier=profile.get("tier", "free"),
        api_key_id=api_key_record["id"],
        scopes=api_key_record.get("scopes", []),
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
async def get_optional_user(
    api_key: Annotated[str | None, Security(api_key_header)] = None,
) -> AuthenticatedUser | None:
    """Resolve the caller if a valid API key was sent; otherwise return None.

    For endpoints that serve both anonymous and authenticated requests: a
    missing key, or any HTTPException raised by get_current_user, yields None.
    Other exceptions still propagate.
    """
    if api_key:
        try:
            return await get_current_user(api_key)
        except HTTPException:
            # Deliberately treat any authentication failure as anonymous.
            return None
    return None
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def require_scope(scope: str) -> Callable[..., Awaitable[AuthenticatedUser]]:
    """Build a FastAPI dependency that demands ``scope`` from the caller.

    The returned coroutine authenticates through get_current_user and answers
    403 unless the user holds ``scope`` (or the "admin" scope).
    """

    async def check_scope(
        user: Annotated[AuthenticatedUser, Depends(get_current_user)],
    ) -> AuthenticatedUser:
        if user.has_scope(scope):
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Scope '{scope}' required",
        )

    return check_scope
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def require_tier(minimum_tier: str) -> Callable[..., Awaitable[AuthenticatedUser]]:
    """Build a FastAPI dependency that rejects users below ``minimum_tier``.

    Tiers rank free < pro < unlimited < enterprise. A user whose own tier is
    not recognised ranks as "free" (fails closed).

    Args:
        minimum_tier: Lowest tier allowed through.

    Returns:
        An async dependency that resolves to the authenticated user, or raises
        HTTPException 403 when the user's tier is too low.

    Raises:
        ValueError: If ``minimum_tier`` is not a known tier. Silently mapping an
            unknown name to the lowest level would disable the check entirely.
    """
    tier_levels = {"free": 0, "pro": 1, "unlimited": 2, "enterprise": 3}

    if minimum_tier not in tier_levels:
        raise ValueError(
            f"Unknown tier {minimum_tier!r}; expected one of {sorted(tier_levels)}"
        )
    required_level = tier_levels[minimum_tier]

    async def check_tier(
        user: Annotated[AuthenticatedUser, Depends(get_current_user)],
    ) -> AuthenticatedUser:
        # Unrecognised user tiers rank lowest, so they are denied gated features.
        user_level = tier_levels.get(user.tier, 0)

        if user_level < required_level:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"This feature requires {minimum_tier} tier or higher",
            )
        return user

    return check_tier
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Type aliases for dependency injection
# Annotate endpoint parameters with these so FastAPI resolves the caller:
#   CurrentUser   - requires a valid, unexpired API key (401 otherwise).
#   OptionalUser  - the authenticated user, or None when no valid key was sent.
#   ProUser       - additionally requires tier "pro" or higher (403 otherwise).
#   UnlimitedUser - additionally requires tier "unlimited" or higher (403 otherwise).
CurrentUser = Annotated[AuthenticatedUser, Depends(get_current_user)]
OptionalUser = Annotated[AuthenticatedUser | None, Depends(get_optional_user)]
ProUser = Annotated[AuthenticatedUser, Depends(require_tier("pro"))]
UnlimitedUser = Annotated[AuthenticatedUser, Depends(require_tier("unlimited"))]
|
codeshift/api/config.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""API configuration settings."""
|
|
2
|
+
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
|
|
5
|
+
from pydantic_settings import BaseSettings
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class APISettings(BaseSettings):
    """Configuration settings for the Codeshift API.

    Per ``model_config``, values are read from environment variables (no name
    prefix) and a local ``.env`` file; unrecognised variables are ignored.
    Every field has a default, so construction requires no variables at all.
    """

    # Supabase
    supabase_url: str = ""
    supabase_anon_key: str = ""
    supabase_service_role_key: str = ""

    # Stripe
    stripe_secret_key: str = ""
    stripe_webhook_secret: str = ""
    stripe_price_id_pro: str = ""
    stripe_price_id_unlimited: str = ""

    # Anthropic (for server-side LLM calls)
    anthropic_api_key: str = ""

    # API settings
    codeshift_api_url: str = "https://py-resolve.replit.app"
    # Prepended to every generated API key.
    api_key_prefix: str = "cs_"

    # Tier quotas (per month; the "unlimited" values act as an effective no-cap)
    tier_free_files: int = 100
    tier_free_llm_calls: int = 50
    tier_pro_files: int = 1000
    tier_pro_llm_calls: int = 500
    tier_unlimited_files: int = 999999999
    tier_unlimited_llm_calls: int = 999999999

    # Environment
    environment: str = "development"

    model_config = {
        "env_prefix": "",
        "env_file": ".env",
        "extra": "ignore",
    }

    @property
    def is_production(self) -> bool:
        """Return True when ``environment`` is "production" (case/whitespace-insensitive)."""
        return self.environment.strip().lower() == "production"

    def get_tier_limits(self, tier: str) -> dict[str, int]:
        """Get monthly quota limits for a tier.

        Args:
            tier: Tier name ("free", "pro", "unlimited" or "enterprise").

        Returns:
            A fresh dict with "files_per_month" and "llm_calls_per_month".
            Unknown tiers get the free-tier limits.
        """
        unlimited_limits = {
            "files_per_month": self.tier_unlimited_files,
            "llm_calls_per_month": self.tier_unlimited_llm_calls,
        }
        limits = {
            "free": {
                "files_per_month": self.tier_free_files,
                "llm_calls_per_month": self.tier_free_llm_calls,
            },
            "pro": {
                "files_per_month": self.tier_pro_files,
                "llm_calls_per_month": self.tier_pro_llm_calls,
            },
            "unlimited": unlimited_limits,
            # NOTE(review): the API's tier ordering ranks "enterprise" above
            # "unlimited", but no enterprise settings exist; without this entry
            # enterprise users fell back to the free quota. Confirm intended limits.
            "enterprise": dict(unlimited_limits),
        }
        return limits.get(tier, limits["free"])
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@lru_cache()
def get_settings() -> APISettings:
    """Return the process-wide APISettings instance.

    The first call constructs the settings; every later call returns that same
    cached object (use ``get_settings.cache_clear()`` to force a reload).
    """
    settings = APISettings()
    return settings
|