@better-auth/sso 1.5.0-beta.1 → 1.5.0-beta.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/routes/sso.ts CHANGED
@@ -1,4 +1,3 @@
1
- import { base64 } from "@better-auth/utils/base64";
2
1
  import { BetterFetchError, betterFetch } from "@better-fetch/fetch";
3
2
  import type { User, Verification } from "better-auth";
4
3
  import {
@@ -12,6 +11,7 @@ import {
12
11
  import {
13
12
  APIError,
14
13
  createAuthEndpoint,
14
+ getSessionFromCtx,
15
15
  sessionMiddleware,
16
16
  } from "better-auth/api";
17
17
  import { setSessionCookie } from "better-auth/cookies";
@@ -37,6 +37,8 @@ import {
37
37
  DEFAULT_ASSERTION_TTL_MS,
38
38
  DEFAULT_AUTHN_REQUEST_TTL_MS,
39
39
  DEFAULT_CLOCK_SKEW_MS,
40
+ DEFAULT_MAX_SAML_METADATA_SIZE,
41
+ DEFAULT_MAX_SAML_RESPONSE_SIZE,
40
42
  USED_ASSERTION_KEY_PREFIX,
41
43
  } from "../constants";
42
44
  import { assignOrganizationFromProvider } from "../linking";
@@ -46,9 +48,14 @@ import {
46
48
  discoverOIDCConfig,
47
49
  mapDiscoveryErrorToAPIError,
48
50
  } from "../oidc";
49
- import { validateConfigAlgorithms, validateSAMLAlgorithms } from "../saml";
51
+ import {
52
+ validateConfigAlgorithms,
53
+ validateSAMLAlgorithms,
54
+ validateSingleAssertion,
55
+ } from "../saml";
56
+ import { generateRelayState, parseRelayState } from "../saml-state";
50
57
  import type { OIDCConfig, SAMLConfig, SSOOptions, SSOProvider } from "../types";
51
- import { safeJsonParse, validateEmailDomain } from "../utils";
58
+ import { domainMatches, safeJsonParse, validateEmailDomain } from "../utils";
52
59
 
53
60
  export interface TimestampValidationOptions {
54
61
  clockSkew?: number;
@@ -160,6 +167,8 @@ const spMetadataQuerySchema = z.object({
160
167
  format: z.enum(["xml", "json"]).default("xml"),
161
168
  });
162
169
 
170
+ type RelayState = Awaited<ReturnType<typeof parseRelayState>>; // Awaited result type of `parseRelayState`; the SAML callback holds the parsed RelayState (or null) in this shape and reads `callbackURL` from it.
171
+
163
172
  export const spMetadata = () => {
164
173
  return createAuthEndpoint(
165
174
  "/sso/saml2/sp/metadata",
@@ -220,6 +229,7 @@ export const spMetadata = () => {
220
229
  },
221
230
  ],
222
231
  wantMessageSigned: parsedSamlConfig.wantAssertionsSigned || false,
232
+ authnRequestsSigned: parsedSamlConfig.authnRequestsSigned || false,
223
233
  nameIDFormat: parsedSamlConfig.identifierFormat
224
234
  ? [parsedSamlConfig.identifierFormat]
225
235
  : undefined,
@@ -242,7 +252,8 @@ const ssoProviderBodySchema = z.object({
242
252
  description: "The issuer of the provider",
243
253
  }),
244
254
  domain: z.string({}).meta({
245
- description: "The domain of the provider. This is used for email matching",
255
+ description:
256
+ "The domain(s) of the provider. For enterprise multi-domain SSO where a single IdP serves multiple email domains, use comma-separated values (e.g., 'company.com,subsidiary.com,acquired-company.com')",
246
257
  }),
247
258
  oidcConfig: z
248
259
  .object({
@@ -380,6 +391,7 @@ const ssoProviderBodySchema = z.object({
380
391
  encPrivateKeyPass: z.string().optional(),
381
392
  }),
382
393
  wantAssertionsSigned: z.boolean().optional(),
394
+ authnRequestsSigned: z.boolean().optional(),
383
395
  signatureAlgorithm: z.string().optional(),
384
396
  digestAlgorithm: z.string().optional(),
385
397
  identifierFormat: z.string().optional(),
@@ -666,6 +678,20 @@ export const registerSSOProvider = <O extends SSOOptions>(options: O) => {
666
678
  message: "Invalid issuer. Must be a valid URL",
667
679
  });
668
680
  }
681
+
682
+ if (body.samlConfig?.idpMetadata?.metadata) {
683
+ const maxMetadataSize =
684
+ options?.saml?.maxMetadataSize ?? DEFAULT_MAX_SAML_METADATA_SIZE;
685
+ if (
686
+ new TextEncoder().encode(body.samlConfig.idpMetadata.metadata)
687
+ .length > maxMetadataSize
688
+ ) {
689
+ throw new APIError("BAD_REQUEST", {
690
+ message: `IdP metadata exceeds maximum allowed size (${maxMetadataSize} bytes)`,
691
+ });
692
+ }
693
+ }
694
+
669
695
  if (ctx.body.organizationId) {
670
696
  const organization = await ctx.context.adapter.findOne({
671
697
  model: "member",
@@ -720,7 +746,7 @@ export const registerSSOProvider = <O extends SSOOptions>(options: O) => {
720
746
  tokenEndpointAuthentication:
721
747
  body.oidcConfig.tokenEndpointAuthentication,
722
748
  },
723
- isTrustedOrigin: ctx.context.isTrustedOrigin,
749
+ isTrustedOrigin: (url: string) => ctx.context.isTrustedOrigin(url),
724
750
  });
725
751
  } catch (error) {
726
752
  if (error instanceof DiscoveryError) {
@@ -811,6 +837,7 @@ export const registerSSOProvider = <O extends SSOOptions>(options: O) => {
811
837
  idpMetadata: body.samlConfig.idpMetadata,
812
838
  spMetadata: body.samlConfig.spMetadata,
813
839
  wantAssertionsSigned: body.samlConfig.wantAssertionsSigned,
840
+ authnRequestsSigned: body.samlConfig.authnRequestsSigned,
814
841
  signatureAlgorithm: body.samlConfig.signatureAlgorithm,
815
842
  digestAlgorithm: body.samlConfig.digestAlgorithm,
816
843
  identifierFormat: body.samlConfig.identifierFormat,
@@ -1102,38 +1129,58 @@ export const signInSSO = (options?: SSOOptions) => {
1102
1129
  }
1103
1130
  // Try to find provider in database
1104
1131
  if (!provider) {
1105
- provider = await ctx.context.adapter
1106
- .findOne<SSOProvider<SSOOptions>>({
1107
- model: "ssoProvider",
1108
- where: [
1109
- {
1110
- field: providerId
1111
- ? "providerId"
1112
- : orgId
1113
- ? "organizationId"
1114
- : "domain",
1115
- value: providerId || orgId || domain!,
1116
- },
1117
- ],
1118
- })
1119
- .then((res) => {
1120
- if (!res) {
1121
- return null;
1122
- }
1123
- return {
1124
- ...res,
1125
- oidcConfig: res.oidcConfig
1126
- ? safeJsonParse<OIDCConfig>(
1127
- res.oidcConfig as unknown as string,
1128
- ) || undefined
1129
- : undefined,
1130
- samlConfig: res.samlConfig
1131
- ? safeJsonParse<SAMLConfig>(
1132
- res.samlConfig as unknown as string,
1133
- ) || undefined
1134
- : undefined,
1135
- };
1136
- });
1132
+ const parseProvider = (res: SSOProvider<SSOOptions> | null) => {
1133
+ if (!res) return null;
1134
+ return {
1135
+ ...res,
1136
+ oidcConfig: res.oidcConfig
1137
+ ? safeJsonParse<OIDCConfig>(
1138
+ res.oidcConfig as unknown as string,
1139
+ ) || undefined
1140
+ : undefined,
1141
+ samlConfig: res.samlConfig
1142
+ ? safeJsonParse<SAMLConfig>(
1143
+ res.samlConfig as unknown as string,
1144
+ ) || undefined
1145
+ : undefined,
1146
+ };
1147
+ };
1148
+
1149
+ if (providerId || orgId) {
1150
+ // Exact match for providerId or orgId
1151
+ provider = parseProvider(
1152
+ await ctx.context.adapter.findOne<SSOProvider<SSOOptions>>({
1153
+ model: "ssoProvider",
1154
+ where: [
1155
+ {
1156
+ field: providerId ? "providerId" : "organizationId",
1157
+ value: providerId || orgId!,
1158
+ },
1159
+ ],
1160
+ }),
1161
+ );
1162
+ } else if (domain) {
1163
+ // For domain lookup, support comma-separated domains
1164
+ // First try exact match (fast path)
1165
+ provider = parseProvider(
1166
+ await ctx.context.adapter.findOne<SSOProvider<SSOOptions>>({
1167
+ model: "ssoProvider",
1168
+ where: [{ field: "domain", value: domain }],
1169
+ }),
1170
+ );
1171
+ // If not found, search all providers for comma-separated domain match
1172
+ if (!provider) {
1173
+ const allProviders = await ctx.context.adapter.findMany<
1174
+ SSOProvider<SSOOptions>
1175
+ >({
1176
+ model: "ssoProvider",
1177
+ });
1178
+ const matchingProvider = allProviders.find((p) =>
1179
+ domainMatches(domain, p.domain),
1180
+ );
1181
+ provider = parseProvider(matchingProvider ?? null);
1182
+ }
1183
+ }
1137
1184
  }
1138
1185
 
1139
1186
  if (!provider) {
@@ -1222,6 +1269,17 @@ export const signInSSO = (options?: SSOOptions) => {
1222
1269
  });
1223
1270
  }
1224
1271
 
1272
+ if (
1273
+ parsedSamlConfig.authnRequestsSigned &&
1274
+ !parsedSamlConfig.spMetadata?.privateKey &&
1275
+ !parsedSamlConfig.privateKey
1276
+ ) {
1277
+ ctx.context.logger.warn(
1278
+ "authnRequestsSigned is enabled but no privateKey provided - AuthnRequests will not be signed",
1279
+ { providerId: provider.providerId },
1280
+ );
1281
+ }
1282
+
1225
1283
  let metadata = parsedSamlConfig.spMetadata.metadata;
1226
1284
 
1227
1285
  if (!metadata) {
@@ -1241,6 +1299,8 @@ export const signInSSO = (options?: SSOOptions) => {
1241
1299
  ],
1242
1300
  wantMessageSigned:
1243
1301
  parsedSamlConfig.wantAssertionsSigned || false,
1302
+ authnRequestsSigned:
1303
+ parsedSamlConfig.authnRequestsSigned || false,
1244
1304
  nameIDFormat: parsedSamlConfig.identifierFormat
1245
1305
  ? [parsedSamlConfig.identifierFormat]
1246
1306
  : undefined,
@@ -1251,6 +1311,10 @@ export const signInSSO = (options?: SSOOptions) => {
1251
1311
  const sp = saml.ServiceProvider({
1252
1312
  metadata: metadata,
1253
1313
  allowCreate: true,
1314
+ privateKey:
1315
+ parsedSamlConfig.spMetadata?.privateKey ||
1316
+ parsedSamlConfig.privateKey,
1317
+ privateKeyPass: parsedSamlConfig.spMetadata?.privateKeyPass,
1254
1318
  });
1255
1319
 
1256
1320
  const idp = saml.IdentityProvider({
@@ -1274,6 +1338,12 @@ export const signInSSO = (options?: SSOOptions) => {
1274
1338
  });
1275
1339
  }
1276
1340
 
1341
+ const { state: relayState } = await generateRelayState(
1342
+ ctx,
1343
+ undefined,
1344
+ false,
1345
+ );
1346
+
1277
1347
  const shouldSaveRequest =
1278
1348
  loginRequest.id && options?.saml?.enableInResponseToValidation;
1279
1349
  if (shouldSaveRequest) {
@@ -1292,9 +1362,7 @@ export const signInSSO = (options?: SSOOptions) => {
1292
1362
  }
1293
1363
 
1294
1364
  return ctx.json({
1295
- url: `${loginRequest.context}&RelayState=${encodeURIComponent(
1296
- body.callbackURL,
1297
- )}`,
1365
+ url: `${loginRequest.context}&RelayState=${encodeURIComponent(relayState)}`,
1298
1366
  redirect: true,
1299
1367
  });
1300
1368
  }
@@ -1664,12 +1732,71 @@ const callbackSSOSAMLBodySchema = z.object({
1664
1732
  RelayState: z.string().optional(),
1665
1733
  });
1666
1734
 
1735
/**
 * Turn a candidate post-login redirect target into one that is safe to send
 * the browser to, falling back to `appOrigin` whenever a check fails.
 *
 * Rules:
 * - A missing or empty target yields `appOrigin`.
 * - A root-relative path (`/path`, but not protocol-relative `//host`) is
 *   resolved against `appOrigin`. It is accepted only if it stays on that
 *   origin and does not land on the pathname of `callbackPath`.
 * - Any other target must pass `isTrustedOrigin` (called with
 *   `allowRelativePaths: false`). It is rejected if its pathname equals the
 *   pathname of `callbackPath`. If either value cannot be parsed as a URL, it
 *   is rejected only when it equals `callbackPath` exactly or is
 *   `callbackPath` followed by a query string.
 *
 * These checks guard against open redirects and against redirect loops back
 * into the callback route.
 *
 * @param url - Candidate redirect target (for example a RelayState callback URL).
 * @param callbackPath - Absolute URL of the callback route that must not be redirected to.
 * @param appOrigin - Origin used both to resolve relative paths and as the fallback.
 * @param isTrustedOrigin - Predicate deciding whether a non-relative target is trusted.
 * @returns `url` unchanged when it passes every check, otherwise `appOrigin`.
 */
const getSafeRedirectUrl = (
	url: string | undefined,
	callbackPath: string,
	appOrigin: string,
	isTrustedOrigin: (
		url: string,
		settings?: { allowRelativePaths: boolean },
	) => boolean,
): string => {
	if (!url) {
		return appOrigin;
	}

	// "//host" is protocol-relative (i.e. off-origin), so it is not treated as a local path.
	const isRootRelative = url.startsWith("/") && !url.startsWith("//");

	if (isRootRelative) {
		try {
			// Resolving through the URL parser catches tricks such as "/\host",
			// which browsers also interpret as a different host.
			const resolved = new URL(url, appOrigin);
			const leavesOrigin = resolved.origin !== appOrigin;
			const hitsCallback =
				resolved.pathname === new URL(callbackPath).pathname;
			return leavesOrigin || hitsCallback ? appOrigin : url;
		} catch {
			return appOrigin;
		}
	}

	const trusted = isTrustedOrigin(url, { allowRelativePaths: false });
	if (!trusted) {
		return appOrigin;
	}

	let hitsCallback: boolean;
	try {
		hitsCallback = new URL(url).pathname === new URL(callbackPath).pathname;
	} catch {
		// Unparseable input: fall back to a plain string comparison against the callback.
		hitsCallback = url === callbackPath || url.startsWith(`${callbackPath}?`);
	}

	return hitsCallback ? appOrigin : url;
};
1788
+
1667
1789
  export const callbackSSOSAML = (options?: SSOOptions) => {
1668
1790
  return createAuthEndpoint(
1669
1791
  "/sso/saml2/callback/:providerId",
1670
1792
  {
1671
- method: "POST",
1672
- body: callbackSSOSAMLBodySchema,
1793
+ method: ["GET", "POST"],
1794
+ body: callbackSSOSAMLBodySchema.optional(),
1795
+ query: z
1796
+ .object({
1797
+ RelayState: z.string().optional(),
1798
+ })
1799
+ .optional(),
1673
1800
  metadata: {
1674
1801
  ...HIDE_METADATA,
1675
1802
  allowedMediaTypes: [
@@ -1680,7 +1807,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1680
1807
  operationId: "handleSAMLCallback",
1681
1808
  summary: "Callback URL for SAML provider",
1682
1809
  description:
1683
- "This endpoint is used as the callback URL for SAML providers.",
1810
+ "This endpoint is used as the callback URL for SAML providers. Supports both GET and POST methods for IdP-initiated and SP-initiated flows.",
1684
1811
  responses: {
1685
1812
  "302": {
1686
1813
  description: "Redirects to the callback URL",
@@ -1696,8 +1823,58 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1696
1823
  },
1697
1824
  },
1698
1825
  async (ctx) => {
1699
- const { SAMLResponse, RelayState } = ctx.body;
1700
1826
  const { providerId } = ctx.params;
1827
+ const appOrigin = new URL(ctx.context.baseURL).origin;
1828
+ const errorURL =
1829
+ ctx.context.options.onAPIError?.errorURL || `${appOrigin}/error`;
1830
+ const currentCallbackPath = `${ctx.context.baseURL}/sso/saml2/callback/${providerId}`;
1831
+
1832
+ // Determine if this is a GET request by checking both method AND body presence
1833
+ // When called via auth.api.*, ctx.method may not be reliable, so we also check for body
1834
+ const isGetRequest = ctx.method === "GET" && !ctx.body?.SAMLResponse;
1835
+
1836
+ if (isGetRequest) {
1837
+ const session = await getSessionFromCtx(ctx);
1838
+
1839
+ if (!session?.session) {
1840
+ throw ctx.redirect(`${errorURL}?error=invalid_request`);
1841
+ }
1842
+
1843
+ const relayState = ctx.query?.RelayState as string | undefined;
1844
+ const safeRedirectUrl = getSafeRedirectUrl(
1845
+ relayState,
1846
+ currentCallbackPath,
1847
+ appOrigin,
1848
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
1849
+ );
1850
+
1851
+ throw ctx.redirect(safeRedirectUrl);
1852
+ }
1853
+
1854
+ if (!ctx.body?.SAMLResponse) {
1855
+ throw new APIError("BAD_REQUEST", {
1856
+ message: "SAMLResponse is required for POST requests",
1857
+ });
1858
+ }
1859
+
1860
+ const { SAMLResponse } = ctx.body;
1861
+
1862
+ const maxResponseSize =
1863
+ options?.saml?.maxResponseSize ?? DEFAULT_MAX_SAML_RESPONSE_SIZE;
1864
+ if (new TextEncoder().encode(SAMLResponse).length > maxResponseSize) {
1865
+ throw new APIError("BAD_REQUEST", {
1866
+ message: `SAML response exceeds maximum allowed size (${maxResponseSize} bytes)`,
1867
+ });
1868
+ }
1869
+
1870
+ let relayState: RelayState | null = null;
1871
+ if (ctx.body.RelayState) {
1872
+ try {
1873
+ relayState = await parseRelayState(ctx);
1874
+ } catch {
1875
+ relayState = null;
1876
+ }
1877
+ }
1701
1878
  let provider: SSOProvider<SSOOptions> | null = null;
1702
1879
  if (options?.defaultSSO?.length) {
1703
1880
  const matchingDefault = options.defaultSSO.find(
@@ -1811,12 +1988,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1811
1988
  : undefined,
1812
1989
  });
1813
1990
 
1991
+ validateSingleAssertion(SAMLResponse);
1992
+
1814
1993
  let parsedResponse: FlowResult;
1815
1994
  try {
1816
1995
  parsedResponse = await sp.parseLoginResponse(idp, "post", {
1817
1996
  body: {
1818
1997
  SAMLResponse,
1819
- RelayState: RelayState || undefined,
1998
+ RelayState: ctx.body.RelayState || undefined,
1820
1999
  },
1821
2000
  });
1822
2001
 
@@ -1826,8 +2005,8 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1826
2005
  } catch (error) {
1827
2006
  ctx.context.logger.error("SAML response validation failed", {
1828
2007
  error,
1829
- decodedResponse: new TextDecoder().decode(
1830
- base64.decode(SAMLResponse),
2008
+ decodedResponse: Buffer.from(SAMLResponse, "base64").toString(
2009
+ "utf-8",
1831
2010
  ),
1832
2011
  });
1833
2012
  throw new APIError("BAD_REQUEST", {
@@ -1879,7 +2058,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1879
2058
  { inResponseTo, providerId: provider.providerId },
1880
2059
  );
1881
2060
  const redirectUrl =
1882
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2061
+ relayState?.callbackURL ||
2062
+ parsedSamlConfig.callbackUrl ||
2063
+ ctx.context.baseURL;
1883
2064
  throw ctx.redirect(
1884
2065
  `${redirectUrl}?error=invalid_saml_response&error_description=Unknown+or+expired+request+ID`,
1885
2066
  );
@@ -1899,7 +2080,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1899
2080
  `${AUTHN_REQUEST_KEY_PREFIX}${inResponseTo}`,
1900
2081
  );
1901
2082
  const redirectUrl =
1902
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2083
+ relayState?.callbackURL ||
2084
+ parsedSamlConfig.callbackUrl ||
2085
+ ctx.context.baseURL;
1903
2086
  throw ctx.redirect(
1904
2087
  `${redirectUrl}?error=invalid_saml_response&error_description=Provider+mismatch`,
1905
2088
  );
@@ -1914,7 +2097,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1914
2097
  { providerId: provider.providerId },
1915
2098
  );
1916
2099
  const redirectUrl =
1917
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2100
+ relayState?.callbackURL ||
2101
+ parsedSamlConfig.callbackUrl ||
2102
+ ctx.context.baseURL;
1918
2103
  throw ctx.redirect(
1919
2104
  `${redirectUrl}?error=unsolicited_response&error_description=IdP-initiated+SSO+not+allowed`,
1920
2105
  );
@@ -1967,7 +2152,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1967
2152
  },
1968
2153
  );
1969
2154
  const redirectUrl =
1970
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2155
+ relayState?.callbackURL ||
2156
+ parsedSamlConfig.callbackUrl ||
2157
+ ctx.context.baseURL;
1971
2158
  throw ctx.redirect(
1972
2159
  `${redirectUrl}?error=replay_detected&error_description=SAML+assertion+has+already+been+used`,
1973
2160
  );
@@ -2002,7 +2189,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2002
2189
  ]),
2003
2190
  ),
2004
2191
  id: attributes[mapping.id || "nameID"] || extract.nameID,
2005
- email: attributes[mapping.email || "email"] || extract.nameID,
2192
+ email: (
2193
+ attributes[mapping.email || "email"] || extract.nameID
2194
+ ).toLowerCase(),
2006
2195
  name:
2007
2196
  [
2008
2197
  attributes[mapping.firstName || "givenName"],
@@ -2041,7 +2230,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2041
2230
  validateEmailDomain(userInfo.email as string, provider.domain));
2042
2231
 
2043
2232
  const callbackUrl =
2044
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2233
+ relayState?.callbackURL ||
2234
+ parsedSamlConfig.callbackUrl ||
2235
+ ctx.context.baseURL;
2045
2236
 
2046
2237
  const result = await handleOAuthUserInfo(ctx, {
2047
2238
  userInfo: {
@@ -2092,15 +2283,18 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2092
2283
  });
2093
2284
 
2094
2285
  await setSessionCookie(ctx, { session, user });
2095
- throw ctx.redirect(callbackUrl);
2286
+
2287
+ const safeRedirectUrl = getSafeRedirectUrl(
2288
+ relayState?.callbackURL || parsedSamlConfig.callbackUrl,
2289
+ currentCallbackPath,
2290
+ appOrigin,
2291
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
2292
+ );
2293
+ throw ctx.redirect(safeRedirectUrl);
2096
2294
  },
2097
2295
  );
2098
2296
  };
2099
2297
 
2100
- const acsEndpointParamsSchema = z.object({
2101
- providerId: z.string().optional(),
2102
- });
2103
-
2104
2298
  const acsEndpointBodySchema = z.object({
2105
2299
  SAMLResponse: z.string(),
2106
2300
  RelayState: z.string().optional(),
@@ -2111,7 +2305,6 @@ export const acsEndpoint = (options?: SSOOptions) => {
2111
2305
  "/sso/saml2/sp/acs/:providerId",
2112
2306
  {
2113
2307
  method: "POST",
2114
- params: acsEndpointParamsSchema,
2115
2308
  body: acsEndpointBodySchema,
2116
2309
  metadata: {
2117
2310
  ...HIDE_METADATA,
@@ -2137,6 +2330,14 @@ export const acsEndpoint = (options?: SSOOptions) => {
2137
2330
  const { SAMLResponse, RelayState = "" } = ctx.body;
2138
2331
  const { providerId } = ctx.params;
2139
2332
 
2333
+ const maxResponseSize =
2334
+ options?.saml?.maxResponseSize ?? DEFAULT_MAX_SAML_RESPONSE_SIZE;
2335
+ if (new TextEncoder().encode(SAMLResponse).length > maxResponseSize) {
2336
+ throw new APIError("BAD_REQUEST", {
2337
+ message: `SAML response exceeds maximum allowed size (${maxResponseSize} bytes)`,
2338
+ });
2339
+ }
2340
+
2140
2341
  // If defaultSSO is configured, use it as the provider
2141
2342
  let provider: SSOProvider<SSOOptions> | null = null;
2142
2343
 
@@ -2167,7 +2368,7 @@ export const acsEndpoint = (options?: SSOOptions) => {
2167
2368
  where: [
2168
2369
  {
2169
2370
  field: "providerId",
2170
- value: providerId ?? "sso",
2371
+ value: providerId,
2171
2372
  },
2172
2373
  ],
2173
2374
  })
@@ -2240,6 +2441,23 @@ export const acsEndpoint = (options?: SSOOptions) => {
2240
2441
  metadata: idpData.metadata,
2241
2442
  });
2242
2443
 
2444
+ try {
2445
+ validateSingleAssertion(SAMLResponse);
2446
+ } catch (error) {
2447
+ if (error instanceof APIError) {
2448
+ const redirectUrl =
2449
+ RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2450
+ const errorCode =
2451
+ error.body?.code === "SAML_MULTIPLE_ASSERTIONS"
2452
+ ? "multiple_assertions"
2453
+ : "no_assertion";
2454
+ throw ctx.redirect(
2455
+ `${redirectUrl}?error=${errorCode}&error_description=${encodeURIComponent(error.message)}`,
2456
+ );
2457
+ }
2458
+ throw error;
2459
+ }
2460
+
2243
2461
  // Parse and validate SAML response
2244
2462
  let parsedResponse: FlowResult;
2245
2463
  try {
@@ -2256,8 +2474,8 @@ export const acsEndpoint = (options?: SSOOptions) => {
2256
2474
  } catch (error) {
2257
2475
  ctx.context.logger.error("SAML response validation failed", {
2258
2476
  error,
2259
- decodedResponse: new TextDecoder().decode(
2260
- base64.decode(SAMLResponse),
2477
+ decodedResponse: Buffer.from(SAMLResponse, "base64").toString(
2478
+ "utf-8",
2261
2479
  ),
2262
2480
  });
2263
2481
  throw new APIError("BAD_REQUEST", {
@@ -2353,8 +2571,8 @@ export const acsEndpoint = (options?: SSOOptions) => {
2353
2571
  }
2354
2572
 
2355
2573
  // Assertion Replay Protection
2356
- const samlContentAcs = new TextDecoder().decode(
2357
- base64.decode(SAMLResponse),
2574
+ const samlContentAcs = Buffer.from(SAMLResponse, "base64").toString(
2575
+ "utf-8",
2358
2576
  );
2359
2577
  const assertionIdAcs = extractAssertionId(samlContentAcs);
2360
2578
 
@@ -2433,7 +2651,9 @@ export const acsEndpoint = (options?: SSOOptions) => {
2433
2651
  ]),
2434
2652
  ),
2435
2653
  id: attributes[mapping.id || "nameID"] || extract.nameID,
2436
- email: attributes[mapping.email || "email"] || extract.nameID,
2654
+ email: (
2655
+ attributes[mapping.email || "email"] || extract.nameID
2656
+ ).toLowerCase(),
2437
2657
  name:
2438
2658
  [
2439
2659
  attributes[mapping.firstName || "givenName"],
@@ -1,5 +1,5 @@
1
1
  import { APIError } from "better-auth/api";
2
- import { XMLParser } from "fast-xml-parser";
2
+ import { findNode, xmlParser } from "./parser";
3
3
 
4
4
  export const SignatureAlgorithm = {
5
5
  RSA_SHA1: "http://www.w3.org/2000/09/xmldsig#rsa-sha1",
@@ -102,36 +102,6 @@ export interface AlgorithmValidationOptions {
102
102
  allowedDataEncryptionAlgorithms?: string[];
103
103
  }
104
104
 
105
- const xmlParser = new XMLParser({
106
- ignoreAttributes: false,
107
- attributeNamePrefix: "@_",
108
- removeNSPrefix: true,
109
- });
110
-
111
- function findNode(obj: unknown, nodeName: string): unknown {
112
- if (!obj || typeof obj !== "object") return null;
113
-
114
- const record = obj as Record<string, unknown>;
115
-
116
- if (nodeName in record) {
117
- return record[nodeName];
118
- }
119
-
120
- for (const value of Object.values(record)) {
121
- if (Array.isArray(value)) {
122
- for (const item of value) {
123
- const found = findNode(item, nodeName);
124
- if (found) return found;
125
- }
126
- } else if (typeof value === "object" && value !== null) {
127
- const found = findNode(value, nodeName);
128
- if (found) return found;
129
- }
130
- }
131
-
132
- return null;
133
- }
134
-
135
105
  function extractEncryptionAlgorithms(xml: string): {
136
106
  keyEncryption: string | null;
137
107
  dataEncryption: string | null;