codeshift 0.3.7__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. codeshift/__init__.py +2 -2
  2. codeshift/cli/__init__.py +1 -1
  3. codeshift/cli/commands/__init__.py +1 -1
  4. codeshift/cli/commands/auth.py +5 -5
  5. codeshift/cli/commands/scan.py +2 -5
  6. codeshift/cli/commands/upgrade.py +2 -7
  7. codeshift/cli/commands/upgrade_all.py +1 -1
  8. codeshift/cli/main.py +2 -2
  9. codeshift/migrator/llm_migrator.py +8 -12
  10. codeshift/utils/__init__.py +1 -1
  11. codeshift/utils/api_client.py +11 -11
  12. codeshift/utils/cache.py +1 -1
  13. {codeshift-0.3.7.dist-info → codeshift-0.4.0.dist-info}/METADATA +2 -17
  14. {codeshift-0.3.7.dist-info → codeshift-0.4.0.dist-info}/RECORD +18 -34
  15. {codeshift-0.3.7.dist-info → codeshift-0.4.0.dist-info}/licenses/LICENSE +1 -1
  16. codeshift/api/__init__.py +0 -1
  17. codeshift/api/auth.py +0 -182
  18. codeshift/api/config.py +0 -73
  19. codeshift/api/database.py +0 -215
  20. codeshift/api/main.py +0 -103
  21. codeshift/api/models/__init__.py +0 -55
  22. codeshift/api/models/auth.py +0 -108
  23. codeshift/api/models/billing.py +0 -92
  24. codeshift/api/models/migrate.py +0 -42
  25. codeshift/api/models/usage.py +0 -116
  26. codeshift/api/routers/__init__.py +0 -5
  27. codeshift/api/routers/auth.py +0 -440
  28. codeshift/api/routers/billing.py +0 -395
  29. codeshift/api/routers/migrate.py +0 -304
  30. codeshift/api/routers/usage.py +0 -291
  31. codeshift/api/routers/webhooks.py +0 -289
  32. {codeshift-0.3.7.dist-info → codeshift-0.4.0.dist-info}/WHEEL +0 -0
  33. {codeshift-0.3.7.dist-info → codeshift-0.4.0.dist-info}/entry_points.txt +0 -0
  34. {codeshift-0.3.7.dist-info → codeshift-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,291 +0,0 @@
1
- """Usage tracking router for the PyResolve API."""
2
-
3
- from datetime import datetime, timezone
4
-
5
- from fastapi import APIRouter, HTTPException, status
6
-
7
- from codeshift.api.auth import CurrentUser
8
- from codeshift.api.config import get_settings
9
- from codeshift.api.database import get_database
10
- from codeshift.api.models.usage import (
11
- QuotaCheckRequest,
12
- QuotaCheckResponse,
13
- QuotaInfo,
14
- UsageEvent,
15
- UsageEventCreate,
16
- UsageResponse,
17
- UsageSummary,
18
- )
19
-
20
# Route table for the usage-tracking endpoints defined in this module.
router: APIRouter = APIRouter()
21
-
22
-
23
@router.get("/quota", response_model=QuotaInfo)
async def get_quota(user: CurrentUser) -> QuotaInfo:
    """Return the caller's quota for the current (UTC) billing month.

    The user's tier selects the monthly limits. Consumption is read from the
    usage totals stored for the ``YYYY-MM`` period of the current UTC time.

    Raises:
        HTTPException: 404 if no profile exists for the authenticated user.
    """
    database = get_database()
    app_settings = get_settings()

    user_profile = database.get_profile_by_id(user.user_id)
    if not user_profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )

    tier_name = user_profile.get("tier", "free")
    tier_limits = app_settings.get_tier_limits(tier_name)

    period = datetime.now(timezone.utc).strftime("%Y-%m")
    totals = database.get_usage_for_period(user.user_id, period)

    return QuotaInfo.from_usage(
        tier=tier_name,
        billing_period=period,
        files_migrated=totals.get("file_migrated", 0),
        llm_calls=totals.get("llm_call", 0),
        files_limit=tier_limits["files_per_month"],
        llm_calls_limit=tier_limits["llm_calls_per_month"],
    )
