msaas-billing 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ node_modules/
2
+ dist/
3
+ .next/
4
+ .turbo/
5
+ *.pyc
6
+ __pycache__/
7
+ .venv/
8
+ *.egg-info/
9
+ .pytest_cache/
10
+ .ruff_cache/
11
+ .env
12
+ .env.local
13
+ .env.*.local
14
+ .DS_Store
15
+ coverage/
16
+
17
+ # Runtime artifacts
18
+ logs_llm/
19
+ vectors.db
20
+ vectors.db-shm
21
+ vectors.db-wal
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: msaas-billing
3
+ Version: 0.1.0
4
+ Summary: Stripe billing module for SaaS products
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: fastapi>=0.115.0
7
+ Requires-Dist: msaas-api-core
8
+ Requires-Dist: msaas-errors
9
+ Requires-Dist: pydantic>=2.0
10
+ Requires-Dist: stripe>=9.0.0
11
+ Provides-Extra: dev
12
+ Requires-Dist: httpx>=0.27; extra == 'dev'
13
+ Requires-Dist: pytest-asyncio>=0.24; extra == 'dev'
14
+ Requires-Dist: pytest>=8.0; extra == 'dev'
@@ -0,0 +1,30 @@
1
+ [project]
2
+ name = "msaas-billing"
3
+ version = "0.1.0"
4
+ description = "Stripe billing module for SaaS products"
5
+ requires-python = ">=3.12"
6
+ dependencies = [
7
+ "msaas-api-core",
8
+ "msaas-errors",
9
+ "stripe>=9.0.0",
10
+ "fastapi>=0.115.0",
11
+ "pydantic>=2.0",
12
+ ]
13
+
14
+ [build-system]
15
+ requires = ["hatchling"]
16
+ build-backend = "hatchling.build"
17
+
18
+ [tool.hatch.build.targets.wheel]
19
+ packages = ["src/billing"]
20
+
21
+ [project.optional-dependencies]
22
+ dev = [
23
+ "pytest>=8.0",
24
+ "pytest-asyncio>=0.24",
25
+ "httpx>=0.27",
26
+ ]
27
+
28
+ [tool.uv.sources]
29
+ msaas-api-core = { workspace = true }
30
+ msaas-errors = { workspace = true }
@@ -0,0 +1,56 @@
1
+ """msaas-billing: Stripe billing module for SaaS products.
2
+
3
+ Usage:
4
+ from billing import init_billing, BillingConfig, BillingRouter
5
+
6
+ config = BillingConfig(
7
+ stripe_secret_key="sk_...",
8
+ stripe_webhook_secret="whsec_...",
9
+ success_url="https://app.example.com/billing/success",
10
+ cancel_url="https://app.example.com/billing/cancel",
11
+ product_plans={"starter": "price_xxx", "pro": "price_yyy"},
12
+ )
13
+ init_billing(config)
14
+
15
+ # Mount the router in your FastAPI app
16
+ app.include_router(BillingRouter)
17
+ """
18
+
19
+ from billing.checkout import create_checkout, create_portal_session
20
+ from billing.config import BillingConfig, init_billing
21
+ from billing.invoices import get_invoices
22
+ from billing.models import (
23
+ CheckoutSession,
24
+ Invoice,
25
+ Subscription,
26
+ UsageRecord,
27
+ WebhookEvent,
28
+ )
29
+ from billing.router import BillingRouter
30
+ from billing.subscriptions import (
31
+ cancel_subscription,
32
+ get_subscription,
33
+ update_subscription,
34
+ )
35
+ from billing.usage import get_usage, record_usage
36
+ from billing.webhooks import handle_webhook
37
+
38
+ __all__ = [
39
+ "BillingConfig",
40
+ "BillingRouter",
41
+ "CheckoutSession",
42
+ "Invoice",
43
+ "Subscription",
44
+ "UsageRecord",
45
+ "WebhookEvent",
46
+ "cancel_subscription",
47
+ "create_checkout",
48
+ "create_portal_session",
49
+ "get_invoices",
50
+ "get_subscription",
51
+ "get_usage",
52
+ "handle_webhook",
53
+ "init_billing",
54
+ "record_usage",
55
+ "update_subscription",
56
+ ]
@@ -0,0 +1,73 @@
1
+ """Stripe Checkout and Customer Portal session management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import stripe
6
+
7
+ from billing.config import get_config
8
+ from billing.models import CheckoutSession
9
+
10
+
11
def create_checkout(
    user_id: str,
    plan: str,
    email: str | None = None,
) -> CheckoutSession:
    """Start a subscription-mode Stripe Checkout session for a configured plan.

    The plan name is translated into a Stripe price through
    ``BillingConfig.product_plans``. The user ID is sent both as the session's
    ``client_reference_id`` and inside its metadata, next to the plan name.

    Args:
        user_id: Internal user identifier, stored as client_reference_id.
        plan: Plan name as defined in BillingConfig.product_plans.
        email: Optional customer email to pre-fill in checkout; ignored when
            empty or None.

    Returns:
        CheckoutSession carrying the hosted checkout URL and the session ID.

    Raises:
        ValueError: If the plan name is not found in configured plans.
        stripe.StripeError: On Stripe API failures.
    """
    settings = get_config()
    plans = settings.product_plans

    price_id = plans.get(plan)
    if not price_id:
        raise ValueError(
            f"Unknown plan '{plan}'. Available plans: {', '.join(plans.keys())}"
        )

    session_params: dict = dict(
        mode="subscription",
        line_items=[{"price": price_id, "quantity": 1}],
        success_url=settings.success_url,
        cancel_url=settings.cancel_url,
        client_reference_id=user_id,
        metadata={"user_id": user_id, "plan": plan},
    )
    if email:
        # Pre-fill the email field on the hosted checkout page.
        session_params["customer_email"] = email

    created = stripe.checkout.Session.create(**session_params)
    return CheckoutSession(url=created.url, session_id=created.id)
55
+
56
+
57
def create_portal_session(customer_id: str) -> str:
    """Open a Stripe Customer Portal session and hand back its URL.

    Customers leaving the portal are sent to the configured ``success_url``.

    Args:
        customer_id: Stripe customer ID.

    Returns:
        The portal session URL for redirect.

    Raises:
        stripe.StripeError: On Stripe API failures.
    """
    return_to = get_config().success_url
    portal = stripe.billing_portal.Session.create(
        customer=customer_id,
        return_url=return_to,
    )
    return portal.url
@@ -0,0 +1,48 @@
1
+ """Billing configuration and Stripe client initialization."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import stripe
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
class BillingConfig(BaseModel):
    """Configuration for the billing module."""

    # Stripe credentials (both required).
    stripe_secret_key: str = Field(..., description="Stripe secret API key")
    stripe_webhook_secret: str = Field(..., description="Stripe webhook signing secret")

    # Where Stripe Checkout sends the browser afterwards (both required).
    success_url: str = Field(..., description="Redirect URL after successful checkout")
    cancel_url: str = Field(..., description="Redirect URL after cancelled checkout")

    # Public plan name -> Stripe price ID; empty unless supplied.
    product_plans: dict[str, str] = Field(
        description="Mapping of plan_name -> stripe_price_id",
        default_factory=dict,
    )
