codeshift 0.3.7__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshift/__init__.py +2 -2
- codeshift/cli/__init__.py +1 -1
- codeshift/cli/commands/__init__.py +1 -1
- codeshift/cli/commands/auth.py +46 -30
- codeshift/cli/commands/scan.py +2 -5
- codeshift/cli/commands/upgrade.py +69 -61
- codeshift/cli/commands/upgrade_all.py +1 -1
- codeshift/cli/main.py +2 -2
- codeshift/knowledge/generator.py +6 -0
- codeshift/knowledge_base/libraries/aiohttp.yaml +3 -3
- codeshift/knowledge_base/libraries/httpx.yaml +4 -4
- codeshift/knowledge_base/libraries/pytest.yaml +1 -1
- codeshift/knowledge_base/models.py +1 -0
- codeshift/migrator/llm_migrator.py +8 -12
- codeshift/migrator/transforms/marshmallow_transformer.py +50 -0
- codeshift/migrator/transforms/pydantic_v1_to_v2.py +191 -22
- codeshift/scanner/code_scanner.py +22 -2
- codeshift/utils/__init__.py +1 -1
- codeshift/utils/api_client.py +155 -15
- codeshift/utils/cache.py +1 -1
- codeshift/utils/credential_store.py +393 -0
- codeshift/utils/llm_client.py +111 -9
- {codeshift-0.3.7.dist-info → codeshift-0.5.0.dist-info}/METADATA +4 -16
- {codeshift-0.3.7.dist-info → codeshift-0.5.0.dist-info}/RECORD +28 -43
- {codeshift-0.3.7.dist-info → codeshift-0.5.0.dist-info}/licenses/LICENSE +1 -1
- codeshift/api/__init__.py +0 -1
- codeshift/api/auth.py +0 -182
- codeshift/api/config.py +0 -73
- codeshift/api/database.py +0 -215
- codeshift/api/main.py +0 -103
- codeshift/api/models/__init__.py +0 -55
- codeshift/api/models/auth.py +0 -108
- codeshift/api/models/billing.py +0 -92
- codeshift/api/models/migrate.py +0 -42
- codeshift/api/models/usage.py +0 -116
- codeshift/api/routers/__init__.py +0 -5
- codeshift/api/routers/auth.py +0 -440
- codeshift/api/routers/billing.py +0 -395
- codeshift/api/routers/migrate.py +0 -304
- codeshift/api/routers/usage.py +0 -291
- codeshift/api/routers/webhooks.py +0 -289
- {codeshift-0.3.7.dist-info → codeshift-0.5.0.dist-info}/WHEEL +0 -0
- {codeshift-0.3.7.dist-info → codeshift-0.5.0.dist-info}/entry_points.txt +0 -0
- {codeshift-0.3.7.dist-info → codeshift-0.5.0.dist-info}/top_level.txt +0 -0
codeshift/api/routers/usage.py
DELETED
|
@@ -1,291 +0,0 @@
|
|
|
1
|
-
"""Usage tracking router for the PyResolve API."""
|
|
2
|
-
|
|
3
|
-
from datetime import datetime, timezone
|
|
4
|
-
|
|
5
|
-
from fastapi import APIRouter, HTTPException, status
|
|
6
|
-
|
|
7
|
-
from codeshift.api.auth import CurrentUser
|
|
8
|
-
from codeshift.api.config import get_settings
|
|
9
|
-
from codeshift.api.database import get_database
|
|
10
|
-
from codeshift.api.models.usage import (
|
|
11
|
-
QuotaCheckRequest,
|
|
12
|
-
QuotaCheckResponse,
|
|
13
|
-
QuotaInfo,
|
|
14
|
-
UsageEvent,
|
|
15
|
-
UsageEventCreate,
|
|
16
|
-
UsageResponse,
|
|
17
|
-
UsageSummary,
|
|
18
|
-
)
|
|
19
|
-
|
|
20
|
-
# FastAPI router collecting the usage-tracking endpoints defined in this module.
router = APIRouter()
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@router.get("/quota", response_model=QuotaInfo)
async def get_quota(user: CurrentUser) -> QuotaInfo:
    """Return the caller's quota for the current UTC billing month.

    The caller's tier is read from their profile (``"free"`` when the profile
    has no ``tier`` key) and mapped to monthly limits through
    ``settings.get_tier_limits``.

    Raises:
        HTTPException: 404 when the caller has no profile row.
    """
    database = get_database()
    config = get_settings()

    user_profile = database.get_profile_by_id(user.user_id)
    if not user_profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )

    plan = user_profile.get("tier", "free")
    plan_limits = config.get_tier_limits(plan)

    # Billing periods are keyed by UTC calendar month, formatted "YYYY-MM".
    period = datetime.now(timezone.utc).strftime("%Y-%m")
    totals = database.get_usage_for_period(user.user_id, period)

    return QuotaInfo.from_usage(
        tier=plan,
        billing_period=period,
        files_migrated=totals.get("file_migrated", 0),
        llm_calls=totals.get("llm_call", 0),
        files_limit=plan_limits["files_per_month"],
        llm_calls_limit=plan_limits["llm_calls_per_month"],
    )
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
@router.get("/", response_model=UsageResponse)
async def get_usage(
    user: CurrentUser,
    billing_period: str | None = None,
    limit: int = 20,
) -> UsageResponse:
    """Get the quota summary and most recent usage events for a billing period.

    Args:
        user: The authenticated caller.
        billing_period: ``"YYYY-MM"`` period to report on. Defaults to the
            current UTC month when omitted or empty.
        limit: Maximum number of recent events to return. Clamped to
            ``[1, 100]`` because it arrives as a raw query parameter, and an
            unbounded or negative page size should never reach the database.

    Returns:
        A ``UsageResponse`` with the period's quota and recent events.

    Raises:
        HTTPException: 404 when the caller has no profile row.
    """
    # Hard ceiling on the events page requested by the client.
    max_events = 100
    limit = max(1, min(limit, max_events))

    db = get_database()
    settings = get_settings()

    # Get user profile for tier info
    profile = db.get_profile_by_id(user.user_id)
    if not profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )

    tier = profile.get("tier", "free")
    limits = settings.get_tier_limits(tier)

    # Default to current billing period (UTC calendar month).
    if not billing_period:
        billing_period = datetime.now(timezone.utc).strftime("%Y-%m")

    # Get usage summary
    usage = db.get_usage_for_period(user.user_id, billing_period)
    files_migrated = usage.get("file_migrated", 0)
    llm_calls = usage.get("llm_call", 0)

    quota = QuotaInfo.from_usage(
        tier=tier,
        billing_period=billing_period,
        files_migrated=files_migrated,
        llm_calls=llm_calls,
        files_limit=limits["files_per_month"],
        llm_calls_limit=limits["llm_calls_per_month"],
    )

    # Get recent events
    events_data = db.get_usage_events(user.user_id, billing_period, limit=limit)

    recent_events = [
        UsageEvent(
            id=e["id"],
            user_id=e["user_id"],
            event_type=e["event_type"],
            library=e.get("library"),
            quantity=e["quantity"],
            metadata=e.get("metadata", {}),
            billing_period=e["billing_period"],
            created_at=e["created_at"],
        )
        for e in events_data
    ]

    return UsageResponse(quota=quota, recent_events=recent_events)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
