@better-auth/sso 1.3.0-beta.9 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/saml.test.ts CHANGED
@@ -13,7 +13,6 @@ import { createAuthClient } from "better-auth/client";
13
13
  import { betterFetch } from "@better-fetch/fetch";
14
14
  import { setCookieToHeader } from "better-auth/cookies";
15
15
  import { bearer } from "better-auth/plugins";
16
- import { IdentityProvider, ServiceProvider } from "samlify";
17
16
  import { sso } from ".";
18
17
  import { ssoClient } from "./client";
19
18
  import { createServer } from "http";
@@ -27,6 +26,7 @@ import type {
27
26
  import express from "express";
28
27
  import bodyParser from "body-parser";
29
28
  import { randomUUID } from "crypto";
29
+ import { getTestInstanceMemory } from "better-auth/test";
30
30
 
31
31
  const spMetadata = `
32
32
  <md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:3001/api/sso/saml2/sp/metadata">
@@ -353,13 +353,15 @@ const createTemplateCallback =
353
353
  context: saml.SamlLib.replaceTagsByValue(template, tagValues),
354
354
  };
355
355
  };
356
- function createMockSAMLIdP(port: number) {
356
+
357
+ const createMockSAMLIdP = (port: number) => {
357
358
  const app: ExpressApp = express();
358
359
  let server: ReturnType<typeof createServer> | undefined;
360
+
359
361
  app.use(bodyParser.urlencoded({ extended: true }));
360
362
  app.use(bodyParser.json());
361
363
 
362
- const idp = IdentityProvider({
364
+ const idp = saml.IdentityProvider({
363
365
  metadata: idpMetadata,
364
366
  privateKey: idPk,
365
367
  isAssertionEncrypted: false,
@@ -389,7 +391,7 @@ function createMockSAMLIdP(port: number) {
389
391
  ],
390
392
  },
391
393
  });
392
- const sp = ServiceProvider({
394
+ const sp = saml.ServiceProvider({
393
395
  metadata: spMetadata,
394
396
  });
395
397
  app.get(
@@ -420,37 +422,32 @@ function createMockSAMLIdP(port: number) {
420
422
  res.status(200).send({ samlResponse: context, entityEndpoint });
421
423
  },
422
424
  );
423
- // @ts-ignore
424
- app.post(
425
- "/api/sso/saml2/sp/acs",
426
- async (req: ExpressRequest, res: ExpressResponse) => {
427
- try {
428
- const parseResult = await sp.parseLoginResponse(
429
- idp,
430
- saml.Constants.wording.binding.post,
431
- req,
432
- );
433
- const { extract } = parseResult;
434
- const { attributes } = extract;
435
- const relayState = req.body.RelayState;
436
- if (relayState) {
437
- return res.status(200).send({ relayState, attributes });
438
- } else {
439
- return res
440
- .status(200)
441
- .send({ extract, message: "RelayState is missing." });
442
- }
443
- } catch (error) {
444
- console.error("Error handling SAML ACS endpoint:", error);
445
- res.status(500).send({ error: "Failed to process SAML response." });
425
+ app.post("/api/sso/saml2/sp/acs", async (req: any, res: any) => {
426
+ try {
427
+ const parseResult = await sp.parseLoginResponse(
428
+ idp,
429
+ saml.Constants.wording.binding.post,
430
+ req,
431
+ );
432
+ const { extract } = parseResult;
433
+ const { attributes } = extract;
434
+ const relayState = req.body.RelayState;
435
+ if (relayState) {
436
+ return res.status(200).send({ relayState, attributes });
437
+ } else {
438
+ return res
439
+ .status(200)
440
+ .send({ extract, message: "RelayState is missing." });
446
441
  }
447
- },
448
- );
442
+ } catch (error) {
443
+ console.error("Error handling SAML ACS endpoint:", error);
444
+ res.status(500).send({ error: "Failed to process SAML response." });
445
+ }
446
+ });
449
447
  app.post(
450
- "/api/sso/saml2/callback",
448
+ "/api/sso/saml2/callback/:providerId",
451
449
  async (req: ExpressRequest, res: ExpressResponse) => {
452
450
  const { SAMLResponse, RelayState } = req.body;
453
-
454
451
  try {
455
452
  const parseResult = await sp.parseLoginResponse(
456
453
  idp,
@@ -474,29 +471,28 @@ function createMockSAMLIdP(port: number) {
474
471
  res.send(idpMetadata);
475
472
  },
476
473
  );
477
-
478
- return {
479
- start: () => {
480
- return new Promise<void>((resolve) => {
481
- server = app.listen(port, () => {
482
- console.log(`Mock SAML IdP running on port ${port}`);
483
- resolve();
484
- });
474
+ const start = () =>
475
+ new Promise<void>((resolve) => {
476
+ app.use(bodyParser.urlencoded({ extended: true }));
477
+ server = app.listen(port, () => {
478
+ console.log(`Mock SAML IdP running on port ${port}`);
479
+ resolve();
485
480
  });
486
- },
487
- stop: () => {
488
- return new Promise<void>((resolve, reject) => {
489
- server?.close((err) => {
490
- if (err) reject(err);
491
- else resolve();
492
- });
481
+ });
482
+
483
+ const stop = () =>
484
+ new Promise<void>((resolve, reject) => {
485
+ app.use(bodyParser.urlencoded({ extended: true }));
486
+ server?.close((err) => {
487
+ if (err) reject(err);
488
+ else resolve();
493
489
  });
494
- },
495
- get metadataUrl() {
496
- return `http://localhost:${port}/idp/metadata`;
497
- },
498
- };
499
- }
490
+ });
491
+
492
+ const metadataUrl = `http://localhost:${port}/idp/metadata`;
493
+
494
+ return { start, stop, metadataUrl };
495
+ };
500
496
 