20
+
21
+
22
# Module-wide configuration installed by init_billing(); None until then.
_config: BillingConfig | None = None


def init_billing(config: BillingConfig) -> None:
    """Install *config* as the active billing configuration.

    Besides remembering the config for get_config(), this assigns the Stripe
    secret key to the module-global ``stripe.api_key``.

    Args:
        config: Billing configuration with Stripe credentials and plan mappings.
    """
    global _config
    _config = config
    stripe.api_key = config.stripe_secret_key


def get_config() -> BillingConfig:
    """Fetch the configuration installed by init_billing().

    Raises:
        RuntimeError: If init_billing has not been called.
    """
    if _config is not None:
        return _config
    raise RuntimeError(
        "Billing not initialized. Call init_billing(config) before using billing functions."
    )
@@ -0,0 +1,47 @@
1
+ """Invoice retrieval operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import UTC, datetime
6
+
7
+ import stripe
8
+
9
+ from billing.models import Invoice
10
+ from billing.subscriptions import _get_customer_id
11
+
12
+
13
def get_invoices(user_id: str, limit: int = 10) -> list[Invoice]:
    """Return a user's most recent Stripe invoices as domain models.

    Args:
        user_id: Internal user identifier.
        limit: Maximum number of invoices to return. Defaults to 10.

    Returns:
        List of Invoice objects, most recent first.

    Raises:
        ValueError: If no Stripe customer is found for the user.
        stripe.StripeError: On Stripe API failures.
    """
    listing = stripe.Invoice.list(
        customer=_get_customer_id(user_id),
        limit=limit,
    )

    return [
        Invoice(
            id=invoice.id,
            # Prefer the paid amount; a falsy (e.g. 0) amount_paid falls back
            # to amount_due.
            amount=invoice.amount_paid or invoice.amount_due,
            currency=invoice.currency,
            status=invoice.status,
            # Stripe reports Unix seconds; convert to an aware UTC datetime.
            created=datetime.fromtimestamp(invoice.created, tz=UTC),
            pdf_url=invoice.invoice_pdf,
        )
        for invoice in listing.data
    ]
@@ -0,0 +1,53 @@
1
+ """Pydantic models for billing domain objects."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import datetime
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class Subscription(BaseModel):
    """Represents an active or past subscription."""

    # Stripe subscription ID.
    id: str
    # Internal (application) user identifier.
    user_id: str
    # Configured plan name for the subscription's price.
    plan_name: str
    status: str = Field(
        ...,
        description="Stripe subscription status (active, canceled, past_due, etc.)",
    )
    # Boundaries of the current billing period.
    current_period_start: datetime
    current_period_end: datetime
    # True when cancellation is scheduled for the end of the current period.
    cancel_at_period_end: bool = Field(default=False)
20
+
21
+
22
class Invoice(BaseModel):
    """Represents a Stripe invoice."""

    # Stripe invoice ID.
    id: str
    amount: int = Field(..., description="Amount in smallest currency unit (e.g., cents)")
    # Currency code as reported by Stripe.
    currency: str
    # Stripe invoice status string.
    status: str
    # When the invoice was created.
    created: datetime
    # Link to the invoice PDF, if Stripe provides one.
    pdf_url: str | None = Field(default=None)
31
+
32
+
33
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    from datetime import UTC

    return datetime.now(UTC)


