@better-auth/sso 1.5.0-beta.6 → 1.5.0-beta.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
 
2
- > @better-auth/sso@1.5.0-beta.6 build /home/runner/work/better-auth/better-auth/packages/sso
2
+ > @better-auth/sso@1.5.0-beta.8 build /home/runner/work/better-auth/better-auth/packages/sso
3
3
  > tsdown
4
4
 
5
5
  ℹ tsdown v0.19.0 powered by rolldown v1.0.0-beta.59
@@ -7,10 +7,10 @@
7
7
  ℹ entry: src/index.ts, src/client.ts
8
8
  ℹ tsconfig: tsconfig.json
9
9
  ℹ Build start
10
- ℹ dist/index.mjs 99.53 kB │ gzip: 19.50 kB
11
- ℹ dist/client.mjs  0.15 kB │ gzip: 0.14 kB
12
- ℹ dist/index.d.mts  1.67 kB │ gzip: 0.57 kB
13
- ℹ dist/client.d.mts  0.49 kB │ gzip: 0.30 kB
14
- ℹ dist/index-BLMoKtp1.d.mts 44.35 kB │ gzip: 9.16 kB
15
- ℹ 5 files, total: 146.19 kB
16
- ✔ Build complete in 17523ms
10
+ ℹ dist/index.mjs 104.03 kB │ gzip: 20.69 kB
11
+ ℹ dist/client.mjs  0.15 kB │ gzip: 0.14 kB
12
+ ℹ dist/index.d.mts  1.67 kB │ gzip: 0.57 kB
13
+ ℹ dist/client.d.mts  0.49 kB │ gzip: 0.29 kB
14
+ ℹ dist/index-BT0wtuq1.d.mts  44.48 kB │ gzip: 9.20 kB
15
+ ℹ 5 files, total: 150.82 kB
16
+ ✔ Build complete in 17021ms
package/dist/client.d.mts CHANGED
@@ -1,4 +1,4 @@
1
- import { t as SSOPlugin } from "./index-BLMoKtp1.mjs";
1
+ import { t as SSOPlugin } from "./index-BT0wtuq1.mjs";
2
2
 
3
3
  //#region src/client.d.ts
4
4
  interface SSOClientOptions {
@@ -917,11 +917,14 @@ declare const callbackSSO: (options?: SSOOptions) => better_call0.StrictEndpoint
917
917
  };
918
918
  }, never>;