57
-
58
-
59
@router.get("/", response_model=UsageResponse)
async def get_usage(
    user: CurrentUser,
    billing_period: str | None = None,
    limit: int = 20,
) -> UsageResponse:
    """Get usage summary and recent events.

    Args:
        user: The authenticated caller.
        billing_period: ``YYYY-MM`` period to report on. Defaults to the
            current UTC month.
        limit: Maximum number of recent events to return. Must be at least 1.
            Larger values are capped at 100 so a single request cannot pull an
            unbounded number of rows.

    Returns:
        The quota for ``billing_period`` together with its most recent events.

    Raises:
        HTTPException: 400 if ``limit`` is below 1; 404 if the caller has no
            profile.
    """
    max_events = 100
    # ``limit`` comes straight from the query string. Validate it before it
    # reaches the database query.
    if limit < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit must be at least 1",
        )
    limit = min(limit, max_events)

    db = get_database()
    settings = get_settings()

    # Get user profile for tier info
    profile = db.get_profile_by_id(user.user_id)
    if not profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )

    tier = profile.get("tier", "free")
    limits = settings.get_tier_limits(tier)

    # Default to current billing period
    if not billing_period:
        billing_period = datetime.now(timezone.utc).strftime("%Y-%m")

    # Get usage summary
    usage = db.get_usage_for_period(user.user_id, billing_period)
    files_migrated = usage.get("file_migrated", 0)
    llm_calls = usage.get("llm_call", 0)

    quota = QuotaInfo.from_usage(
        tier=tier,
        billing_period=billing_period,
        files_migrated=files_migrated,
        llm_calls=llm_calls,
        files_limit=limits["files_per_month"],
        llm_calls_limit=limits["llm_calls_per_month"],
    )

    # Get recent events
    events_data = db.get_usage_events(user.user_id, billing_period, limit=limit)

    recent_events = [
        UsageEvent(
            id=e["id"],
            user_id=e["user_id"],
            event_type=e["event_type"],
            library=e.get("library"),
            quantity=e["quantity"],
            # ``.get("metadata", {})`` only covers a missing key. A row whose
            # metadata is None would pass None through, so normalise it.
            metadata=e.get("metadata") or {},
            billing_period=e["billing_period"],
            created_at=e["created_at"],
        )
        for e in events_data
    ]

    return UsageResponse(quota=quota, recent_events=recent_events)
116
-
117
-
118
@router.post("/", response_model=UsageEvent)
async def record_usage(request: UsageEventCreate, user: CurrentUser) -> UsageEvent:
    """Record a usage event.

    This is called by the CLI after performing migrations. ``file_migrated``
    and ``llm_call`` events are checked against the user's monthly tier
    limits for the current UTC month before being stored. Other event types
    are recorded without a quota check.

    Raises:
        HTTPException: 404 if the caller has no profile; 429 if recording the
            event would push a metered counter past the tier limit.
    """
    db = get_database()
    settings = get_settings()

    # Get user's tier and limits
    profile = db.get_profile_by_id(user.user_id)
    if not profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )

    tier = profile.get("tier", "free")
    limits = settings.get_tier_limits(tier)

    # Get current usage
    billing_period = datetime.now(timezone.utc).strftime("%Y-%m")
    usage = db.get_usage_for_period(user.user_id, billing_period)

    # Metered event types map to (usage counter key, tier-limit key, label
    # used in the error message).
    if request.event_type == "file_migrated":
        quota = ("file_migrated", "files_per_month", "File migration")
    elif request.event_type == "llm_call":
        quota = ("llm_call", "llm_calls_per_month", "LLM call")
    else:
        quota = None

    # NOTE(review): this check-then-insert is not atomic. Concurrent requests
    # can each pass the check and together exceed the limit.
    if quota is not None:
        usage_key, limit_key, label = quota
        current = usage.get(usage_key, 0)
        limit = limits[limit_key]
        if current + request.quantity > limit:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"{label} quota exceeded. Used: {current}, Limit: {limit}",
            )

    # Record the event
    result = db.record_usage_event(
        user_id=user.user_id,
        event_type=request.event_type,
        library=request.library,
        quantity=request.quantity,
        metadata=request.metadata,
    )

    return UsageEvent(
        id=result["id"],
        user_id=result["user_id"],
        event_type=result["event_type"],
        library=result.get("library"),
        quantity=result["quantity"],
        # A stored metadata value of None must not leak into the response.
        metadata=result.get("metadata") or {},
        billing_period=result["billing_period"],
        created_at=result["created_at"],
    )