class UsageRecord(BaseModel):
    """Represents a metered usage record."""

    # Internal user identifier the usage belongs to.
    user_id: str
    # Usage metric name.
    metric: str
    # Number of units recorded.
    quantity: int
    # Defaults to an aware UTC "now". The previous `datetime.now` default was
    # naive local time, inconsistent with the UTC-aware datetimes produced
    # elsewhere in this package, and naive/aware values cannot be compared.
    timestamp: datetime = Field(default_factory=_utc_now)
40
+
41
+
42
class CheckoutSession(BaseModel):
    """Result of creating a Stripe Checkout session."""

    url: str = Field(..., description="Redirect URL for the checkout page")
    # Stripe Checkout Session ID.
    session_id: str = Field(...)
47
+
48
+
49
class WebhookEvent(BaseModel):
    """Parsed and verified Stripe webhook event."""

    type: str = Field(..., description="Stripe event type (e.g., checkout.session.completed)")
    # Extracted fields for handled event types; the raw Stripe object otherwise.
    data: dict = Field(description="Event payload data", default_factory=dict)
@@ -0,0 +1,138 @@
1
+ """FastAPI router exposing billing endpoints."""
2
+
3
+ from __future__ import annotations
4
+
5
from typing import Annotated

from errors import (
    NotFoundError,
    ValidationError,
)
from fastapi import APIRouter, Header, Query, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from billing.checkout import create_checkout, create_portal_session
from billing.invoices import get_invoices
from billing.models import CheckoutSession, Invoice, Subscription
from billing.subscriptions import cancel_subscription, get_subscription
from billing.usage import record_usage
from billing.webhooks import handle_webhook
20
+
21
+
22
class CheckoutRequest(BaseModel):
    """Request body for creating a checkout session."""

    # Internal user identifier; forwarded to create_checkout.
    user_id: str
    # Plan name; must be a key of BillingConfig.product_plans.
    plan: str
    # Optional email to pre-fill on the checkout page.
    email: str | None = None


class PortalRequest(BaseModel):
    """Request body for creating a customer portal session."""

    # Stripe customer ID (not the internal user_id).
    customer_id: str


class UsageRequest(BaseModel):
    """Request body for recording usage."""

    # Internal user identifier.
    user_id: str
    # Metric name, matched against Stripe price lookup_key or price ID.
    metric: str
    # Units to add; defaults to a single unit.
    quantity: int = 1


class CancelRequest(BaseModel):
    """Request body for cancelling a subscription."""

    # Internal user identifier.
    user_id: str
    # True: cancel at period end; False: cancel immediately.
    at_period_end: bool = True


class PortalResponse(BaseModel):
    """Response containing a portal session URL."""

    # URL the client should redirect to.
    url: str
55
+
56
+
57
# All billing endpoints live under /billing and share the "billing" OpenAPI tag.
BillingRouter = APIRouter(tags=["billing"], prefix="/billing")
58
+
59
+
60
@BillingRouter.post("/checkout", response_model=CheckoutSession)
async def checkout_endpoint(body: CheckoutRequest) -> CheckoutSession:
    """Create a Stripe Checkout session for a plan."""
    # create_checkout uses the synchronous stripe client (a blocking HTTP call).
    # Running it inside this async handler would stall the event loop, so it is
    # dispatched to FastAPI's threadpool instead.
    # NOTE(review): no authentication dependency is declared here; user_id is
    # trusted from the request body. Confirm the app guards this router.
    try:
        return await run_in_threadpool(
            create_checkout,
            user_id=body.user_id,
            plan=body.plan,
            email=body.email,
        )
    except ValueError as e:
        # Unknown plan name -> client error.
        raise ValidationError(str(e)) from e
71
+
72
+
73
@BillingRouter.get("/subscription", response_model=Subscription | None)
async def subscription_endpoint(user_id: str) -> Subscription | None:
    """Get the current subscription for a user."""
    # get_subscription makes blocking Stripe HTTP calls; run it in the
    # threadpool so the event loop keeps serving other requests.
    # NOTE(review): no authentication dependency is declared here, so any caller
    # can read any user_id's subscription unless the app guards this router.
    return await run_in_threadpool(get_subscription, user_id)
77
+
78
+
79
@BillingRouter.post("/cancel", response_model=Subscription)
async def cancel_endpoint(body: CancelRequest) -> Subscription:
    """Cancel a user's subscription."""
    # cancel_subscription makes blocking Stripe HTTP calls; run it in the
    # threadpool so the event loop keeps serving other requests.
    # NOTE(review): no authentication dependency is declared here, so any caller
    # can cancel any user_id's subscription unless the app guards this router.
    try:
        return await run_in_threadpool(
            cancel_subscription,
            user_id=body.user_id,
            at_period_end=body.at_period_end,
        )
    except ValueError as e:
        # No subscription found for the user -> 404.
        raise NotFoundError(str(e)) from e
89
+
90
+
91
@BillingRouter.post("/portal", response_model=PortalResponse)
async def portal_endpoint(body: PortalRequest) -> PortalResponse:
    """Create a Stripe Customer Portal session."""
    # create_portal_session makes a blocking Stripe HTTP call; run it in the
    # threadpool so the event loop keeps serving other requests.
    # NOTE(review): customer_id comes straight from the request and no
    # authentication dependency is declared, so any caller could open a portal
    # for any Stripe customer unless the app guards this router. Confirm.
    url = await run_in_threadpool(create_portal_session, customer_id=body.customer_id)
    return PortalResponse(url=url)