501
497
  describe("SAML SSO", async () => {
502
498
  const data = {
@@ -709,7 +705,7 @@ describe("SAML SSO", async () => {
709
705
  samlConfig: {
710
706
  entryPoint: "http://localhost:8081/api/sso/saml2/idp/post",
711
707
  cert: certificate,
712
- callbackUrl: "http://localhost:8081/api/sso/saml2/callback",
708
+ callbackUrl: "http://localhost:8081/dashboard",
713
709
  wantAssertionsSigned: false,
714
710
  signatureAlgorithm: "sha256",
715
711
  digestAlgorithm: "sha256",
@@ -748,7 +744,6 @@ describe("SAML SSO", async () => {
748
744
  url: expect.stringContaining("http://localhost:8081"),
749
745
  redirect: true,
750
746
  });
751
-
752
747
  let samlResponse: any;
753
748
  await betterFetch(signInResponse?.url as string, {
754
749
  onSuccess: async (context) => {
@@ -756,21 +751,168 @@ describe("SAML SSO", async () => {
756
751
  },
757
752
  });
758
753
  let redirectLocation = "";
759
- await betterFetch("http://localhost:8081/api/sso/saml2/callback", {
760
- method: "POST",
761
- redirect: "manual",
762
- headers: {
763
- "Content-Type": "application/x-www-form-urlencoded",
754
+ await betterFetch(
755
+ "http://localhost:8081/api/sso/saml2/callback/saml-provider-1",
756
+ {
757
+ method: "POST",
758
+ redirect: "manual",
759
+ headers: {
760
+ "Content-Type": "application/x-www-form-urlencoded",
761
+ },
762
+ body: new URLSearchParams({
763
+ SAMLResponse: samlResponse.samlResponse,
764
+ }),
765
+ onError: (context) => {
766
+ expect(context.response.status).toBe(302);
767
+ redirectLocation = context.response.headers.get("location") || "";
768
+ },
764
769
  },
765
- body: new URLSearchParams({
766
- SAMLResponse: samlResponse.samlResponse,
767
- RelayState: "http://localhost:3000/dashboard",
770
+ );
771
+ expect(redirectLocation).toBe("http://localhost:3000/dashboard");
772
+ });
773
+
774
+ it("should not allow creating a provider if limit is set to 0", async () => {
775
+ const { auth, signInWithTestUser } = await getTestInstanceMemory({
776
+ plugins: [sso({ providersLimit: 0 })],
777
+ });
778
+ const { headers } = await signInWithTestUser();
779
+ await expect(
780
+ auth.api.registerSSOProvider({
781
+ body: {
782
+ providerId: "saml-provider-1",
783
+ issuer: "http://localhost:8081",
784
+ domain: "http://localhost:8081",
785
+ samlConfig: {
786
+ entryPoint: mockIdP.metadataUrl,
787
+ cert: certificate,
788
+ callbackUrl: "http://localhost:8081/api/sso/saml2/callback",
789
+ wantAssertionsSigned: false,
790
+ signatureAlgorithm: "sha256",
791
+ digestAlgorithm: "sha256",
792
+ spMetadata: {
793
+ metadata: spMetadata,
794
+ },
795
+ },
796
+ },
797
+ headers,
768
798
  }),
769
- onError: (context) => {
770
- expect(context.response.status).toBe(302);
771
- redirectLocation = context.response.headers.get("location") || "";
799
+ ).rejects.toMatchObject({
800
+ status: "FORBIDDEN",
801
+ body: { message: "SSO provider registration is disabled" },
802
+ });
803
+ });
804
+
805
+ it("should not allow creating a provider if limit is reached", async () => {
806
+ const { auth, signInWithTestUser } = await getTestInstanceMemory({
807
+ plugins: [sso({ providersLimit: 1 })],
808
+ });
809
+ const { headers } = await signInWithTestUser();
810
+
811
+ await auth.api.registerSSOProvider({
812
+ body: {
813
+ providerId: "saml-provider-1",
814
+ issuer: "http://localhost:8081",
815
+ domain: "http://localhost:8081",
816
+ samlConfig: {
817
+ entryPoint: mockIdP.metadataUrl,
818
+ cert: certificate,
819
+ callbackUrl: "http://localhost:8081/api/sso/saml2/callback",
820
+ wantAssertionsSigned: false,
821
+ signatureAlgorithm: "sha256",
822
+ digestAlgorithm: "sha256",
823
+ spMetadata: {
824
+ metadata: spMetadata,
825
+ },
826
+ },
827
+ },
828
+ headers,
829
+ });
830
+
831
+ await expect(
832
+ auth.api.registerSSOProvider({
833
+ body: {
834
+ providerId: "saml-provider-2",
835
+ issuer: "http://localhost:8081",
836
+ domain: "http://localhost:8081",
837
+ samlConfig: {
838
+ entryPoint: mockIdP.metadataUrl,
839
+ cert: certificate,
840
+ callbackUrl: "http://localhost:8081/api/sso/saml2/callback",
841
+ wantAssertionsSigned: false,
842
+ signatureAlgorithm: "sha256",
843
+ digestAlgorithm: "sha256",
844
+ spMetadata: {
845
+ metadata: spMetadata,
846
+ },
847
+ },
848
+ },
849
+ headers,
850
+ }),
851
+ ).rejects.toMatchObject({
852
+ status: "FORBIDDEN",
853
+ body: {
854
+ message: "You have reached the maximum number of SSO providers",
855
+ },
856
+ });
857
+ });
858
+
859
+ it("should not allow creating a provider if limit from function is reached", async () => {
860
+ const { auth, signInWithTestUser } = await getTestInstanceMemory({
861
+ plugins: [
862
+ sso({
863
+ providersLimit: async (user) => {
864
+ return user.email === "pro@example.com" ? 2 : 1;
865
+ },
866
+ }),
867
+ ],
868
+ });
869
+ const { headers } = await signInWithTestUser();
870
+
871
+ await auth.api.registerSSOProvider({
872
+ body: {
873
+ providerId: "saml-provider-1",
874
+ issuer: "http://localhost:8081",
875
+ domain: "http://localhost:8081",
876
+ samlConfig: {
877
+ entryPoint: mockIdP.metadataUrl,
878
+ cert: certificate,
879
+ callbackUrl: "http://localhost:8081/api/sso/saml2/callback",
880
+ wantAssertionsSigned: false,
881
+ signatureAlgorithm: "sha256",
882
+ digestAlgorithm: "sha256",
883
+ spMetadata: {
884
+ metadata: spMetadata,
885
+ },
886
+ },
887
+ },
888
+ headers,
889
+ });
890
+
891
+ await expect(
892
+ auth.api.registerSSOProvider({
893
+ body: {
894
+ providerId: "saml-provider-2",
895
+ issuer: "http://localhost:8081",
896
+ domain: "http://localhost:8081",
897
+ samlConfig: {
898
+ entryPoint: mockIdP.metadataUrl,
899
+ cert: certificate,
900
+ callbackUrl: "http://localhost:8081/api/sso/saml2/callback",
901
+ wantAssertionsSigned: false,
902
+ signatureAlgorithm: "sha256",
903
+ digestAlgorithm: "sha256",
904
+ spMetadata: {
905
+ metadata: spMetadata,
906
+ },
907
+ },
908
+ },
909
+ headers,
910
+ }),
911
+ ).rejects.toMatchObject({
912
+ status: "FORBIDDEN",
913
+ body: {
914
+ message: "You have reached the maximum number of SSO providers",
772
915
  },
773
916
  });
774
- expect(redirectLocation).toBe("http://localhost:3000/dashboard");
775
917
  });
776
918
  });