@lobehub/chat 1.79.7 → 1.79.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. package/CHANGELOG.md +25 -0
  2. package/changelog/v1.json +9 -0
  3. package/docs/development/database-schema.dbml +119 -0
  4. package/locales/ar/models.json +12 -0
  5. package/locales/ar/oauth.json +39 -0
  6. package/locales/bg-BG/models.json +12 -0
  7. package/locales/bg-BG/oauth.json +39 -0
  8. package/locales/de-DE/models.json +12 -0
  9. package/locales/de-DE/oauth.json +39 -0
  10. package/locales/en-US/models.json +12 -0
  11. package/locales/en-US/oauth.json +39 -0
  12. package/locales/es-ES/models.json +12 -0
  13. package/locales/es-ES/oauth.json +39 -0
  14. package/locales/fa-IR/models.json +12 -0
  15. package/locales/fa-IR/oauth.json +39 -0
  16. package/locales/fr-FR/models.json +12 -0
  17. package/locales/fr-FR/oauth.json +39 -0
  18. package/locales/it-IT/models.json +12 -0
  19. package/locales/it-IT/oauth.json +39 -0
  20. package/locales/ja-JP/models.json +12 -0
  21. package/locales/ja-JP/oauth.json +39 -0
  22. package/locales/ko-KR/models.json +12 -0
  23. package/locales/ko-KR/oauth.json +39 -0
  24. package/locales/nl-NL/models.json +12 -0
  25. package/locales/nl-NL/oauth.json +39 -0
  26. package/locales/pl-PL/models.json +12 -0
  27. package/locales/pl-PL/oauth.json +39 -0
  28. package/locales/pt-BR/models.json +12 -0
  29. package/locales/pt-BR/oauth.json +39 -0
  30. package/locales/ru-RU/models.json +12 -0
  31. package/locales/ru-RU/oauth.json +39 -0
  32. package/locales/tr-TR/models.json +12 -0
  33. package/locales/tr-TR/oauth.json +39 -0
  34. package/locales/vi-VN/models.json +12 -0
  35. package/locales/vi-VN/oauth.json +39 -0
  36. package/locales/zh-CN/models.json +12 -0
  37. package/locales/zh-CN/oauth.json +39 -0
  38. package/locales/zh-TW/models.json +12 -0
  39. package/locales/zh-TW/oauth.json +39 -0
  40. package/package.json +4 -1
  41. package/scripts/generate-oidc-jwk.mjs +59 -0
  42. package/scripts/migrateServerDB/index.ts +3 -1
  43. package/src/app/(backend)/oidc/[...oidc]/route.ts +270 -0
  44. package/src/app/(backend)/oidc/consent/route.ts +97 -0
  45. package/src/app/[variants]/oauth/consent/[uid]/Client.tsx +97 -0
  46. package/src/app/[variants]/oauth/consent/[uid]/failed/page.tsx +36 -0
  47. package/src/app/[variants]/oauth/consent/[uid]/page.tsx +71 -0
  48. package/src/app/[variants]/oauth/consent/[uid]/success/page.tsx +30 -0
  49. package/src/database/client/migrations.json +27 -8
  50. package/src/database/migrations/0020_add_oidc.sql +124 -0
  51. package/src/database/migrations/meta/0020_snapshot.json +4975 -0
  52. package/src/database/migrations/meta/_journal.json +7 -0
  53. package/src/database/repositories/tableViewer/index.test.ts +1 -1
  54. package/src/database/schemas/index.ts +1 -0
  55. package/src/database/schemas/oidc.ts +158 -0
  56. package/src/database/server/models/__tests__/adapter.test.ts +503 -0
  57. package/src/envs/oidc.ts +18 -0
  58. package/src/libs/agent-runtime/azureOpenai/index.ts +4 -1
  59. package/src/libs/agent-runtime/utils/streams/protocol.ts +2 -4
  60. package/src/libs/oidc-provider/adapter.ts +494 -0
  61. package/src/libs/oidc-provider/config.ts +53 -0
  62. package/src/libs/oidc-provider/http-adapter.ts +279 -0
  63. package/src/libs/oidc-provider/interaction-policy.ts +37 -0
  64. package/src/libs/oidc-provider/provider.ts +260 -0
  65. package/src/locales/default/index.ts +2 -0
  66. package/src/locales/default/oauth.ts +41 -0
  67. package/src/middleware.ts +94 -6
  68. package/src/server/services/oidc/index.ts +29 -0
  69. package/src/server/services/oidc/oidcProvider.ts +27 -0
@@ -140,6 +140,13 @@
140
140
  "when": 1742806552131,