96
+
97
+
98
@BillingRouter.post("/webhook")
async def webhook_endpoint(
    request: Request,
    stripe_signature: Annotated[str, Header(alias="Stripe-Signature")],
) -> dict:
    """Handle Stripe webhook events.

    Reads the raw body and verifies the Stripe signature before processing.
    """
    # Signature verification needs the exact bytes Stripe sent, not parsed JSON.
    raw_body = await request.body()

    try:
        verified = handle_webhook(payload=raw_body, signature=stripe_signature)
    except ValueError as exc:
        # Bad signature (or unparsable payload) -> client error.
        raise ValidationError(str(exc)) from exc

    return {"type": verified.type, "received": True}
115
+
116
+
117
@BillingRouter.post("/usage")
async def usage_endpoint(body: UsageRequest) -> dict:
    """Record metered usage for billing."""
    # record_usage makes blocking Stripe HTTP calls; run it in the threadpool
    # so the event loop keeps serving other requests.
    # NOTE(review): no authentication dependency is declared here, so any caller
    # can record usage against any user_id unless the app guards this router.
    try:
        await run_in_threadpool(
            record_usage,
            user_id=body.user_id,
            metric=body.metric,
            quantity=body.quantity,
        )
    except ValueError as e:
        # No matching subscription item for the metric -> client error.
        raise ValidationError(str(e)) from e

    return {"recorded": True}
130
+
131
+
132
@BillingRouter.get("/invoices", response_model=list[Invoice])
async def invoices_endpoint(
    user_id: str,
    # Stripe list endpoints accept limit 1..100. Validating here turns an
    # out-of-range value into a 422 instead of an unhandled Stripe error.
    limit: Annotated[int, Query(ge=1, le=100)] = 10,
) -> list[Invoice]:
    """Fetch recent invoices for a user."""
    # get_invoices makes blocking Stripe HTTP calls; run it in the threadpool
    # so the event loop keeps serving other requests.
    # NOTE(review): no authentication dependency is declared here, so any caller
    # can list any user_id's invoices unless the app guards this router.
    try:
        return await run_in_threadpool(get_invoices, user_id=user_id, limit=limit)
    except ValueError as e:
        # No Stripe customer for the user -> 404.
        raise NotFoundError(str(e)) from e
