@better-auth/stripe 1.2.3 → 1.2.4-beta.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/index.ts CHANGED
@@ -1,6 +1,7 @@
1
1
  import {
2
2
  type GenericEndpointContext,
3
3
  type BetterAuthPlugin,
4
+ logger,
4
5
  } from "better-auth";
5
6
  import { createAuthEndpoint, createAuthMiddleware } from "better-auth/plugins";
6
7
  import Stripe from "stripe";
@@ -17,7 +18,12 @@ import {
17
18
  onSubscriptionDeleted,
18
19
  onSubscriptionUpdated,
19
20
  } from "./hooks";
20
- import type { InputSubscription, StripeOptions, Subscription } from "./types";
21
+ import type {
22
+ Customer,
23
+ InputSubscription,
24
+ StripeOptions,
25
+ Subscription,
26
+ } from "./types";
21
27
  import { getPlanByName, getPlanByPriceId, getPlans } from "./utils";
22
28
  import { getSchema } from "./schema";
23
29
 
@@ -77,35 +83,78 @@ export const stripe = <O extends StripeOptions>(options: O) => {
77
83
  {
78
84
  method: "POST",
79
85
  body: z.object({
86
+ /**
87
+ * The name of the plan to subscribe
88
+ */
80
89
  plan: z.string({
81
90
  description: "The name of the plan to upgrade to",
82
91
  }),
92
+ /**
93
+ * If annual plan should be applied.
94
+ */
83
95
  annual: z
84
96
  .boolean({
85
97
  description: "Whether to upgrade to an annual plan",
86
98
  })
87
99
  .optional(),
88
- referenceId: z.string().optional(),
100
+ /**
101
+ * Reference id of the subscription to upgrade
102
+ * This is used to identify the subscription to upgrade
103
+ * If not provided, the user's id will be used
104
+ */
105
+ referenceId: z
106
+ .string({
107
+ description: "Reference id of the subscription to upgrade",
108
+ })
109
+ .optional(),
110
+ /**
111
+ * This is to allow a specific subscription to be upgrade.
112
+ * If subscription id is provided, and subscription isn't found,
113
+ * it'll throw an error.
114
+ */
115
+ subscriptionId: z
116
+ .string({
117
+ description: "The id of the subscription to upgrade",
118
+ })
119
+ .optional(),
120
+ /**
121
+ * Any additional data you want to store in your database
122
+ * subscriptions
123
+ */
89
124
  metadata: z.record(z.string(), z.any()).optional(),
125
+ /**
126
+ * If a subscription
127
+ */
90
128
  seats: z
91
129
  .number({
92
130
  description: "Number of seats to upgrade to (if applicable)",
93
131
  })
94
132
  .optional(),
95
- uiMode: z.enum(["embedded", "hosted"]).default("hosted"),
133
+ /**
134
+ * Success url to redirect back after successful subscription
135
+ */
96
136
  successUrl: z
97
137
  .string({
98
138
  description:
99
139
  "callback url to redirect back after successful subscription",
100
140
  })
101
141
  .default("/"),
142
+ /**
143
+ * Cancel URL
144
+ */
102
145
  cancelUrl: z
103
146
  .string({
104
147
  description:
105
148
  "callback url to redirect back after successful subscription",
106
149
  })
107
150
  .default("/"),
151
+ /**
152
+ * Return URL
153
+ */
108
154
  returnUrl: z.string().optional(),
155
+ /**
156
+ * Disable Redirect
157
+ */
109
158
  disableRedirect: z.boolean().default(false),
110
159
  }),
111
160
  use: [
@@ -133,7 +182,22 @@ export const stripe = <O extends StripeOptions>(options: O) => {
133
182
  message: STRIPE_ERROR_CODES.SUBSCRIPTION_PLAN_NOT_FOUND,
134
183
  });
135
184
  }
136
- let customerId = user.stripeCustomerId;
185
+ const subscriptionToUpdate = ctx.body.subscriptionId
186
+ ? await ctx.context.adapter.findOne<Subscription>({
187
+ model: "subscription",
188
+ where: [{ field: "id", value: ctx.body.subscriptionId }],
189
+ })
190
+ : null;
191
+
192
+ if (ctx.body.subscriptionId && !subscriptionToUpdate) {
193
+ throw new APIError("BAD_REQUEST", {
194
+ message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND,
195
+ });
196
+ }
197
+
198
+ let customerId =
199
+ subscriptionToUpdate?.stripeCustomerId || user.stripeCustomerId;
200
+
137
201
  if (!customerId) {
138
202
  try {
139
203
  const stripeCustomer = await client.customers.create(
@@ -179,15 +243,18 @@ export const stripe = <O extends StripeOptions>(options: O) => {
179
243
  .then((res) => res.data[0])
180
244
  .catch((e) => null)
181
245
  : null;
182
- const subscriptions = await ctx.context.adapter.findMany<Subscription>({
183
- model: "subscription",
184
- where: [
185
- {
186
- field: "referenceId",
187
- value: ctx.body.referenceId || user.id,
188
- },
189
- ],
190
- });
246
+
247
+ const subscriptions = subscriptionToUpdate
248
+ ? [subscriptionToUpdate]
249
+ : await ctx.context.adapter.findMany<Subscription>({
250
+ model: "subscription",
251
+ where: [
252
+ {
253
+ field: "referenceId",
254
+ value: ctx.body.referenceId || user.id,
255
+ },
256
+ ],
257
+ });
191
258
 
192
259
  const existingSubscription = subscriptions.find(
193
260
  (sub) => sub.status === "active" || sub.status === "trialing",
@@ -295,7 +362,7 @@ export const stripe = <O extends StripeOptions>(options: O) => {
295
362
  ctx.context.baseURL
296
363
  }/subscription/success?callbackURL=${encodeURIComponent(
297
364
  ctx.body.successUrl,
298
- )}&reference=${encodeURIComponent(referenceId)}`,
365
+ )}&subscriptionId=${encodeURIComponent(subscription.id)}`,
299
366
  ),
300
367
  cancel_url: getUrl(ctx, ctx.body.cancelUrl),
301
368
  line_items: [
@@ -338,9 +405,10 @@ export const stripe = <O extends StripeOptions>(options: O) => {
338
405
  {
339
406
  method: "GET",
340
407
  query: z.record(z.string(), z.any()).optional(),
408
+ use: [originCheck((ctx) => ctx.query.callbackURL)],
341
409
  },
342
410
  async (ctx) => {
343
- if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) {
411
+ if (!ctx.query || !ctx.query.callbackURL || !ctx.query.subscriptionId) {
344
412
  throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
345
413
  }
346
414
  const session = await getSessionFromCtx<{ stripeCustomerId: string }>(
@@ -350,7 +418,7 @@ export const stripe = <O extends StripeOptions>(options: O) => {
350
418
  throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
351
419
  }
352
420
  const { user } = session;
353
- const { callbackURL, reference } = ctx.query;
421
+ const { callbackURL, subscriptionId } = ctx.query;
354
422
 
355
423
  if (user?.stripeCustomerId) {
356
424
  try {
@@ -359,8 +427,8 @@ export const stripe = <O extends StripeOptions>(options: O) => {
359
427
  model: "subscription",
360
428
  where: [
361
429
  {
362
- field: "referenceId",
363
- value: reference,
430
+ field: "id",
431
+ value: subscriptionId,
364
432
  },
365
433
  ],
366
434
  });
@@ -388,8 +456,8 @@ export const stripe = <O extends StripeOptions>(options: O) => {
388
456
  },
389
457
  where: [
390
458
  {
391
- field: "referenceId",
392
- value: reference,
459
+ field: "id",
460
+ value: subscription.id,
393
461
  },
394
462
  ],
395
463
  });
@@ -416,6 +484,7 @@ export const stripe = <O extends StripeOptions>(options: O) => {
416
484
  method: "POST",
417
485
  body: z.object({
418
486
  referenceId: z.string().optional(),
487
+ subscriptionId: z.string().optional(),
419
488
  returnUrl: z.string(),
420
489
  }),
421
490
  use: [
@@ -427,15 +496,27 @@ export const stripe = <O extends StripeOptions>(options: O) => {
427
496
  async (ctx) => {
428
497
  const referenceId =
429
498
  ctx.body?.referenceId || ctx.context.session.user.id;
430
- const subscription = await ctx.context.adapter.findOne<Subscription>({
431
- model: "subscription",
432
- where: [
433
- {
434
- field: "referenceId",
435
- value: referenceId,
436
- },
437
- ],
438
- });
499
+ const subscription = ctx.body.subscriptionId
500
+ ? await ctx.context.adapter.findOne<Subscription>({
501
+ model: "subscription",
502
+ where: [
503
+ {
504
+ field: "id",
505
+ value: ctx.body.subscriptionId,
506
+ },
507
+ ],
508
+ })
509
+ : await ctx.context.adapter
510
+ .findMany<Subscription>({
511
+ model: "subscription",
512
+ where: [{ field: "referenceId", value: referenceId }],
513
+ })
514
+ .then((subs) =>
515
+ subs.find(
516
+ (sub) => sub.status === "active" || sub.status === "trialing",
517
+ ),
518
+ );
519
+
439
520
  if (!subscription || !subscription.stripeCustomerId) {
440
521
  throw ctx.error("BAD_REQUEST", {
441
522
  message: STRIPE_ERROR_CODES.SUBSCRIPTION_NOT_FOUND,
@@ -485,7 +566,7 @@ export const stripe = <O extends StripeOptions>(options: O) => {
485
566
  ctx.context.baseURL
486
567
  }/subscription/cancel/callback?callbackURL=${encodeURIComponent(
487
568
  ctx.body?.returnUrl || "/",
488
- )}&reference=${encodeURIComponent(referenceId)}`,
569
+ )}&subscriptionId=${encodeURIComponent(subscription.id)}`,
489
570
  ),
490
571
  flow_data: {
491
572
  type: "subscription_cancel",
@@ -575,9 +656,10 @@ export const stripe = <O extends StripeOptions>(options: O) => {
575
656
  {
576
657
  method: "GET",
577
658
  query: z.record(z.string(), z.any()).optional(),
659
+ use: [originCheck((ctx) => ctx.query.callbackURL)],
578
660
  },
579
661
  async (ctx) => {
580
- if (!ctx.query || !ctx.query.callbackURL || !ctx.query.reference) {
662
+ if (!ctx.query || !ctx.query.callbackURL || !ctx.query.subscriptionId) {
581
663
  throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
582
664
  }
583
665
  const session = await getSessionFromCtx<{ stripeCustomerId: string }>(
@@ -587,44 +669,32 @@ export const stripe = <O extends StripeOptions>(options: O) => {
587
669
  throw ctx.redirect(getUrl(ctx, ctx.query?.callbackURL || "/"));
588
670
  }
589
671
  const { user } = session;
590
- const { callbackURL, reference } = ctx.query;
672
+ const { callbackURL, subscriptionId } = ctx.query;
591
673
 
592
- const subscriptions = await ctx.context.adapter.findMany<Subscription>({
674
+ const subscription = await ctx.context.adapter.findOne<Subscription>({
593
675
  model: "subscription",
594
676
  where: [
595
677
  {
596
- field: "referenceId",
597
- value: reference,
678
+ field: "id",
679
+ value: subscriptionId,
598
680
  },
599
681
  ],
600
682
  });
601
683
 
602
- const activeSubscription = subscriptions.find(
603
- (sub) => sub.status === "active" || sub.status === "trialing",
604
- );
605
-
606
- if (activeSubscription) {
684
+ if (
685
+ subscription?.status === "active" ||
686
+ subscription?.status === "trialing"
687
+ ) {
607
688
  return ctx.redirect(getUrl(ctx, callbackURL));
608
689
  }
690
+ const customerId =
691
+ subscription?.stripeCustomerId || user.stripeCustomerId;
609
692
 
610
- if (user?.stripeCustomerId) {
693
+ if (customerId) {
611
694
  try {
612
- const subscription =
613
- await ctx.context.adapter.findOne<Subscription>({
614
- model: "subscription",
615
- where: [
616
- {
617
- field: "referenceId",
618
- value: reference,
619
- },
620
- ],
621
- });
622
- if (!subscription || subscription.status === "active") {
623
- throw ctx.redirect(getUrl(ctx, callbackURL));
624
- }
625
695
  const stripeSubscription = await client.subscriptions
626
696
  .list({
627
- customer: user.stripeCustomerId,
697
+ customer: customerId,
628
698
  status: "active",
629
699
  })
630
700
  .then((res) => res.data[0]);
@@ -635,21 +705,36 @@ export const stripe = <O extends StripeOptions>(options: O) => {
635
705
  stripeSubscription.items.data[0]?.plan.id,
636
706
  );
637
707
 
638
- if (plan && subscriptions.length > 0) {
708
+ if (plan && subscription) {
639
709
  await ctx.context.adapter.update({
640
710
  model: "subscription",
641
711
  update: {
642
712
  status: stripeSubscription.status,
643
713
  seats: stripeSubscription.items.data[0]?.quantity || 1,
644
714
  plan: plan.name.toLowerCase(),
645
- periodEnd: stripeSubscription.current_period_end,
646
- periodStart: stripeSubscription.current_period_start,
715
+ periodEnd: new Date(
716
+ stripeSubscription.current_period_end * 1000,
717
+ ),
718
+ periodStart: new Date(
719
+ stripeSubscription.current_period_start * 1000,
720
+ ),
647
721
  stripeSubscriptionId: stripeSubscription.id,
722
+ ...(stripeSubscription.trial_start &&
723
+ stripeSubscription.trial_end
724
+ ? {
725
+ trialStart: new Date(
726
+ stripeSubscription.trial_start * 1000,
727
+ ),
728
+ trialEnd: new Date(
729
+ stripeSubscription.trial_end * 1000,
730
+ ),
731
+ }
732
+ : {}),
648
733
  },
649
734
  where: [
650
735
  {
651
- field: "referenceId",
652
- value: reference,
736
+ field: "id",
737
+ value: subscription.id,
653
738
  },
654
739
  ],
655
740
  });
@@ -692,7 +777,11 @@ export const stripe = <O extends StripeOptions>(options: O) => {
692
777
  message: "Stripe webhook secret not found",
693
778
  });
694
779
  }
695
- event = client.webhooks.constructEvent(buf, sig, webhookSecret);
780
+ event = await client.webhooks.constructEventAsync(
781
+ buf,
782
+ sig,
783
+ webhookSecret,
784
+ );
696
785
  } catch (err: any) {
697
786
  ctx.context.logger.error(`${err.message}`);
698
787
  throw new APIError("BAD_REQUEST", {
@@ -751,18 +840,29 @@ export const stripe = <O extends StripeOptions>(options: O) => {
751
840
  userId: user.id,
752
841
  },
753
842
  });
754
- await ctx.context.adapter.update({
755
- model: "user",
756
- update: {
757
- stripeCustomerId: stripeCustomer.id,
758
- },
759
- where: [
760
- {
761
- field: "id",
762
- value: user.id,
843
+ const customer = await ctx.context.adapter.update<Customer>(
844
+ {
845
+ model: "user",
846
+ update: {
847
+ stripeCustomerId: stripeCustomer.id,
763
848
  },
764
- ],
765
- });
849
+ where: [
850
+ {
851
+ field: "id",
852
+ value: user.id,
853
+ },
854
+ ],
855
+ },
856
+ );
857
+ if (!customer) {
858
+ logger.error("#BETTER_AUTH: Failed to create customer");
859
+ } else {
860
+ await options.onCustomerCreate?.({
861
+ customer,
862
+ stripeCustomer,
863
+ user,
864
+ });
865
+ }
766
866
  }
767
867
  },
768
868
  },
@@ -298,7 +298,9 @@ describe("stripe", async () => {
298
298
  retrieve: vi.fn().mockResolvedValue(mockSubscription),
299
299
  },
300
300
  webhooks: {
301
- constructEvent: vi.fn().mockReturnValue(mockCheckoutSessionEvent),
301
+ constructEventAsync: vi
302
+ .fn()
303
+ .mockResolvedValue(mockCheckoutSessionEvent),
302
304
  },
303
305
  };
304
306
 
@@ -403,7 +405,7 @@ describe("stripe", async () => {
403
405
  const stripeForTest = {
404
406
  ...stripeOptions.stripeClient,
405
407
  webhooks: {
406
- constructEvent: vi.fn().mockReturnValue(mockDeleteEvent),
408
+ constructEventAsync: vi.fn().mockResolvedValue(mockDeleteEvent),
407
409
  },
408
410
  subscriptions: {
409
411
  retrieve: vi.fn().mockResolvedValue({
@@ -514,7 +516,7 @@ describe("stripe", async () => {
514
516
  retrieve: vi.fn().mockResolvedValue(mockSubscription),
515
517
  },
516
518
  webhooks: {
517
- constructEvent: vi.fn().mockReturnValue(completeEvent),
519
+ constructEventAsync: vi.fn().mockResolvedValue(completeEvent),
518
520
  },
519
521
  };
520
522
 
@@ -591,7 +593,9 @@ describe("stripe", async () => {
591
593
  },
592
594
  );
593
595
 
594
- mockStripeForEvents.webhooks.constructEvent.mockReturnValue(updateEvent);
596
+ mockStripeForEvents.webhooks.constructEventAsync.mockReturnValue(
597
+ updateEvent,
598
+ );
595
599
  await eventTestAuth.handler(updateRequest);
596
600
  expect(onSubscriptionUpdate).toHaveBeenCalledWith(
597
601
  expect.objectContaining({
@@ -632,7 +636,7 @@ describe("stripe", async () => {
632
636
  },
633
637
  );
634
638
 
635
- mockStripeForEvents.webhooks.constructEvent.mockReturnValue(
639
+ mockStripeForEvents.webhooks.constructEventAsync.mockReturnValue(
636
640
  userCancelEvent,
637
641
  );
638
642
  await eventTestAuth.handler(userCancelRequest);
@@ -664,7 +668,9 @@ describe("stripe", async () => {
664
668
  },
665
669
  );
666
670
 
667
- mockStripeForEvents.webhooks.constructEvent.mockReturnValue(cancelEvent);
671
+ mockStripeForEvents.webhooks.constructEventAsync.mockReturnValue(
672
+ cancelEvent,
673
+ );
668
674
  await eventTestAuth.handler(cancelRequest);
669
675
 
670
676
  expect(onSubscriptionCancel).toHaveBeenCalled();
@@ -695,7 +701,9 @@ describe("stripe", async () => {
695
701
  },
696
702
  );
697
703
 
698
- mockStripeForEvents.webhooks.constructEvent.mockReturnValue(deleteEvent);
704
+ mockStripeForEvents.webhooks.constructEventAsync.mockReturnValue(
705
+ deleteEvent,
706
+ );
699
707
  await eventTestAuth.handler(deleteRequest);
700
708
 
701
709
  expect(onSubscriptionDeleted).toHaveBeenCalled();
package/src/types.ts CHANGED
@@ -1,7 +1,7 @@
1
1
  import type { Session, User } from "better-auth";
2
2
  import type Stripe from "stripe";
3
3
 
4
- export type Plan = {
4
+ export type StripePlan = {
5
5
  /**
6
6
  * Monthly price id
7
7
  */
@@ -20,6 +20,13 @@ export type Plan = {
20
20
  * yearly subscription
21
21
  */
22
22
  annualDiscountPriceId?: string;
23
+ /**
24
+ * To use lookup key instead of price id
25
+ *
26
+ * https://docs.stripe.com/products-prices/
27
+ * manage-prices#lookup-keys
28
+ */
29
+ annualDiscountLookupKey?: string;
23
30
  /**
24
31
  * Plan name
25
32
  */
@@ -204,7 +211,7 @@ export interface StripeOptions {
204
211
  /**
205
212
  * List of plan
206
213
  */
207
- plans: Plan[] | (() => Promise<Plan[]>);
214
+ plans: StripePlan[] | (() => Promise<StripePlan[]>);
208
215
  /**
209
216
  * Require email verification before a user is allowed to upgrade
210
217
  * their subscriptions
@@ -223,7 +230,7 @@ export interface StripeOptions {
223
230
  event: Stripe.Event;
224
231
  stripeSubscription: Stripe.Subscription;
225
232
  subscription: Subscription;
226
- plan: Plan;
233
+ plan: StripePlan;
227
234
  },
228
235
  request?: Request,
229
236
  ) => Promise<void>;
@@ -284,7 +291,7 @@ export interface StripeOptions {
284
291
  data: {
285
292
  user: User & Record<string, any>;
286
293
  session: Session & Record<string, any>;
287
- plan: Plan;
294
+ plan: StripePlan;
288
295
  subscription: Subscription;
289
296
  },
290
297
  request?: Request,