179
-
180
-
181
@router.post("/check", response_model=QuotaCheckResponse)
async def check_quota(request: QuotaCheckRequest, user: CurrentUser) -> QuotaCheckResponse:
    """Pre-flight quota check used by the CLI before it starts a migration.

    Metered event types (``file_migrated``, ``llm_call``) are compared against
    the tier's monthly limit for the current UTC month. Every other event type
    is reported as allowed, with a sentinel limit of 999999999.

    Raises:
        HTTPException: 404 if the caller has no profile.
    """
    db = get_database()
    settings = get_settings()

    profile = db.get_profile_by_id(user.user_id)
    if not profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )
    limits = settings.get_tier_limits(profile.get("tier", "free"))

    billing_period = datetime.now(timezone.utc).strftime("%Y-%m")
    usage = db.get_usage_for_period(user.user_id, billing_period)

    if request.event_type == "file_migrated":
        current_usage = usage.get("file_migrated", 0)
        limit = limits["files_per_month"]
        allowed = current_usage + request.quantity <= limit
        if allowed:
            message = None
        else:
            message = f"Would exceed file migration quota ({current_usage}/{limit})"
    elif request.event_type == "llm_call":
        current_usage = usage.get("llm_call", 0)
        limit = limits["llm_calls_per_month"]
        allowed = current_usage + request.quantity <= limit
        if allowed:
            message = None
        else:
            message = f"Would exceed LLM call quota ({current_usage}/{limit})"
    else:
        # Unmetered event types (scan, apply): always allowed.
        unlimited = 999999999
        return QuotaCheckResponse(
            allowed=True,
            current_usage=usage.get(request.event_type, 0),
            limit=unlimited,
            remaining=unlimited,
            message=None,
        )

    return QuotaCheckResponse(
        allowed=allowed,
        current_usage=current_usage,
        limit=limit,
        remaining=max(0, limit - current_usage),
        message=message,
    )
235
-
236
-
237
@router.get("/summary", response_model=UsageSummary)
async def get_usage_summary(
    user: CurrentUser,
    billing_period: str | None = None,
) -> UsageSummary:
    """Summarise one billing period's usage for the caller.

    ``billing_period`` defaults to the current UTC ``YYYY-MM``. The
    ``total_events`` field is the sum of every per-type total returned by
    ``get_usage_for_period``.
    """
    db = get_database()
    period = billing_period or datetime.now(timezone.utc).strftime("%Y-%m")
    totals = db.get_usage_for_period(user.user_id, period)

    return UsageSummary(
        billing_period=period,
        files_migrated=totals.get("file_migrated", 0),
        llm_calls=totals.get("llm_call", 0),
        scans=totals.get("scan", 0),
        applies=totals.get("apply", 0),
        total_events=sum(totals.values()),
    )
259
-
260
-
261
@router.get("/history", response_model=list[UsageSummary])
async def get_usage_history(
    user: CurrentUser,
    months: int = 6,
) -> list[UsageSummary]:
    """Get usage history for past months, newest first.

    The first entry is the current UTC billing month, followed by the
    ``months - 1`` preceding calendar months. Each month appears exactly
    once. A non-positive ``months`` yields an empty list.
    """
    db = get_database()
    summaries: list[UsageSummary] = []

    # Step back by whole calendar months. Subtracting multiples of 30 days
    # repeats or skips months: on March 31, "30 days ago" is still March.
    now = datetime.now(timezone.utc)
    month_index = now.year * 12 + (now.month - 1)

    for offset in range(months):
        year, zero_based_month = divmod(month_index - offset, 12)
        billing_period = f"{year:04d}-{zero_based_month + 1:02d}"

        usage = db.get_usage_for_period(user.user_id, billing_period)

        summaries.append(
            UsageSummary(
                billing_period=billing_period,
                files_migrated=usage.get("file_migrated", 0),
                llm_calls=usage.get("llm_call", 0),
                scans=usage.get("scan", 0),
                applies=usage.get("apply", 0),
                total_events=sum(usage.values()),
            )
        )

    return summaries
@@ -1,289 +0,0 @@
1
- """Webhooks router for the PyResolve API."""
2
-
3
- import stripe
4
- from fastapi import APIRouter, Header, HTTPException, Request, status
5
-
6
- from codeshift.api.config import get_settings
7
- from codeshift.api.database import get_database
8
-
9
# Route table for the incoming webhook callbacks handled in this module.
router: APIRouter = APIRouter()
10
-
11
-
12
def get_stripe_client() -> stripe:
    """Configure the ``stripe`` module with the secret key and return it.

    NOTE(review): this assigns the process-global ``stripe.api_key`` on every
    call. The return annotation names the module object itself rather than
    a type.
    """
    stripe.api_key = get_settings().stripe_secret_key
    return stripe
