@better-auth/sso 1.4.0-beta.21 → 1.4.0-beta.23

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,275 @@
1
+ import type { Verification } from "better-auth";
2
+ import {
3
+ APIError,
4
+ createAuthEndpoint,
5
+ sessionMiddleware,
6
+ } from "better-auth/api";
7
+ import { generateRandomString } from "better-auth/crypto";
8
+ import * as z from "zod/v4";
9
+ import type { SSOOptions, SSOProvider } from "../types";
10
+
11
/**
 * Builds the `POST /sso/request-domain-verification` endpoint.
 *
 * Requires an authenticated session (`sessionMiddleware`). The caller must be
 * the provider's owner (`provider.userId`) and, when the provider has an
 * `organizationId`, must also be a member of that organization.
 *
 * For a provider whose domain is not yet verified, the endpoint either:
 * - returns the currently active (unexpired) verification token, or
 * - creates a new one that expires one week from now.
 *
 * The token is stored in the `verification` model under the identifier
 * `<tokenPrefix>-<providerId>`. `tokenPrefix` defaults to `better-auth-token`
 * when `options.domainVerification.tokenPrefix` is not set.
 *
 * @param options - SSO plugin options.
 * @returns The auth endpoint; on success it responds `201` with
 *   `{ domainVerificationToken: string }`.
 * @throws APIError `NOT_FOUND` (`PROVIDER_NOT_FOUND`) if no provider matches `providerId`.
 * @throws APIError `FORBIDDEN` (`INSUFICCIENT_ACCESS`) if the access check fails.
 * @throws APIError `CONFLICT` (`DOMAIN_VERIFIED`) if the domain is already verified.
 */
export const requestDomainVerification = (options: SSOOptions) => {
	return createAuthEndpoint(
		"/sso/request-domain-verification",
		{
			method: "POST",
			body: z.object({
				providerId: z.string(),
			}),
			metadata: {
				openapi: {
					summary: "Request a domain verification",
					description:
						"Request a domain verification for the given SSO provider",
					responses: {
						"403": {
							description:
								"User is not the owner of the SSO provider or not a member of its organization",
						},
						"404": {
							description: "Provider not found",
						},
						"409": {
							description: "Domain has already been verified",
						},
						"201": {
							description: "Domain submitted for verification",
						},
					},
				},
			},
			use: [sessionMiddleware],
		},
		async (ctx) => {
			const body = ctx.body;
			const provider = await ctx.context.adapter.findOne<
				SSOProvider<SSOOptions>
			>({
				model: "ssoProvider",
				where: [{ field: "providerId", value: body.providerId }],
			});

			if (!provider) {
				throw new APIError("NOT_FOUND", {
					message: "Provider not found",
					code: "PROVIDER_NOT_FOUND",
				});
			}

			const userId = ctx.context.session.user.id;
			let isOrgMember = true;
			if (provider.organizationId) {
				const membershipsCount = await ctx.context.adapter.count({
					model: "member",
					where: [
						{ field: "userId", value: userId },
						{ field: "organizationId", value: provider.organizationId },
					],
				});

				isOrgMember = membershipsCount > 0;
			}

			// NOTE(review): this rejects unless the caller is the owner AND (for
			// organization-scoped providers) a member, while the message reads as
			// either/or — confirm the intended policy. The misspelled error code is
			// kept as-is because clients may already match on it.
			if (provider.userId !== userId || !isOrgMember) {
				throw new APIError("FORBIDDEN", {
					message:
						"User must be owner of or belong to the SSO provider organization",
					code: "INSUFICCIENT_ACCESS",
				});
			}

			if ("domainVerified" in provider && provider.domainVerified) {
				throw new APIError("CONFLICT", {
					message: "Domain has already been verified",
					code: "DOMAIN_VERIFIED",
				});
			}

			// Compute the identifier once so the lookup and the insert below can
			// never disagree on where the token lives.
			const identifier = options.domainVerification?.tokenPrefix
				? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
				: `better-auth-token-${provider.providerId}`;

			const activeVerification =
				await ctx.context.adapter.findOne<Verification>({
					model: "verification",
					where: [
						{ field: "identifier", value: identifier },
						{ field: "expiresAt", value: new Date(), operator: "gt" },
					],
				});

			// Hand back the still-valid token rather than issuing a second one.
			if (activeVerification) {
				ctx.setStatus(201);
				return ctx.json({ domainVerificationToken: activeVerification.value });
			}

			const oneWeekMs = 7 * 24 * 60 * 60 * 1000;
			const now = new Date();
			const domainVerificationToken = generateRandomString(24);
			await ctx.context.adapter.create<Verification>({
				model: "verification",
				data: {
					identifier,
					createdAt: now,
					updatedAt: now,
					value: domainVerificationToken,
					expiresAt: new Date(now.getTime() + oneWeekMs),
				},
			});

			ctx.setStatus(201);
			return ctx.json({
				domainVerificationToken,
			});
		},
	);
};
124
+
125
/**
 * Returns the hostname whose TXT records are checked for `domain`.
 *
 * A provider `domain` may be stored as a full URL (`https://example.com`) or
 * as a bare host (`example.com`). `new URL()` throws on a bare host, so a
 * scheme is prepended only when the value does not already have one. Values
 * that already carry a scheme are parsed exactly as before.
 *
 * @throws TypeError if the value is still not a parseable URL (e.g. empty).
 */