@@ -0,0 +1,168 @@
1
+ """Subscription lifecycle management via Stripe."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import UTC, datetime
6
+
7
+ import stripe
8
+
9
+ from billing.config import get_config
10
+ from billing.models import Subscription
11
+
12
+
13
def _get_customer_id(user_id: str) -> str:
    """Look up the Stripe customer whose metadata["user_id"] equals *user_id*.

    *user_id* is embedded in a quoted Stripe Search query string, and it can
    originate from request data. Values containing a double quote or backslash
    are therefore rejected up front: they could terminate the quoted string and
    rewrite the query, e.g. widen it to match a different customer.

    NOTE(review): create_checkout only sets user_id on the Checkout Session
    (client_reference_id / session metadata), not on the Customer. Presumably
    the application copies it into customer metadata elsewhere; otherwise this
    lookup never matches. Confirm.

    Args:
        user_id: Internal user identifier.

    Returns:
        Stripe customer ID.

    Raises:
        ValueError: If user_id contains characters that cannot be safely
            embedded in the search query, or no customer is found for it.
    """
    if '"' in user_id or "\\" in user_id:
        raise ValueError(
            f"Invalid user_id {user_id!r}: quotes and backslashes are not allowed"
        )

    customers = stripe.Customer.search(
        query=f'metadata["user_id"]:"{user_id}"',
        limit=1,
    )

    if not customers.data:
        raise ValueError(f"No Stripe customer found for user_id '{user_id}'")

    return customers.data[0].id
34
+
35
+
36
def _resolve_plan_name(price_id: str) -> str:
    """Map a Stripe price ID back to its configured plan name.

    Falls back to returning *price_id* unchanged when no configured plan uses
    that price. If several plans share a price, the first in mapping order wins.
    """
    plans = get_config().product_plans
    return next(
        (name for name, configured in plans.items() if configured == price_id),
        price_id,
    )
46
+
47
+
48
def _stripe_sub_to_model(sub: stripe.Subscription, user_id: str) -> Subscription:
    """Convert a Stripe Subscription object to our domain model.

    The plan comes from the first subscription item whose price is one of the
    configured ``product_plans`` prices. If none matches, the first item is
    used. This matters because the usage module looks for metered items on the
    same subscription, so ``items[0]`` is not necessarily the plan item.

    Billing-period timestamps are read from the subscription and, when absent
    there, from the chosen item. NOTE(review): newer Stripe API versions
    (2025-03-31.basil onward) report current_period_start/end on subscription
    items rather than on the subscription. Confirm against the stripe version
    actually installed.

    Args:
        sub: Stripe subscription with ``items.data[].price`` populated.
        user_id: Internal user identifier to attach to the model.

    Returns:
        Subscription domain model with UTC-aware period datetimes.
    """
    items = sub["items"]["data"]
    plan_prices = set(get_config().product_plans.values())
    plan_item = next(
        (item for item in items if item["price"]["id"] in plan_prices),
        items[0],
    )
    price_id = plan_item["price"]["id"]

    def period(field: str) -> datetime:
        # Subscription-level field first (older API versions), then the item's.
        value = getattr(sub, field, None)
        if value is None:
            value = getattr(plan_item, field, None)
        return datetime.fromtimestamp(value, tz=UTC)

    return Subscription(
        id=sub.id,
        user_id=user_id,
        plan_name=_resolve_plan_name(price_id),
        status=sub.status,
        current_period_start=period("current_period_start"),
        current_period_end=period("current_period_end"),
        cancel_at_period_end=sub.cancel_at_period_end,
    )
60
+
61
+
62
def get_subscription(user_id: str) -> Subscription | None:
    """Fetch the user's current subscription, preferring an active one.

    Two listings are tried in order:

    1. The most recent subscription with status "active".
    2. Failing that, the most recent one from an unfiltered listing. This is
       intended to catch e.g. past_due or trialing subscriptions, so the result
       is not necessarily active.

    Args:
        user_id: Internal user identifier.

    Returns:
        The subscription found, or None when the user has no Stripe customer or
        neither listing returns anything.

    Raises:
        stripe.StripeError: On Stripe API failures.
    """
    try:
        customer_id = _get_customer_id(user_id)
    except ValueError:
        return None

    for status_filter in ({"status": "active"}, {}):
        page = stripe.Subscription.list(
            customer=customer_id,
            limit=1,
            expand=["data.items.data.price"],
            **status_filter,
        )
        if page.data:
            return _stripe_sub_to_model(page.data[0], user_id)

    return None
97
+
98
+
99
def cancel_subscription(user_id: str, at_period_end: bool = True) -> Subscription:
    """Cancel the subscription that get_subscription() returns for a user.

    Args:
        user_id: Internal user identifier.
        at_period_end: True schedules cancellation for the end of the current
            billing period; False cancels immediately.

    Returns:
        The subscription as Stripe reports it after the change.

    Raises:
        ValueError: If get_subscription() finds nothing for the user.
        stripe.StripeError: On Stripe API failures.
    """
    current = get_subscription(user_id)
    if current is None:
        raise ValueError(f"No active subscription found for user_id '{user_id}'")

    updated = (
        stripe.Subscription.modify(current.id, cancel_at_period_end=True)
        if at_period_end
        else stripe.Subscription.cancel(current.id)
    )
    return _stripe_sub_to_model(updated, user_id)
127
+
128
+
129
def update_subscription(user_id: str, new_plan: str) -> Subscription:
    """Change a user's subscription to a different plan (proration applied).

    The item swapped to the new price is the first subscription item whose
    price is one of the configured ``product_plans`` prices. Only if none
    matches is the first item used. Always replacing ``items[0]`` could swap
    out a metered add-on item instead of the plan, since the usage module looks
    for metered items on the same subscription.

    Args:
        user_id: Internal user identifier.
        new_plan: New plan name as defined in BillingConfig.product_plans.

    Returns:
        Updated Subscription with the new plan.

    Raises:
        ValueError: If no active subscription or unknown plan.
        stripe.StripeError: On Stripe API failures.
    """
    config = get_config()
    new_price_id = config.product_plans.get(new_plan)
    if not new_price_id:
        available = ", ".join(config.product_plans.keys())
        raise ValueError(f"Unknown plan '{new_plan}'. Available plans: {available}")

    sub = get_subscription(user_id)
    if sub is None:
        raise ValueError(f"No active subscription found for user_id '{user_id}'")

    # The domain model does not carry item IDs, so re-fetch to find the item.
    stripe_sub = stripe.Subscription.retrieve(sub.id, expand=["items.data.price"])
    items = stripe_sub["items"]["data"]
    plan_prices = set(config.product_plans.values())
    plan_item = next(
        (item for item in items if item["price"]["id"] in plan_prices),
        items[0],
    )

    updated = stripe.Subscription.modify(
        sub.id,
        items=[
            {
                "id": plan_item.id,
                "price": new_price_id,
            }
        ],
        proration_behavior="create_prorations",
    )

    return _stripe_sub_to_model(updated, user_id)
@@ -0,0 +1,108 @@
1
+ """Metered usage billing operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+
7
+ import stripe
8
+
9
+ from billing.subscriptions import _get_customer_id
10
+
11
+
12
def _get_subscription_item_id(user_id: str, metric: str) -> str:
    """Locate the subscription item that bills a given metered metric.

    Only the user's most recent subscription with status "active" is
    inspected. An item matches when its price's lookup_key, or failing that
    its price ID, equals *metric*.

    Args:
        user_id: Internal user identifier.
        metric: The usage metric name (should match a Stripe price lookup_key).

    Returns:
        Stripe subscription item ID.

    Raises:
        ValueError: If no matching subscription item is found.
    """
    active = stripe.Subscription.list(
        customer=_get_customer_id(user_id),
        status="active",
        limit=1,
        expand=["data.items.data.price"],
    )
    if not active.data:
        raise ValueError(f"No active subscription for user_id '{user_id}'")

    def _matches(price) -> bool:
        return price.get("lookup_key") == metric or price.get("id") == metric

    hit = next(
        (item for item in active.data[0]["items"]["data"] if _matches(item["price"])),
        None,
    )
    if hit is None:
        raise ValueError(
            f"No subscription item found for metric '{metric}' on user '{user_id}'. "
            "Ensure the Stripe price has a lookup_key matching the metric name."
        )
    return hit.id