141
141
  "tag": "0019_add_hotkey_user_settings",
142
142
  "breakpoints": true
143
+ },
144
+ {
145
+ "idx": 20,
146
+ "version": "7",
147
+ "when": 1744458287757,
148
+ "tag": "0020_add_oidc",
149
+ "breakpoints": true
143
150
  }
144
151
  ],
145
152
  "version": "6"
@@ -23,7 +23,7 @@ describe('TableViewerRepo', () => {
23
23
  it('should return all tables with counts', async () => {
24
24
  const result = await repo.getAllTables();
25
25
 
26
- expect(result.length).toEqual(39);
26
+ expect(result.length).toEqual(48);
27
27
  expect(result[0]).toEqual({ name: 'agents', count: 0, type: 'BASE TABLE' });
28
28
  });
29
29
 
@@ -4,6 +4,7 @@ export * from './asyncTask';
4
4
  export * from './file';
5
5
  export * from './message';
6
6
  export * from './nextauth';
7
+ export * from './oidc';
7
8
  export * from './rag';
8
9
  export * from './ragEvals';
9
10
  export * from './relations';
@@ -0,0 +1,158 @@
1
+ /* eslint-disable sort-keys-fix/sort-keys-fix */
2
+ import { boolean, jsonb, pgTable, primaryKey, text, varchar } from 'drizzle-orm/pg-core';
3
+
4
+ import { timestamps, timestamptz } from './_helpers';
5
+ import { users } from './user';
6
+
7
+ /**
8
+ * OIDC authorization codes.
9
+ * One of the models that oidc-provider requires to be persisted.
10
+ */
11
+ export const oidcAuthorizationCodes = pgTable('oidc_authorization_codes', {
12
+ id: varchar('id', { length: 255 }).primaryKey(),
13
+ data: jsonb('data').notNull(),
14
+ expiresAt: timestamptz('expires_at').notNull(),
15
+ consumedAt: timestamptz('consumed_at'), // nullable
16
+ userId: text('user_id')
17
+ .references(() => users.id, { onDelete: 'cascade' }) // FK to users.id with ON DELETE CASCADE
18
+ .notNull(),
19
+ clientId: varchar('client_id', { length: 255 }).notNull(),
20
+ grantId: varchar('grant_id', { length: 255 }), // nullable
21
+ ...timestamps,
22
+ });
23
+
24
+ /**
25
+ * OIDC access tokens.
26
+ * One of the models that oidc-provider requires to be persisted.
27
+ */
28
+ export const oidcAccessTokens = pgTable('oidc_access_tokens', {
29
+ id: varchar('id', { length: 255 }).primaryKey(),
30
+ data: jsonb('data').notNull(),
31
+ expiresAt: timestamptz('expires_at').notNull(),
32
+ consumedAt: timestamptz('consumed_at'), // nullable
33
+ userId: text('user_id')
34
+ .references(() => users.id, { onDelete: 'cascade' }) // FK to users.id with ON DELETE CASCADE
35
+ .notNull(),
36
+ clientId: varchar('client_id', { length: 255 }).notNull(),
37
+ grantId: varchar('grant_id', { length: 255 }), // nullable
38
+ ...timestamps,
39
+ });
40
+
41
+ /**
42
+ * OIDC refresh tokens.
43
+ * One of the models that oidc-provider requires to be persisted.
44
+ */
45
+ export const oidcRefreshTokens = pgTable('oidc_refresh_tokens', {
46
+ id: varchar('id', { length: 255 }).primaryKey(),
47
+ data: jsonb('data').notNull(),
48
+ expiresAt: timestamptz('expires_at').notNull(),
49
+ consumedAt: timestamptz('consumed_at'), // nullable
50
+ userId: text('user_id')
51
+ .references(() => users.id, { onDelete: 'cascade' }) // FK to users.id with ON DELETE CASCADE
52
+ .notNull(),
53
+ clientId: varchar('client_id', { length: 255 }).notNull(),
54
+ grantId: varchar('grant_id', { length: 255 }), // nullable
55
+ ...timestamps,
56
+ });
57
+
58
+ /**
59
+ * OIDC device codes.
60
+ * One of the models that oidc-provider requires to be persisted.
61
+ */
62
+ export const oidcDeviceCodes = pgTable('oidc_device_codes', {
63
+ id: varchar('id', { length: 255 }).primaryKey(),
64
+ data: jsonb('data').notNull(),
65
+ expiresAt: timestamptz('expires_at').notNull(),
66
+ consumedAt: timestamptz('consumed_at'), // nullable
67
+ userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }), // nullable here, unlike the token tables
68
+ clientId: varchar('client_id', { length: 255 }).notNull(),
69
+ grantId: varchar('grant_id', { length: 255 }), // nullable
70
+ userCode: varchar('user_code', { length: 255 }), // nullable
71
+ ...timestamps,
72
+ });
73
+
74
+ /**
75
+ * OIDC interaction sessions.
76
+ * One of the models that oidc-provider requires to be persisted.
77
+ */
78
+ export const oidcInteractions = pgTable('oidc_interactions', {
79
+ id: varchar('id', { length: 255 }).primaryKey(),
80
+ data: jsonb('data').notNull(),
81
+ expiresAt: timestamptz('expires_at').notNull(),
82
+ ...timestamps,
83
+ });
84
+
85
+ /**
86
+ * OIDC grant records.
87
+ * One of the models that oidc-provider requires to be persisted.
88
+ */
89
+ export const oidcGrants = pgTable('oidc_grants', {
90
+ id: varchar('id', { length: 255 }).primaryKey(),
91
+ data: jsonb('data').notNull(),
92
+ expiresAt: timestamptz('expires_at').notNull(),
93
+ consumedAt: timestamptz('consumed_at'), // nullable
94
+ userId: text('user_id')
95
+ .references(() => users.id, { onDelete: 'cascade' }) // FK to users.id with ON DELETE CASCADE
96
+ .notNull(),
97
+ clientId: varchar('client_id', { length: 255 }).notNull(),
98
+ ...timestamps,
99
+ });
100
+
101
+ /**
102
+ * OIDC client configuration.
103
+ * Stores the configuration information of OIDC clients.
104
+ */
105
+ export const oidcClients = pgTable('oidc_clients', {
106
+ id: varchar('id', { length: 255 }).primaryKey(), // client_id
107
+ name: text('name').notNull(),
108
+ description: text('description'),
109
+ clientSecret: varchar('client_secret', { length: 255 }), // may be null for public clients
110
+ redirectUris: text('redirect_uris').array().notNull(),
111
+ grants: text('grants').array().notNull(),
112
+ responseTypes: text('response_types').array().notNull(),
113
+ scopes: text('scopes').array().notNull(),
114
+ tokenEndpointAuthMethod: varchar('token_endpoint_auth_method', { length: 20 }),
115
+ applicationType: varchar('application_type', { length: 20 }),
116
+ clientUri: text('client_uri'),
117
+ logoUri: text('logo_uri'),
118
+ policyUri: text('policy_uri'),
119
+ tosUri: text('tos_uri'),
120
+ isFirstParty: boolean('is_first_party').default(false), // defaults to false; column is not declared notNull
121
+ ...timestamps,
122
+ });
123
+
124
+ /**
125
+ * OIDC sessions.
126
+ * One of the models that oidc-provider requires to be persisted.
127
+ */
128
+ export const oidcSessions = pgTable('oidc_sessions', {
129
+ id: varchar('id', { length: 255 }).primaryKey(),
130
+ data: jsonb('data').notNull(),
131
+ expiresAt: timestamptz('expires_at').notNull(),
132
+ userId: text('user_id')
133
+ .references(() => users.id, { onDelete: 'cascade' }) // FK to users.id with ON DELETE CASCADE
134
+ .notNull(),
135
+ ...timestamps,
136
+ });
137
+
138
+ /**
139
+ * OIDC consent records.
140
+ * Records the history of users' authorization consent for clients.
141
+ */
142
+ export const oidcConsents = pgTable(
143
+ 'oidc_consents',
144
+ {
145
+ userId: text('user_id')
146
+ .references(() => users.id, { onDelete: 'cascade' }) // FK to users.id with ON DELETE CASCADE
147
+ .notNull(),
148
+ clientId: varchar('client_id', { length: 255 })
149
+ .references(() => oidcClients.id, { onDelete: 'cascade' }) // FK to oidc_clients.id with ON DELETE CASCADE
150
+ .notNull(),
151
+ scopes: text('scopes').array().notNull(),
152
+ expiresAt: timestamptz('expires_at'), // nullable
153
+ ...timestamps,
154
+ },
155
+ (table) => ({
156
+ pk: primaryKey({ columns: [table.userId, table.clientId] }), // composite primary key: one row per (user, client) pair
157
+ }),
158
+ );
@@ -0,0 +1,503 @@
1
+ import type { AdapterUser } from '@auth/core/adapters';
2
+ import { eq } from 'drizzle-orm/expressions';
3
+ import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it } from 'vitest';
4
+
5
+ import { getTestDBInstance } from '@/database/core/dbForTest';
6
+ import { users } from '@/database/schemas';
7
+ import {
8
+ oidcAccessTokens,
9
+ oidcAuthorizationCodes,
10
+ oidcClients,
11
+ oidcDeviceCodes,
12
+ oidcGrants,
13
+ oidcInteractions,
14
+ oidcRefreshTokens,
15
+ oidcSessions,
16
+ } from '@/database/schemas/oidc';
17
+ import { LobeChatDatabase } from '@/database/type';
18
+ import { DrizzleAdapter } from '@/libs/oidc-provider/adapter';
19
+
20
+ let serverDB = await getTestDBInstance();
21
+
22
+ // Test data
23
+ const testModelName = 'Session';
24
+ const testId = 'test-id';
25
+ const testUserId = 'test-user-id';
26
+ const testClientId = 'test-client-id';
27
+ const testGrantId = 'test-grant-id';
28
+ const testUserCode = 'test-user-code';
29
+ const testExpires = new Date(Date.now() + 3600 * 1000); // expires in 1 hour
30
+
31
+ beforeEach(async () => {
32
+ await serverDB.insert(users).values({ id: testUserId }); // seed the user row used as accountId in the test payloads
33
+ });
34
+
35
+ // Clean up data after each test
36
+ afterEach(async () => {
37
+ await serverDB.delete(users); // NOTE(review): presumably also clears user-owned OIDC rows via their ON DELETE CASCADE FKs — confirm
38
+ await serverDB.delete(oidcClients);
39
+ await serverDB.delete(oidcDeviceCodes);
40
+ await serverDB.delete(oidcInteractions);
41
+ });
42
+
43
+ describe('DrizzleAdapter', () => {
44
+ describe('constructor', () => {
45
+ it('应该正确创建适配器实例', () => { // title: "should create the adapter instance correctly"
46
+ const adapter = new DrizzleAdapter(testModelName, serverDB);
47
+ expect(adapter).toBeDefined();
48
+ });
49
+ });
50
+
51
+ describe('upsert', () => {
52
+ it('应该为Session模型创建新记录', async () => { // title: "should create a new record for the Session model"
53
+ const adapter = new DrizzleAdapter('Session', serverDB);
54
+ const payload = {
55
+ accountId: testUserId,
56
+ cookie: 'cookie-value',
57
+ exp: Math.floor(Date.now() / 1000) + 3600,
58
+ };
59
+
60
+ await adapter.upsert(testId, payload, 3600);
61
+
62
+ const result = await serverDB.query.oidcSessions.findFirst({
63
+ where: eq(oidcSessions.id, testId),
64
+ });
65
+
66
+ expect(result).toBeDefined();
67
+ expect(result?.id).toBe(testId);
68
+ expect(result?.userId).toBe(testUserId);
69
+ expect(result?.data).toEqual(payload);
70
+ });
71
+
72
+ it('应该为Client模型创建新记录', async () => { // title: "should create a new record for the Client model"
73
+ const adapter = new DrizzleAdapter('Client', serverDB);
74
+ const payload = {
75
+ client_id: testClientId,
76
+ client_uri: 'https://example.com',
77
+ application_type: 'web',
78
+ client_secret: 'secret',
79
+ grant_types: ['authorization_code', 'refresh_token'],
80
+ name: 'Test Client',
81
+ redirectUris: ['https://example.com/callback'],
82
+ response_types: ['code'],
83
+ scope: 'openid profile email',
84
+ token_endpoint_auth_method: 'client_secret_basic',
85
+ };
86
+
87
+ await adapter.upsert(testClientId, payload, 0);
88
+
89
+ const result = await serverDB.query.oidcClients.findFirst({
90
+ where: eq(oidcClients.id, testClientId),
91
+ });
92
+
93
+ expect(result).toBeDefined();
94
+ expect(result?.id).toBe(testClientId);
95
+ expect(result?.name).toBe(payload.name);
96
+ expect(result?.redirectUris).toEqual(payload.redirectUris);
97
+ expect(result?.scopes).toEqual(['openid', 'profile', 'email']); // the space-separated `scope` string is expected to be stored as an array
98
+ });
99
+
100
+ it('应该为AccessToken模型创建新记录', async () => { // title: "should create a new record for the AccessToken model"
101
+ const adapter = new DrizzleAdapter('AccessToken', serverDB);
102
+ const payload = {
103
+ accountId: testUserId,
104
+ clientId: testClientId,
105
+ grantId: testGrantId,
106
+ scope: 'openid profile',
107
+ iat: Math.floor(Date.now() / 1000),
108
+ };
109
+
110
+ await adapter.upsert(testId, payload, 3600);
111
+
112
+ const result = await serverDB.query.oidcAccessTokens.findFirst({
113
+ where: eq(oidcAccessTokens.id, testId),
114
+ });
115
+
116
+ expect(result).toBeDefined();
117
+ expect(result?.id).toBe(testId);
118
+ expect(result?.userId).toBe(testUserId);
119
+ expect(result?.clientId).toBe(testClientId);
120
+ expect(result?.grantId).toBe(testGrantId);
121
+ expect(result?.data).toEqual(payload);
122
+ });
123
+
124
+ it('应该为DeviceCode模型创建新记录并包含userCode', async () => { // title: "should create a new record for the DeviceCode model, including userCode"
125
+ const adapter = new DrizzleAdapter('DeviceCode', serverDB);
126
+ const payload = {
127
+ clientId: testClientId,
128
+ userCode: testUserCode,
129
+ exp: Math.floor(Date.now() / 1000) + 3600,
130
+ };
131
+
132
+ await adapter.upsert(testId, payload, 3600);
133
+
134
+ const result = await serverDB.query.oidcDeviceCodes.findFirst({
135
+ where: eq(oidcDeviceCodes.id, testId),
136
+ });
137
+
138
+ expect(result).toBeDefined();
139
+ expect(result?.id).toBe(testId);
140
+ expect(result?.clientId).toBe(testClientId);
141
+ expect(result?.userCode).toBe(testUserCode);
142
+ expect(result?.data).toEqual(payload);
143
+ });
144
+
145
+ it('应该更新现有的Session记录', async () => { // title: "should update an existing Session record"
146
+ const adapter = new DrizzleAdapter('Session', serverDB);
147
+ const initialPayload = { accountId: testUserId, cookie: 'initial-cookie' };
148
+ const updatedPayload = { accountId: testUserId, cookie: 'updated-cookie' };
149
+
150
+ // Initial insert
151
+ await adapter.upsert(testId, initialPayload, 3600);
152
+ let result = await serverDB.query.oidcSessions.findFirst({
153
+ where: eq(oidcSessions.id, testId),
154
+ });
155
+ expect(result?.data).toEqual(initialPayload);
156
+
157
+ // Update
158
+ await adapter.upsert(testId, updatedPayload, 7200); // new expiration time
159
+ result = await serverDB.query.oidcSessions.findFirst({ where: eq(oidcSessions.id, testId) });
160
+ expect(result?.data).toEqual(updatedPayload);
161
+ // Verify that expiresAt was updated as well (roughly 2 hours from now)
162
+ expect(result?.expiresAt).toBeInstanceOf(Date);
163
+ const expectedExpires = Date.now() + 7200 * 1000;
164
+ expect(result!.expiresAt!.getTime()).toBeGreaterThan(expectedExpires - 5000); // allow a 5-second tolerance
165
+ expect(result!.expiresAt!.getTime()).toBeLessThan(expectedExpires + 5000);
166
+ });
167
+
168
+ it('应该更新现有的Client记录', async () => { // title: "should update an existing Client record"
169
+ const adapter = new DrizzleAdapter('Client', serverDB);
170
+ const initialPayload = {
171
+ client_id: testClientId,
172
+ client_uri: 'https://initial.com',
173
+ name: 'Initial Client',
174
+ redirectUris: ['https://initial.com/callback'],
175
+ scopes: ['openid'],
176
+ };
177
+ const updatedPayload = {
178
+ ...initialPayload,
179
+ client_uri: 'https://updated.com',
180
+ name: 'Updated Client',
181
+ scopes: ['openid', 'profile'], // original note: assumes the scope format is a space-separated string (cf. the `scope` field on the next row)
182
+ scope: 'openid profile',
183
+ redirectUris: ['https://updated.com/callback'],
184
+ };
185
+
186
+ // Initial insert
187
+ await adapter.upsert(testClientId, initialPayload, 0);
188
+ let result = await serverDB.query.oidcClients.findFirst({
189
+ where: eq(oidcClients.id, testClientId),
190
+ });
191
+
192
+ expect(result?.name).toBe('Initial Client');
193
+ expect(result?.clientUri).toBe('https://initial.com');
194
+ expect(result?.scopes).toEqual(['openid']);
195
+
196
+ // Update
197
+ await adapter.upsert(testClientId, updatedPayload, 0);
198
+ result = await serverDB.query.oidcClients.findFirst({
199
+ where: eq(oidcClients.id, testClientId),
200
+ });
201
+ expect(result?.name).toBe('Updated Client');
202
+ expect(result?.clientUri).toBe('https://updated.com');
203
+ expect(result?.scopes).toEqual(['openid', 'profile']); // verify that the database stores an array
204
+ expect(result?.redirectUris).toEqual(['https://updated.com/callback']);
205
+ });
206
+ });
207
+
208
+ describe('find', () => {
209
+ it('应该找到存在的记录', async () => { // title: "should find an existing record"
210
+ // First create a record
211
+ const adapter = new DrizzleAdapter('Session', serverDB);
212
+ const payload = {
213
+ accountId: testUserId,
214
+ cookie: 'cookie-value',
215
+ exp: Math.floor(Date.now() / 1000) + 3600,
216
+ };
217
+
218
+ await adapter.upsert(testId, payload, 3600);
219
+
220
+ // Then look it up
221
+ const result = await adapter.find(testId);
222
+
223
+ expect(result).toBeDefined();
224
+ expect(result).toEqual(payload);
225
+ });
226
+
227
+ it('应该为Client模型返回正确的格式', async () => { // title: "should return the correct format for the Client model"
228
+ // First create a Client record
229
+ const adapter = new DrizzleAdapter('Client', serverDB);
230
+ const payload = {
231
+ client_id: testClientId,
232
+ client_uri: 'https://example.com',
233
+ application_type: 'web',
234
+ client_secret: 'secret',
235
+ grant_types: ['authorization_code', 'refresh_token'],
236
+ name: 'Test Client',
237
+ redirectUris: ['https://example.com/callback'],
238
+ response_types: ['code'],
239
+ scope: 'openid profile email',
240
+ token_endpoint_auth_method: 'client_secret_basic',
241
+ };
242
+
243
+ await adapter.upsert(testClientId, payload, 0);
244
+
245
+ // Then look it up
246
+ const result = await adapter.find(testClientId);
247
+
248
+ expect(result).toBeDefined();
249
+ expect(result.client_id).toBe(testClientId);
250
+ expect(result.client_secret).toBe(payload.client_secret);
251
+ expect(result.redirect_uris).toEqual(payload.redirectUris); // result is expected in snake_case although the payload key is camelCase
252
+ expect(result.scope).toBe(payload.scope);
253
+ });
254
+
255
+ it('应该返回undefined如果记录不存在', async () => { // title: "should return undefined if the record does not exist"
256
+ const adapter = new DrizzleAdapter('Session', serverDB);
257
+ const result = await adapter.find('non-existent-id');
258
+ expect(result).toBeUndefined();
259
+ });
260
+
261
+ it('应该返回undefined如果记录已过期', async () => { // title: "should return undefined if the record has expired"
262
+ // Create an expired record (expiry time set in the past)
263
+ const adapter = new DrizzleAdapter('Session', serverDB);
264
+ const payload = {
265
+ accountId: testUserId,
266
+ cookie: 'cookie-value',
267
+ exp: Math.floor(Date.now() / 1000) - 3600, // 1 hour ago
268
+ };
269
+
270
+ // A negative expiry means the record expires immediately
271
+ await adapter.upsert(testId, payload, -1);
272
+
273
+ // Wait a short moment to make sure it has expired
274
+ await new Promise((resolve) => setTimeout(resolve, 10));
275
+
276
+ // Then look it up
277
+ const result = await adapter.find(testId);
278
+
279
+ expect(result).toBeUndefined();
280
+ });
281
+
282
+ it('应该返回undefined如果记录已被消费', async () => { // title: "should return undefined if the record has been consumed"
283
+ const adapter = new DrizzleAdapter('AccessToken', serverDB);
284
+ const payload = { accountId: testUserId, clientId: testClientId };
285
+ await adapter.upsert(testId, payload, 3600);
286
+
287
+ // Consume the record
288
+ await adapter.consume(testId);
289
+
290
+ // Look up the consumed record
291
+ const result = await adapter.find(testId);
292
+ expect(result).toBeUndefined();
293
+ });
294
+ });
295
+
296
+ describe('findByUserCode', () => {
297
+ it('应该通过userCode找到DeviceCode记录', async () => { // title: "should find a DeviceCode record by userCode"
298
+ // First create a DeviceCode record
299
+ const adapter = new DrizzleAdapter('DeviceCode', serverDB);
300
+ const payload = {
301
+ clientId: testClientId,
302
+ userCode: testUserCode,
303
+ exp: Math.floor(Date.now() / 1000) + 3600,
304
+ };
305
+
306
+ await adapter.upsert(testId, payload, 3600);
307
+
308
+ // Then look it up by userCode
309
+ const result = await adapter.findByUserCode(testUserCode);
310
+
311
+ expect(result).toBeDefined();
312
+ expect(result).toEqual(payload);
313
+ });
314
+
315
+ it('应该返回undefined如果DeviceCode记录已过期', async () => { // title: "should return undefined if the DeviceCode record has expired"
316
+ const adapter = new DrizzleAdapter('DeviceCode', serverDB);
317
+ const payload = { clientId: testClientId, userCode: testUserCode };
318
+ // Use a negative expiresIn so the record expires immediately
319
+ await adapter.upsert(testId, payload, -1);
320
+ await new Promise((resolve) => setTimeout(resolve, 10)); // wait briefly to make sure it has expired
321
+
322
+ const result = await adapter.findByUserCode(testUserCode);
323
+ expect(result).toBeUndefined();
324
+ });
325
+
326
+ it('应该返回undefined如果DeviceCode记录已被消费', async () => { // title: "should return undefined if the DeviceCode record has been consumed"
327
+ const adapter = new DrizzleAdapter('DeviceCode', serverDB);
328
+ const payload = { clientId: testClientId, userCode: testUserCode };
329
+ await adapter.upsert(testId, payload, 3600);
330
+
331
+ // Consume the record
332
+ await adapter.consume(testId);
333
+
334
+ // Look up the consumed record
335
+ const result = await adapter.findByUserCode(testUserCode);
336
+ expect(result).toBeUndefined();
337
+ });
338
+
339
+ it('应该在非DeviceCode模型上抛出错误', async () => { // title: "should throw an error on a non-DeviceCode model"
340
+ const adapter = new DrizzleAdapter('Session', serverDB);
341
+ await expect(adapter.findByUserCode(testUserCode)).rejects.toThrow();
342
+ });
343
+ });
344
+
345
+ describe('findSessionByUserId', () => {
346
+ it('应该通过userId找到Session记录', async () => { // title: "should find a Session record by userId"
347
+ // First create a Session record
348
+ const adapter = new DrizzleAdapter('Session', serverDB);
349
+ const payload = {
350
+ accountId: testUserId,
351
+ cookie: 'cookie-value',
352
+ exp: Math.floor(Date.now() / 1000) + 3600,
353
+ };
354
+
355
+ await adapter.upsert(testId, payload, 3600);
356
+
357
+ // Then look it up by userId
358
+ const result = await adapter.findSessionByUserId(testUserId);
359
+
360
+ expect(result).toBeDefined();
361
+ expect(result).toEqual(payload);
362
+ });
363
+
364
+ it('应该在非Session模型上返回undefined', async () => { // title: "should return undefined on a non-Session model"
365
+ const adapter = new DrizzleAdapter('AccessToken', serverDB);
366
+ const result = await adapter.findSessionByUserId(testUserId);
367
+ expect(result).toBeUndefined();
368
+ });
369
+ });
370
+
371
+ describe('destroy', () => {
372
+ it('应该删除存在的记录', async () => { // title: "should delete an existing record"
373
+ // First create a record
374
+ const adapter = new DrizzleAdapter('Session', serverDB);
375
+ const payload = {
376
+ accountId: testUserId,
377
+ cookie: 'cookie-value',
378
+ exp: Math.floor(Date.now() / 1000) + 3600,
379
+ };
380
+
381
+ await adapter.upsert(testId, payload, 3600);
382
+
383
+ // Confirm the record exists
384
+ let result = await serverDB.query.oidcSessions.findFirst({
385
+ where: eq(oidcSessions.id, testId),
386
+ });
387
+ expect(result).toBeDefined();
388
+
389
+ // Delete the record
390
+ await adapter.destroy(testId);
391
+
392
+ // Verify the record has been deleted
393
+ result = await serverDB.query.oidcSessions.findFirst({
394
+ where: eq(oidcSessions.id, testId),
395
+ });
396
+ expect(result).toBeUndefined();
397
+ });
398
+ });
399
+
400
+ describe('consume', () => {
401
+ it('应该标记记录为已消费', async () => { // title: "should mark the record as consumed"
402
+ // First create a record
403
+ const adapter = new DrizzleAdapter('AccessToken', serverDB);
404
+ const payload = {
405
+ accountId: testUserId,
406
+ clientId: testClientId,
407
+ exp: Math.floor(Date.now() / 1000) + 3600,
408
+ };
409
+
410
+ await adapter.upsert(testId, payload, 3600);
411
+
412
+ // Consume the record
413
+ await adapter.consume(testId);
414
+
415
+ // Verify the record has been marked as consumed
416
+ const result = await serverDB.query.oidcAccessTokens.findFirst({
417
+ where: eq(oidcAccessTokens.id, testId),
418
+ });
419
+
420
+ expect(result).toBeDefined();
421
+ expect(result?.consumedAt).not.toBeNull();
422
+ });
423
+ });
424
+
425
+ describe('revokeByGrantId', () => {
426
+ it('应该撤销与指定 grantId 相关的所有记录', async () => { // title: "should revoke all records associated with the given grantId"
427
+ // Create an AccessToken record
428
+ const accessTokenAdapter = new DrizzleAdapter('AccessToken', serverDB);
429
+ const accessTokenPayload = {
430
+ accountId: testUserId,
431
+ clientId: testClientId,
432
+ grantId: testGrantId,
433
+ exp: Math.floor(Date.now() / 1000) + 3600,
434
+ };
435
+ await accessTokenAdapter.upsert(testId, accessTokenPayload, 3600);
436
+
437
+ // Create a RefreshToken record
438
+ const refreshTokenAdapter = new DrizzleAdapter('RefreshToken', serverDB);
439
+ const refreshTokenPayload = {
440
+ accountId: testUserId,
441
+ clientId: testClientId,
442
+ grantId: testGrantId,
443
+ exp: Math.floor(Date.now() / 1000) + 3600,
444
+ };
445
+ await refreshTokenAdapter.upsert('refresh-' + testId, refreshTokenPayload, 3600);
446
+
447
+ // Revoke all records associated with testGrantId
448
+ await accessTokenAdapter.revokeByGrantId(testGrantId);
449
+
450
+ // Verify the records have been deleted
451
+ const accessTokenResult = await serverDB.query.oidcAccessTokens.findFirst({
452
+ where: eq(oidcAccessTokens.id, testId),
453
+ });
454
+
455
+ expect(accessTokenResult).toBeUndefined();
456
+
457
+ const refreshTokenResult = await serverDB.query.oidcRefreshTokens.findFirst({
458
+ where: eq(oidcRefreshTokens.id, `refresh-${testId}`),
459
+ });
460
+ console.log('refreshTokenResult:', refreshTokenResult); // NOTE(review): looks like leftover debug output — consider removing
461
+ expect(refreshTokenResult).toBeUndefined();
462
+ });
463
+
464
+ it('应该在Grant模型上直接返回', async () => { // title: "should return immediately for the Grant model"
465
+ // The Grant model does not need to be revoked by grantId
466
+ const adapter = new DrizzleAdapter('Grant', serverDB);
467
+ await adapter.revokeByGrantId(testGrantId);
468
+ // If no error is thrown, the test passes
469
+ });
470
+ });
471
+
472
+ describe('createAdapterFactory', () => {
473
+ it('应该创建一个适配器工厂函数', () => { // title: "should create an adapter factory function"
474
+ const factory = DrizzleAdapter.createAdapterFactory(serverDB as any);
475
+ expect(factory).toBeDefined();
476
+ expect(typeof factory).toBe('function');
477
+
478
+ const adapter = factory('Session'); // the factory is called with a model name
479
+ expect(adapter).toBeDefined();
480
+ expect(adapter).toBeInstanceOf(DrizzleAdapter);
481
+ });
482
+ });
483
+
484
+ describe('getTable (indirectly via public methods)', () => {
485
+ it('当使用不支持的模型名称时应该抛出错误', async () => { // title: "should throw an error when an unsupported model name is used"
486
+ const invalidAdapter = new DrizzleAdapter('InvalidModelName', serverDB);
487
+ // Call methods that trigger getTable; the expected message reads "Unsupported model: InvalidModelName"
488
+ await expect(invalidAdapter.find('any-id')).rejects.toThrow('不支持的模型: InvalidModelName');
489
+ await expect(invalidAdapter.upsert('any-id', {}, 3600)).rejects.toThrow(
490
+ '不支持的模型: InvalidModelName',
491
+ );
492
+ await expect(invalidAdapter.destroy('any-id')).rejects.toThrow(
493
+ '不支持的模型: InvalidModelName',
494
+ );
495
+ await expect(invalidAdapter.consume('any-id')).rejects.toThrow(
496
+ '不支持的模型: InvalidModelName',
497
+ );
498
+ await expect(invalidAdapter.revokeByGrantId('any-grant-id')).rejects.toThrow(
499
+ '不支持的模型: InvalidModelName',
500
+ );
501
+ });
502
+ });
503
+ });