@better-auth/core 1.4.18 → 1.4.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,241 @@
1
+ import type { JWK } from "jose";
2
+ import { exportJWK, generateKeyPair, SignJWT } from "jose";
3
+ import {
4
+ afterAll,
5
+ beforeAll,
6
+ beforeEach,
7
+ describe,
8
+ expect,
9
+ it,
10
+ vi,
11
+ } from "vitest";
12
+ import { validateToken } from "./validate-authorization-code";
13
+
14
describe("validateToken", () => {
	const originalFetch = globalThis.fetch;
	const mockedFetch = vi.fn() as unknown as typeof fetch &
		ReturnType<typeof vi.fn>;

	// Monotonic counter used to give every generated key a unique `kid`.
	// `Date.now()` alone is not unique: EC/OKP key generation can finish in the
	// same millisecond as the previous key, which would produce duplicate `kid`s
	// and make the multi-key lookup test ambiguous.
	let keyCounter = 0;

	// Route every `fetch` made by `validateToken` through the mock for the whole
	// suite, and restore the real implementation once the suite is done.
	beforeAll(() => {
		globalThis.fetch = mockedFetch;
	});

	afterAll(() => {
		globalThis.fetch = originalFetch;
	});

	// Drop queued responses and recorded calls so tests stay independent.
	beforeEach(() => {
		mockedFetch.mockReset();
	});

	/**
	 * Generate an extractable key pair for `alg` (optionally on curve `crv`) and
	 * export both halves as JWKs tagged with a unique `kid` and the given `alg`.
	 */
	async function createTestJWKS(alg: string, crv?: string) {
		const { publicKey, privateKey } = await generateKeyPair(alg, {
			crv,
			extractable: true,
		});
		const publicJWK = await exportJWK(publicKey);
		const privateJWK = await exportJWK(privateKey);
		keyCounter += 1;
		const kid = `test-key-${Date.now()}-${keyCounter}`;
		publicJWK.kid = kid;
		publicJWK.alg = alg;
		privateJWK.kid = kid;
		privateJWK.alg = alg;
		return { publicJWK, privateJWK, kid, publicKey, privateKey };
	}

	/**
	 * Sign a JWT whose default claims (sub, email, iss, aud) are merged with
	 * `payload`. The token carries `alg`/`kid` in its protected header, is
	 * issued now, and expires in one hour.
	 */
	async function createSignedToken(
		privateKey: CryptoKey,
		alg: string,
		kid: string,
		payload: Record<string, unknown> = {},
	) {
		return await new SignJWT({
			sub: "user-123",
			email: "test@example.com",
			iss: "https://example.com",
			aud: "test-client",
			...payload,
		})
			.setProtectedHeader({ alg, kid })
			.setIssuedAt()
			.setExpirationTime("1h")
			.sign(privateKey);
	}

	/** Queue one successful JWKS response containing `publicJWKs`. */
	function mockJWKSResponse(...publicJWKs: JWK[]) {
		mockedFetch.mockResolvedValueOnce(
			new Response(JSON.stringify({ keys: publicJWKs }), {
				status: 200,
				headers: { "content-type": "application/json" },
			}),
		);
	}

	it("should verify RS256 signed token", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockJWKSResponse(publicJWK);

		const result = await validateToken(
			token,
			"https://example.com/.well-known/jwks",
		);

		expect(result).toBeDefined();
		expect(result.payload.sub).toBe("user-123");
		expect(result.payload.email).toBe("test@example.com");
	});

	it("should verify ES256 signed token", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS("ES256");
		const token = await createSignedToken(privateKey, "ES256", kid);
		mockJWKSResponse(publicJWK);

		const result = await validateToken(
			token,
			"https://example.com/.well-known/jwks",
		);

		expect(result).toBeDefined();
		expect(result.payload.sub).toBe("user-123");
	});

	it("should verify EdDSA (Ed25519) signed token", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS(
			"EdDSA",
			"Ed25519",
		);
		const token = await createSignedToken(privateKey, "EdDSA", kid);
		mockJWKSResponse(publicJWK);

		const result = await validateToken(
			token,
			"https://example.com/.well-known/jwks",
		);

		expect(result).toBeDefined();
		expect(result.payload.sub).toBe("user-123");
	});

	it("should throw when kid doesn't match any key", async () => {
		const { publicJWK, privateKey } = await createTestJWKS("RS256");
		publicJWK.kid = "different-kid";
		const token = await createSignedToken(privateKey, "RS256", "original-kid");
		mockJWKSResponse(publicJWK);

		await expect(
			validateToken(token, "https://example.com/.well-known/jwks"),
		).rejects.toThrow();
	});

	it("should find correct key when multiple keys exist", async () => {
		// Each key gets a distinct `kid`, so the lookup must select key2 by `kid`.
		const key1 = await createTestJWKS("RS256");
		const key2 = await createTestJWKS("RS256");
		const key3 = await createTestJWKS("ES256");
		const token = await createSignedToken(key2.privateKey, "RS256", key2.kid);
		mockJWKSResponse(key1.publicJWK, key2.publicJWK, key3.publicJWK);

		const result = await validateToken(
			token,
			"https://example.com/.well-known/jwks",
		);

		expect(result).toBeDefined();
		expect(result.payload.sub).toBe("user-123");
	});

	it("should throw when JWKS returns empty keys array", async () => {
		const { privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockJWKSResponse();

		await expect(
			validateToken(token, "https://example.com/.well-known/jwks"),
		).rejects.toThrow();
	});

	it("should throw when JWKS fetch fails", async () => {
		const { privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockedFetch.mockResolvedValueOnce(
			new Response("Internal Server Error", { status: 500 }),
		);

		await expect(
			validateToken(token, "https://example.com/.well-known/jwks"),
		).rejects.toBeDefined();
	});

	it("should verify token with matching audience", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockJWKSResponse(publicJWK);

		const result = await validateToken(
			token,
			"https://example.com/.well-known/jwks",
			{ audience: "test-client" },
		);

		expect(result).toBeDefined();
		expect(result.payload.aud).toBe("test-client");
	});

	it("should reject token with mismatched audience", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockJWKSResponse(publicJWK);

		await expect(
			validateToken(token, "https://example.com/.well-known/jwks", {
				audience: "wrong-client",
			}),
		).rejects.toThrow();
	});

	it("should verify token with matching issuer", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockJWKSResponse(publicJWK);

		const result = await validateToken(
			token,
			"https://example.com/.well-known/jwks",
			{ issuer: "https://example.com" },
		);

		expect(result).toBeDefined();
		expect(result.payload.iss).toBe("https://example.com");
	});

	it("should reject token with mismatched issuer", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockJWKSResponse(publicJWK);

		await expect(
			validateToken(token, "https://example.com/.well-known/jwks", {
				issuer: "https://wrong-issuer.com",
			}),
		).rejects.toThrow();
	});

	it("should verify token with both audience and issuer", async () => {
		const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
		const token = await createSignedToken(privateKey, "RS256", kid);
		mockJWKSResponse(publicJWK);

		const result = await validateToken(
			token,
			"https://example.com/.well-known/jwks",
			{
				audience: "test-client",
				issuer: "https://example.com",
			},
		);

		expect(result).toBeDefined();
		expect(result.payload.aud).toBe("test-client");
		expect(result.payload.iss).toBe("https://example.com");
	});
});
@@ -111,30 +111,34 @@ export const apple = (options: AppleOptions) => {
111
111
  if (options.verifyIdToken) {
112
112
  return options.verifyIdToken(token, nonce);
113
113
  }
114
- const decodedHeader = decodeProtectedHeader(token);
115
- const { kid, alg: jwtAlg } = decodedHeader;
116
- if (!kid || !jwtAlg) return false;
117
- const publicKey = await getApplePublicKey(kid);
118
- const { payload: jwtClaims } = await jwtVerify(token, publicKey, {
119
- algorithms: [jwtAlg],
120
- issuer: "https://appleid.apple.com",
121
- audience:
122
- options.audience && options.audience.length
123
- ? options.audience
124
- : options.appBundleIdentifier
125
- ? options.appBundleIdentifier
126
- : options.clientId,
127
- maxTokenAge: "1h",
128
- });
129
- ["email_verified", "is_private_email"].forEach((field) => {
130
- if (jwtClaims[field] !== undefined) {
131
- jwtClaims[field] = Boolean(jwtClaims[field]);
114
+ try {
115
+ const decodedHeader = decodeProtectedHeader(token);
116
+ const { kid, alg: jwtAlg } = decodedHeader;
117
+ if (!kid || !jwtAlg) return false;
118
+ const publicKey = await getApplePublicKey(kid);
119
+ const { payload: jwtClaims } = await jwtVerify(token, publicKey, {
120
+ algorithms: [jwtAlg],
121
+ issuer: "https://appleid.apple.com",
122
+ audience:
123
+ options.audience && options.audience.length
124
+ ? options.audience
125
+ : options.appBundleIdentifier
126
+ ? options.appBundleIdentifier
127
+ : options.clientId,
128
+ maxTokenAge: "1h",
129
+ });
130
+ ["email_verified", "is_private_email"].forEach((field) => {
131
+ if (jwtClaims[field] !== undefined) {
132
+ jwtClaims[field] = Boolean(jwtClaims[field]);
133
+ }
134
+ });
135
+ if (nonce && jwtClaims.nonce !== nonce) {
136
+ return false;
132
137
  }
133
- });
134
- if (nonce && jwtClaims.nonce !== nonce) {
138
+ return !!jwtClaims;
139
+ } catch {
135
140
  return false;
136
141
  }
137
- return !!jwtClaims;
138
142
  },
139
143
  refreshAccessToken: options.refreshAccessToken
140
144
  ? options.refreshAccessToken
@@ -131,22 +131,26 @@ export const google = (options: GoogleOptions) => {
131
131
  // Verify JWT integrity
132
132
  // See https://developers.google.com/identity/sign-in/web/backend-auth#verify-the-integrity-of-the-id-token
133
133
 
134
- const { kid, alg: jwtAlg } = decodeProtectedHeader(token);
135
- if (!kid || !jwtAlg) return false;
134
+ try {
135
+ const { kid, alg: jwtAlg } = decodeProtectedHeader(token);
136
+ if (!kid || !jwtAlg) return false;
136
137
 
137
- const publicKey = await getGooglePublicKey(kid);
138
- const { payload: jwtClaims } = await jwtVerify(token, publicKey, {
139
- algorithms: [jwtAlg],
140
- issuer: ["https://accounts.google.com", "accounts.google.com"],
141
- audience: options.clientId,
142
- maxTokenAge: "1h",
143
- });
138
+ const publicKey = await getGooglePublicKey(kid);
139
+ const { payload: jwtClaims } = await jwtVerify(token, publicKey, {
140
+ algorithms: [jwtAlg],
141
+ issuer: ["https://accounts.google.com", "accounts.google.com"],
142
+ audience: options.clientId,
143
+ maxTokenAge: "1h",
144
+ });
145
+
146
+ if (nonce && jwtClaims.nonce !== nonce) {
147
+ return false;
148
+ }
144
149
 
145
- if (nonce && jwtClaims.nonce !== nonce) {
150
+ return true;
151
+ } catch {
146
152
  return false;
147
153
  }
148
-
149
- return true;
150
154
  },
151
155
  async getUserInfo(token) {
152
156
  if (options.getUserInfo) {
@@ -1,6 +1,7 @@
1
1
  import { base64 } from "@better-auth/utils/base64";
2
2
  import { betterFetch } from "@better-fetch/fetch";
3
- import { decodeJwt } from "jose";
3
+ import { APIError } from "better-call";
4
+ import { decodeJwt, decodeProtectedHeader, importJWK, jwtVerify } from "jose";
4
5
  import { logger } from "../env";
5
6
  import type { OAuthProvider, ProviderOptions } from "../oauth2";
6
7
  import {
@@ -174,6 +175,56 @@ export const microsoft = (options: MicrosoftOptions) => {
174
175
  tokenEndpoint,
175
176
  });
176
177
  },
178
/**
 * Verify a Microsoft identity platform ID token.
 *
 * - Returns `false` immediately when ID-token sign-in is disabled.
 * - Delegates entirely to `options.verifyIdToken` when one is supplied.
 * - Otherwise, the token's `kid` and `alg` header values select the signing
 *   key (via `getMicrosoftPublicKey`) and constrain the accepted algorithm.
 * - `jwtVerify` then checks the signature, the audience (`options.clientId`)
 *   and a maximum token age of one hour. It also checks the issuer, but only
 *   when a specific tenant is configured.
 * - When a `nonce` is given, it must equal the token's `nonce` claim.
 *
 * Errors thrown during verification are logged and reported as `false`.
 */
async verifyIdToken(token, nonce) {
	if (options.disableIdTokenSignIn) return false;
	if (options.verifyIdToken) return options.verifyIdToken(token, nonce);

	try {
		const header = decodeProtectedHeader(token);
		const keyId = header.kid;
		const headerAlg = header.alg;
		if (!keyId || !headerAlg) {
			return false;
		}

		const signingKey = await getMicrosoftPublicKey(keyId, tenant, authority);

		/**
		 * Issuer varies per user's tenant for multi-tenant endpoints, so it is only
		 * pinned when a specific tenant is configured.
		 * @see https://learn.microsoft.com/en-us/entra/identity-platform/v2-protocols#endpoints
		 */
		const isMultiTenantEndpoint = [
			"common",
			"organizations",
			"consumers",
		].includes(tenant);

		const { payload } = await jwtVerify(token, signingKey, {
			algorithms: [headerAlg],
			audience: options.clientId,
			maxTokenAge: "1h",
			...(isMultiTenantEndpoint
				? {}
				: { issuer: `${authority}/${tenant}/v2.0` }),
		});

		const nonceMismatch = Boolean(nonce) && payload.nonce !== nonce;
		return !nonceMismatch;
	} catch (error) {
		logger.error("Failed to verify ID token:", error);
		return false;
	}
},
177
228
  async getUserInfo(token) {
178
229
  if (options.getUserInfo) {
179
230
  return options.getUserInfo(token);
@@ -257,3 +308,35 @@ export const microsoft = (options: MicrosoftOptions) => {
257
308
  options,
258
309
  } satisfies OAuthProvider;
259
310
  };
311
+
312
/**
 * Fetch the signing key identified by `kid` from the Microsoft identity
 * platform JWKS endpoint (`{authority}/{tenant}/discovery/v2.0/keys`) and
 * import it for use with `jwtVerify`.
 *
 * Microsoft's JWKS entries typically omit the optional `alg` member. `importJWK`
 * cannot import an RSA key without an algorithm, so RS256 (the algorithm
 * Microsoft uses to sign ID tokens) is used as the fallback when `alg` is
 * absent.
 *
 * @param kid - Key ID taken from the ID token's protected header.
 * @param tenant - Tenant segment of the JWKS URL.
 * @param authority - Base authority URL the tenant path is appended to.
 * @returns The imported public key.
 * @throws APIError ("BAD_REQUEST") when the response has no `keys` array.
 * @throws Error when no key with the requested `kid` is published.
 */
export const getMicrosoftPublicKey = async (
	kid: string,
	tenant: string,
	authority: string,
) => {
	const { data } = await betterFetch<{
		keys: Array<{
			kid: string;
			// Optional per RFC 7517 and usually absent in Microsoft's key set.
			alg?: string;
			kty: string;
			use: string;
			n: string;
			e: string;
			x5c?: string[];
			x5t?: string;
		}>;
	}>(`${authority}/${tenant}/discovery/v2.0/keys`);

	if (!data?.keys) {
		throw new APIError("BAD_REQUEST", {
			message: "Keys not found",
		});
	}

	const jwk = data.keys.find((key) => key.kid === kid);
	if (!jwk) {
		throw new Error(`JWK with kid ${kid} not found`);
	}

	// Without an algorithm, importJWK rejects RSA keys, so fall back to RS256.
	return await importJWK(jwk, jwk.alg ?? "RS256");
};