49
+
50
+
51
def record_usage(user_id: str, metric: str, quantity: int = 1) -> None:
    """Add *quantity* units of *metric* usage to the user's subscription item.

    The usage record is stamped with the current Unix time and uses the
    "increment" action, so it adds to (rather than replaces) existing usage.

    NOTE(review): this relies on the legacy usage-records API
    (SubscriptionItem.create_usage_record). Confirm that the installed stripe
    version and API version still provide it.

    Args:
        user_id: Internal user identifier.
        metric: The usage metric name (matches Stripe price lookup_key).
        quantity: Number of units to record. Defaults to 1.

    Raises:
        ValueError: If no matching subscription item is found.
        stripe.StripeError: On Stripe API failures.
    """
    item_id = _get_subscription_item_id(user_id, metric)
    now = int(time.time())
    stripe.SubscriptionItem.create_usage_record(
        item_id,
        quantity=quantity,
        timestamp=now,
        action="increment",
    )
71
+
72
+
73
def get_usage(
    user_id: str,
    metric: str,
    period_start: int | None = None,
) -> int:
    """Get total metered usage for a subscription item in the current period.

    Args:
        user_id: Internal user identifier.
        metric: The usage metric name.
        period_start: Unix timestamp for the start of the period to query.
            When given, usage from every summarized period starting at or
            after it is summed. Defaults to the current subscription period
            only.

    Returns:
        Total quantity used in the period (0 if Stripe reports no summaries).

    Raises:
        ValueError: If no matching subscription item is found.
        stripe.StripeError: On Stripe API failures.
    """
    subscription_item_id = _get_subscription_item_id(user_id, metric)

    summaries = stripe.SubscriptionItem.list_usage_record_summaries(
        subscription_item_id,
        limit=100,
    )

    if period_start is None:
        # Stripe lists summaries newest first, and the first entry is the
        # current, still-open period. Summing every page would also include
        # past periods, contradicting the documented default.
        return summaries.data[0].total_usage if summaries.data else 0

    total = 0
    for summary in summaries.auto_paging_iter():
        start = summary.period.start
        # Guard against a null start (None < int raises TypeError). Such
        # summaries cannot be placed before period_start, so they are counted.
        if start is not None and start < period_start:
            continue
        total += summary.total_usage

    return total
@@ -0,0 +1,120 @@
1
+ """Stripe webhook handling with signature verification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ import stripe
8
+ from stripe import SignatureVerificationError
9
+
10
+ from billing.config import get_config
11
+ from billing.models import WebhookEvent
12
+
13
logger = logging.getLogger("billing.webhooks")

# Event types for which handle_webhook extracts a structured payload; any
# other type is passed through with Stripe's raw object.
HANDLED_EVENTS = frozenset(
    (
        "checkout.session.completed",
        "customer.subscription.updated",
        "customer.subscription.deleted",
        "invoice.payment_succeeded",
        "invoice.payment_failed",
    )
)
25
+
26
+
27
def _checkout_email(session) -> str | None:
    """Prefer the session's customer_email, else customer_details["email"]."""
    details = session.customer_details
    if not details:
        return session.customer_email
    return session.customer_email or details.get("email")


def _summarize_checkout_completed(obj) -> dict:
    """Fields of interest from a completed Checkout Session."""
    return {
        "session_id": obj.id,
        "customer_id": obj.customer,
        "user_id": obj.client_reference_id,
        "subscription_id": obj.subscription,
        "email": _checkout_email(obj),
        "metadata": dict(obj.metadata) if obj.metadata else {},
    }


def _summarize_subscription_updated(obj) -> dict:
    """Fields of interest from an updated Subscription, including its items."""
    return {
        "subscription_id": obj.id,
        "customer_id": obj.customer,
        "status": obj.status,
        "cancel_at_period_end": obj.cancel_at_period_end,
        "current_period_end": obj.current_period_end,
        "items": [
            {"price_id": line.price.id, "quantity": line.quantity}
            for line in obj["items"]["data"]
        ],
    }


def _summarize_subscription_deleted(obj) -> dict:
    """Fields of interest from a deleted Subscription."""
    return {
        "subscription_id": obj.id,
        "customer_id": obj.customer,
        "status": obj.status,
        "ended_at": obj.ended_at,
    }


def _summarize_payment_succeeded(obj) -> dict:
    """Fields of interest from a successfully paid Invoice."""
    return {
        "invoice_id": obj.id,
        "customer_id": obj.customer,
        "subscription_id": obj.subscription,
        "amount_paid": obj.amount_paid,
        "currency": obj.currency,
        "hosted_invoice_url": obj.hosted_invoice_url,
        "pdf_url": obj.invoice_pdf,
    }


def _summarize_payment_failed(obj) -> dict:
    """Fields of interest from an Invoice whose payment failed."""
    return {
        "invoice_id": obj.id,
        "customer_id": obj.customer,
        "subscription_id": obj.subscription,
        "amount_due": obj.amount_due,
        "currency": obj.currency,
        "attempt_count": obj.attempt_count,
        "next_payment_attempt": obj.next_payment_attempt,
    }