@router.post("/", response_model=UsageEvent)
async def record_usage(request: UsageEventCreate, user: CurrentUser) -> UsageEvent:
    """Store a usage event for the caller, enforcing monthly quotas first.

    Called by the CLI after it performs migrations. Only ``file_migrated`` and
    ``llm_call`` events are checked against the tier's monthly limits. All
    other event types are recorded without a quota check.

    Raises:
        HTTPException: 404 if the caller has no profile. 429 if recording
            ``request.quantity`` more units would exceed the current month's
            limit for that event type.
    """
    database = get_database()
    config = get_settings()

    user_profile = database.get_profile_by_id(user.user_id)
    if not user_profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )
    plan_limits = config.get_tier_limits(user_profile.get("tier", "free"))

    # Usage is bucketed per UTC calendar month ("YYYY-MM").
    period = datetime.now(timezone.utc).strftime("%Y-%m")
    totals = database.get_usage_for_period(user.user_id, period)

    # Work out whether the event would overflow its quota; raise once below.
    quota_error = None
    if request.event_type == "file_migrated":
        used = totals.get("file_migrated", 0)
        cap = plan_limits["files_per_month"]
        if used + request.quantity > cap:
            quota_error = f"File migration quota exceeded. Used: {used}, Limit: {cap}"
    elif request.event_type == "llm_call":
        used = totals.get("llm_call", 0)
        cap = plan_limits["llm_calls_per_month"]
        if used + request.quantity > cap:
            quota_error = f"LLM call quota exceeded. Used: {used}, Limit: {cap}"

    if quota_error is not None:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=quota_error,
        )

    stored = database.record_usage_event(
        user_id=user.user_id,
        event_type=request.event_type,
        library=request.library,
        quantity=request.quantity,
        metadata=request.metadata,
    )

    return UsageEvent(
        id=stored["id"],
        user_id=stored["user_id"],
        event_type=stored["event_type"],
        library=stored.get("library"),
        quantity=stored["quantity"],
        metadata=stored.get("metadata", {}),
        billing_period=stored["billing_period"],
        created_at=stored["created_at"],
    )
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
@router.post("/check", response_model=QuotaCheckResponse)
async def check_quota(request: QuotaCheckRequest, user: CurrentUser) -> QuotaCheckResponse:
    """Report whether a planned operation fits within the caller's quota.

    Used by the CLI to pre-check before starting a migration. ``file_migrated``
    and ``llm_call`` are metered against the tier's monthly limits. Any other
    event type (e.g. scan/apply) is always allowed and reported with the
    sentinel limit ``999999999``.

    Raises:
        HTTPException: 404 when the caller has no profile row.
    """
    database = get_database()
    config = get_settings()

    user_profile = database.get_profile_by_id(user.user_id)
    if not user_profile:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User profile not found",
        )
    plan_limits = config.get_tier_limits(user_profile.get("tier", "free"))

    # Usage is bucketed per UTC calendar month ("YYYY-MM").
    period = datetime.now(timezone.utc).strftime("%Y-%m")
    totals = database.get_usage_for_period(user.user_id, period)

    if request.event_type == "file_migrated":
        used = totals.get("file_migrated", 0)
        cap = plan_limits["files_per_month"]
        denial = f"Would exceed file migration quota ({used}/{cap})"
    elif request.event_type == "llm_call":
        used = totals.get("llm_call", 0)
        cap = plan_limits["llm_calls_per_month"]
        denial = f"Would exceed LLM call quota ({used}/{cap})"
    else:
        # Unmetered event types never block; report a very large limit.
        unlimited = 999999999
        return QuotaCheckResponse(
            allowed=True,
            current_usage=totals.get(request.event_type, 0),
            limit=unlimited,
            remaining=unlimited,
            message=None,
        )

    fits = used + request.quantity <= cap
    return QuotaCheckResponse(
        allowed=fits,
        current_usage=used,
        limit=cap,
        remaining=max(0, cap - used),
        message=None if fits else denial,
    )
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
@router.get("/summary", response_model=UsageSummary)
async def get_usage_summary(
    user: CurrentUser,
    billing_period: str | None = None,
) -> UsageSummary:
    """Return per-event-type usage totals for one billing period.

    ``billing_period`` is expected to be a ``"YYYY-MM"`` string (it is not
    validated here). When it is omitted or empty, the current UTC month is
    used. ``total_events`` is the sum of every per-type total the database
    returns for the period.
    """
    database = get_database()

    period = billing_period or datetime.now(timezone.utc).strftime("%Y-%m")
    totals = database.get_usage_for_period(user.user_id, period)

    return UsageSummary(
        billing_period=period,
        files_migrated=totals.get("file_migrated", 0),
        llm_calls=totals.get("llm_call", 0),
        scans=totals.get("scan", 0),
        applies=totals.get("apply", 0),
        total_events=sum(totals.values()),
    )
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
@router.get("/history", response_model=list[UsageSummary])
async def get_usage_history(
    user: CurrentUser,
    months: int = 6,
) -> list[UsageSummary]:
    """Get usage summaries for the most recent ``months`` calendar months.

    The result starts with the current UTC month and steps back exactly one
    calendar month per entry, so every ``"YYYY-MM"`` billing period appears
    at most once and none is skipped. ``months <= 0`` returns an empty list.

    Args:
        user: The authenticated caller.
        months: Number of billing periods to include.

    Returns:
        One ``UsageSummary`` per billing period, newest first.
    """
    db = get_database()
    summaries: list[UsageSummary] = []

    # Put months on one linear scale so stepping back ``offset`` months is
    # plain integer subtraction. A fixed 30-day stride drifts against real
    # month lengths: 30 days before 31 March is still March, which would
    # report March twice and skip February.
    now = datetime.now(timezone.utc)
    current_month_index = now.year * 12 + (now.month - 1)

    for offset in range(months):
        year, zero_based_month = divmod(current_month_index - offset, 12)
        billing_period = f"{year:04d}-{zero_based_month + 1:02d}"

        usage = db.get_usage_for_period(user.user_id, billing_period)

        summaries.append(
            UsageSummary(
                billing_period=billing_period,
                files_migrated=usage.get("file_migrated", 0),
                llm_calls=usage.get("llm_call", 0),
                scans=usage.get("scan", 0),
                applies=usage.get("apply", 0),
                total_events=sum(usage.values()),
            )
        )

    return summaries
