msaas-billing 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- msaas_billing-0.1.0/.gitignore +21 -0
- msaas_billing-0.1.0/PKG-INFO +14 -0
- msaas_billing-0.1.0/pyproject.toml +30 -0
- msaas_billing-0.1.0/src/billing/__init__.py +56 -0
- msaas_billing-0.1.0/src/billing/checkout.py +73 -0
- msaas_billing-0.1.0/src/billing/config.py +48 -0
- msaas_billing-0.1.0/src/billing/invoices.py +47 -0
- msaas_billing-0.1.0/src/billing/models.py +53 -0
- msaas_billing-0.1.0/src/billing/router.py +138 -0
- msaas_billing-0.1.0/src/billing/subscriptions.py +168 -0
- msaas_billing-0.1.0/src/billing/usage.py +108 -0
- msaas_billing-0.1.0/src/billing/webhooks.py +120 -0
- msaas_billing-0.1.0/tests/__init__.py +0 -0
- msaas_billing-0.1.0/tests/conftest.py +24 -0
- msaas_billing-0.1.0/tests/test_checkout.py +78 -0
- msaas_billing-0.1.0/tests/test_webhooks.py +148 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
node_modules/
|
|
2
|
+
dist/
|
|
3
|
+
.next/
|
|
4
|
+
.turbo/
|
|
5
|
+
*.pyc
|
|
6
|
+
__pycache__/
|
|
7
|
+
.venv/
|
|
8
|
+
*.egg-info/
|
|
9
|
+
.pytest_cache/
|
|
10
|
+
.ruff_cache/
|
|
11
|
+
.env
|
|
12
|
+
.env.local
|
|
13
|
+
.env.*.local
|
|
14
|
+
.DS_Store
|
|
15
|
+
coverage/
|
|
16
|
+
|
|
17
|
+
# Runtime artifacts
|
|
18
|
+
logs_llm/
|
|
19
|
+
vectors.db
|
|
20
|
+
vectors.db-shm
|
|
21
|
+
vectors.db-wal
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: msaas-billing
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Stripe billing module for SaaS products
|
|
5
|
+
Requires-Python: >=3.12
|
|
6
|
+
Requires-Dist: fastapi>=0.115.0
|
|
7
|
+
Requires-Dist: msaas-api-core
|
|
8
|
+
Requires-Dist: msaas-errors
|
|
9
|
+
Requires-Dist: pydantic>=2.0
|
|
10
|
+
Requires-Dist: stripe>=9.0.0
|
|
11
|
+
Provides-Extra: dev
|
|
12
|
+
Requires-Dist: httpx>=0.27; extra == 'dev'
|
|
13
|
+
Requires-Dist: pytest-asyncio>=0.24; extra == 'dev'
|
|
14
|
+
Requires-Dist: pytest>=8.0; extra == 'dev'
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "msaas-billing"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Stripe billing module for SaaS products"
|
|
5
|
+
requires-python = ">=3.12"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"msaas-api-core",
|
|
8
|
+
"msaas-errors",
|
|
9
|
+
"stripe>=9.0.0",
|
|
10
|
+
"fastapi>=0.115.0",
|
|
11
|
+
"pydantic>=2.0",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[build-system]
|
|
15
|
+
requires = ["hatchling"]
|
|
16
|
+
build-backend = "hatchling.build"
|
|
17
|
+
|
|
18
|
+
[tool.hatch.build.targets.wheel]
|
|
19
|
+
packages = ["src/billing"]
|
|
20
|
+
|
|
21
|
+
[project.optional-dependencies]
|
|
22
|
+
dev = [
|
|
23
|
+
"pytest>=8.0",
|
|
24
|
+
"pytest-asyncio>=0.24",
|
|
25
|
+
"httpx>=0.27",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
[tool.uv.sources]
|
|
29
|
+
msaas-api-core = { workspace = true }
|
|
30
|
+
msaas-errors = { workspace = true }
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""msaas-billing: Stripe billing module for SaaS products.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
from billing import init_billing, BillingConfig, BillingRouter
|
|
5
|
+
|
|
6
|
+
config = BillingConfig(
|
|
7
|
+
stripe_secret_key="sk_...",
|
|
8
|
+
stripe_webhook_secret="whsec_...",
|
|
9
|
+
success_url="https://app.example.com/billing/success",
|
|
10
|
+
cancel_url="https://app.example.com/billing/cancel",
|
|
11
|
+
product_plans={"starter": "price_xxx", "pro": "price_yyy"},
|
|
12
|
+
)
|
|
13
|
+
init_billing(config)
|
|
14
|
+
|
|
15
|
+
# Mount the router in your FastAPI app
|
|
16
|
+
app.include_router(BillingRouter)
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from billing.checkout import create_checkout, create_portal_session
|
|
20
|
+
from billing.config import BillingConfig, init_billing
|
|
21
|
+
from billing.invoices import get_invoices
|
|
22
|
+
from billing.models import (
|
|
23
|
+
CheckoutSession,
|
|
24
|
+
Invoice,
|
|
25
|
+
Subscription,
|
|
26
|
+
UsageRecord,
|
|
27
|
+
WebhookEvent,
|
|
28
|
+
)
|
|
29
|
+
from billing.router import BillingRouter
|
|
30
|
+
from billing.subscriptions import (
|
|
31
|
+
cancel_subscription,
|
|
32
|
+
get_subscription,
|
|
33
|
+
update_subscription,
|
|
34
|
+
)
|
|
35
|
+
from billing.usage import get_usage, record_usage
|
|
36
|
+
from billing.webhooks import handle_webhook
|
|
37
|
+
|
|
38
|
+
# Public API of the ``billing`` package: exactly the names imported above.
# Kept in ASCII sort order (capitalized class names sort before functions).
__all__ = [
    "BillingConfig",
    "BillingRouter",
    "CheckoutSession",
    "Invoice",
    "Subscription",
    "UsageRecord",
    "WebhookEvent",
    "cancel_subscription",
    "create_checkout",
    "create_portal_session",
    "get_invoices",
    "get_subscription",
    "get_usage",
    "handle_webhook",
    "init_billing",
    "record_usage",
    "update_subscription",
]
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Stripe Checkout and Customer Portal session management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import stripe
|
|
6
|
+
|
|
7
|
+
from billing.config import get_config
|
|
8
|
+
from billing.models import CheckoutSession
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_checkout(
    user_id: str,
    plan: str,
    email: str | None = None,
) -> CheckoutSession:
    """Start a subscription-mode Stripe Checkout session for ``plan``.

    Args:
        user_id: Internal user identifier. Sent to Stripe as the session's
            ``client_reference_id`` and inside the session ``metadata``.
        plan: Plan name; must be a key of ``BillingConfig.product_plans``
            mapped to a non-empty price ID.
        email: Optional address used to pre-fill the checkout form.

    Returns:
        CheckoutSession carrying the hosted checkout URL and the session ID.

    Raises:
        ValueError: If ``plan`` has no configured price ID.
        stripe.StripeError: On Stripe API failures.

    NOTE(review): ``user_id`` is attached to the Checkout Session only, while
    ``billing.subscriptions._get_customer_id`` finds customers through
    Customer ``metadata["user_id"]``, which this call does not set. Confirm
    the application copies it onto the Customer (e.g. when handling
    ``checkout.session.completed``).
    """
    cfg = get_config()
    plans = cfg.product_plans

    price = plans.get(plan)
    if not price:
        available = ", ".join(plans.keys())
        raise ValueError(f"Unknown plan '{plan}'. Available plans: {available}")

    checkout_kwargs: dict = dict(
        mode="subscription",
        line_items=[{"price": price, "quantity": 1}],
        success_url=cfg.success_url,
        cancel_url=cfg.cancel_url,
        client_reference_id=user_id,
        metadata={"user_id": user_id, "plan": plan},
    )
    # Only pre-fill the email when a non-empty value was supplied.
    if email:
        checkout_kwargs["customer_email"] = email

    stripe_session = stripe.checkout.Session.create(**checkout_kwargs)
    return CheckoutSession(url=stripe_session.url, session_id=stripe_session.id)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def create_portal_session(customer_id: str) -> str:
    """Open a Stripe Customer Portal session for an existing customer.

    When the customer leaves the portal, Stripe sends them back to
    ``BillingConfig.success_url``.

    Args:
        customer_id: Stripe customer ID.

    Returns:
        URL of the portal session to redirect the customer to.

    Raises:
        stripe.StripeError: On Stripe API failures.
    """
    return_to = get_config().success_url
    portal = stripe.billing_portal.Session.create(
        customer=customer_id,
        return_url=return_to,
    )
    return portal.url
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Billing configuration and Stripe client initialization."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import stripe
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BillingConfig(BaseModel):
    """Configuration for the billing module.

    Passed to ``init_billing``; the stored instance is what ``get_config``
    returns to every billing function.

    NOTE(review): the two secrets are plain ``str`` fields, so they appear in
    the model's repr and serialized output. Consider ``pydantic.SecretStr``.
    """

    # init_billing assigns this to the global ``stripe.api_key``.
    stripe_secret_key: str = Field(description="Stripe secret API key")
    # Used by handle_webhook to verify the Stripe-Signature header.
    stripe_webhook_secret: str = Field(description="Stripe webhook signing secret")
    # Checkout success redirect; create_portal_session also uses it as the
    # portal's return_url.
    success_url: str = Field(description="Redirect URL after successful checkout")
    cancel_url: str = Field(description="Redirect URL after cancelled checkout")
    # Plan name -> Stripe price ID. Used to build checkouts and plan changes, and
    # reversed when mapping a subscription's price back to a plan name.
    product_plans: dict[str, str] = Field(
        default_factory=dict,
        description="Mapping of plan_name -> stripe_price_id",
    )
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Module-level configuration set by init_billing(); get_config() raises while it is None.
_config: BillingConfig | None = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def init_billing(config: BillingConfig) -> None:
    """Install ``config`` as the active billing configuration.

    Besides storing the configuration for ``get_config``, this sets the
    process-wide ``stripe.api_key``, so the Stripe calls made by the billing
    functions authenticate with ``config.stripe_secret_key``. Calling it again
    replaces both.

    Args:
        config: Stripe credentials, redirect URLs and plan mappings.
    """
    global _config
    stripe.api_key = config.stripe_secret_key
    _config = config
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_config() -> BillingConfig:
    """Return the configuration installed by ``init_billing``.

    Raises:
        RuntimeError: If ``init_billing`` has not been called yet.
    """
    if _config is not None:
        return _config
    raise RuntimeError(
        "Billing not initialized. Call init_billing(config) before using billing functions."
    )
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Invoice retrieval operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import UTC, datetime
|
|
6
|
+
|
|
7
|
+
import stripe
|
|
8
|
+
|
|
9
|
+
from billing.models import Invoice
|
|
10
|
+
from billing.subscriptions import _get_customer_id
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_invoices(user_id: str, limit: int = 10) -> list[Invoice]:
    """Return up to ``limit`` of the user's Stripe invoices.

    Invoices come back in the order ``stripe.Invoice.list`` returns them
    (most recent first). Only a single page is requested.

    Args:
        user_id: Internal user identifier, resolved to a Stripe customer.
        limit: Maximum number of invoices to return. Defaults to 10.

    Returns:
        Invoice models with ``created`` converted to an aware UTC datetime.

    Raises:
        ValueError: If no Stripe customer is found for the user.
        stripe.StripeError: On Stripe API failures.
    """
    customer_id = _get_customer_id(user_id)
    page = stripe.Invoice.list(customer=customer_id, limit=limit)

    return [
        Invoice(
            id=invoice.id,
            # Report what was paid; fall back to the amount due when
            # amount_paid is falsy (e.g. 0 for an unpaid invoice).
            amount=invoice.amount_paid or invoice.amount_due,
            currency=invoice.currency,
            status=invoice.status,
            created=datetime.fromtimestamp(invoice.created, tz=UTC),
            pdf_url=invoice.invoice_pdf,
        )
        for invoice in page.data
    ]
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Pydantic models for billing domain objects."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import UTC, datetime

from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Subscription(BaseModel):
    """Represents an active or past subscription.

    The billing functions build instances via
    ``billing.subscriptions._stripe_sub_to_model``, which fills the period
    bounds with timezone-aware UTC datetimes.
    """

    # Stripe subscription ID.
    id: str
    # Internal user identifier passed in by the caller; it is not read from Stripe.
    user_id: str
    # Plan name from BillingConfig.product_plans, or the raw Stripe price ID
    # when no configured plan matches.
    plan_name: str
    status: str = Field(description="Stripe subscription status (active, canceled, past_due, etc.)")
    current_period_start: datetime
    current_period_end: datetime
    # True when the subscription is scheduled to end with the current period.
    cancel_at_period_end: bool = False
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Invoice(BaseModel):
    """Represents a Stripe invoice.

    Built by ``billing.invoices.get_invoices`` from Stripe Invoice objects.
    """

    # Stripe invoice ID.
    id: str
    # get_invoices uses amount_paid, falling back to amount_due when amount_paid is falsy.
    amount: int = Field(description="Amount in smallest currency unit (e.g., cents)")
    # Currency code exactly as Stripe returns it.
    currency: str
    # Stripe invoice status, passed through unchanged.
    status: str
    # get_invoices converts Stripe's Unix timestamp to an aware UTC datetime.
    created: datetime
    # Link to the invoice PDF when Stripe provides one.
    pdf_url: str | None = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class UsageRecord(BaseModel):
    """Represents a metered usage record.

    When not supplied, ``timestamp`` defaults to the current time as a
    timezone-aware UTC datetime. This matches the aware UTC datetimes produced
    elsewhere in the billing package.
    """

    # Internal user identifier.
    user_id: str
    # Metric name (matched against a Stripe price lookup_key by the usage helpers).
    metric: str
    # Number of units recorded.
    quantity: int
    # Aware UTC default: plain ``datetime.now`` returns naive local wall-clock
    # time, which cannot be compared with the aware datetimes used elsewhere.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CheckoutSession(BaseModel):
    """Result of creating a Stripe Checkout session.

    Returned by ``billing.checkout.create_checkout`` and used as the response
    model of the ``POST /billing/checkout`` endpoint.
    """

    url: str = Field(description="Redirect URL for the checkout page")
    # Stripe Checkout Session ID.
    session_id: str
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class WebhookEvent(BaseModel):
    """Parsed and verified Stripe webhook event.

    For event types listed in ``billing.webhooks.HANDLED_EVENTS``, ``data``
    holds a small dict of extracted fields. For any other type, it holds the
    raw ``event.data.object`` payload.
    """

    # Named after Stripe's event "type" field; as an attribute it does not
    # shadow the ``type`` builtin at module scope.
    type: str = Field(description="Stripe event type (e.g., checkout.session.completed)")
    data: dict = Field(default_factory=dict, description="Event payload data")
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""FastAPI router exposing billing endpoints."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Annotated
|
|
6
|
+
|
|
7
|
+
from errors import (
|
|
8
|
+
NotFoundError,
|
|
9
|
+
ValidationError,
|
|
10
|
+
)
|
|
11
|
+
from fastapi import APIRouter, Header, Request
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
|
|
14
|
+
from billing.checkout import create_checkout, create_portal_session
|
|
15
|
+
from billing.invoices import get_invoices
|
|
16
|
+
from billing.models import CheckoutSession, Invoice, Subscription
|
|
17
|
+
from billing.subscriptions import cancel_subscription, get_subscription
|
|
18
|
+
from billing.usage import record_usage
|
|
19
|
+
from billing.webhooks import handle_webhook
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CheckoutRequest(BaseModel):
    """Request body for creating a checkout session."""

    # Internal user identifier, forwarded to create_checkout unchanged.
    # NOTE(review): nothing in this router authenticates it. Confirm the
    # application verifies that the caller may act for this user.
    user_id: str
    # Plan name; must be a key of BillingConfig.product_plans.
    plan: str
    # Optional email to pre-fill on the checkout page (not format-validated here).
    email: str | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class PortalRequest(BaseModel):
    """Request body for creating a customer portal session."""

    # Stripe customer ID, passed straight to create_portal_session.
    # NOTE(review): supplied by the caller and not checked against any user here.
    # Confirm access control happens upstream.
    customer_id: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UsageRequest(BaseModel):
    """Request body for recording usage."""

    # Internal user identifier.
    user_id: str
    # Metric name, matched against the Stripe price lookup_key (or price ID).
    metric: str
    # Units to add; no range validation is applied at this layer.
    quantity: int = 1
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class CancelRequest(BaseModel):
    """Request body for cancelling a subscription."""

    # Internal user identifier.
    user_id: str
    # True (default): cancel at the end of the current period; False: cancel now.
    at_period_end: bool = True
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class PortalResponse(BaseModel):
    """Response containing a portal session URL."""

    # Stripe Customer Portal URL to redirect the user to.
    url: str
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Router exported by the package; every endpoint below is mounted under /billing.
# The PascalCase name is part of the public API (see the package __all__).
BillingRouter = APIRouter(prefix="/billing", tags=["billing"])
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@BillingRouter.post("/checkout", response_model=CheckoutSession)
async def checkout_endpoint(body: CheckoutRequest) -> CheckoutSession:
    """Create a Stripe Checkout session for a plan.

    ``create_checkout`` makes a synchronous Stripe HTTP call. Calling it
    directly inside this coroutine would block the event loop for every other
    request, so it runs in a worker thread.

    Raises:
        ValidationError: If the requested plan is not configured.
    """
    import asyncio

    try:
        return await asyncio.to_thread(
            create_checkout,
            user_id=body.user_id,
            plan=body.plan,
            email=body.email,
        )
    except ValueError as e:
        raise ValidationError(str(e)) from e
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@BillingRouter.get("/subscription", response_model=Subscription | None)
async def subscription_endpoint(user_id: str) -> Subscription | None:
    """Get the current subscription for a user.

    Responds with ``null`` when ``get_subscription`` returns None. This happens
    when the user has no Stripe customer or no current subscription.
    """
    import asyncio

    # get_subscription makes synchronous Stripe HTTP calls; keep them off the event loop.
    return await asyncio.to_thread(get_subscription, user_id)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@BillingRouter.post("/cancel", response_model=Subscription)
async def cancel_endpoint(body: CancelRequest) -> Subscription:
    """Cancel a user's subscription.

    ``cancel_subscription`` makes synchronous Stripe HTTP calls, so it runs in
    a worker thread instead of blocking the event loop.

    Raises:
        NotFoundError: If the user has no current subscription to cancel.
    """
    import asyncio

    try:
        return await asyncio.to_thread(
            cancel_subscription,
            user_id=body.user_id,
            at_period_end=body.at_period_end,
        )
    except ValueError as e:
        raise NotFoundError(str(e)) from e
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@BillingRouter.post("/portal", response_model=PortalResponse)
async def portal_endpoint(body: PortalRequest) -> PortalResponse:
    """Create a Stripe Customer Portal session.

    ``create_portal_session`` makes a synchronous Stripe HTTP call, so it runs
    in a worker thread rather than on the event loop. Stripe errors are not
    translated here; they propagate to the application's error handling.

    NOTE(review): ``customer_id`` comes straight from the request body and is
    not tied to an authenticated user here. Confirm access control upstream.
    """
    import asyncio

    url = await asyncio.to_thread(create_portal_session, customer_id=body.customer_id)
    return PortalResponse(url=url)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@BillingRouter.post("/webhook")
async def webhook_endpoint(
    request: Request,
    stripe_signature: Annotated[str, Header(alias="Stripe-Signature")],
) -> dict:
    """Receive a Stripe webhook delivery.

    The untouched raw request bytes and the ``Stripe-Signature`` header go to
    ``handle_webhook``, which verifies the signature before parsing the event.

    Raises:
        ValidationError: If ``handle_webhook`` raises ValueError (for example
            on an invalid signature).
    """
    raw_body = await request.body()

    try:
        parsed = handle_webhook(payload=raw_body, signature=stripe_signature)
    except ValueError as err:
        raise ValidationError(str(err)) from err

    return {"type": parsed.type, "received": True}
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@BillingRouter.post("/usage")
async def usage_endpoint(body: UsageRequest) -> dict:
    """Record metered usage for billing.

    ``record_usage`` makes synchronous Stripe lookups and a usage write, so it
    runs in a worker thread instead of blocking the event loop.

    Raises:
        ValidationError: If ``record_usage`` rejects the request with
            ValueError (for example, no subscription item matches the metric).
    """
    import asyncio

    try:
        await asyncio.to_thread(
            record_usage,
            user_id=body.user_id,
            metric=body.metric,
            quantity=body.quantity,
        )
    except ValueError as e:
        raise ValidationError(str(e)) from e

    return {"recorded": True}
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@BillingRouter.get("/invoices", response_model=list[Invoice])
async def invoices_endpoint(user_id: str, limit: int = 10) -> list[Invoice]:
    """Fetch recent invoices for a user.

    Args:
        user_id: Internal user identifier (query parameter).
        limit: Maximum number of invoices, 1-100 (query parameter, default 10).

    Raises:
        ValidationError: If ``limit`` is outside 1-100, the range Stripe's list
            endpoints accept. Rejecting it here returns a client error instead
            of surfacing a Stripe API failure.
        NotFoundError: If no Stripe customer exists for the user.
    """
    import asyncio

    if not 1 <= limit <= 100:
        raise ValidationError("limit must be between 1 and 100")

    try:
        # get_invoices makes synchronous Stripe HTTP calls; keep them off the event loop.
        return await asyncio.to_thread(get_invoices, user_id=user_id, limit=limit)
    except ValueError as e:
        raise NotFoundError(str(e)) from e
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Subscription lifecycle management via Stripe."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import UTC, datetime
|
|
6
|
+
|
|
7
|
+
import stripe
|
|
8
|
+
|
|
9
|
+
from billing.config import get_config
|
|
10
|
+
from billing.models import Subscription
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _quote_search_value(value: str) -> str:
    """Return ``value`` as a double-quoted Stripe search string literal.

    Backslashes and double quotes are backslash-escaped, so the value cannot
    close the literal early and inject extra query clauses.
    """
    escaped = value.replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def _get_customer_id(user_id: str) -> str:
    """Look up the Stripe customer ID by user_id stored in metadata.

    Runs a Customer Search for ``metadata["user_id"]``. ``user_id`` often
    comes straight from an HTTP request, so it is escaped before being
    embedded in the query.

    NOTE(review): Stripe documents search as eventually consistent, so a
    customer created or updated moments ago may not be found yet.

    Args:
        user_id: Internal user identifier.

    Returns:
        Stripe customer ID of the first match.

    Raises:
        ValueError: If no customer is found for the given user_id.
    """
    customers = stripe.Customer.search(
        query=f'metadata["user_id"]:{_quote_search_value(user_id)}',
        limit=1,
    )

    if not customers.data:
        raise ValueError(f"No Stripe customer found for user_id '{user_id}'")

    return customers.data[0].id
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _resolve_plan_name(price_id: str) -> str:
    """Map a Stripe price ID back to its configured plan name.

    Plans are checked in ``product_plans`` order and the first match wins.
    An unknown price falls back to the price ID itself, so callers always
    receive some label.
    """
    plans = get_config().product_plans
    return next(
        (name for name, configured in plans.items() if configured == price_id),
        price_id,
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _stripe_sub_to_model(sub: stripe.Subscription, user_id: str) -> Subscription:
    """Convert a Stripe Subscription object to our domain model.

    Only the first subscription item is used to resolve the plan name.

    Billing-period bounds are read from the subscription when it carries them.
    Stripe API versions from 2025-03-31 ("basil") onward no longer return
    ``current_period_start``/``current_period_end`` on the subscription and
    report them on each subscription item instead, so the first item is the
    fallback source.

    Args:
        sub: Stripe Subscription whose items include their price.
        user_id: Internal user identifier to copy into the model.

    Returns:
        Subscription model with period bounds as aware UTC datetimes.

    Raises:
        KeyError: If neither the subscription nor its first item carries the
            period bounds.
    """
    first_item = sub["items"]["data"][0]
    price_id = first_item["price"]["id"]

    # getattr with a default: on newer API versions the attribute is simply absent.
    period_start = getattr(sub, "current_period_start", None)
    period_end = getattr(sub, "current_period_end", None)
    if period_start is None or period_end is None:
        period_start = first_item["current_period_start"]
        period_end = first_item["current_period_end"]

    return Subscription(
        id=sub.id,
        user_id=user_id,
        plan_name=_resolve_plan_name(price_id),
        status=sub.status,
        current_period_start=datetime.fromtimestamp(period_start, tz=UTC),
        current_period_end=datetime.fromtimestamp(period_end, tz=UTC),
        cancel_at_period_end=sub.cancel_at_period_end,
    )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_subscription(user_id: str) -> Subscription | None:
    """Fetch the user's current subscription.

    Statuses are tried in order of preference: ``active``, then ``trialing``,
    then ``past_due``. The first subscription Stripe lists for the first
    status that has one is returned. Subscriptions in other states
    (``incomplete``, ``incomplete_expired``, ``unpaid``, ``paused``,
    ``canceled``) are not treated as current, so callers such as
    ``cancel_subscription`` never act on a dead subscription.

    Args:
        user_id: Internal user identifier.

    Returns:
        Subscription if found. None if the user has no Stripe customer or no
        subscription in one of the statuses above.

    Raises:
        stripe.StripeError: On Stripe API failures.
    """
    try:
        customer_id = _get_customer_id(user_id)
    except ValueError:
        return None

    # The list endpoint's ``status`` filter takes a single value, so query each
    # accepted status in preference order.
    for status in ("active", "trialing", "past_due"):
        subscriptions = stripe.Subscription.list(
            customer=customer_id,
            status=status,
            limit=1,
            expand=["data.items.data.price"],
        )
        if subscriptions.data:
            return _stripe_sub_to_model(subscriptions.data[0], user_id)

    return None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def cancel_subscription(user_id: str, at_period_end: bool = True) -> Subscription:
    """Cancel the user's current subscription.

    Args:
        user_id: Internal user identifier.
        at_period_end: When True (default), set ``cancel_at_period_end`` so
            the subscription ends with the current billing period. When False,
            cancel it immediately via ``stripe.Subscription.cancel``.

    Returns:
        The subscription as Stripe reports it after the change.

    Raises:
        ValueError: If ``get_subscription`` finds no subscription for the user.
        stripe.StripeError: On Stripe API failures.
    """
    current = get_subscription(user_id)
    if current is None:
        raise ValueError(f"No active subscription found for user_id '{user_id}'")

    updated = (
        stripe.Subscription.modify(current.id, cancel_at_period_end=True)
        if at_period_end
        else stripe.Subscription.cancel(current.id)
    )
    return _stripe_sub_to_model(updated, user_id)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def update_subscription(user_id: str, new_plan: str) -> Subscription:
    """Move the user's subscription to ``new_plan``, creating prorations.

    The price of the subscription's first item is swapped for the new plan's
    price.

    Args:
        user_id: Internal user identifier.
        new_plan: Target plan name; must be a key of
            ``BillingConfig.product_plans``.

    Returns:
        The subscription as Stripe reports it after the change.

    Raises:
        ValueError: If the plan is unknown or the user has no subscription.
        stripe.StripeError: On Stripe API failures.
    """
    plans = get_config().product_plans
    target_price = plans.get(new_plan)
    if not target_price:
        available = ", ".join(plans.keys())
        raise ValueError(f"Unknown plan '{new_plan}'. Available plans: {available}")

    current = get_subscription(user_id)
    if current is None:
        raise ValueError(f"No active subscription found for user_id '{user_id}'")

    # The domain model does not keep item IDs, so re-fetch the subscription to
    # find the item whose price is being replaced.
    live = stripe.Subscription.retrieve(current.id, expand=["items.data.price"])
    item_id = live["items"]["data"][0].id

    updated = stripe.Subscription.modify(
        current.id,
        items=[{"id": item_id, "price": target_price}],
        proration_behavior="create_prorations",
    )
    return _stripe_sub_to_model(updated, user_id)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Metered usage billing operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
import stripe
|
|
8
|
+
|
|
9
|
+
from billing.subscriptions import _get_customer_id
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _get_subscription_item_id(user_id: str, metric: str) -> str:
    """Return the ID of the subscription item that bills ``metric``.

    Only the first ``active`` subscription Stripe lists for the user's
    customer is inspected. An item matches when its price's ``lookup_key``,
    or failing that its price ID, equals ``metric``.

    Args:
        user_id: Internal user identifier.
        metric: Usage metric name (expected to match a price lookup_key).

    Returns:
        Stripe subscription item ID.

    Raises:
        ValueError: If the user has no active subscription or no item matches.
    """
    customer_id = _get_customer_id(user_id)

    active_subs = stripe.Subscription.list(
        customer=customer_id,
        status="active",
        limit=1,
        expand=["data.items.data.price"],
    )
    if not active_subs.data:
        raise ValueError(f"No active subscription for user_id '{user_id}'")

    line_items = active_subs.data[0]["items"]["data"]
    matching = next(
        (
            line
            for line in line_items
            if line["price"].get("lookup_key") == metric or line["price"].get("id") == metric
        ),
        None,
    )
    if matching is None:
        raise ValueError(
            f"No subscription item found for metric '{metric}' on user '{user_id}'. "
            "Ensure the Stripe price has a lookup_key matching the metric name."
        )
    return matching.id
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def record_usage(user_id: str, metric: str, quantity: int = 1) -> None:
    """Record metered usage for a subscription item.

    Adds ``quantity`` units at the current time (``action="increment"``) to
    the subscription item matched by ``_get_subscription_item_id``.

    NOTE(review): this uses Stripe's legacy usage-records API
    (``SubscriptionItem.create_usage_record``). The ``stripe>=9.0.0``
    requirement has no upper bound, so confirm the installed SDK/API version
    still provides it.

    Args:
        user_id: Internal user identifier.
        metric: The usage metric name (matches Stripe price lookup_key).
        quantity: Number of units to add; must be non-negative. Defaults to 1.

    Raises:
        ValueError: If ``quantity`` is negative or no matching subscription
            item is found.
        stripe.StripeError: On Stripe API failures.
    """
    if quantity < 0:
        # Reject up front: the router maps ValueError to a client error, and this
        # avoids two Stripe lookups for a request that cannot be recorded.
        raise ValueError(f"quantity must be non-negative, got {quantity}")

    subscription_item_id = _get_subscription_item_id(user_id, metric)

    stripe.SubscriptionItem.create_usage_record(
        subscription_item_id,
        quantity=quantity,
        timestamp=int(time.time()),
        action="increment",
    )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_usage(
    user_id: str,
    metric: str,
    period_start: int | None = None,
) -> int:
    """Get total metered usage for a subscription item.

    Args:
        user_id: Internal user identifier.
        metric: The usage metric name.
        period_start: Unix timestamp. When given, the usage of every summarized
            period starting at or after it is summed. When omitted, only the
            current (most recent) billing period is reported.

    Returns:
        Total quantity used. Returns 0 when Stripe has no summaries.

    Raises:
        ValueError: If no matching subscription item is found.
        stripe.StripeError: On Stripe API failures.
    """
    subscription_item_id = _get_subscription_item_id(user_id, metric)

    summaries = stripe.SubscriptionItem.list_usage_record_summaries(
        subscription_item_id,
        limit=100,
    )

    if period_start is None:
        # Stripe lists summaries newest first; the first one is the current
        # period that has not ended yet.
        current = next(iter(summaries.data), None)
        return current.total_usage if current is not None else 0

    total = 0
    for summary in summaries.auto_paging_iter():
        start = summary.period.start
        # Skip summaries with an unknown start (comparing None to an int would
        # raise TypeError) as well as those before the cutoff.
        if start is None or start < period_start:
            continue
        total += summary.total_usage

    return total
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Stripe webhook handling with signature verification."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
import stripe
|
|
8
|
+
from stripe import SignatureVerificationError
|
|
9
|
+
|
|
10
|
+
from billing.config import get_config
|
|
11
|
+
from billing.models import WebhookEvent
|
|
12
|
+
|
|
13
|
+
# Module logger. The name is spelled out, so it matches the module path billing.webhooks.
logger = logging.getLogger("billing.webhooks")

# Events this module understands and returns structured data for.
# handle_webhook extracts a small dict of fields for each of these. Any other
# event type is passed through with its raw ``data.object`` payload.
HANDLED_EVENTS = frozenset(
    {
        "checkout.session.completed",
        "customer.subscription.updated",
        "customer.subscription.deleted",
        "invoice.payment_succeeded",
        "invoice.payment_failed",
    }
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _checkout_session_completed_data(obj: stripe.StripeObject) -> dict:
    """Extract the fields callers need from a completed Checkout Session."""
    return {
        "session_id": obj.id,
        "customer_id": obj.customer,
        # create_checkout stores the internal user_id as client_reference_id.
        "user_id": obj.client_reference_id,
        "subscription_id": obj.subscription,
        # Prefer customer_email; consult customer_details only when it is present.
        "email": (obj.customer_email or obj.customer_details.get("email"))
        if obj.customer_details
        else obj.customer_email,
        "metadata": dict(obj.metadata) if obj.metadata else {},
    }


def _subscription_updated_data(obj: stripe.StripeObject) -> dict:
    """Extract status, cancellation flag, period end and price items."""
    return {
        "subscription_id": obj.id,
        "customer_id": obj.customer,
        "status": obj.status,
        "cancel_at_period_end": obj.cancel_at_period_end,
        # NOTE(review): newer Stripe API versions report period bounds on the
        # subscription items rather than the subscription. Confirm this field
        # exists for the API version in use.
        "current_period_end": obj.current_period_end,
        "items": [
            {"price_id": item.price.id, "quantity": item.quantity}
            for item in obj["items"]["data"]
        ],
    }


def _subscription_deleted_data(obj: stripe.StripeObject) -> dict:
    """Extract the identifying fields and end time of a deleted subscription."""
    return {
        "subscription_id": obj.id,
        "customer_id": obj.customer,
        "status": obj.status,
        "ended_at": obj.ended_at,
    }


def _invoice_payment_succeeded_data(obj: stripe.StripeObject) -> dict:
    """Extract amount and document links from a paid invoice."""
    # NOTE(review): Stripe API versions from 2025-03-31 no longer return
    # ``subscription`` on invoices. Confirm for the API version in use.
    return {
        "invoice_id": obj.id,
        "customer_id": obj.customer,
        "subscription_id": obj.subscription,
        "amount_paid": obj.amount_paid,
        "currency": obj.currency,
        "hosted_invoice_url": obj.hosted_invoice_url,
        "pdf_url": obj.invoice_pdf,
    }


def _invoice_payment_failed_data(obj: stripe.StripeObject) -> dict:
    """Extract amount due and retry information from a failed invoice payment."""
    return {
        "invoice_id": obj.id,
        "customer_id": obj.customer,
        "subscription_id": obj.subscription,
        "amount_due": obj.amount_due,
        "currency": obj.currency,
        "attempt_count": obj.attempt_count,
        "next_payment_attempt": obj.next_payment_attempt,
    }


# Event type -> function extracting the structured payload for that type.
_EVENT_EXTRACTORS = {
    "checkout.session.completed": _checkout_session_completed_data,
    "customer.subscription.updated": _subscription_updated_data,
    "customer.subscription.deleted": _subscription_deleted_data,
    "invoice.payment_succeeded": _invoice_payment_succeeded_data,
    "invoice.payment_failed": _invoice_payment_failed_data,
}


def handle_webhook(payload: bytes, signature: str) -> WebhookEvent:
    """Verify a Stripe webhook delivery and turn it into a WebhookEvent.

    The signature is checked with ``BillingConfig.stripe_webhook_secret``.
    Event types in ``HANDLED_EVENTS`` get a trimmed dict of relevant fields.
    Any other type is returned with its raw ``data.object`` payload.

    Args:
        payload: Raw request body bytes from the webhook POST.
        signature: Value of the Stripe-Signature header.

    Returns:
        WebhookEvent with the event type and its data.

    Raises:
        ValueError: If signature verification fails.
    """
    webhook_secret = get_config().stripe_webhook_secret

    try:
        event = stripe.Webhook.construct_event(payload, signature, webhook_secret)
    except SignatureVerificationError as e:
        logger.warning("Webhook signature verification failed: %s", e)
        raise ValueError("Invalid webhook signature") from e

    event_type = event.type
    payload_object = event.data.object

    if event_type not in HANDLED_EVENTS:
        logger.debug("Unhandled webhook event type: %s", event_type)
        return WebhookEvent(type=event_type, data=payload_object)

    extract = _EVENT_EXTRACTORS.get(event_type)
    event_data = extract(payload_object) if extract is not None else {}

    logger.info("Processed webhook event: %s", event_type)
    return WebhookEvent(type=event_type, data=event_data)
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Shared test fixtures for billing tests."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
from billing.config import BillingConfig, init_billing
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture(autouse=True)
def billing_config():
    """Install a fake-credential ``BillingConfig`` before every test and return it.

    Runs automatically for each test (``autouse=True``). The Stripe keys are
    placeholders, and the three plan names map to dummy price IDs that the
    checkout tests assert against.
    """
    plan_prices = {
        "starter": "price_starter_test",
        "pro": "price_pro_test",
        "enterprise": "price_enterprise_test",
    }
    test_config = BillingConfig(
        stripe_secret_key="sk_test_fake_key_for_testing",
        stripe_webhook_secret="whsec_test_fake_secret",
        success_url="https://example.com/success",
        cancel_url="https://example.com/cancel",
        product_plans=plan_prices,
    )
    # The billing functions under test are called without an explicit config,
    # so they presumably read whatever init_billing installed.
    init_billing(test_config)
    return test_config
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Tests for checkout session creation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from unittest.mock import MagicMock, patch
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from billing.checkout import create_checkout, create_portal_session
|
|
9
|
+
from billing.models import CheckoutSession
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestCreateCheckout:
    """Exercise ``create_checkout`` against a patched Stripe checkout-session factory."""

    @patch("billing.checkout.stripe.checkout.Session.create")
    def test_creates_session_with_correct_params(self, mock_create: MagicMock) -> None:
        # Only ``url`` and ``id`` are read from the Stripe response, so fake just those.
        fake_session = MagicMock(
            url="https://checkout.stripe.com/c/pay_test123",
            id="cs_test_session123",
        )
        mock_create.return_value = fake_session

        session = create_checkout(user_id="user_abc", plan="pro", email="user@example.com")

        assert isinstance(session, CheckoutSession)
        assert session.url == "https://checkout.stripe.com/c/pay_test123"
        assert session.session_id == "cs_test_session123"

        sent = mock_create.call_args.kwargs
        expected_kwargs = {
            "mode": "subscription",
            "line_items": [{"price": "price_pro_test", "quantity": 1}],
            "client_reference_id": "user_abc",
            "customer_email": "user@example.com",
            "success_url": "https://example.com/success",
            "cancel_url": "https://example.com/cancel",
        }
        for name, value in expected_kwargs.items():
            assert sent[name] == value
        # Metadata is checked key by key; extra metadata keys are allowed.
        sent_metadata = sent["metadata"]
        assert sent_metadata["user_id"] == "user_abc"
        assert sent_metadata["plan"] == "pro"

    @patch("billing.checkout.stripe.checkout.Session.create")
    def test_creates_session_without_email(self, mock_create: MagicMock) -> None:
        mock_create.return_value = MagicMock(
            url="https://checkout.stripe.com/c/pay_test456",
            id="cs_test_session456",
        )

        session = create_checkout(user_id="user_xyz", plan="starter")

        assert session.session_id == "cs_test_session456"
        # With no email supplied, the key must be absent entirely.
        assert "customer_email" not in mock_create.call_args.kwargs

    def test_raises_value_error_for_unknown_plan(self) -> None:
        with pytest.raises(ValueError, match="Unknown plan 'nonexistent'"):
            create_checkout(user_id="user_abc", plan="nonexistent")

    def test_error_message_lists_available_plans(self) -> None:
        with pytest.raises(ValueError, match="starter") as excinfo:
            create_checkout(user_id="user_abc", plan="bad_plan")
        message = str(excinfo.value)
        assert "pro" in message
        assert "enterprise" in message
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class TestCreatePortalSession:
    """Exercise ``create_portal_session`` against a patched Stripe billing-portal factory."""

    @patch("billing.checkout.stripe.billing_portal.Session.create")
    def test_returns_portal_url(self, mock_create: MagicMock) -> None:
        portal_url = "https://billing.stripe.com/p/session/test_portal"
        mock_create.return_value = MagicMock(url=portal_url)

        returned = create_portal_session(customer_id="cus_test123")

        assert returned == portal_url
        # The expected return_url equals the success_url from the test billing config.
        mock_create.assert_called_once_with(
            customer="cus_test123",
            return_url="https://example.com/success",
        )
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""Tests for webhook handling and signature verification."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from unittest.mock import MagicMock, patch
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from billing.webhooks import handle_webhook
|
|
9
|
+
from stripe import SignatureVerificationError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestHandleWebhook:
    """Tests for handle_webhook: signature verification and per-event payload extraction.

    ``stripe.Webhook.construct_event`` is patched in every test, so no real
    signature check runs; each test decides which event object it returns.
    """

    def _make_stripe_event(self, event_type: str, obj_data: dict) -> MagicMock:
        """Build a mock Stripe event of ``event_type`` whose ``data.object`` carries ``obj_data``.

        Every key in ``obj_data`` becomes an attribute of the mock object. When
        ``obj_data`` contains ``"items"``, subscript access (``obj[key]``) is
        also wired up, because the subscription-updated handler reads
        ``obj["items"]["data"]`` rather than an attribute.
        """
        event = MagicMock()
        event.type = event_type
        event.data.object = MagicMock(**obj_data)
        # Ensure dict-style access works for nested items
        if "items" in obj_data:
            event.data.object.__getitem__ = lambda self, key: obj_data[key]
        return event

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_checkout_completed_event(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "checkout.session.completed",
            {
                "id": "cs_test_123",
                "customer": "cus_test_456",
                "client_reference_id": "user_abc",
                "subscription": "sub_test_789",
                "customer_email": "user@example.com",
                "customer_details": None,
                "metadata": {"user_id": "user_abc", "plan": "pro"},
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "checkout.session.completed"
        assert result.data["session_id"] == "cs_test_123"
        assert result.data["customer_id"] == "cus_test_456"
        assert result.data["user_id"] == "user_abc"
        assert result.data["subscription_id"] == "sub_test_789"
        # customer_details is None, so the top-level customer_email is used as-is.
        assert result.data["email"] == "user@example.com"
        assert result.data["metadata"] == {"user_id": "user_abc", "plan": "pro"}

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_checkout_completed_falls_back_to_customer_details_email(
        self, mock_construct: MagicMock
    ) -> None:
        mock_construct.return_value = self._make_stripe_event(
            "checkout.session.completed",
            {
                "id": "cs_test_124",
                "customer": "cus_test_456",
                "client_reference_id": "user_abc",
                "subscription": "sub_test_789",
                "customer_email": None,
                "customer_details": {"email": "details@example.com"},
                "metadata": {"user_id": "user_abc", "plan": "pro"},
            },
        )

        result = handle_webhook(b"payload", "sig_header")

        # With no top-level customer_email, the address comes from customer_details.
        assert result.data["email"] == "details@example.com"

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_subscription_updated_event(self, mock_construct: MagicMock) -> None:
        line_item = MagicMock(quantity=3)
        line_item.price.id = "price_pro_test"
        mock_event = self._make_stripe_event(
            "customer.subscription.updated",
            {
                "id": "sub_test_789",
                "customer": "cus_test_456",
                "status": "active",
                "cancel_at_period_end": True,
                "current_period_end": 1700200000,
                # Read via obj["items"]["data"], which exercises the helper's subscript wiring.
                "items": {"data": [line_item]},
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "customer.subscription.updated"
        assert result.data["subscription_id"] == "sub_test_789"
        assert result.data["customer_id"] == "cus_test_456"
        assert result.data["status"] == "active"
        assert result.data["cancel_at_period_end"] is True
        assert result.data["current_period_end"] == 1700200000
        assert result.data["items"] == [{"price_id": "price_pro_test", "quantity": 3}]

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_subscription_deleted_event(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "customer.subscription.deleted",
            {
                "id": "sub_test_789",
                "customer": "cus_test_456",
                "status": "canceled",
                "ended_at": 1700000000,
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "customer.subscription.deleted"
        assert result.data["subscription_id"] == "sub_test_789"
        assert result.data["customer_id"] == "cus_test_456"
        assert result.data["status"] == "canceled"
        assert result.data["ended_at"] == 1700000000

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_invoice_payment_succeeded(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "invoice.payment_succeeded",
            {
                "id": "in_test_001",
                "customer": "cus_test_456",
                "subscription": "sub_test_789",
                "amount_paid": 2999,
                "currency": "usd",
                "hosted_invoice_url": "https://invoice.stripe.com/i/test",
                "invoice_pdf": "https://invoice.stripe.com/i/test/pdf",
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "invoice.payment_succeeded"
        assert result.data["invoice_id"] == "in_test_001"
        assert result.data["customer_id"] == "cus_test_456"
        assert result.data["subscription_id"] == "sub_test_789"
        assert result.data["amount_paid"] == 2999
        assert result.data["currency"] == "usd"
        assert result.data["hosted_invoice_url"] == "https://invoice.stripe.com/i/test"
        # Stripe's ``invoice_pdf`` field is exposed under the key ``pdf_url``.
        assert result.data["pdf_url"] == "https://invoice.stripe.com/i/test/pdf"

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_invoice_payment_failed(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "invoice.payment_failed",
            {
                "id": "in_test_002",
                "customer": "cus_test_456",
                "subscription": "sub_test_789",
                "amount_due": 2999,
                "currency": "usd",
                "attempt_count": 2,
                "next_payment_attempt": 1700100000,
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "invoice.payment_failed"
        assert result.data["invoice_id"] == "in_test_002"
        assert result.data["amount_due"] == 2999
        assert result.data["currency"] == "usd"
        assert result.data["attempt_count"] == 2
        assert result.data["next_payment_attempt"] == 1700100000

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_unhandled_event_type_returns_raw_data(self, mock_construct: MagicMock) -> None:
        mock_event = MagicMock()
        mock_event.type = "charge.refunded"
        mock_event.data.object = {"id": "ch_test", "amount": 1000}
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "charge.refunded"

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_invalid_signature_raises_value_error(self, mock_construct: MagicMock) -> None:
        mock_construct.side_effect = SignatureVerificationError(
            "No signatures found matching the expected signature for payload",
            sig_header="bad_sig",
        )

        # The Stripe exception is expected to surface to callers as a ValueError.
        with pytest.raises(ValueError, match="Invalid webhook signature"):
            handle_webhook(b"payload", "bad_sig")

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_construct_event_called_with_correct_args(self, mock_construct: MagicMock) -> None:
        mock_event = MagicMock()
        mock_event.type = "charge.refunded"
        mock_event.data.object = {}
        mock_construct.return_value = mock_event

        handle_webhook(b"raw_body_data", "whsec_sig_value")

        # Raw body and signature header are passed through unchanged, together
        # with the webhook secret from the test billing config.
        mock_construct.assert_called_once_with(
            b"raw_body_data",
            "whsec_sig_value",
            "whsec_test_fake_secret",
        )
|