@better-auth/sso 1.4.0-beta.21 → 1.4.0-beta.23

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,550 @@
1
+ import { betterAuth } from "better-auth";
2
+ import { memoryAdapter } from "better-auth/adapters/memory";
3
+ import { createAuthClient } from "better-auth/client";
4
+ import { setCookieToHeader } from "better-auth/cookies";
5
+ import { bearer, organization } from "better-auth/plugins";
6
+ import { afterEach, describe, expect, it, vi } from "vitest";
7
+ import { sso } from ".";
8
+ import { ssoClient } from "./client";
9
+ import type { SSOOptions } from "./types";
10
+
11
+ const dnsMock = vi.hoisted(() => { // created via vi.hoisted so the hoisted vi.mock("node:dns/promises") factory below can reference it
12
+ return {
13
+ resolveTxt: vi.fn(), // tests stub TXT records per case with dnsMock.resolveTxt.mockResolvedValue(...)
14
+ };
15
+ });
16
+
17
+ vi.mock("node:dns/promises", () => { // replace DNS lookups with the shared mock so domain verification never hits the network
18
+ return {
19
+ ...dnsMock, // named-import form of the module
20
+ default: dnsMock, // default-import form of the module
21
+ };
22
+ });
23
+
24
+ describe("Domain verification", async () => { // tests for the SSO plugin's request-domain-verification and verify-domain endpoints
25
+ type TestUser = { email: string; password: string; name: string };
26
+ const testUser: TestUser = {
27
+ email: "test@email.com",
28
+ password: "password",
29
+ name: "Test User",
30
+ };
31
+
32
+ const createTestAuth = (options?: SSOOptions) => { // builds an isolated auth instance backed by a fresh in-memory database
33
+ const data = {
34
+ user: [],
35
+ session: [],
36
+ verification: [],
37
+ account: [],
38
+ ssoProvider: [],
39
+ member: [],
40
+ organization: [],
41
+ };
42
+
43
+ const memory = memoryAdapter(data);
44
+
45
+ const ssoOptions = { // domain verification is always forced on; caller settings (e.g. tokenPrefix) are merged in
46
+ ...options,
47
+ domainVerification: {
48
+ ...options?.domainVerification,
49
+ enabled: true,
50
+ },
51
+ } satisfies SSOOptions;
52
+
53
+ const auth = betterAuth({
54
+ database: memory,
55
+ baseURL: "http://localhost:3000",
56
+ emailAndPassword: {
57
+ enabled: true,
58
+ },
59
+ plugins: [sso(ssoOptions), organization()],
60
+ });
61
+
62
+ const authClient = createAuthClient({
63
+ baseURL: "http://localhost:3000",
64
+ plugins: [bearer(), ssoClient({ domainVerification: { enabled: true } })],
65
+ fetchOptions: {
66
+ customFetchImpl: async (url, init) => { // route client requests straight into the in-process auth handler
67
+ return auth.handler(new Request(url, init));
68
+ },
69
+ },
70
+ });
71
+
72
+ async function createOrganization(name: string, headers: Headers) { // the org name doubles as its slug
73
+ return await auth.api.createOrganization({
74
+ body: {
75
+ name,
76
+ slug: name,
77
+ },
78
+ headers,
79
+ });
80
+ }
81
+
82
+ async function getAuthHeaders(user: TestUser, organizationId?: string) { // signs up (errors for an existing user are not thrown), optionally adds the new user to an org, then signs in and returns headers carrying the session cookie
83
+ const headers = new Headers();
84
+ const response = await authClient.signUp.email({
85
+ email: user.email,
86
+ password: user.password,
87
+ name: user.name,
88
+ });
89
+
90
+ if (response.data && organizationId) {
91
+ await auth.api.addMember({
92
+ body: {
93
+ userId: response.data.user.id,
94
+ role: "member",
95
+ organizationId, // target org must be explicit: `headers` holds no session yet (sign-in happens below)
96
+ },
97
+ });
98
+ }
99
+
100
+ await authClient.signIn.email(user, {
101
+ throw: true,
102
+ onSuccess: setCookieToHeader(headers),
103
+ });
104
+
105
+ return headers;
106
+ }
107
+
108
+ async function registerSSOProvider( // registers the fixed "saml-provider-1" SAML provider, optionally scoped to an organization
109
+ headers: Headers,
110
+ organizationId?: string,
111
+ ) {
112
+ return auth.api.registerSSOProvider({
113
+ body: {
114
+ providerId: "saml-provider-1",
115
+ issuer: "http://hello.com:8081",
116
+ domain: "http://hello.com:8081",
117
+ samlConfig: {
118
+ entryPoint: "http://idp.com:",
119
+ cert: "the-cert",
120
+ callbackUrl: "http://hello.com:8081/api/sso/saml2/callback",
121
+ spMetadata: {},
122
+ },
123
+ organizationId,
124
+ },
125
+ headers,
126
+ });
127
+ }
128
+
129
+ return {
130
+ auth,
131
+ authClient,
132
+ registerSSOProvider,
133
+ getAuthHeaders,
134
+ createOrganization,
135
+ };
136
+ };
137
+
138
+ afterEach(() => { // clear recorded mock calls and restore real timers after every test
139
+ vi.clearAllMocks();
140
+ vi.useRealTimers();
141
+ });
142
+
143
+ describe("POST /sso/request-domain-verification", () => {
144
+ it("should return unauthorized when session is missing", async () => {
145
+ const { auth } = createTestAuth();
146
+ const response = await auth.api.requestDomainVerification({
147
+ body: {
148
+ providerId: "the-provider",
149
+ },
150
+ asResponse: true,
151
+ });
152
+
153
+ expect(response.status).toBe(401);
154
+ });
155
+
156
+ it("should return not found when no provider is found", async () => {
157
+ const { auth, getAuthHeaders } = createTestAuth();
158
+ const headers = await getAuthHeaders(testUser);
159
+ const response = await auth.api.requestDomainVerification({
160
+ body: {
161
+ providerId: "unknown",
162
+ },
163
+ headers,
164
+ asResponse: true,
165
+ });
166
+
167
+ expect(response.status).toBe(404);
168
+ expect(await response.json()).toEqual({
169
+ message: "Provider not found",
170
+ code: "PROVIDER_NOT_FOUND",
171
+ });
172
+ });
173
+
174
+ it("should return the existing active verification token", async () => {
175
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
176
+ const headers = await getAuthHeaders(testUser);
177
+ const provider = await registerSSOProvider(headers);
178
+
179
+ vi.useFakeTimers({ toFake: ["Date"] }); // freeze Date without advancing, so the token from registration is still active
180
+
181
+ const newAuthHeaders = await getAuthHeaders(testUser);
182
+
183
+ const response = await auth.api.requestDomainVerification({
184
+ body: {
185
+ providerId: provider.providerId,
186
+ },
187
+ headers: newAuthHeaders,
188
+ asResponse: true,
189
+ });
190
+
191
+ expect(response.status).toBe(201);
192
+ expect(await response.json()).toEqual({
193
+ domainVerificationToken: provider.domainVerificationToken,
194
+ });
195
+ });
196
+
197
+ it("should return forbidden if user does not own the provider", async () => {
198
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
199
+ const headers = await getAuthHeaders(testUser);
200
+ const provider = await registerSSOProvider(headers);
201
+
202
+ const notOwnerHeaders = await getAuthHeaders({
203
+ name: "other",
204
+ email: "other@test.com",
205
+ password: "password",
206
+ });
207
+ const response = await auth.api.requestDomainVerification({
208
+ body: {
209
+ providerId: provider.providerId,
210
+ },
211
+ headers: notOwnerHeaders,
212
+ asResponse: true,
213
+ });
214
+
215
+ expect(response.status).toBe(403);
216
+ expect(await response.json()).toEqual({
217
+ message:
218
+ "User must be owner of or belong to the SSO provider organization",
219
+ code: "INSUFICCIENT_ACCESS",
220
+ });
221
+ });
222
+
223
+ it("should return forbidden if user does not belong to the provider organization", async () => {
224
+ const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
225
+ createTestAuth();
226
+ const headers = await getAuthHeaders(testUser);
227
+
228
+ const orgA = await createOrganization("org-a", headers);
229
+ const orgB = await createOrganization("org-b", headers);
230
+
231
+ const provider = await registerSSOProvider(headers, orgA?.id);
232
+
233
+ const notOrgHeaders = await getAuthHeaders( // "other" becomes a member of org-b only
234
+ {
235
+ name: "other",
236
+ email: "other@test.com",
237
+ password: "password",
238
+ },
239
+ orgB?.id,
240
+ );
241
+
242
+ const response = await auth.api.requestDomainVerification({
243
+ body: {
244
+ providerId: provider.providerId,
245
+ },
246
+ headers: notOrgHeaders,
247
+ asResponse: true,
248
+ });
249
+
250
+ expect(response.status).toBe(403);
251
+ expect(await response.json()).toEqual({
252
+ message:
253
+ "User must be owner of or belong to the SSO provider organization",
254
+ code: "INSUFICCIENT_ACCESS",
255
+ });
256
+ });
257
+
258
+ it("should return a new domain verification token", async () => {
259
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
260
+ const headers = await getAuthHeaders(testUser);
261
+ const provider = await registerSSOProvider(headers);
262
+
263
+ vi.useFakeTimers({ toFake: ["Date"] });
264
+ vi.advanceTimersByTime(3600 * 24 * 7 * 1000 + 10 * 1000); // advance 1 week + 10 seconds (a relative delta, not an absolute timestamp)
265
+
266
+ const newHeaders = await getAuthHeaders(testUser);
267
+ const response = await auth.api.requestDomainVerification({
268
+ body: {
269
+ providerId: provider.providerId,
270
+ },
271
+ headers: newHeaders,
272
+ asResponse: true,
273
+ });
274
+
275
+ expect(response.status).toBe(201);
276
+ expect(await response.json()).toMatchObject({
277
+ domainVerificationToken: expect.any(String),
278
+ });
279
+ });
280
+
281
+ it("should fail to create a new token on an already verified domain", async () => {
282
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
283
+ const headers = await getAuthHeaders(testUser);
284
+ const provider = await registerSSOProvider(headers);
285
+
286
+ dnsMock.resolveTxt.mockResolvedValue([
287
+ [
288
+ `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
289
+ ],
290
+ ]);
291
+
292
+ const domainVerificationResponse = await auth.api.verifyDomain({
293
+ body: {
294
+ providerId: provider.providerId,
295
+ },
296
+ headers,
297
+ asResponse: true,
298
+ });
299
+
300
+ expect(domainVerificationResponse.status).toBe(204);
301
+
302
+ const domainVerificationSubmissionResponse =
303
+ await auth.api.requestDomainVerification({
304
+ body: {
305
+ providerId: provider.providerId,
306
+ },
307
+ headers,
308
+ asResponse: true,
309
+ });
310
+
311
+ expect(domainVerificationSubmissionResponse.status).toBe(409);
312
+ expect(await domainVerificationSubmissionResponse.json()).toEqual({
313
+ message: "Domain has already been verified",
314
+ code: "DOMAIN_VERIFIED",
315
+ });
316
+ });
317
+ });
318
+
319
+ describe("POST /sso/verify-domain", () => {
320
+ it("should return unauthorized when session is missing", async () => {
321
+ const { auth } = createTestAuth();
322
+ const response = await auth.api.verifyDomain({
323
+ body: {
324
+ providerId: "the-provider",
325
+ },
326
+ asResponse: true,
327
+ });
328
+
329
+ expect(response.status).toBe(401);
330
+ });
331
+
332
+ it("should return not found when no provider is found", async () => {
333
+ const { auth, getAuthHeaders } = createTestAuth();
334
+ const headers = await getAuthHeaders(testUser);
335
+ const response = await auth.api.verifyDomain({
336
+ body: {
337
+ providerId: "unknown",
338
+ },
339
+ headers,
340
+ asResponse: true,
341
+ });
342
+
343
+ expect(response.status).toBe(404);
344
+ expect(await response.json()).toEqual({
345
+ message: "Provider not found",
346
+ code: "PROVIDER_NOT_FOUND",
347
+ });
348
+ });
349
+
350
+ it("should return not found when no pending verification is found", async () => {
351
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
352
+ const headers = await getAuthHeaders(testUser);
353
+ const provider = await registerSSOProvider(headers);
354
+
355
+ vi.useFakeTimers({ toFake: ["Date"] });
356
+ vi.advanceTimersByTime(3600 * 24 * 7 * 1000 + 10 * 1000); // advance 1 week + 10 seconds (a relative delta, not an absolute timestamp)
357
+
358
+ const newAuthHeaders = await getAuthHeaders(testUser);
359
+
360
+ const response = await auth.api.verifyDomain({
361
+ body: {
362
+ providerId: provider.providerId,
363
+ },
364
+ headers: newAuthHeaders,
365
+ asResponse: true,
366
+ });
367
+
368
+ expect(response.status).toBe(404);
369
+ expect(await response.json()).toEqual({
370
+ message: "No pending domain verification exists",
371
+ code: "NO_PENDING_VERIFICATION",
372
+ });
373
+ });
374
+
375
+ it("should return bad gateway when unable to verify domain", async () => {
376
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
377
+ const headers = await getAuthHeaders(testUser);
378
+ const provider = await registerSSOProvider(headers);
379
+
380
+ dnsMock.resolveTxt.mockResolvedValue([ // TXT records without the provider's verification token
381
+ ["google-site-verification=the-token"],
382
+ ]);
383
+
384
+ const response = await auth.api.verifyDomain({
385
+ body: {
386
+ providerId: provider.providerId,
387
+ },
388
+ headers,
389
+ asResponse: true,
390
+ });
391
+
392
+ expect(response.status).toBe(502);
393
+ expect(await response.json()).toEqual({
394
+ message: "Unable to verify domain ownership. Try again later",
395
+ code: "DOMAIN_VERIFICATION_FAILED",
396
+ });
397
+ });
398
+
399
+ it("should return forbidden if user does not own the provider", async () => {
400
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
401
+ const headers = await getAuthHeaders(testUser);
402
+ const provider = await registerSSOProvider(headers);
403
+
404
+ const notOwnerHeaders = await getAuthHeaders({
405
+ name: "other",
406
+ email: "other@test.com",
407
+ password: "password",
408
+ });
409
+ const response = await auth.api.verifyDomain({
410
+ body: {
411
+ providerId: provider.providerId,
412
+ },
413
+ headers: notOwnerHeaders,
414
+ asResponse: true,
415
+ });
416
+
417
+ expect(response.status).toBe(403);
418
+ expect(await response.json()).toEqual({
419
+ message:
420
+ "User must be owner of or belong to the SSO provider organization",
421
+ code: "INSUFICCIENT_ACCESS",
422
+ });
423
+ });
424
+
425
+ it("should return forbidden if user does not belong to the provider organization", async () => {
426
+ const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
427
+ createTestAuth();
428
+ const headers = await getAuthHeaders(testUser);
429
+ const orgA = await createOrganization("org-a", headers);
430
+ const orgB = await createOrganization("org-b", headers);
431
+
432
+ const provider = await registerSSOProvider(headers, orgA?.id);
433
+
434
+ const notOrgHeaders = await getAuthHeaders( // "other" becomes a member of org-b only
435
+ {
436
+ name: "other",
437
+ email: "other@test.com",
438
+ password: "password",
439
+ },
440
+ orgB?.id,
441
+ );
442
+ const response = await auth.api.verifyDomain({
443
+ body: {
444
+ providerId: provider.providerId,
445
+ },
446
+ headers: notOrgHeaders,
447
+ asResponse: true,
448
+ });
449
+
450
+ expect(response.status).toBe(403);
451
+ expect(await response.json()).toEqual({
452
+ message:
453
+ "User must be owner of or belong to the SSO provider organization",
454
+ code: "INSUFICCIENT_ACCESS",
455
+ });
456
+ });
457
+
458
+ it("should verify a provider domain ownership", async () => {
459
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
460
+ const headers = await getAuthHeaders(testUser);
461
+ const provider = await registerSSOProvider(headers);
462
+
463
+ expect(provider.domain).toBe("http://hello.com:8081");
464
+ expect(provider.domainVerified).toBe(false);
465
+ expect(provider.domainVerificationToken).toBeTypeOf("string");
466
+
467
+ dnsMock.resolveTxt.mockResolvedValue([ // matching token mixed in among unrelated TXT records
468
+ ["google-site-verification=the-token"],
469
+ [
470
+ "v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
471
+ ],
472
+ [
473
+ `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
474
+ ],
475
+ ]);
476
+
477
+ const response = await auth.api.verifyDomain({
478
+ body: {
479
+ providerId: provider.providerId,
480
+ },
481
+ headers,
482
+ asResponse: true,
483
+ });
484
+
485
+ expect(response.status).toBe(204);
486
+ });
487
+
488
+ it("should verify a provider domain ownership (custom token verification prefix)", async () => {
489
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth({
490
+ domainVerification: { tokenPrefix: "auth-prefix" },
491
+ });
492
+ const headers = await getAuthHeaders(testUser);
493
+ const provider = await registerSSOProvider(headers);
494
+
495
+ dnsMock.resolveTxt.mockResolvedValue([
496
+ ["google-site-verification=the-token"],
497
+ [
498
+ "v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
499
+ ],
500
+ [`auth-prefix-saml-provider-1=${provider.domainVerificationToken}`],
501
+ ]);
502
+
503
+ const response = await auth.api.verifyDomain({
504
+ body: {
505
+ providerId: provider.providerId,
506
+ },
507
+ headers,
508
+ asResponse: true,
509
+ });
510
+
511
+ expect(response.status).toBe(204);
512
+ });
513
+
514
+ it("should fail to verify an already verified domain", async () => {
515
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
516
+ const headers = await getAuthHeaders(testUser);
517
+ const provider = await registerSSOProvider(headers);
518
+
519
+ dnsMock.resolveTxt.mockResolvedValue([
520
+ [
521
+ `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
522
+ ],
523
+ ]);
524
+
525
+ const firstResponse = await auth.api.verifyDomain({
526
+ body: {
527
+ providerId: provider.providerId,
528
+ },
529
+ headers,
530
+ asResponse: true,
531
+ });
532
+
533
+ expect(firstResponse.status).toBe(204);
534
+
535
+ const secondResponse = await auth.api.verifyDomain({
536
+ body: {
537
+ providerId: provider.providerId,
538
+ },
539
+ headers,
540
+ asResponse: true,
541
+ });
542
+
543
+ expect(secondResponse.status).toBe(409);
544
+ expect(await secondResponse.json()).toEqual({
545
+ message: "Domain has already been verified",
546
+ code: "DOMAIN_VERIFIED",
547
+ });
548
+ });
549
+ });
550
+ });
package/src/index.ts CHANGED
@@ -1,6 +1,10 @@
1
- import { type BetterAuthPlugin } from "better-auth";
1
+ import type { BetterAuthPlugin } from "better-auth";
2
2
  import { XMLValidator } from "fast-xml-parser";
3
3
  import * as saml from "samlify";
4
+ import {
5
+ requestDomainVerification,
6
+ verifyDomain,
7
+ } from "./routes/domain-verification";
4
8
  import {
5
9
  acsEndpoint,
6
10
  callbackSSO,
@@ -25,47 +29,90 @@ const fastValidator = {
25
29
 
26
30
  saml.setSchemaValidator(fastValidator); // install `fastValidator` (defined just before this hunk; presumably wraps fast-xml-parser's XMLValidator — confirm) as samlify's schema validator
27
31
 
28
- type SSOEndpoints = {
32
+ type DomainVerificationEndpoints = { // endpoints the plugin adds only when `domainVerification.enabled` is set
33
+ requestDomainVerification: ReturnType<typeof requestDomainVerification>;
34
+ verifyDomain: ReturnType<typeof verifyDomain>;
35
+ };
36
+
37
+ type SSOEndpoints<O extends SSOOptions> = { // base endpoints; generic over the options so registerSSOProvider's type can depend on them
29
38
  spMetadata: ReturnType<typeof spMetadata>;
30
- registerSSOProvider: ReturnType<typeof registerSSOProvider>;
39
+ registerSSOProvider: ReturnType<typeof registerSSOProvider<O>>;
31
40
  signInSSO: ReturnType<typeof signInSSO>;
32
41
  callbackSSO: ReturnType<typeof callbackSSO>;
33
42
  callbackSSOSAML: ReturnType<typeof callbackSSOSAML>;
34
43
  acsEndpoint: ReturnType<typeof acsEndpoint>;
35
44
  };
36
45
 
46
+ export type SSOPlugin<O extends SSOOptions> = { // plugin shape whose endpoints depend on O; NOTE(review): not referenced by the visible overloads of `sso`
47
+ id: "sso";
48
+ endpoints: SSOEndpoints<O> &
49
+ (O extends { domainVerification: { enabled: true } } // only a literal `enabled: true` exposes the domain verification endpoints
50
+ ? DomainVerificationEndpoints
51
+ : {});
52
+ };
53
+
54
+ export function sso< // overload for options that explicitly enable domain verification; adds its endpoints to the type
55
+ O extends SSOOptions & {
56
+ domainVerification: { enabled: true }; // required, not optional: otherwise `sso({})` would claim endpoints the runtime never registers
57
+ },
58
+ >(
59
+ options: O, // required: with `options?`, `sso()` infers O from its constraint and would match this overload
60
+ ): {
61
+ id: "sso";
62
+ endpoints: SSOEndpoints<O> & DomainVerificationEndpoints;
63
+ schema: any;
64
+ options: O;
65
+ };
37
66
  export function sso<O extends SSOOptions>( // fallback overload: base SSO endpoints only, for options that don't match the domain-verification overload
38
67
  options?: O | undefined,
39
68
  ): {
40
69
  id: "sso";
41
- endpoints: SSOEndpoints;
70
+ endpoints: SSOEndpoints<O>;
42
71
  };
43
72
 
44
73
  export function sso<O extends SSOOptions>(options?: O | undefined): any {
74
+ let endpoints = {
75
+ spMetadata: spMetadata(),
76
+ registerSSOProvider: registerSSOProvider(options as O),
77
+ signInSSO: signInSSO(options as O),
78
+ callbackSSO: callbackSSO(options as O),
79
+ callbackSSOSAML: callbackSSOSAML(options as O),
80
+ acsEndpoint: acsEndpoint(options as O),
81
+ };
82
+
83
+ if (options?.domainVerification?.enabled) {
84
+ const domainVerificationEndpoints = {
85
+ requestDomainVerification: requestDomainVerification(options as O),
86
+ verifyDomain: verifyDomain(options as O),
87
+ };
88
+
89
+ endpoints = {
90
+ ...endpoints,
91
+ ...domainVerificationEndpoints,
92
+ };
93
+ }
94
+
45
95
  return {
46
96
  id: "sso",
47
- endpoints: {
48
- spMetadata: spMetadata(),
49
- registerSSOProvider: registerSSOProvider(options),
50
- signInSSO: signInSSO(options),
51
- callbackSSO: callbackSSO(options),
52
- callbackSSOSAML: callbackSSOSAML(options),
53
- acsEndpoint: acsEndpoint(options),
54
- },
97
+ endpoints,
55
98
  schema: {
56
99
  ssoProvider: {
100
+ modelName: options?.modelName ?? "ssoProvider",
57
101
  fields: {
58
102
  issuer: {
59
103
  type: "string",
60
104
  required: true,
105
+ fieldName: options?.fields?.issuer ?? "issuer",
61
106
  },
62
107
  oidcConfig: {
63
108
  type: "string",
64
109
  required: false,
110
+ fieldName: options?.fields?.oidcConfig ?? "oidcConfig",
65
111
  },
66
112
  samlConfig: {
67
113
  type: "string",
68
114
  required: false,
115
+ fieldName: options?.fields?.samlConfig ?? "samlConfig",
69
116
  },
70
117
  userId: {
71
118
  type: "string",
@@ -73,20 +120,27 @@ export function sso<O extends SSOOptions>(options?: O | undefined): any {
73
120
  model: "user",
74
121
  field: "id",
75
122
  },
123
+ fieldName: options?.fields?.userId ?? "userId",
76
124
  },
77
125
  providerId: {
78
126
  type: "string",
79
127
  required: true,
80
128
  unique: true,
129
+ fieldName: options?.fields?.providerId ?? "providerId",
81
130
  },
82
131
  organizationId: {
83
132
  type: "string",
84
133
  required: false,
134
+ fieldName: options?.fields?.organizationId ?? "organizationId",
85
135
  },
86
136
  domain: {
87
137
  type: "string",
88
138
  required: true,
139
+ fieldName: options?.fields?.domain ?? "domain",
89
140
  },
141
+ ...(options?.domainVerification?.enabled
142
+ ? { domainVerified: { type: "boolean", required: false } }
143
+ : {}),
90
144
  },
91
145
  },
92
146
  },