17
-
18
-
19
@router.post("/stripe")
async def handle_stripe_webhook(
    request: Request,
    stripe_signature: str | None = Header(None, alias="Stripe-Signature"),
) -> dict:
    """Handle Stripe webhook events.

    Handles the following events:
    - checkout.session.completed: User completed checkout, activate subscription
    - customer.subscription.created: New subscription created
    - customer.subscription.updated: Subscription updated (upgrade/downgrade)
    - customer.subscription.deleted: Subscription canceled
    - invoice.paid: Invoice paid successfully
    - invoice.payment_failed: Payment failed

    Raises:
        HTTPException: 400 if the Stripe-Signature header is missing, or the
            payload or signature fails verification.

    Returns:
        ``{"received": True}`` once the signature is verified, including when
        the event type is unhandled or its handler fails.
    """
    import logging

    logger = logging.getLogger(__name__)
    settings = get_settings()
    stripe_client = get_stripe_client()

    # The signature is computed over the exact raw bytes, so use the
    # unparsed body.
    payload = await request.body()

    if not stripe_signature:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing Stripe-Signature header",
        )

    try:
        event = stripe_client.Webhook.construct_event(
            payload,
            stripe_signature,
            settings.stripe_webhook_secret,
        )
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid payload",
        ) from e
    except stripe.error.SignatureVerificationError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid signature",
        ) from e

    event_type = event["type"]
    data = event["data"]["object"]

    handlers = {
        "checkout.session.completed": handle_checkout_completed,
        "customer.subscription.created": handle_subscription_created,
        "customer.subscription.updated": handle_subscription_updated,
        "customer.subscription.deleted": handle_subscription_deleted,
        "invoice.paid": handle_invoice_paid,
        "invoice.payment_failed": handle_invoice_payment_failed,
    }
    handler = handlers.get(event_type)
    if handler is None:
        logger.warning("Unhandled webhook event: %s", event_type)
        return {"received": True}

    try:
        await handler(data)
    except Exception:
        # Deliberately best-effort: the error is logged with its traceback and
        # the endpoint still answers 200.
        # NOTE(review): Stripe only retries deliveries that get a non-2xx
        # response. A transient failure here (e.g. the DB is down during
        # checkout.session.completed) is therefore never redelivered. Consider
        # returning 5xx for billing-critical events.
        logger.exception("Error handling webhook %s", event_type)

    return {"received": True}
89
-
90
-
91
async def handle_checkout_completed(session: dict) -> None:
    """Apply a completed Checkout Session to the purchasing user's profile.

    ``user_id`` and ``tier`` (default ``"pro"``) are read from the session
    metadata. The profile receives the tier plus the session's Stripe customer
    and subscription IDs. Sessions without a ``user_id`` are logged and
    ignored.
    """
    db = get_database()

    metadata = session.get("metadata", {})
    user_id = metadata.get("user_id")
    tier = metadata.get("tier", "pro")

    if user_id:
        db.update_profile(
            user_id,
            {
                "tier": tier,
                "stripe_customer_id": session.get("customer"),
                "stripe_subscription_id": session.get("subscription"),
            },
        )
        print(f"User {user_id} upgraded to {tier}")
    else:
        print("Checkout session missing user_id in metadata")
118
-
119
-
120
async def handle_subscription_created(subscription: dict) -> None:
    """Record a newly created subscription on the owning user's profile.

    The user is located by ``stripe_customer_id``. The tier comes from the
    first line item's price: the configured "unlimited" price maps to
    ``"unlimited"``, and anything else (including no items) maps to ``"pro"``.
    Unknown customers are logged and ignored.
    """
    db = get_database()

    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")

    lookup = (
        db.client.table("profiles")
        .select("id")
        .eq("stripe_customer_id", customer_id)
        .execute()
    )
    if not lookup.data:
        print(f"No user found for customer {customer_id}")
        return
    user_id = lookup.data[0]["id"]

    tier = "pro"
    line_items = subscription.get("items", {}).get("data", [])
    if line_items:
        price_id = line_items[0].get("price", {}).get("id")
        settings = get_settings()
        # Later keys win, so if both settings hold the same ID the "pro"
        # mapping takes precedence, matching a pro-first comparison.
        price_to_tier = {
            settings.stripe_price_id_unlimited: "unlimited",
            settings.stripe_price_id_pro: "pro",
        }
        tier = price_to_tier.get(price_id, "pro")

    db.update_profile(
        user_id,
        {
            "tier": tier,
            "stripe_subscription_id": subscription_id,
            "billing_period_start": subscription.get("current_period_start"),
            "billing_period_end": subscription.get("current_period_end"),
        },
    )

    print(f"Subscription created for user {user_id}: {tier}")
