@better-auth/sso 1.4.17 → 1.5.0-beta.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/routes/sso.ts CHANGED
@@ -11,6 +11,7 @@ import {
11
11
  import {
12
12
  APIError,
13
13
  createAuthEndpoint,
14
+ getSessionFromCtx,
14
15
  sessionMiddleware,
15
16
  } from "better-auth/api";
16
17
  import { setSessionCookie } from "better-auth/cookies";
@@ -52,8 +53,9 @@ import {
52
53
  validateSAMLAlgorithms,
53
54
  validateSingleAssertion,
54
55
  } from "../saml";
56
+ import { generateRelayState, parseRelayState } from "../saml-state";
55
57
  import type { OIDCConfig, SAMLConfig, SSOOptions, SSOProvider } from "../types";
56
- import { safeJsonParse, validateEmailDomain } from "../utils";
58
+ import { domainMatches, safeJsonParse, validateEmailDomain } from "../utils";
57
59
 
58
60
  export interface TimestampValidationOptions {
59
61
  clockSkew?: number;
@@ -165,6 +167,8 @@ const spMetadataQuerySchema = z.object({
165
167
  format: z.enum(["xml", "json"]).default("xml"),
166
168
  });
167
169
 
170
+ type RelayState = Awaited<ReturnType<typeof parseRelayState>>;
171
+
168
172
  export const spMetadata = () => {
169
173
  return createAuthEndpoint(
170
174
  "/sso/saml2/sp/metadata",
@@ -225,6 +229,7 @@ export const spMetadata = () => {
225
229
  },
226
230
  ],
227
231
  wantMessageSigned: parsedSamlConfig.wantAssertionsSigned || false,
232
+ authnRequestsSigned: parsedSamlConfig.authnRequestsSigned || false,
228
233
  nameIDFormat: parsedSamlConfig.identifierFormat
229
234
  ? [parsedSamlConfig.identifierFormat]
230
235
  : undefined,
@@ -247,7 +252,8 @@ const ssoProviderBodySchema = z.object({
247
252
  description: "The issuer of the provider",
248
253
  }),
249
254
  domain: z.string({}).meta({
250
- description: "The domain of the provider. This is used for email matching",
255
+ description:
256
+ "The domain(s) of the provider. For enterprise multi-domain SSO where a single IdP serves multiple email domains, use comma-separated values (e.g., 'company.com,subsidiary.com,acquired-company.com')",
251
257
  }),
252
258
  oidcConfig: z
253
259
  .object({
@@ -385,6 +391,7 @@ const ssoProviderBodySchema = z.object({
385
391
  encPrivateKeyPass: z.string().optional(),
386
392
  }),
387
393
  wantAssertionsSigned: z.boolean().optional(),
394
+ authnRequestsSigned: z.boolean().optional(),
388
395
  signatureAlgorithm: z.string().optional(),
389
396
  digestAlgorithm: z.string().optional(),
390
397
  identifierFormat: z.string().optional(),
@@ -830,6 +837,7 @@ export const registerSSOProvider = <O extends SSOOptions>(options: O) => {
830
837
  idpMetadata: body.samlConfig.idpMetadata,
831
838
  spMetadata: body.samlConfig.spMetadata,
832
839
  wantAssertionsSigned: body.samlConfig.wantAssertionsSigned,
840
+ authnRequestsSigned: body.samlConfig.authnRequestsSigned,
833
841
  signatureAlgorithm: body.samlConfig.signatureAlgorithm,
834
842
  digestAlgorithm: body.samlConfig.digestAlgorithm,
835
843
  identifierFormat: body.samlConfig.identifierFormat,
@@ -1121,38 +1129,58 @@ export const signInSSO = (options?: SSOOptions) => {
1121
1129
  }
1122
1130
  // Try to find provider in database
1123
1131
  if (!provider) {
1124
- provider = await ctx.context.adapter
1125
- .findOne<SSOProvider<SSOOptions>>({
1126
- model: "ssoProvider",
1127
- where: [
1128
- {
1129
- field: providerId
1130
- ? "providerId"
1131
- : orgId
1132
- ? "organizationId"
1133
- : "domain",
1134
- value: providerId || orgId || domain!,
1135
- },
1136
- ],
1137
- })
1138
- .then((res) => {
1139
- if (!res) {
1140
- return null;
1141
- }
1142
- return {
1143
- ...res,
1144
- oidcConfig: res.oidcConfig
1145
- ? safeJsonParse<OIDCConfig>(
1146
- res.oidcConfig as unknown as string,
1147
- ) || undefined
1148
- : undefined,
1149
- samlConfig: res.samlConfig
1150
- ? safeJsonParse<SAMLConfig>(
1151
- res.samlConfig as unknown as string,
1152
- ) || undefined
1153
- : undefined,
1154
- };
1155
- });
1132
+ const parseProvider = (res: SSOProvider<SSOOptions> | null) => {
1133
+ if (!res) return null;
1134
+ return {
1135
+ ...res,
1136
+ oidcConfig: res.oidcConfig
1137
+ ? safeJsonParse<OIDCConfig>(
1138
+ res.oidcConfig as unknown as string,
1139
+ ) || undefined
1140
+ : undefined,
1141
+ samlConfig: res.samlConfig
1142
+ ? safeJsonParse<SAMLConfig>(
1143
+ res.samlConfig as unknown as string,
1144
+ ) || undefined
1145
+ : undefined,
1146
+ };
1147
+ };
1148
+
1149
+ if (providerId || orgId) {
1150
+ // Exact match for providerId or orgId
1151
+ provider = parseProvider(
1152
+ await ctx.context.adapter.findOne<SSOProvider<SSOOptions>>({
1153
+ model: "ssoProvider",
1154
+ where: [
1155
+ {
1156
+ field: providerId ? "providerId" : "organizationId",
1157
+ value: providerId || orgId!,
1158
+ },
1159
+ ],
1160
+ }),
1161
+ );
1162
+ } else if (domain) {
1163
+ // For domain lookup, support comma-separated domains
1164
+ // First try exact match (fast path)
1165
+ provider = parseProvider(
1166
+ await ctx.context.adapter.findOne<SSOProvider<SSOOptions>>({
1167
+ model: "ssoProvider",
1168
+ where: [{ field: "domain", value: domain }],
1169
+ }),
1170
+ );
1171
+ // If not found, search all providers for comma-separated domain match
1172
+ if (!provider) {
1173
+ const allProviders = await ctx.context.adapter.findMany<
1174
+ SSOProvider<SSOOptions>
1175
+ >({
1176
+ model: "ssoProvider",
1177
+ });
1178
+ const matchingProvider = allProviders.find((p) =>
1179
+ domainMatches(domain, p.domain),
1180
+ );
1181
+ provider = parseProvider(matchingProvider ?? null);
1182
+ }
1183
+ }
1156
1184
  }
1157
1185
 
1158
1186
  if (!provider) {
@@ -1241,6 +1269,17 @@ export const signInSSO = (options?: SSOOptions) => {
1241
1269
  });
1242
1270
  }
1243
1271
 
1272
+ if (
1273
+ parsedSamlConfig.authnRequestsSigned &&
1274
+ !parsedSamlConfig.spMetadata?.privateKey &&
1275
+ !parsedSamlConfig.privateKey
1276
+ ) {
1277
+ ctx.context.logger.warn(
1278
+ "authnRequestsSigned is enabled but no privateKey provided - AuthnRequests will not be signed",
1279
+ { providerId: provider.providerId },
1280
+ );
1281
+ }
1282
+
1244
1283
  let metadata = parsedSamlConfig.spMetadata.metadata;
1245
1284
 
1246
1285
  if (!metadata) {
@@ -1260,6 +1299,8 @@ export const signInSSO = (options?: SSOOptions) => {
1260
1299
  ],
1261
1300
  wantMessageSigned:
1262
1301
  parsedSamlConfig.wantAssertionsSigned || false,
1302
+ authnRequestsSigned:
1303
+ parsedSamlConfig.authnRequestsSigned || false,
1263
1304
  nameIDFormat: parsedSamlConfig.identifierFormat
1264
1305
  ? [parsedSamlConfig.identifierFormat]
1265
1306
  : undefined,
@@ -1270,6 +1311,10 @@ export const signInSSO = (options?: SSOOptions) => {
1270
1311
  const sp = saml.ServiceProvider({
1271
1312
  metadata: metadata,
1272
1313
  allowCreate: true,
1314
+ privateKey:
1315
+ parsedSamlConfig.spMetadata?.privateKey ||
1316
+ parsedSamlConfig.privateKey,
1317
+ privateKeyPass: parsedSamlConfig.spMetadata?.privateKeyPass,
1273
1318
  });
1274
1319
 
1275
1320
  const idp = saml.IdentityProvider({
@@ -1293,6 +1338,12 @@ export const signInSSO = (options?: SSOOptions) => {
1293
1338
  });
1294
1339
  }
1295
1340
 
1341
+ const { state: relayState } = await generateRelayState(
1342
+ ctx,
1343
+ undefined,
1344
+ false,
1345
+ );
1346
+
1296
1347
  const shouldSaveRequest =
1297
1348
  loginRequest.id && options?.saml?.enableInResponseToValidation;
1298
1349
  if (shouldSaveRequest) {
@@ -1311,9 +1362,7 @@ export const signInSSO = (options?: SSOOptions) => {
1311
1362
  }
1312
1363
 
1313
1364
  return ctx.json({
1314
- url: `${loginRequest.context}&RelayState=${encodeURIComponent(
1315
- body.callbackURL,
1316
- )}`,
1365
+ url: `${loginRequest.context}&RelayState=${encodeURIComponent(relayState)}`,
1317
1366
  redirect: true,
1318
1367
  });
1319
1368
  }
@@ -1683,12 +1732,71 @@ const callbackSSOSAMLBodySchema = z.object({
1683
1732
  RelayState: z.string().optional(),
1684
1733
  });
1685
1734
 
1735
/**
 * Validates a post-login redirect target and returns a safe URL.
 *
 * - Relative paths (a single leading `/`) are accepted only when they resolve
 *   to `appOrigin`. The WHATWG parser turns inputs such as `/\evil.com` into
 *   `//evil.com`, so the origin comparison rejects them.
 * - Absolute URLs are accepted only when `isTrustedOrigin` approves them
 *   (relative paths are not allowed on that check).
 * - Any URL whose path equals the SAML callback route is rejected, ignoring
 *   trailing slashes, to prevent redirect loops back into the callback.
 * - Falls back to `appOrigin` when the URL is missing, unparsable, or unsafe.
 *
 * @param url - Candidate redirect target, e.g. the callback URL from RelayState.
 * @param callbackPath - Absolute URL of the current SAML callback route.
 * @param appOrigin - Origin used to resolve relative paths and as the safe fallback.
 * @param isTrustedOrigin - Predicate deciding whether an absolute URL may be used.
 * @returns `url` unchanged when it is considered safe, otherwise `appOrigin`.
 */
const getSafeRedirectUrl = (
	url: string | undefined,
	callbackPath: string,
	appOrigin: string,
	isTrustedOrigin: (
		url: string,
		settings?: { allowRelativePaths: boolean },
	) => boolean,
): string => {
	if (!url) {
		return appOrigin;
	}

	// Drop trailing slashes so `/callback/p` and `/callback/p/` compare equal.
	// A linear scan avoids the quadratic backtracking of `/\/+$/` on long,
	// attacker-supplied paths.
	const normalizePath = (pathname: string): string => {
		let end = pathname.length;
		while (end > 1 && pathname[end - 1] === "/") {
			end--;
		}
		return pathname.slice(0, end);
	};

	// Parse the callback route once. It is normally built from the absolute
	// baseURL, but a null result keeps the original fallback behavior below.
	let callbackPathname: string | null;
	try {
		callbackPathname = normalizePath(new URL(callbackPath).pathname);
	} catch {
		callbackPathname = null;
	}

	// Relative path on this app: it must stay on appOrigin and must not be the callback.
	if (url.startsWith("/") && !url.startsWith("//")) {
		let resolved: URL;
		try {
			resolved = new URL(url, appOrigin);
		} catch {
			return appOrigin;
		}
		if (resolved.origin !== appOrigin) {
			return appOrigin;
		}
		if (
			callbackPathname === null ||
			normalizePath(resolved.pathname) === callbackPathname
		) {
			return appOrigin;
		}
		return url;
	}

	// Absolute (or protocol-relative) URL: it must be explicitly trusted.
	if (!isTrustedOrigin(url, { allowRelativePaths: false })) {
		return appOrigin;
	}

	let targetPathname: string | null;
	try {
		targetPathname = normalizePath(new URL(url).pathname);
	} catch {
		targetPathname = null;
	}

	if (callbackPathname !== null && targetPathname !== null) {
		return targetPathname === callbackPathname ? appOrigin : url;
	}

	// One of the URLs could not be parsed: fall back to a raw string comparison.
	if (url === callbackPath || url.startsWith(`${callbackPath}?`)) {
		return appOrigin;
	}
	return url;
};
1788
+
1686
1789
  export const callbackSSOSAML = (options?: SSOOptions) => {
1687
1790
  return createAuthEndpoint(
1688
1791
  "/sso/saml2/callback/:providerId",
1689
1792
  {
1690
- method: "POST",
1691
- body: callbackSSOSAMLBodySchema,
1793
+ method: ["GET", "POST"],
1794
+ body: callbackSSOSAMLBodySchema.optional(),
1795
+ query: z
1796
+ .object({
1797
+ RelayState: z.string().optional(),
1798
+ })
1799
+ .optional(),
1692
1800
  metadata: {
1693
1801
  ...HIDE_METADATA,
1694
1802
  allowedMediaTypes: [
@@ -1699,7 +1807,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1699
1807
  operationId: "handleSAMLCallback",
1700
1808
  summary: "Callback URL for SAML provider",
1701
1809
  description:
1702
- "This endpoint is used as the callback URL for SAML providers.",
1810
+ "This endpoint is used as the callback URL for SAML providers. Supports both GET and POST methods for IdP-initiated and SP-initiated flows.",
1703
1811
  responses: {
1704
1812
  "302": {
1705
1813
  description: "Redirects to the callback URL",
@@ -1715,8 +1823,41 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1715
1823
  },
1716
1824
  },
1717
1825
  async (ctx) => {
1718
- const { SAMLResponse, RelayState } = ctx.body;
1719
1826
  const { providerId } = ctx.params;
1827
+ const appOrigin = new URL(ctx.context.baseURL).origin;
1828
+ const errorURL =
1829
+ ctx.context.options.onAPIError?.errorURL || `${appOrigin}/error`;
1830
+ const currentCallbackPath = `${ctx.context.baseURL}/sso/saml2/callback/${providerId}`;
1831
+
1832
+ // Determine if this is a GET request by checking both method AND body presence
1833
+ // When called via auth.api.*, ctx.method may not be reliable, so we also check for body
1834
+ const isGetRequest = ctx.method === "GET" && !ctx.body?.SAMLResponse;
1835
+
1836
+ if (isGetRequest) {
1837
+ const session = await getSessionFromCtx(ctx);
1838
+
1839
+ if (!session?.session) {
1840
+ throw ctx.redirect(`${errorURL}?error=invalid_request`);
1841
+ }
1842
+
1843
+ const relayState = ctx.query?.RelayState as string | undefined;
1844
+ const safeRedirectUrl = getSafeRedirectUrl(
1845
+ relayState,
1846
+ currentCallbackPath,
1847
+ appOrigin,
1848
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
1849
+ );
1850
+
1851
+ throw ctx.redirect(safeRedirectUrl);
1852
+ }
1853
+
1854
+ if (!ctx.body?.SAMLResponse) {
1855
+ throw new APIError("BAD_REQUEST", {
1856
+ message: "SAMLResponse is required for POST requests",
1857
+ });
1858
+ }
1859
+
1860
+ const { SAMLResponse } = ctx.body;
1720
1861
 
1721
1862
  const maxResponseSize =
1722
1863
  options?.saml?.maxResponseSize ?? DEFAULT_MAX_SAML_RESPONSE_SIZE;
@@ -1726,6 +1867,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1726
1867
  });
1727
1868
  }
1728
1869
 
1870
+ let relayState: RelayState | null = null;
1871
+ if (ctx.body.RelayState) {
1872
+ try {
1873
+ relayState = await parseRelayState(ctx);
1874
+ } catch {
1875
+ relayState = null;
1876
+ }
1877
+ }
1729
1878
  let provider: SSOProvider<SSOOptions> | null = null;
1730
1879
  if (options?.defaultSSO?.length) {
1731
1880
  const matchingDefault = options.defaultSSO.find(
@@ -1846,7 +1995,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1846
1995
  parsedResponse = await sp.parseLoginResponse(idp, "post", {
1847
1996
  body: {
1848
1997
  SAMLResponse,
1849
- RelayState: RelayState || undefined,
1998
+ RelayState: ctx.body.RelayState || undefined,
1850
1999
  },
1851
2000
  });
1852
2001
 
@@ -1909,7 +2058,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1909
2058
  { inResponseTo, providerId: provider.providerId },
1910
2059
  );
1911
2060
  const redirectUrl =
1912
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2061
+ relayState?.callbackURL ||
2062
+ parsedSamlConfig.callbackUrl ||
2063
+ ctx.context.baseURL;
1913
2064
  throw ctx.redirect(
1914
2065
  `${redirectUrl}?error=invalid_saml_response&error_description=Unknown+or+expired+request+ID`,
1915
2066
  );
@@ -1929,7 +2080,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1929
2080
  `${AUTHN_REQUEST_KEY_PREFIX}${inResponseTo}`,
1930
2081
  );
1931
2082
  const redirectUrl =
1932
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2083
+ relayState?.callbackURL ||
2084
+ parsedSamlConfig.callbackUrl ||
2085
+ ctx.context.baseURL;
1933
2086
  throw ctx.redirect(
1934
2087
  `${redirectUrl}?error=invalid_saml_response&error_description=Provider+mismatch`,
1935
2088
  );
@@ -1944,7 +2097,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1944
2097
  { providerId: provider.providerId },
1945
2098
  );
1946
2099
  const redirectUrl =
1947
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2100
+ relayState?.callbackURL ||
2101
+ parsedSamlConfig.callbackUrl ||
2102
+ ctx.context.baseURL;
1948
2103
  throw ctx.redirect(
1949
2104
  `${redirectUrl}?error=unsolicited_response&error_description=IdP-initiated+SSO+not+allowed`,
1950
2105
  );
@@ -1997,7 +2152,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1997
2152
  },
1998
2153
  );
1999
2154
  const redirectUrl =
2000
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2155
+ relayState?.callbackURL ||
2156
+ parsedSamlConfig.callbackUrl ||
2157
+ ctx.context.baseURL;
2001
2158
  throw ctx.redirect(
2002
2159
  `${redirectUrl}?error=replay_detected&error_description=SAML+assertion+has+already+been+used`,
2003
2160
  );
@@ -2032,7 +2189,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2032
2189
  ]),
2033
2190
  ),
2034
2191
  id: attributes[mapping.id || "nameID"] || extract.nameID,
2035
- email: attributes[mapping.email || "email"] || extract.nameID,
2192
+ email: (
2193
+ attributes[mapping.email || "email"] || extract.nameID
2194
+ ).toLowerCase(),
2036
2195
  name:
2037
2196
  [
2038
2197
  attributes[mapping.firstName || "givenName"],
@@ -2071,7 +2230,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2071
2230
  validateEmailDomain(userInfo.email as string, provider.domain));
2072
2231
 
2073
2232
  const callbackUrl =
2074
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2233
+ relayState?.callbackURL ||
2234
+ parsedSamlConfig.callbackUrl ||
2235
+ ctx.context.baseURL;
2075
2236
 
2076
2237
  const result = await handleOAuthUserInfo(ctx, {
2077
2238
  userInfo: {
@@ -2122,7 +2283,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2122
2283
  });
2123
2284
 
2124
2285
  await setSessionCookie(ctx, { session, user });
2125
- throw ctx.redirect(callbackUrl);
2286
+
2287
+ const safeRedirectUrl = getSafeRedirectUrl(
2288
+ relayState?.callbackURL || parsedSamlConfig.callbackUrl,
2289
+ currentCallbackPath,
2290
+ appOrigin,
2291
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
2292
+ );
2293
+ throw ctx.redirect(safeRedirectUrl);
2126
2294
  },
2127
2295
  );
2128
2296
  };
@@ -2483,7 +2651,9 @@ export const acsEndpoint = (options?: SSOOptions) => {
2483
2651
  ]),
2484
2652
  ),
