@workos-inc/authkit-nextjs 3.0.0-beta.1 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. package/README.md +276 -102
  2. package/dist/esm/actions.js +35 -4
  3. package/dist/esm/actions.js.map +1 -1
  4. package/dist/esm/auth.js +51 -20
  5. package/dist/esm/auth.js.map +1 -1
  6. package/dist/esm/authkit-callback-route.js +82 -93
  7. package/dist/esm/authkit-callback-route.js.map +1 -1
  8. package/dist/esm/components/authkit-provider.js +36 -15
  9. package/dist/esm/components/authkit-provider.js.map +1 -1
  10. package/dist/esm/components/impersonation.js +17 -15
  11. package/dist/esm/components/impersonation.js.map +1 -1
  12. package/dist/esm/components/min-max-button.js +1 -1
  13. package/dist/esm/components/min-max-button.js.map +1 -1
  14. package/dist/esm/components/tokenStore.js +28 -19
  15. package/dist/esm/components/tokenStore.js.map +1 -1
  16. package/dist/esm/components/useAccessToken.js +1 -1
  17. package/dist/esm/components/useAccessToken.js.map +1 -1
  18. package/dist/esm/components/useTokenClaims.js +1 -1
  19. package/dist/esm/components/useTokenClaims.js.map +1 -1
  20. package/dist/esm/cookie.js +16 -5
  21. package/dist/esm/cookie.js.map +1 -1
  22. package/dist/esm/env-variables.js +6 -6
  23. package/dist/esm/env-variables.js.map +1 -1
  24. package/dist/esm/errors.js +36 -0
  25. package/dist/esm/errors.js.map +1 -0
  26. package/dist/esm/get-authorization-url.js +51 -12
  27. package/dist/esm/get-authorization-url.js.map +1 -1
  28. package/dist/esm/index.js +5 -2
  29. package/dist/esm/index.js.map +1 -1
  30. package/dist/esm/interfaces.js +7 -1
  31. package/dist/esm/interfaces.js.map +1 -1
  32. package/dist/esm/middleware-helpers.js +102 -0
  33. package/dist/esm/middleware-helpers.js.map +1 -0
  34. package/dist/esm/middleware.js +3 -1
  35. package/dist/esm/middleware.js.map +1 -1
  36. package/dist/esm/pkce.js +38 -0
  37. package/dist/esm/pkce.js.map +1 -0
  38. package/dist/esm/session.js +73 -35
  39. package/dist/esm/session.js.map +1 -1
  40. package/dist/esm/test-helpers.js +1 -1
  41. package/dist/esm/test-helpers.js.map +1 -1
  42. package/dist/esm/types/actions.d.ts +34 -5
  43. package/dist/esm/types/auth.d.ts +7 -15
  44. package/dist/esm/types/components/authkit-provider.d.ts +6 -2
  45. package/dist/esm/types/components/impersonation.d.ts +2 -1
  46. package/dist/esm/types/cookie.d.ts +8 -0
  47. package/dist/esm/types/env-variables.d.ts +2 -1
  48. package/dist/esm/types/errors.d.ts +15 -0
  49. package/dist/esm/types/get-authorization-url.d.ts +2 -2
  50. package/dist/esm/types/index.d.ts +5 -2
  51. package/dist/esm/types/interfaces.d.ts +12 -0
  52. package/dist/esm/types/jwt.d.ts +9 -9
  53. package/dist/esm/types/middleware-helpers.d.ts +27 -0
  54. package/dist/esm/types/middleware.d.ts +3 -1
  55. package/dist/esm/types/pkce.d.ts +12 -0
  56. package/dist/esm/types/session.d.ts +1 -1
  57. package/dist/esm/types/utils.d.ts +5 -0
  58. package/dist/esm/types/validate-api-key.d.ts +1 -0
  59. package/dist/esm/types/workos.d.ts +1 -1
  60. package/dist/esm/utils.js +10 -2
  61. package/dist/esm/utils.js.map +1 -1
  62. package/dist/esm/validate-api-key.js +16 -0
  63. package/dist/esm/validate-api-key.js.map +1 -0
  64. package/dist/esm/workos.js +1 -1
  65. package/package.json +32 -34
  66. package/src/actions.spec.ts +94 -17
  67. package/src/actions.ts +44 -5
  68. package/src/auth.spec.ts +60 -29
  69. package/src/auth.ts +55 -41
  70. package/src/authkit-callback-route.spec.ts +310 -58
  71. package/src/authkit-callback-route.ts +106 -103
  72. package/src/components/authkit-provider.spec.tsx +264 -70
  73. package/src/components/authkit-provider.tsx +40 -15
  74. package/src/components/button.spec.tsx +4 -6
  75. package/src/components/impersonation.spec.tsx +152 -35
  76. package/src/components/impersonation.tsx +37 -30
  77. package/src/components/min-max-button.spec.tsx +2 -1
  78. package/src/components/tokenStore.spec.ts +59 -44
  79. package/src/components/tokenStore.ts +11 -3
  80. package/src/components/useAccessToken.spec.tsx +82 -83
  81. package/src/components/useTokenClaims.spec.tsx +23 -22
  82. package/src/cookie.spec.ts +14 -9
  83. package/src/cookie.ts +29 -0
  84. package/src/env-variables.ts +2 -0
  85. package/src/errors.spec.ts +108 -0
  86. package/src/errors.ts +46 -0
  87. package/src/get-authorization-url.spec.ts +170 -15
  88. package/src/get-authorization-url.ts +69 -23
  89. package/src/index.ts +20 -2
  90. package/src/interfaces.ts +15 -0
  91. package/src/jwt.ts +9 -9
  92. package/src/middleware-helpers.spec.ts +238 -0
  93. package/src/middleware-helpers.ts +134 -0
  94. package/src/middleware.spec.ts +25 -0
  95. package/src/middleware.ts +4 -1
  96. package/src/pkce.spec.ts +125 -0
  97. package/src/pkce.ts +42 -0
  98. package/src/session.spec.ts +87 -89
  99. package/src/session.ts +91 -27
  100. package/src/test-helpers.ts +1 -1
  101. package/src/utils.spec.ts +14 -31
  102. package/src/utils.ts +9 -0
  103. package/src/validate-api-key.spec.ts +111 -0
  104. package/src/validate-api-key.ts +19 -0
  105. package/src/workos.spec.ts +2 -2
  106. package/src/workos.ts +1 -1