function getVerificationHostname(domain: string): string {
	const value = domain.trim();
	const hasScheme = /^[a-z][a-z\d+.-]*:\/\//i.test(value);
	return new URL(hasScheme ? value : `https://${value}`).hostname;
}

/**
 * Builds the `POST /sso/verify-domain` endpoint.
 *
 * Requires an authenticated session (`sessionMiddleware`). The access check is
 * the same as in the request endpoint: the caller must be the provider owner
 * and, for organization-scoped providers, an organization member.
 *
 * Ownership is proven by a TXT record on the provider domain's hostname that
 * contains `<identifier>=<token>`, where:
 * - `identifier` is `<tokenPrefix>-<providerId>`, with `tokenPrefix` defaulting
 *   to `better-auth-token`;
 * - `token` is the active (unexpired) verification value.
 *
 * On success the provider's `domainVerified` flag is set to `true`.
 *
 * @param options - SSO plugin options.
 * @returns The auth endpoint; responds `204` with no body on success.
 * @throws APIError `NOT_FOUND` (`PROVIDER_NOT_FOUND`) if the provider does not exist.
 * @throws APIError `FORBIDDEN` (`INSUFICCIENT_ACCESS`) if the access check fails.
 * @throws APIError `CONFLICT` (`DOMAIN_VERIFIED`) if the domain is already verified.
 * @throws APIError `NOT_FOUND` (`NO_PENDING_VERIFICATION`) if no unexpired token exists.
 * @throws APIError `INTERNAL_SERVER_ERROR` (`DOMAIN_VERIFICATION_FAILED`) if
 *   `node:dns/promises` cannot be loaded.
 * @throws APIError `BAD_GATEWAY` (`DOMAIN_VERIFICATION_FAILED`) if DNS lookup
 *   fails or no matching TXT record is found.
 */