# Event type -> extractor producing WebhookEvent.data for that type.
_SUMMARIZERS = {
    "checkout.session.completed": _summarize_checkout_completed,
    "customer.subscription.updated": _summarize_subscription_updated,
    "customer.subscription.deleted": _summarize_subscription_deleted,
    "invoice.payment_succeeded": _summarize_payment_succeeded,
    "invoice.payment_failed": _summarize_payment_failed,
}


def handle_webhook(payload: bytes, signature: str) -> WebhookEvent:
    """Verify a Stripe webhook delivery and reduce it to a WebhookEvent.

    The signature is checked against the configured webhook secret. Event
    types listed in HANDLED_EVENTS get a curated ``data`` dict; any other type
    carries Stripe's raw ``event.data.object``.

    Args:
        payload: Raw request body bytes from the webhook POST.
        signature: Value of the Stripe-Signature header.

    Returns:
        WebhookEvent with the event type and extracted data.

    Raises:
        ValueError: If signature verification fails.
    """
    webhook_secret = get_config().stripe_webhook_secret

    try:
        event = stripe.Webhook.construct_event(payload, signature, webhook_secret)
    except SignatureVerificationError as e:
        logger.warning("Webhook signature verification failed: %s", e)
        raise ValueError("Invalid webhook signature") from e

    event_type = event.type
    if event_type not in HANDLED_EVENTS:
        logger.debug("Unhandled webhook event type: %s", event_type)
        return WebhookEvent(type=event_type, data=event.data.object)

    summarize = _SUMMARIZERS.get(event_type)
    event_data = summarize(event.data.object) if summarize else {}

    logger.info("Processed webhook event: %s", event_type)
    return WebhookEvent(type=event_type, data=event_data)
File without changes
@@ -0,0 +1,24 @@
1
+ """Shared test fixtures for billing tests."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ from billing.config import BillingConfig, init_billing
7
+
8
+
9
@pytest.fixture(autouse=True)
def billing_config():
    """Install a fake-credential BillingConfig with three plans before every test.

    Returns the installed config so tests can inspect it.
    """
    plans = {
        "starter": "price_starter_test",
        "pro": "price_pro_test",
        "enterprise": "price_enterprise_test",
    }
    cfg = BillingConfig(
        stripe_secret_key="sk_test_fake_key_for_testing",
        stripe_webhook_secret="whsec_test_fake_secret",
        success_url="https://example.com/success",
        cancel_url="https://example.com/cancel",
        product_plans=plans,
    )
    init_billing(cfg)
    return cfg
@@ -0,0 +1,78 @@
1
+ """Tests for checkout session creation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from unittest.mock import MagicMock, patch
6
+
7
+ import pytest
8
+ from billing.checkout import create_checkout, create_portal_session
9
+ from billing.models import CheckoutSession
10
+
11
+
12
class TestCreateCheckout:
    """Tests for create_checkout function."""

    @patch("billing.checkout.stripe.checkout.Session.create")
    def test_creates_session_with_correct_params(self, mock_create: MagicMock) -> None:
        mock_create.return_value = MagicMock(
            url="https://checkout.stripe.com/c/pay_test123",
            id="cs_test_session123",
        )

        session = create_checkout(user_id="user_abc", plan="pro", email="user@example.com")

        assert isinstance(session, CheckoutSession)
        assert session.url == "https://checkout.stripe.com/c/pay_test123"
        assert session.session_id == "cs_test_session123"

        sent = mock_create.call_args.kwargs
        expected = {
            "mode": "subscription",
            "line_items": [{"price": "price_pro_test", "quantity": 1}],
            "client_reference_id": "user_abc",
            "customer_email": "user@example.com",
            "success_url": "https://example.com/success",
            "cancel_url": "https://example.com/cancel",
        }
        for key, value in expected.items():
            assert sent[key] == value, key
        assert sent["metadata"]["user_id"] == "user_abc"
        assert sent["metadata"]["plan"] == "pro"

    @patch("billing.checkout.stripe.checkout.Session.create")
    def test_creates_session_without_email(self, mock_create: MagicMock) -> None:
        mock_create.return_value = MagicMock(
            url="https://checkout.stripe.com/c/pay_test456",
            id="cs_test_session456",
        )

        session = create_checkout(user_id="user_xyz", plan="starter")

        assert session.session_id == "cs_test_session456"
        # No email given -> the key must not be sent at all.
        assert "customer_email" not in mock_create.call_args.kwargs

    def test_raises_value_error_for_unknown_plan(self) -> None:
        with pytest.raises(ValueError, match="Unknown plan 'nonexistent'"):
            create_checkout(user_id="user_abc", plan="nonexistent")

    def test_error_message_lists_available_plans(self) -> None:
        with pytest.raises(ValueError, match="starter") as excinfo:
            create_checkout(user_id="user_abc", plan="bad_plan")
        message = str(excinfo.value)
        for plan_name in ("pro", "enterprise"):
            assert plan_name in message
61
+
62
+
63
class TestCreatePortalSession:
    """Tests for create_portal_session function."""

    @patch("billing.checkout.stripe.billing_portal.Session.create")
    def test_returns_portal_url(self, mock_create: MagicMock) -> None:
        portal_url = "https://billing.stripe.com/p/session/test_portal"
        mock_create.return_value = MagicMock(url=portal_url)

        assert create_portal_session(customer_id="cus_test123") == portal_url

        # Customers should be sent back to the configured success URL.
        mock_create.assert_called_once_with(
            customer="cus_test123",
            return_url="https://example.com/success",
        )