2485
2653
  id: attributes[mapping.id || "nameID"] || extract.nameID,
2486
- email: attributes[mapping.email || "email"] || extract.nameID,
2654
+ email: (
2655
+ attributes[mapping.email || "email"] || extract.nameID
2656
+ ).toLowerCase(),
2487
2657
  name:
2488
2658
  [
2489
2659
  attributes[mapping.firstName || "givenName"],
@@ -0,0 +1,78 @@
1
+ import type { GenericEndpointContext, StateData } from "better-auth";
2
+ import { generateGenericState, parseGenericState } from "better-auth";
3
+ import { generateRandomString } from "better-auth/crypto";
4
+ import { APIError } from "better-call";
5
+
6
/** Lifetime of a generated relay state, in milliseconds (10 minutes). */
const RELAY_STATE_TTL_MS = 10 * 60 * 1000;

/**
 * Builds the state payload for an SP-initiated SAML login and persists it via
 * better-auth's `generateGenericState`, using the cookie name `relay_state`.
 *
 * The payload carries `callbackURL`, `errorCallbackURL`, `newUserCallbackURL`
 * and `requestSignUp` from the request body, plus a random 128-character
 * `codeVerifier`, the optional `link` data and an `expiresAt` timestamp.
 * `additionalData` fields are spread first, so the named fields override them.
 *
 * @param c - Endpoint context. `c.body.callbackURL` is required.
 * @param link - Optional account-linking data stored in the state.
 * @param additionalData - Extra fields merged into the state, or `false`/`undefined` for none.
 * @returns Whatever `generateGenericState` resolves to. Callers read its
 *   `state` property as the RelayState value.
 * @throws APIError BAD_REQUEST when no `callbackURL` is supplied.
 * @throws APIError INTERNAL_SERVER_ERROR when the state cannot be created.
 */
export async function generateRelayState(
	c: GenericEndpointContext,
	link:
		| {
				email: string;
				userId: string;
		  }
		| undefined,
	additionalData: Record<string, any> | false | undefined,
) {
	// Optional chaining so a missing body yields a 400, not a TypeError.
	const callbackURL = c.body?.callbackURL;
	if (!callbackURL) {
		throw new APIError("BAD_REQUEST", {
			message: "callbackURL is required",
		});
	}

	const codeVerifier = generateRandomString(128);
	const stateData: StateData = {
		...(additionalData ? additionalData : {}),
		callbackURL,
		codeVerifier,
		errorURL: c.body.errorCallbackURL,
		newUserURL: c.body.newUserCallbackURL,
		link,
		/**
		 * This is the actual expiry time of the state
		 */
		expiresAt: Date.now() + RELAY_STATE_TTL_MS,
		requestSignUp: c.body.requestSignUp,
	};

	try {
		// `return await` (not a bare `return`) so that an async rejection from
		// generateGenericState is caught, logged and wrapped below. A bare
		// returned promise would bypass this catch.
		return await generateGenericState(c, stateData, {
			cookieName: "relay_state",
		});
	} catch (error) {
		c.context.logger.error(
			"Failed to create verification for relay state",
			error,
		);
		throw new APIError("INTERNAL_SERVER_ERROR", {
			message: "State error: Unable to create verification for relay state",
			cause: error,
		});
	}
}
53
+
54
/**
 * Decodes the `RelayState` field of the request body using better-auth's
 * `parseGenericState` (cookie name `relay_state`).
 *
 * When the decoded payload has no `errorURL`, it is filled with the
 * configured `onAPIError.errorURL`, or `${baseURL}/error` as a last resort.
 *
 * @param c - Endpoint context whose body holds the `RelayState` value.
 * @returns The decoded state data.
 * @throws APIError BAD_REQUEST when the relay state cannot be parsed or validated.
 */
export async function parseRelayState(c: GenericEndpointContext) {
	const encodedState = c.body.RelayState;
	const fallbackErrorURL =
		c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;

	let decoded: StateData;
	try {
		decoded = await parseGenericState(c, encodedState, {
			cookieName: "relay_state",
		});
	} catch (cause) {
		c.context.logger.error("Failed to parse relay state", cause);
		throw new APIError("BAD_REQUEST", {
			message: "State error: failed to validate relay state",
			cause,
		});
	}

	// Only fill the error URL when the stored state did not carry one.
	if (!decoded.errorURL) decoded.errorURL = fallbackErrorURL;

	return decoded;
}