export const verifyDomain = (options: SSOOptions) => {
	return createAuthEndpoint(
		"/sso/verify-domain",
		{
			method: "POST",
			body: z.object({
				providerId: z.string(),
			}),
			metadata: {
				openapi: {
					summary: "Verify the provider domain ownership",
					description: "Verify the provider domain ownership via DNS records",
					responses: {
						"403": {
							description:
								"User is not the owner of the SSO provider or not a member of its organization",
						},
						"404": {
							description:
								"Provider not found or no pending verification exists",
						},
						"409": {
							description: "Domain has already been verified",
						},
						"500": {
							description:
								"Unable to verify domain ownership due to server error",
						},
						"502": {
							description:
								"Unable to verify domain ownership due to upstream validator error",
						},
						"204": {
							description: "Domain ownership was verified",
						},
					},
				},
			},
			use: [sessionMiddleware],
		},
		async (ctx) => {
			const body = ctx.body;
			const provider = await ctx.context.adapter.findOne<
				SSOProvider<SSOOptions>
			>({
				model: "ssoProvider",
				where: [{ field: "providerId", value: body.providerId }],
			});

			if (!provider) {
				throw new APIError("NOT_FOUND", {
					message: "Provider not found",
					code: "PROVIDER_NOT_FOUND",
				});
			}

			const userId = ctx.context.session.user.id;
			let isOrgMember = true;
			if (provider.organizationId) {
				const membershipsCount = await ctx.context.adapter.count({
					model: "member",
					where: [
						{ field: "userId", value: userId },
						{ field: "organizationId", value: provider.organizationId },
					],
				});

				isOrgMember = membershipsCount > 0;
			}

			// NOTE(review): owner AND (for org providers) member are both required,
			// although the message reads as either/or — confirm the intended policy.
			// The misspelled error code is kept because clients may match on it.
			if (provider.userId !== userId || !isOrgMember) {
				throw new APIError("FORBIDDEN", {
					message:
						"User must be owner of or belong to the SSO provider organization",
					code: "INSUFICCIENT_ACCESS",
				});
			}

			if ("domainVerified" in provider && provider.domainVerified) {
				throw new APIError("CONFLICT", {
					message: "Domain has already been verified",
					code: "DOMAIN_VERIFIED",
				});
			}

			const identifier = options.domainVerification?.tokenPrefix
				? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
				: `better-auth-token-${provider.providerId}`;

			const activeVerification =
				await ctx.context.adapter.findOne<Verification>({
					model: "verification",
					where: [
						{ field: "identifier", value: identifier },
						{ field: "expiresAt", value: new Date(), operator: "gt" },
					],
				});

			if (!activeVerification) {
				throw new APIError("NOT_FOUND", {
					message: "No pending domain verification exists",
					code: "NO_PENDING_VERIFICATION",
				});
			}

			// Loaded lazily so runtimes without node:dns only fail when this
			// endpoint is actually used.
			let dns: typeof import("node:dns/promises");
			try {
				dns = await import("node:dns/promises");
			} catch (error) {
				ctx.context.logger.error(
					"The core node:dns module is required for the domain verification feature",
					error,
				);
				throw new APIError("INTERNAL_SERVER_ERROR", {
					message: "Unable to verify domain ownership due to server error",
					code: "DOMAIN_VERIFICATION_FAILED",
				});
			}

			let records: string[] = [];
			try {
				const txtRecords = await dns.resolveTxt(
					getVerificationHostname(provider.domain),
				);
				// resolveTxt yields each TXT record as an array of chunks; join them
				// so a record split across several chunks is matched as a whole.
				records = txtRecords.map((chunks) => chunks.join(""));
			} catch (error) {
				// Best-effort: a lookup failure falls through to the "no matching
				// record" error below.
				ctx.context.logger.warn(
					"DNS resolution failure while validating domain ownership",
					error,
				);
			}

			const expectedRecord = `${activeVerification.identifier}=${activeVerification.value}`;
			const hasMatchingRecord = records.some((record) =>
				record.includes(expectedRecord),
			);
			if (!hasMatchingRecord) {
				throw new APIError("BAD_GATEWAY", {
					message: "Unable to verify domain ownership. Try again later",
					code: "DOMAIN_VERIFICATION_FAILED",
				});
			}

			await ctx.context.adapter.update<SSOProvider<SSOOptions>>({
				model: "ssoProvider",
				where: [{ field: "providerId", value: provider.providerId }],
				update: {
					domainVerified: true,
				},
			});

			ctx.setStatus(204);
			return;
		},
	);
};
package/src/routes/sso.ts CHANGED
@@ -1,11 +1,9 @@
1
1
  import { BetterFetchError, betterFetch } from "@better-fetch/fetch";