@@ -1,3 +1,4 @@
1
+ import type { Mock } from 'vitest';
1
2
  import React from 'react';
2
3
  import { render, waitFor, act } from '@testing-library/react';
3
4
  import '@testing-library/jest-dom';
@@ -10,17 +11,17 @@ import {
10
11
  switchToOrganizationAction,
11
12
  } from '../actions.js';
12
13
 
13
- jest.mock('../actions', () => ({
14
- checkSessionAction: jest.fn(),
15
- getAuthAction: jest.fn(),
16
- refreshAuthAction: jest.fn(),
17
- handleSignOutAction: jest.fn(),
18
- switchToOrganizationAction: jest.fn(),
14
+ vi.mock('../actions', () => ({
15
+ checkSessionAction: vi.fn(),
16
+ getAuthAction: vi.fn(),
17
+ refreshAuthAction: vi.fn(),
18
+ handleSignOutAction: vi.fn(),
19
+ switchToOrganizationAction: vi.fn(),
19
20
  }));
20
21
 
21
22
  describe('AuthKitProvider', () => {
22
23
  beforeEach(() => {
23
- jest.clearAllMocks();
24
+ vi.clearAllMocks();
24
25
  });
25
26
 
26
27
  it('should render children', async () => {
@@ -35,8 +36,123 @@ describe('AuthKitProvider', () => {
35
36
  expect(getByText('Test Child')).toBeInTheDocument();
36
37
  });
37
38
 
39
+ it('should skip initial getAuthAction call when initialAuth is provided', async () => {
40
+ const initialAuth = {
41
+ user: {
42
+ id: 'user-123',
43
+ email: 'test@example.com',
44
+ emailVerified: true,
45
+ profilePictureUrl: null,
46
+ firstName: 'Test',
47
+ lastName: 'User',
48
+ object: 'user' as const,
49
+ createdAt: '2024-01-01T00:00:00Z',
50
+ updatedAt: '2024-01-01T00:00:00Z',
51
+ lastSignInAt: '2024-01-01T00:00:00Z',
52
+ externalId: null,
53
+ locale: 'en-US',
54
+ metadata: {},
55
+ },
56
+ sessionId: 'test-session',
57
+ organizationId: 'test-org',
58
+ role: 'admin',
59
+ roles: ['admin'],
60
+ permissions: ['read', 'write'],
61
+ entitlements: ['feature1'],
62
+ featureFlags: ['test-flag'],
63
+ impersonator: undefined,
64
+ };
65
+
66
+ render(
67
+ <AuthKitProvider initialAuth={initialAuth}>
68
+ <div>Test Child</div>
69
+ </AuthKitProvider>,
70
+ );
71
+
72
+ // Wait a bit to ensure no call is made
73
+ await waitFor(
74
+ () => {
75
+ expect(getAuthAction).not.toHaveBeenCalled();
76
+ },
77
+ { timeout: 100 },
78
+ );
79
+ });
80
+
81
+ it('should initialize state with initialAuth values', async () => {
82
+ const initialAuth = {
83
+ user: {
84
+ id: 'user-123',
85
+ email: 'test@example.com',
86
+ emailVerified: true,
87
+ profilePictureUrl: null,
88
+ firstName: 'Test',
89
+ lastName: 'User',
90
+ object: 'user' as const,
91
+ createdAt: '2024-01-01T00:00:00Z',
92
+ updatedAt: '2024-01-01T00:00:00Z',
93
+ lastSignInAt: '2024-01-01T00:00:00Z',
94
+ locale: 'en-US',
95
+ externalId: null,
96
+ metadata: {},
97
+ },
98
+ sessionId: 'test-session',
99
+ organizationId: 'test-org',
100
+ role: 'admin',
101
+ roles: ['admin'],
102
+ permissions: ['read', 'write'],
103
+ entitlements: ['feature1'],
104
+ featureFlags: ['test-flag'],
105
+ impersonator: { email: 'admin@example.com', reason: 'Support request' },
106
+ };
107
+
108
+ const TestComponent = () => {
109
+ const auth = useAuth();
110
+ return (
111
+ <div>
112
+ <div data-testid="loading">{auth.loading.toString()}</div>
113
+ <div data-testid="email">{auth.user?.email}</div>
114
+ <div data-testid="session">{auth.sessionId}</div>
115
+ <div data-testid="org">{auth.organizationId}</div>
116
+ <div data-testid="role">{auth.role}</div>
117
+ <div data-testid="impersonator">{auth.impersonator?.email}</div>
118
+ </div>
119
+ );
120
+ };
121
+
122
+ const { getByTestId } = render(
123
+ <AuthKitProvider initialAuth={initialAuth}>
124
+ <TestComponent />
125
+ </AuthKitProvider>,
126
+ );
127
+
128
+ // Should not be loading when initialAuth is provided
129
+ expect(getByTestId('loading')).toHaveTextContent('false');
130
+ expect(getByTestId('email')).toHaveTextContent('test@example.com');
131
+ expect(getByTestId('session')).toHaveTextContent('test-session');
132
+ expect(getByTestId('org')).toHaveTextContent('test-org');
133
+ expect(getByTestId('role')).toHaveTextContent('admin');
134
+ expect(getByTestId('impersonator')).toHaveTextContent('admin@example.com');
135
+ });
136
+
137
+ it('should call getAuthAction when initialAuth is not provided', async () => {
138
+ (getAuthAction as Mock).mockResolvedValueOnce({
139
+ user: { email: 'test@example.com' },
140
+ sessionId: 'test-session',
141
+ });
142
+
143
+ render(
144
+ <AuthKitProvider>
145
+ <div>Test Child</div>
146
+ </AuthKitProvider>,
147
+ );
148
+
149
+ await waitFor(() => {
150
+ expect(getAuthAction).toHaveBeenCalledTimes(1);
151
+ });
152
+ });
153
+
38
154
  it('should do nothing if onSessionExpired is false', async () => {
39
- jest.spyOn(window, 'addEventListener');
155
+ vi.spyOn(window, 'addEventListener');
40
156
 
41
157
  await act(async () => {
42
158
  render(
@@ -51,8 +167,8 @@ describe('AuthKitProvider', () => {
51
167
  });
52
168
 
53
169
  it('should call onSessionExpired when session is expired', async () => {
54
- (checkSessionAction as jest.Mock).mockRejectedValueOnce(new Error('Failed to fetch'));
55
- const onSessionExpired = jest.fn();
170
+ (checkSessionAction as Mock).mockRejectedValueOnce(new Error('Failed to fetch'));
171
+ const onSessionExpired = vi.fn();
56
172
 
57
173
  render(
58
174
  <AuthKitProvider onSessionExpired={onSessionExpired}>
@@ -71,8 +187,8 @@ describe('AuthKitProvider', () => {
71
187
  });
72
188
 
73
189
  it('should only call onSessionExpired once if multiple visibility changes occur', async () => {
74
- (checkSessionAction as jest.Mock).mockRejectedValueOnce(new Error('Failed to fetch'));
75
- const onSessionExpired = jest.fn();
190
+ (checkSessionAction as Mock).mockRejectedValueOnce(new Error('Failed to fetch'));
191
+ const onSessionExpired = vi.fn();
76
192
 
77
193
  render(
78
194
  <AuthKitProvider onSessionExpired={onSessionExpired}>
@@ -92,9 +208,9 @@ describe('AuthKitProvider', () => {
92
208
  });
93
209
 
94
210
  it('should pass through if checkSessionAction does not throw "Failed to fetch"', async () => {
95
- (checkSessionAction as jest.Mock).mockResolvedValueOnce(false);
211
+ (checkSessionAction as Mock).mockResolvedValueOnce(false);
96
212
 
97
- const onSessionExpired = jest.fn();
213
+ const onSessionExpired = vi.fn();
98
214
 
99
215
  render(
100
216
  <AuthKitProvider onSessionExpired={onSessionExpired}>
@@ -112,74 +228,73 @@ describe('AuthKitProvider', () => {
112
228
  });
113
229
  });
114
230
 
115
- it('should reload the page when session is expired and no onSessionExpired handler is provided', async () => {
116
- (checkSessionAction as jest.Mock).mockRejectedValueOnce(new Error('Failed to fetch'));
231
+ describe('window.location.reload behavior', () => {
232
+ let originalLocationDescriptor: PropertyDescriptor | undefined;
117
233
 
118
- const originalLocation = window.location;
119
-
120
- // @ts-expect-error - we're deleting the property to test the mock
121
- delete window.location;
122
-
123
- window.location = { ...window.location, reload: jest.fn() };
124
-
125
- render(
126
- <AuthKitProvider>
127
- <div>Test Child</div>
128
- </AuthKitProvider>,
129
- );
130
-
131
- act(() => {
132
- // Simulate visibility change
133
- window.dispatchEvent(new Event('visibilitychange'));
234
+ beforeEach(() => {
235
+ originalLocationDescriptor = Object.getOwnPropertyDescriptor(window, 'location');
236
+ Object.defineProperty(window, 'location', {
237
+ writable: true,
238
+ value: { reload: vi.fn() },
239
+ });
134
240
  });
135
241
 
136
- await waitFor(() => {
137
- expect(window.location.reload).toHaveBeenCalled();
242
+ afterEach(() => {
243
+ if (originalLocationDescriptor) {
244
+ Object.defineProperty(window, 'location', originalLocationDescriptor);
245
+ }
138
246
  });
139
247
 
140
- // Restore original reload function
141
- window.location = originalLocation;
142
- });
248
+ it('should reload the page when session is expired and no onSessionExpired handler is provided', async () => {
249
+ (checkSessionAction as Mock).mockRejectedValueOnce(new Error('Failed to fetch'));
143
250
 
144
- it('should not call onSessionExpired or reload the page if session is valid', async () => {
145
- (checkSessionAction as jest.Mock).mockResolvedValueOnce(true);
146
- const onSessionExpired = jest.fn();
251
+ render(
252
+ <AuthKitProvider>
253
+ <div>Test Child</div>
254
+ </AuthKitProvider>,
255
+ );
147
256
 
148
- const originalLocation = window.location;
257
+ act(() => {
258
+ // Simulate visibility change
259
+ window.dispatchEvent(new Event('visibilitychange'));
260
+ });
149
261
 
150
- // @ts-expect-error - we're deleting the property to test the mock
151
- delete window.location;
262
+ await waitFor(() => {
263
+ expect(window.location.reload).toHaveBeenCalled();
264
+ });
265
+ });
152
266
 
153
- window.location = { ...window.location, reload: jest.fn() };
267
+ it('should not call onSessionExpired or reload the page if session is valid', async () => {
268
+ (checkSessionAction as Mock).mockResolvedValueOnce(true);
269
+ const onSessionExpired = vi.fn();
154
270
 
155
- render(
156
- <AuthKitProvider onSessionExpired={onSessionExpired}>
157
- <div>Test Child</div>
158
- </AuthKitProvider>,
159
- );
271
+ render(
272
+ <AuthKitProvider onSessionExpired={onSessionExpired}>
273
+ <div>Test Child</div>
274
+ </AuthKitProvider>,
275
+ );
160
276
 
161
- act(() => {
162
- // Simulate visibility change
163
- window.dispatchEvent(new Event('visibilitychange'));
164
- });
277
+ act(() => {
278
+ // Simulate visibility change
279
+ window.dispatchEvent(new Event('visibilitychange'));
280
+ });
165
281
 
166
- await waitFor(() => {
167
- expect(onSessionExpired).not.toHaveBeenCalled();
168
- expect(window.location.reload).not.toHaveBeenCalled();
282
+ await waitFor(() => {
283
+ expect(onSessionExpired).not.toHaveBeenCalled();
284
+ expect(window.location.reload).not.toHaveBeenCalled();
285
+ });
169
286
  });
170
-
171
- window.location = originalLocation;
172
287
  });
173
288
  });
174
289
 
175
290
  describe('useAuth', () => {
176
291
  beforeEach(() => {
177
- jest.clearAllMocks();
292
+ vi.clearAllMocks();
178
293
  });
179
294
 
180
295
  it('should call getAuth when a user is not returned when ensureSignedIn is true', async () => {
181
296
  // First and second calls return no user, second call returns a user
182
- (getAuthAction as jest.Mock)
297
+ (getAuthAction as Mock)
183
298
  .mockResolvedValueOnce({ user: null, loading: true })
184
299
  .mockResolvedValueOnce({ user: { email: 'test@example.com' }, loading: false });
185
300
 
@@ -201,6 +316,85 @@ describe('useAuth', () => {
201
316
  });
202
317
  });
203
318
 
319
+ describe('client-side redirect for ensureSignedIn', () => {
320
+ let originalLocationDescriptor: PropertyDescriptor | undefined;
321
+
322
+ beforeEach(() => {
323
+ originalLocationDescriptor = Object.getOwnPropertyDescriptor(window, 'location');
324
+ Object.defineProperty(window, 'location', {
325
+ writable: true,
326
+ value: { href: '' },
327
+ });
328
+ });
329
+
330
+ afterEach(() => {
331
+ if (originalLocationDescriptor) {
332
+ Object.defineProperty(window, 'location', originalLocationDescriptor);
333
+ }
334
+ });
335
+
336
+ it('should redirect via window.location.href when getAuthAction returns signInUrl', async () => {
337
+ // First call (initial load): no user, no signInUrl
338
+ // Second call (ensureSignedIn triggered): no user, signInUrl returned
339
+ (getAuthAction as Mock)
340
+ .mockResolvedValueOnce({ user: null })
341
+ .mockResolvedValueOnce({ user: null, signInUrl: 'https://api.workos.com/authorize?client_id=test' });
342
+
343
+ const TestComponent = () => {
344
+ const auth = useAuth({ ensureSignedIn: true });
345
+ return <div data-testid="loading">{auth.loading.toString()}</div>;
346
+ };
347
+
348
+ render(
349
+ <AuthKitProvider>
350
+ <TestComponent />
351
+ </AuthKitProvider>,
352
+ );
353
+
354
+ await waitFor(() => {
355
+ expect(window.location.href).toBe('https://api.workos.com/authorize?client_id=test');
356
+ });
357
+ });
358
+
359
+ it('should redirect via window.location.href when refreshAuthAction returns signInUrl', async () => {
360
+ (getAuthAction as Mock).mockResolvedValueOnce({
361
+ user: { email: 'test@example.com' },
362
+ sessionId: 'test-session',
363
+ });
364
+ (refreshAuthAction as Mock).mockResolvedValueOnce({
365
+ user: null,
366
+ signInUrl: 'https://api.workos.com/authorize?client_id=refresh_test',
367
+ });
368
+
369
+ const TestComponent = () => {
370
+ const auth = useAuth();
371
+ return (
372
+ <div>
373
+ <button onClick={() => auth.refreshAuth({ ensureSignedIn: true })}>Refresh</button>
374
+ </div>
375
+ );
376
+ };
377
+
378
+ const { getByRole } = render(
379
+ <AuthKitProvider>
380
+ <TestComponent />
381
+ </AuthKitProvider>,
382
+ );
383
+
384
+ await waitFor(() => {
385
+ expect(getAuthAction).toHaveBeenCalledTimes(1);
386
+ });
387
+
388
+ act(() => {
389
+ getByRole('button').click();
390
+ });
391
+
392
+ await waitFor(() => {
393
+ expect(window.location.href).toBe('https://api.workos.com/authorize?client_id=refresh_test');
394
+ });
395
+ });
396
+ });
397
+
204
398
  it('should throw error when used outside of AuthKitProvider', () => {
205
399
  const TestComponent = () => {
206
400
  const auth = useAuth();
@@ -208,7 +402,7 @@ describe('useAuth', () => {
208
402
  };
209
403
 
210
404
  // Suppress console.error for this test since we expect an error
211
- const consoleSpy = jest.spyOn(console, 'error').mockImplementation(() => {});
405
+ const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
212
406
 
213
407
  expect(() => {
214
408
  render(<TestComponent />);
@@ -218,7 +412,7 @@ describe('useAuth', () => {
218
412
  });
219
413
 
220
414
  it('should provide auth context values when used within AuthKitProvider', async () => {
221
- (getAuthAction as jest.Mock).mockResolvedValueOnce({
415
+ (getAuthAction as Mock).mockResolvedValueOnce({
222
416
  user: { email: 'test@example.com' },
223
417
  sessionId: 'test-session',
224
418
  organizationId: 'test-org',
@@ -266,8 +460,8 @@ describe('useAuth', () => {
266
460
  sessionId: 'test-session',
267
461
  };
268
462
 
269
- (getAuthAction as jest.Mock).mockResolvedValueOnce(mockAuth);
270
- (refreshAuthAction as jest.Mock).mockResolvedValueOnce({
463
+ (getAuthAction as Mock).mockResolvedValueOnce(mockAuth);
464
+ (refreshAuthAction as Mock).mockResolvedValueOnce({
271
465
  ...mockAuth,
272
466
  sessionId: 'new-session',
273
467
  });
@@ -309,10 +503,10 @@ describe('useAuth', () => {
309
503
  organizationId: 'new-org',
310
504
  };
311
505
 
312
- (getAuthAction as jest.Mock)
506
+ (getAuthAction as Mock)
313
507
  .mockResolvedValue(mockAuth)
314
508
  .mockResolvedValueOnce({ ...mockAuth, organizationId: 'old-org' });
315
- (switchToOrganizationAction as jest.Mock).mockResolvedValueOnce(mockAuth);
509
+ (switchToOrganizationAction as Mock).mockResolvedValueOnce(mockAuth);
316
510
 
317
511
  const TestComponent = () => {
318
512
  const auth = useAuth();
@@ -345,7 +539,7 @@ describe('useAuth', () => {
345
539
  });
346
540
 
347
541
  it('should receive an error when refreshAuth fails with an error', async () => {
348
- (refreshAuthAction as jest.Mock).mockRejectedValueOnce(new Error('Refresh failed'));
542
+ (refreshAuthAction as Mock).mockRejectedValueOnce(new Error('Refresh failed'));
349
543
 
350
544
  let error: string | undefined;
351
545
 
@@ -382,7 +576,7 @@ describe('useAuth', () => {
382
576
  });
383
577
 
384
578
  it('should receive an error when refreshAuth fails with a string error', async () => {
385
- (refreshAuthAction as jest.Mock).mockRejectedValueOnce('Refresh failed');
579
+ (refreshAuthAction as Mock).mockRejectedValueOnce('Refresh failed');
386
580
 
387
581
  let error: string | undefined;
388
582
 
@@ -419,7 +613,7 @@ describe('useAuth', () => {
419
613
  });
420
614
 
421
615
  it('should call handleSignOutAction when signOut is called', async () => {
422
- (handleSignOutAction as jest.Mock).mockResolvedValueOnce({});
616
+ (handleSignOutAction as Mock).mockResolvedValueOnce({});
423
617
 
424
618
  const TestComponent = () => {
425
619
  const auth = useAuth();
@@ -445,7 +639,7 @@ describe('useAuth', () => {
445
639
  });
446
640
 
447
641
  it('should pass returnTo parameter to handleSignOutAction', async () => {
448
- (handleSignOutAction as jest.Mock).mockResolvedValueOnce({});
642
+ (handleSignOutAction as Mock).mockResolvedValueOnce({});
449
643
 
450
644
  const TestComponent = () => {
451
645
  const auth = useAuth();
@@ -1,6 +1,6 @@
1
1
  'use client';
2
2
 
3
- import React, { createContext, ReactNode, useCallback, useContext, useEffect, useState } from 'react';
3
+ import React, { createContext, ReactNode, useCallback, useContext, useEffect, useRef, useState } from 'react';
4
4
  import {
5
5
  checkSessionAction,
6
6
  getAuthAction,
@@ -9,7 +9,7 @@ import {
9
9
  switchToOrganizationAction,
10
10
  } from '../actions.js';
11
11
  import type { Impersonator, User } from '@workos-inc/node';
12
- import type { UserInfo, SwitchToOrganizationOptions } from '../interfaces.js';
12
+ import type { UserInfo, SwitchToOrganizationOptions, NoUserInfo } from '../interfaces.js';
13
13
 
14
14
  type AuthContextType = {
15
15
  user: User | null;
@@ -40,24 +40,44 @@ interface AuthKitProviderProps {
40
40
  * You can also pass this as `false` to disable the expired session checks.
41
41
  */
42
42
  onSessionExpired?: false | (() => void);
43
+ /**
44
+ * Initial auth data from the server. If provided, the provider will skip the initial client-side fetch.
45
+ */
46
+ initialAuth?: Omit<UserInfo | NoUserInfo, 'accessToken'>;
43
47
  }
44
48
 
45
- export const AuthKitProvider = ({ children, onSessionExpired }: AuthKitProviderProps) => {
46
- const [user, setUser] = useState<User | null>(null);
47
- const [sessionId, setSessionId] = useState<string | undefined>(undefined);
48
- const [organizationId, setOrganizationId] = useState<string | undefined>(undefined);
49
- const [role, setRole] = useState<string | undefined>(undefined);
50
- const [roles, setRoles] = useState<string[] | undefined>(undefined);
51
- const [permissions, setPermissions] = useState<string[] | undefined>(undefined);
52
- const [entitlements, setEntitlements] = useState<string[] | undefined>(undefined);
53
- const [featureFlags, setFeatureFlags] = useState<string[] | undefined>(undefined);
54
- const [impersonator, setImpersonator] = useState<Impersonator | undefined>(undefined);
55
- const [loading, setLoading] = useState(true);
49
+ export const AuthKitProvider = ({ children, onSessionExpired, initialAuth }: AuthKitProviderProps) => {
50
+ const [user, setUser] = useState<User | null>(initialAuth?.user ?? null);
51
+ const [sessionId, setSessionId] = useState<string | undefined>(initialAuth?.sessionId);
52
+ const [organizationId, setOrganizationId] = useState<string | undefined>(initialAuth?.organizationId);
53
+ const [role, setRole] = useState<string | undefined>(initialAuth?.role);
54
+ const [roles, setRoles] = useState<string[] | undefined>(initialAuth?.roles);
55
+ const [permissions, setPermissions] = useState<string[] | undefined>(initialAuth?.permissions);
56
+ const [entitlements, setEntitlements] = useState<string[] | undefined>(initialAuth?.entitlements);
57
+ const [featureFlags, setFeatureFlags] = useState<string[] | undefined>(initialAuth?.featureFlags);
58
+ const [impersonator, setImpersonator] = useState<Impersonator | undefined>(initialAuth?.impersonator);
59
+ const [loading, setLoading] = useState(!initialAuth);
60
+ const redirectingRef = useRef(false);
61
+
62
+ // Redirect client-side to avoid CORS errors that occur when redirect()
63
+ // is called from a server action to an external URL.
64
+ const handleSignInRedirect = useCallback((auth: Record<string, unknown>): boolean => {
65
+ if ('signInUrl' in auth && auth.signInUrl) {
66
+ redirectingRef.current = true;
67
+ window.location.href = auth.signInUrl as string;
68
+ return true;
69
+ }
70
+ return false;
71
+ }, []);
56
72
 
57
73
  const getAuth = useCallback(async ({ ensureSignedIn = false }: { ensureSignedIn?: boolean } = {}) => {
74
+ if (redirectingRef.current) return;
58
75
  setLoading(true);
59
76
  try {
60
77
  const auth = await getAuthAction({ ensureSignedIn });
78
+
79
+ if (handleSignInRedirect(auth)) return;
80
+
61
81
  setUser(auth.user);
62
82
  setSessionId(auth.sessionId);
63
83
  setOrganizationId(auth.organizationId);
@@ -67,7 +87,7 @@ export const AuthKitProvider = ({ children, onSessionExpired }: AuthKitProviderP
67
87
  setEntitlements(auth.entitlements);
68
88
  setFeatureFlags(auth.featureFlags);
69
89
  setImpersonator(auth.impersonator);
70
- } catch (error) {
90
+ } catch {
71
91
  setUser(null);
72
92
  setSessionId(undefined);
73
93
  setOrganizationId(undefined);
@@ -101,10 +121,13 @@ export const AuthKitProvider = ({ children, onSessionExpired }: AuthKitProviderP
101
121
 
102
122
  const refreshAuth = useCallback(
103
123
  async ({ ensureSignedIn = false, organizationId }: { ensureSignedIn?: boolean; organizationId?: string } = {}) => {
124
+ if (redirectingRef.current) return;
104
125
  try {
105
126
  setLoading(true);
106
127
  const auth = await refreshAuthAction({ ensureSignedIn, organizationId });
107
128
 
129
+ if (handleSignInRedirect(auth)) return;
130
+
108
131
  setUser(auth.user);
109
132
  setSessionId(auth.sessionId);
110
133
  setOrganizationId(auth.organizationId);
@@ -128,7 +151,9 @@ export const AuthKitProvider = ({ children, onSessionExpired }: AuthKitProviderP
128
151
  }, []);
129
152
 
130
153
  useEffect(() => {
131
- getAuth();
154
+ if (!initialAuth) {
155
+ getAuth();
156
+ }
132
157
 
133
158
  // Return early if the session expired checks are disabled.
134
159
  if (onSessionExpired === false) {
@@ -24,12 +24,10 @@ describe('Button', () => {
24
24
  const { getByRole } = render(<Button style={{ backgroundColor: 'red' }}>Click me</Button>);
25
25
  const button = getByRole('button');
26
26
 
27
- expect(button).toHaveStyle({
28
- backgroundColor: 'red',
29
- display: 'inline-flex',
30
- alignItems: 'center',
31
- justifyContent: 'center',
32
- });
27
+ expect(button.style.backgroundColor).toBe('red');
28
+ expect(button.style.display).toBe('inline-flex');
29
+ expect(button.style.alignItems).toBe('center');
30
+ expect(button.style.justifyContent).toBe('center');
33
31
  });
34
32
 
35
33
  it('should pass through additional props', () => {