919
919
  declare const callbackSSOSAML: (options?: SSOOptions) => better_call0.StrictEndpoint<"/sso/saml2/callback/:providerId", {
920
- method: "POST";
921
- body: z.ZodObject<{
920
+ method: ("POST" | "GET")[];
921
+ body: z.ZodOptional<z.ZodObject<{
922
922
  SAMLResponse: z.ZodString;
923
923
  RelayState: z.ZodOptional<z.ZodString>;
924
- }, z.core.$strip>;
924
+ }, z.core.$strip>>;
925
+ query: z.ZodOptional<z.ZodObject<{
926
+ RelayState: z.ZodOptional<z.ZodString>;
927
+ }, z.core.$strip>>;
925
928
  metadata: {
926
929
  allowedMediaTypes: string[];
927
930
  openapi: {
package/dist/index.d.mts CHANGED
@@ -1,2 +1,2 @@
1
- import { A as DataEncryptionAlgorithm, C as TimestampValidationOptions, D as SSOOptions, E as SAMLConfig, M as DigestAlgorithm, N as KeyEncryptionAlgorithm, O as SSOProvider, P as SignatureAlgorithm, S as SAMLConditions, T as OIDCConfig, _ as REQUIRED_DISCOVERY_FIELDS, a as fetchDiscoveryDocument, b as DEFAULT_MAX_SAML_METADATA_SIZE, c as normalizeUrl, d as validateDiscoveryUrl, f as DiscoverOIDCConfigParams, g as OIDCDiscoveryDocument, h as HydratedOIDCConfig, i as discoverOIDCConfig, j as DeprecatedAlgorithmBehavior, k as AlgorithmValidationOptions, l as selectTokenEndpointAuthMethod, m as DiscoveryErrorCode, n as sso, o as needsRuntimeDiscovery, p as DiscoveryError, r as computeDiscoveryUrl, s as normalizeDiscoveryUrls, t as SSOPlugin, u as validateDiscoveryDocument, v as RequiredDiscoveryField, w as validateSAMLTimestamp, x as DEFAULT_MAX_SAML_RESPONSE_SIZE, y as DEFAULT_CLOCK_SKEW_MS } from "./index-BLMoKtp1.mjs";
1
+ import { A as DataEncryptionAlgorithm, C as TimestampValidationOptions, D as SSOOptions, E as SAMLConfig, M as DigestAlgorithm, N as KeyEncryptionAlgorithm, O as SSOProvider, P as SignatureAlgorithm, S as SAMLConditions, T as OIDCConfig, _ as REQUIRED_DISCOVERY_FIELDS, a as fetchDiscoveryDocument, b as DEFAULT_MAX_SAML_METADATA_SIZE, c as normalizeUrl, d as validateDiscoveryUrl, f as DiscoverOIDCConfigParams, g as OIDCDiscoveryDocument, h as HydratedOIDCConfig, i as discoverOIDCConfig, j as DeprecatedAlgorithmBehavior, k as AlgorithmValidationOptions, l as selectTokenEndpointAuthMethod, m as DiscoveryErrorCode, n as sso, o as needsRuntimeDiscovery, p as DiscoveryError, r as computeDiscoveryUrl, s as normalizeDiscoveryUrls, t as SSOPlugin, u as validateDiscoveryDocument, v as RequiredDiscoveryField, w as validateSAMLTimestamp, x as DEFAULT_MAX_SAML_RESPONSE_SIZE, y as DEFAULT_CLOCK_SKEW_MS } from "./index-BT0wtuq1.mjs";
2
2
  export { AlgorithmValidationOptions, DEFAULT_CLOCK_SKEW_MS, DEFAULT_MAX_SAML_METADATA_SIZE, DEFAULT_MAX_SAML_RESPONSE_SIZE, DataEncryptionAlgorithm, DeprecatedAlgorithmBehavior, DigestAlgorithm, DiscoverOIDCConfigParams, DiscoveryError, DiscoveryErrorCode, HydratedOIDCConfig, KeyEncryptionAlgorithm, OIDCConfig, OIDCDiscoveryDocument, REQUIRED_DISCOVERY_FIELDS, RequiredDiscoveryField, SAMLConditions, SAMLConfig, SSOOptions, SSOPlugin, SSOProvider, SignatureAlgorithm, TimestampValidationOptions, computeDiscoveryUrl, discoverOIDCConfig, fetchDiscoveryDocument, needsRuntimeDiscovery, normalizeDiscoveryUrls, normalizeUrl, selectTokenEndpointAuthMethod, sso, validateDiscoveryDocument, validateDiscoveryUrl, validateSAMLTimestamp };
package/dist/index.mjs CHANGED
@@ -1,15 +1,16 @@
1
- import { APIError, createAuthEndpoint, createAuthMiddleware, sessionMiddleware } from "better-auth/api";
1
+ import { APIError, createAuthEndpoint, createAuthMiddleware, getSessionFromCtx, sessionMiddleware } from "better-auth/api";
2
2
  import { XMLParser, XMLValidator } from "fast-xml-parser";
3
3
  import * as saml from "samlify";
4
4
  import { generateRandomString } from "better-auth/crypto";
5
5
  import * as z$1 from "zod/v4";
6
6
  import z from "zod/v4";
7
7
  import { BetterFetchError, betterFetch } from "@better-fetch/fetch";
8
- import { HIDE_METADATA, createAuthorizationURL, generateState, parseState, validateAuthorizationCode, validateToken } from "better-auth";
8
+ import { HIDE_METADATA, createAuthorizationURL, generateGenericState, generateState, parseGenericState, parseState, validateAuthorizationCode, validateToken } from "better-auth";
9
9
  import { setSessionCookie } from "better-auth/cookies";
10
10
  import { handleOAuthUserInfo } from "better-auth/oauth2";
11
11
  import { decodeJwt } from "jose";
12
12
  import { base64 } from "@better-auth/utils/base64";
13
+ import { APIError as APIError$1 } from "better-call";
13
14
 
14
15
  //#region src/linking/org-assignment.ts
15
16
  /**
@@ -922,6 +923,49 @@ function validateSingleAssertion(samlResponse) {
922
923
  });
923
924
  }
924
925
 
926
+ //#endregion
927
+ //#region src/saml-state.ts
928
+ async function generateRelayState(c, link, additionalData) {
929
+ const callbackURL = c.body.callbackURL;
930
+ if (!callbackURL) throw new APIError$1("BAD_REQUEST", { message: "callbackURL is required" });
931
+ const codeVerifier = generateRandomString(128);
932
+ const stateData = {
933
+ ...additionalData ? additionalData : {},
934
+ callbackURL,
935
+ codeVerifier,
936
+ errorURL: c.body.errorCallbackURL,
937
+ newUserURL: c.body.newUserCallbackURL,
938
+ link,
939
+ expiresAt: Date.now() + 600 * 1e3,
940
+ requestSignUp: c.body.requestSignUp
941
+ };
942
+ try {
943
+ return generateGenericState(c, stateData, { cookieName: "relay_state" });
944
+ } catch (error) {
945
+ c.context.logger.error("Failed to create verification for relay state", error);
946
+ throw new APIError$1("INTERNAL_SERVER_ERROR", {
947
+ message: "State error: Unable to create verification for relay state",
948
+ cause: error
949
+ });
950
+ }
951
+ }
952
+ async function parseRelayState(c) {
953
+ const state = c.body.RelayState;
954
+ const errorURL = c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
955
+ let parsedData;
956
+ try {
957
+ parsedData = await parseGenericState(c, state, { cookieName: "relay_state" });
958
+ } catch (error) {
959
+ c.context.logger.error("Failed to parse relay state", error);
960
+ throw new APIError$1("BAD_REQUEST", {
961
+ message: "State error: failed to validate relay state",
962
+ cause: error
963
+ });
964
+ }
965
+ if (!parsedData.errorURL) parsedData.errorURL = errorURL;
966
+ return parsedData;
967
+ }
968
+
925
969
  //#endregion
926
970
  //#region src/utils.ts
927
971
  /**
@@ -1628,6 +1672,7 @@ const signInSSO = (options) => {
1628
1672
  });
1629
1673
  const loginRequest = sp.createLoginRequest(idp, "redirect");
1630
1674
  if (!loginRequest) throw new APIError("BAD_REQUEST", { message: "Invalid SAML request" });
1675
+ const { state: relayState } = await generateRelayState(ctx, void 0, false);
1631
1676
  if (loginRequest.id && options?.saml?.enableInResponseToValidation) {
1632
1677
  const ttl = options?.saml?.requestTTL ?? DEFAULT_AUTHN_REQUEST_TTL_MS;
1633
1678
  const record = {
@@ -1643,7 +1688,7 @@ const signInSSO = (options) => {
1643
1688
  });
1644
1689
  }
1645
1690
  return ctx.json({
1646
- url: `${loginRequest.context}&RelayState=${encodeURIComponent(body.callbackURL)}`,
1691
+ url: `${loginRequest.context}&RelayState=${encodeURIComponent(relayState)}`,
1647
1692
  redirect: true
1648
1693
  });
1649
1694
  }
@@ -1825,17 +1870,46 @@ const callbackSSOSAMLBodySchema = z.object({
1825
1870
  SAMLResponse: z.string(),
1826
1871
  RelayState: z.string().optional()
1827
1872
  });
1873
+ /**
1874
+ * Validates and returns a safe redirect URL.
1875
+ * - Prevents open redirect attacks by validating against trusted origins
1876
+ * - Prevents redirect loops by checking if URL points to callback route
1877
+ * - Falls back to appOrigin if URL is invalid or unsafe
1878
+ */
1879
+ const getSafeRedirectUrl = (url, callbackPath, appOrigin, isTrustedOrigin) => {
1880
+ if (!url) return appOrigin;
1881
+ if (url.startsWith("/") && !url.startsWith("//")) {
1882
+ try {
1883
+ const absoluteUrl = new URL(url, appOrigin);
1884
+ if (absoluteUrl.origin !== appOrigin) return appOrigin;
1885
+ const callbackPathname = new URL(callbackPath).pathname;
1886
+ if (absoluteUrl.pathname === callbackPathname) return appOrigin;
1887
+ } catch {
1888
+ return appOrigin;
1889
+ }
1890
+ return url;
1891
+ }
1892
+ if (!isTrustedOrigin(url, { allowRelativePaths: false })) return appOrigin;
1893
+ try {
1894
+ const callbackPathname = new URL(callbackPath).pathname;
1895
+ if (new URL(url).pathname === callbackPathname) return appOrigin;
1896
+ } catch {
1897
+ if (url === callbackPath || url.startsWith(`${callbackPath}?`)) return appOrigin;
1898
+ }
1899
+ return url;
1900
+ };
1828
1901
  const callbackSSOSAML = (options) => {
1829
1902
  return createAuthEndpoint("/sso/saml2/callback/:providerId", {
1830
- method: "POST",
1831
- body: callbackSSOSAMLBodySchema,
1903
+ method: ["GET", "POST"],
1904
+ body: callbackSSOSAMLBodySchema.optional(),
1905
+ query: z.object({ RelayState: z.string().optional() }).optional(),
1832
1906
  metadata: {
1833
1907
  ...HIDE_METADATA,
1834
1908
  allowedMediaTypes: ["application/x-www-form-urlencoded", "application/json"],
1835
1909
  openapi: {
1836
1910
  operationId: "handleSAMLCallback",
1837
1911
  summary: "Callback URL for SAML provider",
1838
- description: "This endpoint is used as the callback URL for SAML providers.",
1912
+ description: "This endpoint is used as the callback URL for SAML providers. Supports both GET and POST methods for IdP-initiated and SP-initiated flows.",
1839
1913
  responses: {
1840
1914
  "302": { description: "Redirects to the callback URL" },
1841
1915
  "400": { description: "Invalid SAML response" },
@@ -1844,10 +1918,26 @@ const callbackSSOSAML = (options) => {
1844
1918
  }
1845
1919
  }
1846
1920
  }, async (ctx) => {
1847
- const { SAMLResponse, RelayState } = ctx.body;
1848
1921
  const { providerId } = ctx.params;
1922
+ const appOrigin = new URL(ctx.context.baseURL).origin;
1923
+ const errorURL = ctx.context.options.onAPIError?.errorURL || `${appOrigin}/error`;
1924
+ const currentCallbackPath = `${ctx.context.baseURL}/sso/saml2/callback/${providerId}`;
1925
+ if (ctx.method === "GET" && !ctx.body?.SAMLResponse) {
1926
+ if (!(await getSessionFromCtx(ctx))?.session) throw ctx.redirect(`${errorURL}?error=invalid_request`);
1927
+ const relayState$1 = ctx.query?.RelayState;
1928
+ const safeRedirectUrl$1 = getSafeRedirectUrl(relayState$1, currentCallbackPath, appOrigin, (url, settings) => ctx.context.isTrustedOrigin(url, settings));
1929
+ throw ctx.redirect(safeRedirectUrl$1);
1930
+ }
1931
+ if (!ctx.body?.SAMLResponse) throw new APIError("BAD_REQUEST", { message: "SAMLResponse is required for POST requests" });
1932
+ const { SAMLResponse } = ctx.body;
1849
1933
  const maxResponseSize = options?.saml?.maxResponseSize ?? DEFAULT_MAX_SAML_RESPONSE_SIZE;
1850
1934
  if (new TextEncoder().encode(SAMLResponse).length > maxResponseSize) throw new APIError("BAD_REQUEST", { message: `SAML response exceeds maximum allowed size (${maxResponseSize} bytes)` });
1935
+ let relayState = null;
1936
+ if (ctx.body.RelayState) try {
1937
+ relayState = await parseRelayState(ctx);
1938
+ } catch {
1939
+ relayState = null;
1940
+ }
1851
1941
  let provider = null;
1852
1942
  if (options?.defaultSSO?.length) {
1853
1943
  const matchingDefault = options.defaultSSO.find((defaultProvider) => defaultProvider.providerId === providerId);
@@ -1918,7 +2008,7 @@ const callbackSSOSAML = (options) => {
1918
2008
  try {
1919
2009
  parsedResponse = await sp.parseLoginResponse(idp, "post", { body: {
1920
2010
  SAMLResponse,
1921
- RelayState: RelayState || void 0
2011
+ RelayState: ctx.body.RelayState || void 0
1922
2012
  } });
1923
2013
  if (!parsedResponse?.extract) throw new Error("Invalid SAML response structure");
1924
2014
  } catch (error) {
@@ -1955,7 +2045,7 @@ const callbackSSOSAML = (options) => {
1955
2045
  inResponseTo,
1956
2046
  providerId: provider.providerId
1957
2047
  });
1958
- const redirectUrl = RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2048
+ const redirectUrl = relayState?.callbackURL || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
1959
2049
  throw ctx.redirect(`${redirectUrl}?error=invalid_saml_response&error_description=Unknown+or+expired+request+ID`);
1960
2050
  }
1961
2051
  if (storedRequest.providerId !== provider.providerId) {
@@ -1965,13 +2055,13 @@ const callbackSSOSAML = (options) => {
1965
2055
  actualProvider: provider.providerId
1966
2056
  });
1967
2057
  await ctx.context.internalAdapter.deleteVerificationByIdentifier(`${AUTHN_REQUEST_KEY_PREFIX}${inResponseTo}`);
1968
- const redirectUrl = RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2058
+ const redirectUrl = relayState?.callbackURL || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
1969
2059
  throw ctx.redirect(`${redirectUrl}?error=invalid_saml_response&error_description=Provider+mismatch`);
1970
2060
  }
1971
2061
  await ctx.context.internalAdapter.deleteVerificationByIdentifier(`${AUTHN_REQUEST_KEY_PREFIX}${inResponseTo}`);
1972
2062
  } else if (!allowIdpInitiated) {
1973
2063
  ctx.context.logger.error("SAML IdP-initiated SSO rejected: InResponseTo missing and allowIdpInitiated is false", { providerId: provider.providerId });
1974
- const redirectUrl = RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2064
+ const redirectUrl = relayState?.callbackURL || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
1975
2065
  throw ctx.redirect(`${redirectUrl}?error=unsolicited_response&error_description=IdP-initiated+SSO+not+allowed`);
1976
2066
  }
1977
2067
  }
@@ -1998,7 +2088,7 @@ const callbackSSOSAML = (options) => {
1998
2088
  issuer,
1999
2089
  providerId: provider.providerId
2000
2090
  });
2001
- const redirectUrl = RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2091
+ const redirectUrl = relayState?.callbackURL || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2002
2092
  throw ctx.redirect(`${redirectUrl}?error=replay_detected&error_description=SAML+assertion+has+already+been+used`);
2003
2093
  }
2004
2094
  await ctx.context.internalAdapter.createVerificationValue({
@@ -2032,7 +2122,7 @@ const callbackSSOSAML = (options) => {
2032
2122
  throw new APIError("BAD_REQUEST", { message: "Unable to extract user ID or email from SAML response" });
2033
2123
  }
2034
2124
  const isTrustedProvider = !!ctx.context.options.account?.accountLinking?.trustedProviders?.includes(provider.providerId) || "domainVerified" in provider && !!provider.domainVerified && validateEmailDomain(userInfo.email, provider.domain);
2035
- const callbackUrl = RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2125
+ const callbackUrl = relayState?.callbackURL || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2036
2126
  const result = await handleOAuthUserInfo(ctx, {
2037
2127
  userInfo: {
2038
2128
  email: userInfo.email,
@@ -2074,7 +2164,8 @@ const callbackSSOSAML = (options) => {
2074
2164
  session,
2075
2165
  user
2076
2166
  });
2077
- throw ctx.redirect(callbackUrl);
2167
+ const safeRedirectUrl = getSafeRedirectUrl(relayState?.callbackURL || parsedSamlConfig.callbackUrl, currentCallbackPath, appOrigin, (url, settings) => ctx.context.isTrustedOrigin(url, settings));
2168
+ throw ctx.redirect(safeRedirectUrl);
2078
2169
  });
2079
2170
  };
2080
2171
  const acsEndpointParamsSchema = z.object({ providerId: z.string().optional() });
@@ -2329,6 +2420,12 @@ saml.setSchemaValidator({ async validate(xml) {
2329
2420
  if (XMLValidator.validate(xml, { allowBooleanAttributes: true }) === true) return "SUCCESS_VALIDATE_XML";
2330
2421
  throw "ERR_INVALID_XML";
2331
2422
  } });
2423
+ /**
2424
+ * SAML endpoint paths that should skip origin check validation.
2425
+ * These endpoints receive POST requests from external Identity Providers,
2426
+ * which won't have a matching Origin header.
2427
+ */
2428
+ const SAML_SKIP_ORIGIN_CHECK_PATHS = ["/sso/saml2/callback", "/sso/saml2/sp/acs"];
2332
2429
  function sso(options) {
2333
2430
  const optionsWithStore = options;
2334
2431
  let endpoints = {
@@ -2351,6 +2448,11 @@ function sso(options) {
2351
2448
  }
2352
2449
  return {
2353
2450
  id: "sso",
2451
+ init(ctx) {
2452
+ const existing = ctx.skipOriginCheck;
2453
+ if (existing === true) return {};
2454
+ return { context: { skipOriginCheck: [...Array.isArray(existing) ? existing : [], ...SAML_SKIP_ORIGIN_CHECK_PATHS] } };
2455
+ },
2354
2456
  endpoints,
2355
2457
  hooks: { after: [{
2356
2458
  matcher(context) {
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@better-auth/sso",
3
3
  "author": "Bereket Engida",
4
- "version": "1.5.0-beta.6",
4
+ "version": "1.5.0-beta.8",
5
5
  "type": "module",
6
6
  "main": "dist/index.mjs",
7
7
  "types": "dist/index.d.mts",
@@ -67,13 +67,13 @@
67
67
  "express": "^5.1.0",
68
68
  "oauth2-mock-server": "^8.2.0",
69
69
  "tsdown": "^0.19.0",
70
- "@better-auth/core": "1.5.0-beta.6",
71
- "better-auth": "1.5.0-beta.6"
70
+ "better-auth": "1.5.0-beta.8",
71
+ "@better-auth/core": "1.5.0-beta.8"
72
72
  },
73
73
  "peerDependencies": {
74
74
  "@better-auth/utils": "0.3.0",
75
- "@better-auth/core": "1.5.0-beta.6",
76
- "better-auth": "1.5.0-beta.6"
75
+ "@better-auth/core": "1.5.0-beta.8",
76
+ "better-auth": "1.5.0-beta.8"
77
77
  },
78
78
  "scripts": {
79
79
  "test": "vitest",
package/src/index.ts CHANGED
@@ -103,6 +103,16 @@ export type SSOPlugin<O extends SSOOptions> = {
103
103
  : {});
104
104
  };
105
105
 
106
+ /**
107
+ * SAML endpoint paths that should skip origin check validation.
108
+ * These endpoints receive POST requests from external Identity Providers,
109
+ * which won't have a matching Origin header.
110
+ */
111
+ const SAML_SKIP_ORIGIN_CHECK_PATHS = [
112
+ "/sso/saml2/callback", // SP-initiated SSO callback (prefix matches /callback/:providerId)
113
+ "/sso/saml2/sp/acs", // IdP-initiated SSO ACS (prefix matches /sp/acs/:providerId)
114
+ ];
115
+
106
116
  export function sso<
107
117
  O extends SSOOptions & {
108
118
  domainVerification?: { enabled: true };
@@ -148,6 +158,18 @@ export function sso<O extends SSOOptions>(options?: O | undefined): any {
148
158
 
149
159
  return {
150
160
  id: "sso",
161
+ init(ctx) {
162
+ const existing = ctx.skipOriginCheck;
163
+ if (existing === true) {
164
+ return {};
165
+ }
166
+ const existingPaths = Array.isArray(existing) ? existing : [];
167
+ return {
168
+ context: {
169
+ skipOriginCheck: [...existingPaths, ...SAML_SKIP_ORIGIN_CHECK_PATHS],
170
+ },
171
+ };
172
+ },
151
173
  endpoints,
152
174
  hooks: {
153
175
  after: [
package/src/oidc.test.ts CHANGED
@@ -7,7 +7,7 @@ import { afterAll, beforeAll, describe, expect, it } from "vitest";
7
7
  import { sso } from ".";
8
8
  import { ssoClient } from "./client";
9
9
 
10
- let server = new OAuth2Server();
10
+ const server = new OAuth2Server();
11
11
 
12
12
  describe("SSO", async () => {
13
13
  const { auth, signInWithTestUser, customFetchImpl, cookieSetter } =
package/src/routes/sso.ts CHANGED
@@ -11,6 +11,7 @@ import {
11
11
  import {
12
12
  APIError,
13
13
  createAuthEndpoint,
14
+ getSessionFromCtx,
14
15
  sessionMiddleware,
15
16
  } from "better-auth/api";
16
17
  import { setSessionCookie } from "better-auth/cookies";
@@ -52,6 +53,7 @@ import {
52
53
  validateSAMLAlgorithms,
53
54
  validateSingleAssertion,
54
55
  } from "../saml";
56
+ import { generateRelayState, parseRelayState } from "../saml-state";
55
57
  import type { OIDCConfig, SAMLConfig, SSOOptions, SSOProvider } from "../types";
56
58
  import { safeJsonParse, validateEmailDomain } from "../utils";
57
59
 
@@ -165,6 +167,8 @@ const spMetadataQuerySchema = z.object({
165
167
  format: z.enum(["xml", "json"]).default("xml"),
166
168
  });
167
169
 
170
+ type RelayState = Awaited<ReturnType<typeof parseRelayState>>;
171
+
168
172
  export const spMetadata = () => {
169
173
  return createAuthEndpoint(
170
174
  "/sso/saml2/sp/metadata",
@@ -1293,6 +1297,12 @@ export const signInSSO = (options?: SSOOptions) => {
1293
1297
  });
1294
1298
  }
1295
1299
 
1300
+ const { state: relayState } = await generateRelayState(
1301
+ ctx,
1302
+ undefined,
1303
+ false,
1304
+ );
1305
+
1296
1306
  const shouldSaveRequest =
1297
1307
  loginRequest.id && options?.saml?.enableInResponseToValidation;
1298
1308
  if (shouldSaveRequest) {
@@ -1311,9 +1321,7 @@ export const signInSSO = (options?: SSOOptions) => {
1311
1321
  }
1312
1322
 
1313
1323
  return ctx.json({
1314
- url: `${loginRequest.context}&RelayState=${encodeURIComponent(
1315
- body.callbackURL,
1316
- )}`,
1324
+ url: `${loginRequest.context}&RelayState=${encodeURIComponent(relayState)}`,
1317
1325
  redirect: true,
1318
1326
  });
1319
1327
  }
@@ -1683,12 +1691,71 @@ const callbackSSOSAMLBodySchema = z.object({
1683
1691
  RelayState: z.string().optional(),
1684
1692
  });
1685
1693
 
1694
+ /**
1695
+ * Validates and returns a safe redirect URL.
1696
+ * - Prevents open redirect attacks by validating against trusted origins
1697
+ * - Prevents redirect loops by checking if URL points to callback route
1698
+ * - Falls back to appOrigin if URL is invalid or unsafe
1699
+ */
1700
+ const getSafeRedirectUrl = (
1701
+ url: string | undefined,
1702
+ callbackPath: string,
1703
+ appOrigin: string,
1704
+ isTrustedOrigin: (
1705
+ url: string,
1706
+ settings?: { allowRelativePaths: boolean },
1707
+ ) => boolean,
1708
+ ): string => {
1709
+ if (!url) {
1710
+ return appOrigin;
1711
+ }
1712
+
1713
+ if (url.startsWith("/") && !url.startsWith("//")) {
1714
+ try {
1715
+ const absoluteUrl = new URL(url, appOrigin);
1716
+ if (absoluteUrl.origin !== appOrigin) {
1717
+ return appOrigin;
1718
+ }
1719
+ const callbackPathname = new URL(callbackPath).pathname;
1720
+ if (absoluteUrl.pathname === callbackPathname) {
1721
+ return appOrigin;
1722
+ }
1723
+ } catch {
1724
+ return appOrigin;
1725
+ }
1726
+ return url;
1727
+ }
1728
+
1729
+ if (!isTrustedOrigin(url, { allowRelativePaths: false })) {
1730
+ return appOrigin;
1731
+ }
1732
+
1733
+ try {
1734
+ const callbackPathname = new URL(callbackPath).pathname;
1735
+ const urlPathname = new URL(url).pathname;
1736
+ if (urlPathname === callbackPathname) {
1737
+ return appOrigin;
1738
+ }
1739
+ } catch {
1740
+ if (url === callbackPath || url.startsWith(`${callbackPath}?`)) {
1741
+ return appOrigin;
1742
+ }
1743
+ }
1744
+
1745
+ return url;
1746
+ };
1747
+
1686
1748
  export const callbackSSOSAML = (options?: SSOOptions) => {
1687
1749
  return createAuthEndpoint(
1688
1750
  "/sso/saml2/callback/:providerId",
1689
1751
  {
1690
- method: "POST",
1691
- body: callbackSSOSAMLBodySchema,
1752
+ method: ["GET", "POST"],
1753
+ body: callbackSSOSAMLBodySchema.optional(),
1754
+ query: z
1755
+ .object({
1756
+ RelayState: z.string().optional(),
1757
+ })
1758
+ .optional(),
1692
1759
  metadata: {
1693
1760
  ...HIDE_METADATA,
1694
1761
  allowedMediaTypes: [
@@ -1699,7 +1766,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1699
1766
  operationId: "handleSAMLCallback",
1700
1767
  summary: "Callback URL for SAML provider",
1701
1768
  description:
1702
- "This endpoint is used as the callback URL for SAML providers.",
1769
+ "This endpoint is used as the callback URL for SAML providers. Supports both GET and POST methods for IdP-initiated and SP-initiated flows.",
1703
1770
  responses: {
1704
1771
  "302": {
1705
1772
  description: "Redirects to the callback URL",
@@ -1715,8 +1782,41 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1715
1782
  },
1716
1783
  },
1717
1784
  async (ctx) => {
1718
- const { SAMLResponse, RelayState } = ctx.body;
1719
1785
  const { providerId } = ctx.params;
1786
+ const appOrigin = new URL(ctx.context.baseURL).origin;
1787
+ const errorURL =
1788
+ ctx.context.options.onAPIError?.errorURL || `${appOrigin}/error`;
1789
+ const currentCallbackPath = `${ctx.context.baseURL}/sso/saml2/callback/${providerId}`;
1790
+
1791
+ // Determine if this is a GET request by checking both method AND body presence
1792
+ // When called via auth.api.*, ctx.method may not be reliable, so we also check for body
1793
+ const isGetRequest = ctx.method === "GET" && !ctx.body?.SAMLResponse;
1794
+
1795
+ if (isGetRequest) {
1796
+ const session = await getSessionFromCtx(ctx);
1797
+
1798
+ if (!session?.session) {
1799
+ throw ctx.redirect(`${errorURL}?error=invalid_request`);
1800
+ }
1801
+
1802
+ const relayState = ctx.query?.RelayState as string | undefined;
1803
+ const safeRedirectUrl = getSafeRedirectUrl(
1804
+ relayState,
1805
+ currentCallbackPath,
1806
+ appOrigin,
1807
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
1808
+ );
1809
+
1810
+ throw ctx.redirect(safeRedirectUrl);
1811
+ }
1812
+
1813
+ if (!ctx.body?.SAMLResponse) {
1814
+ throw new APIError("BAD_REQUEST", {
1815
+ message: "SAMLResponse is required for POST requests",
1816
+ });
1817
+ }
1818
+
1819
+ const { SAMLResponse } = ctx.body;
1720
1820
 
1721
1821
  const maxResponseSize =
1722
1822
  options?.saml?.maxResponseSize ?? DEFAULT_MAX_SAML_RESPONSE_SIZE;
@@ -1726,6 +1826,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1726
1826
  });
1727
1827
  }
1728
1828
 
1829
+ let relayState: RelayState | null = null;
1830
+ if (ctx.body.RelayState) {
1831
+ try {
1832
+ relayState = await parseRelayState(ctx);
1833
+ } catch {
1834
+ relayState = null;
1835
+ }
1836
+ }
1729
1837
  let provider: SSOProvider<SSOOptions> | null = null;
1730
1838
  if (options?.defaultSSO?.length) {
1731
1839
  const matchingDefault = options.defaultSSO.find(
@@ -1846,7 +1954,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1846
1954
  parsedResponse = await sp.parseLoginResponse(idp, "post", {
1847
1955
  body: {
1848
1956
  SAMLResponse,
1849
- RelayState: RelayState || undefined,
1957
+ RelayState: ctx.body.RelayState || undefined,
1850
1958
  },
1851
1959
  });
1852
1960
 
@@ -1909,7 +2017,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1909
2017
  { inResponseTo, providerId: provider.providerId },
1910
2018
  );
1911
2019
  const redirectUrl =
1912
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2020
+ relayState?.callbackURL ||
2021
+ parsedSamlConfig.callbackUrl ||
2022
+ ctx.context.baseURL;
1913
2023
  throw ctx.redirect(
1914
2024
  `${redirectUrl}?error=invalid_saml_response&error_description=Unknown+or+expired+request+ID`,
1915
2025
  );
@@ -1929,7 +2039,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1929
2039
  `${AUTHN_REQUEST_KEY_PREFIX}${inResponseTo}`,
1930
2040
  );
1931
2041
  const redirectUrl =
1932
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2042
+ relayState?.callbackURL ||
2043
+ parsedSamlConfig.callbackUrl ||
2044
+ ctx.context.baseURL;
1933
2045
  throw ctx.redirect(
1934
2046
  `${redirectUrl}?error=invalid_saml_response&error_description=Provider+mismatch`,
1935
2047
  );
@@ -1944,7 +2056,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1944
2056
  { providerId: provider.providerId },
1945
2057
  );
1946
2058
  const redirectUrl =
1947
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2059
+ relayState?.callbackURL ||
2060
+ parsedSamlConfig.callbackUrl ||
2061
+ ctx.context.baseURL;
1948
2062
  throw ctx.redirect(
1949
2063
  `${redirectUrl}?error=unsolicited_response&error_description=IdP-initiated+SSO+not+allowed`,
1950
2064
  );
@@ -1997,7 +2111,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
1997
2111
  },
1998
2112
  );
1999
2113
  const redirectUrl =
2000
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2114
+ relayState?.callbackURL ||
2115
+ parsedSamlConfig.callbackUrl ||
2116
+ ctx.context.baseURL;
2001
2117
  throw ctx.redirect(
2002
2118
  `${redirectUrl}?error=replay_detected&error_description=SAML+assertion+has+already+been+used`,
2003
2119
  );
@@ -2071,7 +2187,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2071
2187
  validateEmailDomain(userInfo.email as string, provider.domain));
2072
2188
 
2073
2189
  const callbackUrl =
2074
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
2190
+ relayState?.callbackURL ||
2191
+ parsedSamlConfig.callbackUrl ||
2192
+ ctx.context.baseURL;
2075
2193
 
2076
2194
  const result = await handleOAuthUserInfo(ctx, {
2077
2195
  userInfo: {
@@ -2122,7 +2240,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
2122
2240
  });
2123
2241
 
2124
2242
  await setSessionCookie(ctx, { session, user });
2125
- throw ctx.redirect(callbackUrl);
2243
+
2244
+ const safeRedirectUrl = getSafeRedirectUrl(
2245
+ relayState?.callbackURL || parsedSamlConfig.callbackUrl,
2246
+ currentCallbackPath,
2247
+ appOrigin,
2248
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
2249
+ );
2250
+ throw ctx.redirect(safeRedirectUrl);
2126
2251
  },
2127
2252
  );
2128
2253
  };