|
|
codeshift/api/routers/webhooks.py
DELETED
@@ -1,289 +0,0 @@
|
|
|
1
|
-
"""Webhooks router for the PyResolve API."""
|
|
2
|
-
|
|
3
|
-
import stripe
|
|
4
|
-
from fastapi import APIRouter, Header, HTTPException, Request, status
|
|
5
|
-
|
|
6
|
-
from codeshift.api.config import get_settings
|
|
7
|
-
from codeshift.api.database import get_database
|
|
8
|
-
|
|
9
|
-
# FastAPI router for the incoming Stripe webhook endpoint defined below.
router = APIRouter()
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def get_stripe_client() -> stripe:
    """Configure the ``stripe`` module with the secret key and return it.

    Note that this sets the module-global ``stripe.api_key``. The "client"
    handed back is the ``stripe`` module itself, not a separate instance.
    """
    stripe.api_key = get_settings().stripe_secret_key
    return stripe
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@router.post("/stripe")
async def handle_stripe_webhook(
    request: Request,
    stripe_signature: str = Header(None, alias="Stripe-Signature"),
) -> dict:
    """Verify an incoming Stripe webhook and route it to its handler.

    The raw request body is checked against the ``Stripe-Signature`` header
    using the configured webhook secret. The event is then dispatched by type:

    - checkout.session.completed: User completed checkout, activate subscription
    - customer.subscription.created: New subscription created
    - customer.subscription.updated: Subscription updated (upgrade/downgrade)
    - customer.subscription.deleted: Subscription canceled
    - invoice.paid: Invoice paid successfully
    - invoice.payment_failed: Payment failed

    Returns:
        ``{"received": True}`` once the event is verified, regardless of
        whether its handler succeeded.

    Raises:
        HTTPException: 400 when the signature header is missing, the payload
            is rejected, or the signature does not verify.
    """
    settings = get_settings()
    stripe_client = get_stripe_client()

    # Signature verification must run over the exact bytes Stripe sent.
    raw_body = await request.body()

    if not stripe_signature:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing Stripe-Signature header",
        )

    try:
        event = stripe_client.Webhook.construct_event(
            raw_body,
            stripe_signature,
            settings.stripe_webhook_secret,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid payload",
        ) from exc
    except stripe.error.SignatureVerificationError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid signature",
        ) from exc

    event_type = event["type"]
    event_object = event["data"]["object"]

    dispatch = {
        "checkout.session.completed": handle_checkout_completed,
        "customer.subscription.created": handle_subscription_created,
        "customer.subscription.updated": handle_subscription_updated,
        "customer.subscription.deleted": handle_subscription_deleted,
        "invoice.paid": handle_invoice_paid,
        "invoice.payment_failed": handle_invoice_payment_failed,
    }

    try:
        handler = dispatch.get(event_type)
        if handler is None:
            # Log unhandled events for debugging
            print(f"Unhandled webhook event: {event_type}")
        else:
            await handler(event_object)
    except Exception as e:
        # Best-effort: the error is logged and a 200 is still returned, so
        # Stripe will not retry. NOTE(review): this also swallows failures
        # such as a DB error while upgrading a paying customer; confirm that
        # suppressing retries for those is intended.
        print(f"Error handling webhook {event_type}: {e}")

    return {"received": True}
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
async def handle_checkout_completed(session: dict) -> None:
    """Upgrade the user referenced by a completed Checkout session.

    The user id and target tier come from the session's ``metadata``; the tier
    defaults to ``"pro"``. The Stripe customer and subscription ids are stored
    on the profile so that subscription events, which look profiles up by
    ``stripe_customer_id``, can find the user later. A session without a
    ``user_id`` in its metadata is logged and ignored.
    """
    db = get_database()

    metadata = session.get("metadata", {})
    user_id = metadata.get("user_id")
    tier = metadata.get("tier", "pro")

    if not user_id:
        print("Checkout session missing user_id in metadata")
        return

    db.update_profile(
        user_id,
        {
            "tier": tier,
            "stripe_customer_id": session.get("customer"),
            "stripe_subscription_id": session.get("subscription"),
        },
    )

    print(f"User {user_id} upgraded to {tier}")
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
async def handle_subscription_created(subscription: dict) -> None:
    """Apply a newly created subscription to the matching user's profile.

    The user is found by ``stripe_customer_id``; if none matches, the event is
    logged and ignored. The tier is derived from the first item's price id:
    the configured unlimited price maps to ``"unlimited"``. Anything else
    (including no items) maps to ``"pro"``, and the pro price id takes
    precedence if both settings hold the same value.

    NOTE(review): ``current_period_start``/``current_period_end`` are copied
    as Stripe sends them (Unix timestamps); confirm the profile columns accept
    that form.
    """
    db = get_database()

    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")

    matches = (
        db.client.table("profiles").select("id").eq("stripe_customer_id", customer_id).execute()
    )
    if not matches.data:
        print(f"No user found for customer {customer_id}")
        return
    user_id = matches.data[0]["id"]

    # Fallback when there are no items or the price is not recognised.
    tier = "pro"
    items = subscription.get("items", {}).get("data", [])
    if items:
        price_id = items[0].get("price", {}).get("id")
        settings = get_settings()
        is_pro_price = price_id == settings.stripe_price_id_pro
        if not is_pro_price and price_id == settings.stripe_price_id_unlimited:
            tier = "unlimited"

    db.update_profile(
        user_id,
        {
            "tier": tier,
            "stripe_subscription_id": subscription_id,
            "billing_period_start": subscription.get("current_period_start"),
            "billing_period_end": subscription.get("current_period_end"),
        },
    )

    print(f"Subscription created for user {user_id}: {tier}")
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
# Stripe subscription statuses in which the customer should not keep a paid
# tier. Besides "canceled" and "unpaid", Stripe uses "incomplete_expired"
# (the first payment never succeeded) and "paused" (e.g. a trial ended without
# a payment method); neither represents a paying subscription.
INACTIVE_SUBSCRIPTION_STATUSES = frozenset(
    {"canceled", "unpaid", "incomplete_expired", "paused"}
)