@@ -0,0 +1,148 @@
1
+ """Tests for webhook handling and signature verification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from unittest.mock import MagicMock, patch
6
+
7
+ import pytest
8
+ from billing.webhooks import handle_webhook
9
+ from stripe import SignatureVerificationError
10
+
11
+
12
class TestHandleWebhook:
    """Tests for handle_webhook function."""

    def _make_stripe_event(self, event_type: str, obj_data: dict) -> MagicMock:
        """Helper to create a mock Stripe event.

        Attribute access on ``event.data.object`` returns the values in
        *obj_data*. When an ``items`` key is present, subscript access
        (``obj["items"]``) is also wired to *obj_data*, matching how the
        handler reads subscription items.
        """
        event = MagicMock()
        event.type = event_type
        event.data.object = MagicMock(**obj_data)
        # Ensure dict-style access works for nested items
        if "items" in obj_data:
            event.data.object.__getitem__ = lambda self, key: obj_data[key]
        return event

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_checkout_completed_event(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "checkout.session.completed",
            {
                "id": "cs_test_123",
                "customer": "cus_test_456",
                "client_reference_id": "user_abc",
                "subscription": "sub_test_789",
                "customer_email": "user@example.com",
                "customer_details": None,
                "metadata": {"user_id": "user_abc", "plan": "pro"},
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "checkout.session.completed"
        assert result.data["session_id"] == "cs_test_123"
        assert result.data["customer_id"] == "cus_test_456"
        assert result.data["user_id"] == "user_abc"
        assert result.data["subscription_id"] == "sub_test_789"

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_subscription_updated_event(self, mock_construct: MagicMock) -> None:
        # Covers the only branch that reads obj["items"]["data"].
        item = MagicMock(quantity=3)
        item.price.id = "price_pro_test"
        mock_event = self._make_stripe_event(
            "customer.subscription.updated",
            {
                "id": "sub_test_789",
                "customer": "cus_test_456",
                "status": "active",
                "cancel_at_period_end": True,
                "current_period_end": 1700000000,
                "items": {"data": [item]},
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "customer.subscription.updated"
        assert result.data["subscription_id"] == "sub_test_789"
        assert result.data["status"] == "active"
        assert result.data["cancel_at_period_end"] is True
        assert result.data["current_period_end"] == 1700000000
        assert result.data["items"] == [{"price_id": "price_pro_test", "quantity": 3}]

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_subscription_deleted_event(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "customer.subscription.deleted",
            {
                "id": "sub_test_789",
                "customer": "cus_test_456",
                "status": "canceled",
                "ended_at": 1700000000,
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "customer.subscription.deleted"
        assert result.data["subscription_id"] == "sub_test_789"
        assert result.data["status"] == "canceled"
        assert result.data["ended_at"] == 1700000000

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_invoice_payment_succeeded(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "invoice.payment_succeeded",
            {
                "id": "in_test_001",
                "customer": "cus_test_456",
                "subscription": "sub_test_789",
                "amount_paid": 2999,
                "currency": "usd",
                "hosted_invoice_url": "https://invoice.stripe.com/i/test",
                "invoice_pdf": "https://invoice.stripe.com/i/test/pdf",
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "invoice.payment_succeeded"
        assert result.data["amount_paid"] == 2999
        assert result.data["currency"] == "usd"
        assert result.data["pdf_url"] == "https://invoice.stripe.com/i/test/pdf"

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_invoice_payment_failed(self, mock_construct: MagicMock) -> None:
        mock_event = self._make_stripe_event(
            "invoice.payment_failed",
            {
                "id": "in_test_002",
                "customer": "cus_test_456",
                "subscription": "sub_test_789",
                "amount_due": 2999,
                "currency": "usd",
                "attempt_count": 2,
                "next_payment_attempt": 1700100000,
            },
        )
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "invoice.payment_failed"
        assert result.data["attempt_count"] == 2
        assert result.data["next_payment_attempt"] == 1700100000

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_unhandled_event_type_returns_raw_data(self, mock_construct: MagicMock) -> None:
        mock_event = MagicMock()
        mock_event.type = "charge.refunded"
        mock_event.data.object = {"id": "ch_test", "amount": 1000}
        mock_construct.return_value = mock_event

        result = handle_webhook(b"payload", "sig_header")

        assert result.type == "charge.refunded"
        # The test name promises the raw object is passed through; check it.
        assert result.data == {"id": "ch_test", "amount": 1000}

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_invalid_signature_raises_value_error(self, mock_construct: MagicMock) -> None:
        mock_construct.side_effect = SignatureVerificationError(
            "No signatures found matching the expected signature for payload",
            sig_header="bad_sig",
        )

        with pytest.raises(ValueError, match="Invalid webhook signature"):
            handle_webhook(b"payload", "bad_sig")

    @patch("billing.webhooks.stripe.Webhook.construct_event")
    def test_construct_event_called_with_correct_args(self, mock_construct: MagicMock) -> None:
        mock_event = MagicMock()
        mock_event.type = "charge.refunded"
        mock_event.data.object = {}
        mock_construct.return_value = mock_event

        handle_webhook(b"raw_body_data", "whsec_sig_value")

        mock_construct.assert_called_once_with(
            b"raw_body_data",
            "whsec_sig_value",
            "whsec_test_fake_secret",
        )