165
-
166
-
167
async def handle_subscription_updated(subscription: dict) -> None:
    """Sync a user's tier after a subscription change (upgrade, downgrade, status).

    The tier is derived from the first line item's price; unrecognised prices
    fall back to ``"pro"``. A subscription whose status no longer grants paid
    access is downgraded to ``"free"`` and its subscription ID is cleared.
    Unknown customers are logged and ignored.
    """
    db = get_database()

    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")
    status_value = subscription.get("status")

    # Find user by customer ID
    result = (
        db.client.table("profiles").select("id").eq("stripe_customer_id", customer_id).execute()
    )

    if not result.data:
        print(f"No user found for customer {customer_id}")
        return

    user_id = result.data[0]["id"]

    # Determine tier from the first line item's price.
    tier = "pro"
    items = subscription.get("items", {}).get("data", [])
    if items:
        price_id = items[0].get("price", {}).get("id")
        settings = get_settings()

        if price_id == settings.stripe_price_id_pro:
            tier = "pro"
        elif price_id == settings.stripe_price_id_unlimited:
            tier = "unlimited"

    # Statuses in which the customer is not, or no longer, paying for access.
    # "incomplete_expired" means the first payment never succeeded, and
    # "paused" means a trial ended without a payment method. Neither should
    # keep a paid tier.
    if status_value in ("canceled", "unpaid", "incomplete_expired", "paused"):
        tier = "free"

    # Update profile
    db.update_profile(
        user_id,
        {
            "tier": tier,
            "stripe_subscription_id": subscription_id if tier != "free" else None,
            "billing_period_start": subscription.get("current_period_start"),
            "billing_period_end": subscription.get("current_period_end"),
        },
    )

    print(f"Subscription updated for user {user_id}: {tier} ({status_value})")
217
-
218
-
219
async def handle_subscription_deleted(subscription: dict) -> None:
    """Downgrade the owner of a deleted subscription to the free tier.

    The user is located by ``stripe_customer_id``. The subscription ID and
    both billing-period fields are cleared. Unknown customers are logged and
    ignored.
    """
    db = get_database()
    customer_id = subscription.get("customer")

    matches = (
        db.client.table("profiles")
        .select("id")
        .eq("stripe_customer_id", customer_id)
        .execute()
    ).data

    if matches:
        user_id = matches[0]["id"]
        db.update_profile(
            user_id,
            {
                "tier": "free",
                "stripe_subscription_id": None,
                "billing_period_start": None,
                "billing_period_end": None,
            },
        )
        print(f"Subscription deleted for user {user_id}, downgraded to free")
    else:
        print(f"No user found for customer {customer_id}")
248
-
249
-
250
async def handle_invoice_paid(invoice: dict) -> None:
    """Log a successfully paid invoice.

    ``amount_paid`` is divided by 100 for display, i.e. treated as a
    two-decimal minor currency unit. NOTE(review): zero-decimal currencies
    (e.g. JPY) would be misreported; confirm which currencies are in use.
    """
    amount_paid = invoice.get("amount_paid", 0)
    customer_id = invoice.get("customer")
    print(f"Invoice paid for customer {customer_id}: ${amount_paid / 100:.2f}")
    # Hook point for follow-ups such as receipt emails or analytics.
258
-
259
-
260
async def handle_invoice_payment_failed(invoice: dict) -> None:
    """Log a failed invoice payment and flag customers with repeated failures.

    Every failure is logged. Only when ``attempt_count`` reaches 3 is the
    customer's profile looked up, and a downgrade suggestion logged for the
    matching user. No profile changes are made.
    """
    customer_id = invoice.get("customer")
    attempt_count = invoice.get("attempt_count", 0)

    print(f"Invoice payment failed for customer {customer_id} (attempt {attempt_count})")

    # Below the threshold nothing acts on the user, so skip the profile query.
    if attempt_count < 3:
        return

    db = get_database()
    result = (
        db.client.table("profiles")
        .select("id")
        .eq("stripe_customer_id", customer_id)
        .execute()
    )

    if not result.data:
        return

    user_id = result.data[0]["id"]
    print(f"Multiple payment failures for user {user_id}, consider downgrade")
    # Possible escalations: send a warning email, temporarily restrict
    # features, or downgrade via db.update_profile(user_id, {"tier": "free"}).