async def handle_subscription_updated(subscription: dict) -> None:
    """Handle subscription updates (upgrades/downgrades/cancellations).

    The user is found by ``stripe_customer_id``; if none matches, the event is
    logged and ignored. The tier is derived from the first item's price id
    (pro price -> "pro", unlimited price -> "unlimited", otherwise "pro"). A
    status in ``INACTIVE_SUBSCRIPTION_STATUSES`` overrides that to "free" and
    clears the stored subscription id.

    NOTE(review): ``current_period_start``/``current_period_end`` are copied
    as Stripe sends them (Unix timestamps); confirm the profile columns accept
    that form.
    """
    db = get_database()

    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")
    status_value = subscription.get("status")

    # Find user by customer ID
    result = (
        db.client.table("profiles").select("id").eq("stripe_customer_id", customer_id).execute()
    )

    if not result.data:
        print(f"No user found for customer {customer_id}")
        return

    user_id = result.data[0]["id"]

    # Determine tier from price
    items = subscription.get("items", {}).get("data", [])
    if items:
        price_id = items[0].get("price", {}).get("id")
        settings = get_settings()

        if price_id == settings.stripe_price_id_pro:
            tier = "pro"
        elif price_id == settings.stripe_price_id_unlimited:
            tier = "unlimited"
        else:
            tier = "pro"
    else:
        tier = "pro"

    # Any status without paid access drops the user back to the free tier.
    if status_value in INACTIVE_SUBSCRIPTION_STATUSES:
        tier = "free"

    # Update profile
    db.update_profile(
        user_id,
        {
            "tier": tier,
            "stripe_subscription_id": subscription_id if tier != "free" else None,
            "billing_period_start": subscription.get("current_period_start"),
            "billing_period_end": subscription.get("current_period_end"),
        },
    )

    print(f"Subscription updated for user {user_id}: {tier} ({status_value})")
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
async def handle_subscription_deleted(subscription: dict) -> None:
    """Downgrade the owner of a deleted subscription to the free tier.

    The user is found by ``stripe_customer_id``; if none matches, the event is
    logged and ignored. Otherwise the tier is reset to ``"free"`` and the
    stored subscription id and billing-period bounds are cleared.
    """
    db = get_database()

    customer_id = subscription.get("customer")

    lookup = (
        db.client.table("profiles").select("id").eq("stripe_customer_id", customer_id).execute()
    )
    if not lookup.data:
        print(f"No user found for customer {customer_id}")
        return

    user_id = lookup.data[0]["id"]
    reset_fields = {
        "tier": "free",
        "stripe_subscription_id": None,
        "billing_period_start": None,
        "billing_period_end": None,
    }
    db.update_profile(user_id, reset_fields)

    print(f"Subscription deleted for user {user_id}, downgraded to free")
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
async def handle_invoice_paid(invoice: dict) -> None:
    """Log a successfully paid invoice. No state is changed.

    ``amount_paid`` is divided by 100 for display, which assumes a
    two-decimal currency. NOTE(review): zero-decimal currencies (e.g. JPY)
    would be misreported.
    """
    customer = invoice.get("customer")
    minor_units = invoice.get("amount_paid", 0)

    print(f"Invoice paid for customer {customer}: ${minor_units / 100:.2f}")

    # TODO: hook point for receipt emails, analytics, etc.
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
async def handle_invoice_payment_failed(invoice: dict) -> None:
    """Log a failed invoice payment and flag customers with repeated failures.

    Every failure is logged. If the customer maps to a profile (matched on
    ``stripe_customer_id``) and Stripe reports 3 or more attempts, a
    downgrade suggestion is logged. No profile changes are made.
    """
    db = get_database()

    customer_id = invoice.get("customer")
    attempts = invoice.get("attempt_count", 0)

    print(f"Invoice payment failed for customer {customer_id} (attempt {attempts})")

    lookup = (
        db.client.table("profiles")
        .select("id, email")
        .eq("stripe_customer_id", customer_id)
        .execute()
    )

    # Nothing more to do for unknown customers or early attempts.
    if not lookup.data or attempts < 3:
        return

    user_id = lookup.data[0]["id"]
    print(f"Multiple payment failures for user {user_id}, consider downgrade")
    # Auto-downgrade is intentionally disabled. To enable it, call
    # db.update_profile(user_id, {"tier": "free"}) here. Other options:
    # warning emails or temporarily restricting features.
|
|
File without changes
|
|
File without changes
|
|
File without changes
|