2
+ import type { Account, Session, User, Verification } from "better-auth";
2
3
  import {
3
- type Account,
4
4
  createAuthorizationURL,
5
5
  generateState,
6
6
  parseState,
7
- type Session,
8
- type User,
9
7
  validateAuthorizationCode,
10
8
  validateToken,
11
9
  } from "better-auth";
@@ -15,6 +13,7 @@ import {
15
13
  sessionMiddleware,
16
14
  } from "better-auth/api";
17
15
  import { setSessionCookie } from "better-auth/cookies";
16
+ import { generateRandomString } from "better-auth/crypto";
18
17
  import { handleOAuthUserInfo } from "better-auth/oauth2";
19
18
  import { decodeJwt } from "jose";
20
19
  import * as saml from "samlify";
@@ -23,6 +22,7 @@ import type { IdentityProvider } from "samlify/types/src/entity-idp";
23
22
  import type { FlowResult } from "samlify/types/src/flow";
24
23
  import * as z from "zod/v4";
25
24
  import type { OIDCConfig, SAMLConfig, SSOOptions, SSOProvider } from "../types";
25
+ import { validateEmailDomain } from "../utils";
26
26
 
27
27
  /**
28
28
  * Safely parses a value that might be a JSON string or already a parsed object
@@ -128,7 +128,7 @@ export const spMetadata = () => {
128
128
  );
129
129
  };
130
130
 
131
- export const registerSSOProvider = (options?: SSOOptions) => {
131
+ export const registerSSOProvider = <O extends SSOOptions>(options: O) => {
132
132
  return createAuthEndpoint(
133
133
  "/sso/register",
134
134
  {
@@ -360,6 +360,16 @@ export const registerSSOProvider = (options?: SSOOptions) => {
360
360
  description:
361
361
  "The domain of the provider, used for email matching",
362
362
  },
363
+ domainVerified: {
364
+ type: "boolean",
365
+ description:
366
+ "A boolean indicating whether the domain has been verified or not",
367
+ },
368
+ domainVerificationToken: {
369
+ type: "string",
370
+ description:
371
+ "Domain verification token. It can be used to prove ownership over the SSO domain",
372
+ },
363
373
  oidcConfig: {
364
374
  type: "object",
365
375
  properties: {
@@ -588,12 +598,13 @@ export const registerSSOProvider = (options?: SSOOptions) => {
588
598
 
589
599
  const provider = await ctx.context.adapter.create<
590
600
  Record<string, any>,
591
- SSOProvider
601
+ SSOProvider<O>
592
602
  >({
593
603
  model: "ssoProvider",
594
604
  data: {
595
605
  issuer: body.issuer,
596
606
  domain: body.domain,
607
+ domainVerified: false,
597
608
  oidcConfig: body.oidcConfig
598
609
  ? JSON.stringify({
599
610
  issuer: body.issuer,
@@ -642,6 +653,34 @@ export const registerSSOProvider = (options?: SSOOptions) => {
642
653
  },
643
654
  });
644
655
 
656
+ let domainVerificationToken: string | undefined;
657
+ let domainVerified: boolean | undefined;
658
+
659
+ if (options?.domainVerification?.enabled) {
660
+ domainVerified = false;
661
+ domainVerificationToken = generateRandomString(24);
662
+
663
+ await ctx.context.adapter.create<Verification>({
664
+ model: "verification",
665
+ data: {
666
+ identifier: options.domainVerification?.tokenPrefix
667
+ ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
668
+ : `better-auth-token-${provider.providerId}`,
669
+ createdAt: new Date(),
670
+ updatedAt: new Date(),
671
+ value: domainVerificationToken,
672
+ expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week
673
+ },
674
+ });
675
+ }
676
+
677
+ type SSOProviderReturn = O["domainVerification"] extends { enabled: true }
678
+ ? {
679
+ domainVerified: boolean;
680
+ domainVerificationToken: string;
681
+ } & SSOProvider<O>
682
+ : SSOProvider<O>;
683
+
645
684
  return ctx.json({
646
685
  ...provider,
647
686
  oidcConfig: JSON.parse(
@@ -651,7 +690,11 @@ export const registerSSOProvider = (options?: SSOOptions) => {
651
690
  provider.samlConfig as unknown as string,
652
691
  ) as SAMLConfig,
653
692
  redirectURI: `${ctx.context.baseURL}/sso/callback/${provider.providerId}`,
654
- });
693
+ ...(options?.domainVerification?.enabled ? { domainVerified } : {}),
694
+ ...(options?.domainVerification?.enabled
695
+ ? { domainVerificationToken }
696
+ : {}),
697
+ } as unknown as SSOProviderReturn);
655
698
  },
656
699
  );
657
700
  };
@@ -842,7 +885,7 @@ export const signInSSO = (options?: SSOOptions) => {
842
885
  return res.id;
843
886
  });
844
887
  }
845
- let provider: SSOProvider | null = null;
888
+ let provider: SSOProvider<SSOOptions> | null = null;
846
889
  if (options?.defaultSSO?.length) {
847
890
  // Find matching default SSO provider by providerId
848
891
  const matchingDefault = providerId
@@ -864,7 +907,10 @@ export const signInSSO = (options?: SSOOptions) => {
864
907
  oidcConfig: matchingDefault.oidcConfig,
865
908
  samlConfig: matchingDefault.samlConfig,
866
909
  domain: matchingDefault.domain,
867
- };
910
+ ...(options.domainVerification?.enabled
911
+ ? { domainVerified: true }
912
+ : {}),
913
+ } as SSOProvider<SSOOptions>;
868
914
  }
869
915
  }
870
916
  if (!providerId && !orgId && !domain) {
@@ -875,7 +921,7 @@ export const signInSSO = (options?: SSOOptions) => {
875
921
  // Try to find provider in database
876
922
  if (!provider) {
877
923
  provider = await ctx.context.adapter
878
- .findOne<SSOProvider>({
924
+ .findOne<SSOProvider<SSOOptions>>({
879
925
  model: "ssoProvider",
880
926
  where: [
881
927
  {
@@ -927,7 +973,32 @@ export const signInSSO = (options?: SSOOptions) => {
927
973
  }
928
974
  }
929
975
 
976
+ if (
977
+ options?.domainVerification?.enabled &&
978
+ !("domainVerified" in provider && provider.domainVerified)
979
+ ) {
980
+ throw new APIError("UNAUTHORIZED", {
981
+ message: "Provider domain has not been verified",
982
+ });
983
+ }
984
+
930
985
  if (provider.oidcConfig && body.providerType !== "saml") {
986
+ let finalAuthUrl = provider.oidcConfig.authorizationEndpoint;
987
+ if (!finalAuthUrl && provider.oidcConfig.discoveryEndpoint) {
988
+ const discovery = await betterFetch<{
989
+ authorization_endpoint: string;
990
+ }>(provider.oidcConfig.discoveryEndpoint, {
991
+ method: "GET",
992
+ });
993
+ if (discovery.data) {
994
+ finalAuthUrl = discovery.data.authorization_endpoint;
995
+ }
996
+ }
997
+ if (!finalAuthUrl) {
998
+ throw new APIError("BAD_REQUEST", {
999
+ message: "Invalid OIDC configuration. Authorization URL not found.",
1000
+ });
1001
+ }
931
1002
  const state = await generateState(ctx, undefined, false);
932
1003
  const redirectURI = `${ctx.context.baseURL}/sso/callback/${provider.providerId}`;
933
1004
  const authorizationURL = await createAuthorizationURL({
@@ -949,7 +1020,7 @@ export const signInSSO = (options?: SSOOptions) => {
949
1020
  "offline_access",
950
1021
  ],
951
1022
  loginHint: ctx.body.loginHint || email,
952
- authorizationEndpoint: provider.oidcConfig.authorizationEndpoint!,
1023
+ authorizationEndpoint: finalAuthUrl,
953
1024
  });
954
1025
  return ctx.json({
955
1026
  url: authorizationURL.toString(),
@@ -1014,6 +1085,10 @@ export const callbackSSO = (options?: SSOOptions) => {
1014
1085
  error: z.string().optional(),
1015
1086
  error_description: z.string().optional(),
1016
1087
  }),
1088
+ allowedMediaTypes: [
1089
+ "application/x-www-form-urlencoded",
1090
+ "application/json",
1091
+ ],
1017
1092
  metadata: {
1018
1093
  isAction: false,
1019
1094
  openapi: {
@@ -1046,7 +1121,7 @@ export const callbackSSO = (options?: SSOOptions) => {
1046
1121
  }?error=${error}&error_description=${error_description}`,
1047
1122
  );
1048
1123
  }
1049
- let provider: SSOProvider | null = null;
1124
+ let provider: SSOProvider<SSOOptions> | null = null;
1050
1125
  if (options?.defaultSSO?.length) {
1051
1126
  const matchingDefault = options.defaultSSO.find(
1052
1127
  (defaultProvider) =>
@@ -1057,7 +1132,10 @@ export const callbackSSO = (options?: SSOOptions) => {
1057
1132
  ...matchingDefault,
1058
1133
  issuer: matchingDefault.oidcConfig?.issuer || "",
1059
1134
  userId: "default",
1060
- };
1135
+ ...(options.domainVerification?.enabled
1136
+ ? { domainVerified: true }
1137
+ : {}),
1138
+ } as SSOProvider<SSOOptions>;
1061
1139
  }
1062
1140
  }
1063
1141
  if (!provider) {
@@ -1081,7 +1159,7 @@ export const callbackSSO = (options?: SSOOptions) => {
1081
1159
  ...res,
1082
1160
  oidcConfig:
1083
1161
  safeJsonParse<OIDCConfig>(res.oidcConfig) || undefined,
1084
- } as SSOProvider;
1162
+ } as SSOProvider<SSOOptions>;
1085
1163
  });
1086
1164
  }
1087
1165
  if (!provider) {
@@ -1092,6 +1170,15 @@ export const callbackSSO = (options?: SSOOptions) => {
1092
1170
  );
1093
1171
  }
1094
1172
 
1173
+ if (
1174
+ options?.domainVerification?.enabled &&
1175
+ !("domainVerified" in provider && provider.domainVerified)
1176
+ ) {
1177
+ throw new APIError("UNAUTHORIZED", {
1178
+ message: "Provider domain has not been verified",
1179
+ });
1180
+ }
1181
+
1095
1182
  let config = provider.oidcConfig;
1096
1183
 
1097
1184
  if (!config) {
@@ -1387,7 +1474,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1387
1474
  async (ctx) => {
1388
1475
  const { SAMLResponse, RelayState } = ctx.body;
1389
1476
  const { providerId } = ctx.params;
1390
- let provider: SSOProvider | null = null;
1477
+ let provider: SSOProvider<SSOOptions> | null = null;
1391
1478
  if (options?.defaultSSO?.length) {
1392
1479
  const matchingDefault = options.defaultSSO.find(
1393
1480
  (defaultProvider) => defaultProvider.providerId === providerId,
@@ -1397,12 +1484,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1397
1484
  ...matchingDefault,
1398
1485
  userId: "default",
1399
1486
  issuer: matchingDefault.samlConfig?.issuer || "",
1400
- };
1487
+ ...(options.domainVerification?.enabled
1488
+ ? { domainVerified: true }
1489
+ : {}),
1490
+ } as SSOProvider<SSOOptions>;
1401
1491
  }
1402
1492
  }
1403
1493
  if (!provider) {
1404
1494
  provider = await ctx.context.adapter
1405
- .findOne<SSOProvider>({
1495
+ .findOne<SSOProvider<SSOOptions>>({
1406
1496
  model: "ssoProvider",
1407
1497
  where: [{ field: "providerId", value: providerId }],
1408
1498
  })
@@ -1425,6 +1515,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1425
1515
  });
1426
1516
  }
1427
1517
 
1518
+ if (
1519
+ options?.domainVerification?.enabled &&
1520
+ !("domainVerified" in provider && provider.domainVerified)
1521
+ ) {
1522
+ throw new APIError("UNAUTHORIZED", {
1523
+ message: "Provider domain has not been verified",
1524
+ });
1525
+ }
1526
+
1428
1527
  const parsedSamlConfig = safeJsonParse<SAMLConfig>(
1429
1528
  provider.samlConfig as unknown as string,
1430
1529
  );
@@ -1718,7 +1817,7 @@ export const acsEndpoint = (options?: SSOOptions) => {
1718
1817
  const { providerId } = ctx.params;
1719
1818
 
1720
1819
  // If defaultSSO is configured, use it as the provider
1721
- let provider: SSOProvider | null = null;
1820
+ let provider: SSOProvider<SSOOptions> | null = null;
1722
1821
 
1723
1822
  if (options?.defaultSSO?.length) {
1724
1823
  // For ACS endpoint, we can use the first default provider or try to match by providerId
@@ -1735,11 +1834,14 @@ export const acsEndpoint = (options?: SSOOptions) => {
1735
1834
  userId: "default",
1736
1835
  samlConfig: matchingDefault.samlConfig,
1737
1836
  domain: matchingDefault.domain,
1837
+ ...(options.domainVerification?.enabled
1838
+ ? { domainVerified: true }
1839
+ : {}),
1738
1840
  };
1739
1841
  }
1740
1842
  } else {
1741
1843
  provider = await ctx.context.adapter
1742
- .findOne<SSOProvider>({
1844
+ .findOne<SSOProvider<SSOOptions>>({
1743
1845
  model: "ssoProvider",
1744
1846
  where: [
1745
1847
  {
@@ -1767,6 +1869,15 @@ export const acsEndpoint = (options?: SSOOptions) => {
1767
1869
  });
1768
1870
  }
1769
1871
 
1872
+ if (
1873
+ options?.domainVerification?.enabled &&
1874
+ !("domainVerified" in provider && provider.domainVerified)
1875
+ ) {
1876
+ throw new APIError("UNAUTHORIZED", {
1877
+ message: "Provider domain has not been verified",
1878
+ });
1879
+ }
1880
+
1770
1881
  const parsedSamlConfig = provider.samlConfig;
1771
1882
  // Configure SP and IdP
1772
1883
  const sp = saml.ServiceProvider({
@@ -1940,7 +2051,10 @@ export const acsEndpoint = (options?: SSOOptions) => {
1940
2051
  const isTrustedProvider =
1941
2052
  ctx.context.options.account?.accountLinking?.trustedProviders?.includes(
1942
2053
  provider.providerId,
1943
- );
2054
+ ) ||
2055
+ ("domainVerified" in provider &&
2056
+ provider.domainVerified &&
2057
+ validateEmailDomain(userInfo.email, provider.domain));
1944
2058
  if (!isTrustedProvider) {
1945
2059
  throw ctx.redirect(
1946
2060
  `${parsedSamlConfig.callbackUrl}?error=account_not_found`,