@lobehub/chat 1.79.7 → 1.79.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121) hide show
  1. package/.eslintrc.js +1 -0
  2. package/CHANGELOG.md +58 -0
  3. package/changelog/v1.json +18 -0
  4. package/docs/development/database-schema.dbml +119 -0
  5. package/locales/ar/models.json +12 -0
  6. package/locales/ar/oauth.json +40 -0
  7. package/locales/bg-BG/models.json +12 -0
  8. package/locales/bg-BG/oauth.json +40 -0
  9. package/locales/de-DE/models.json +12 -0
  10. package/locales/de-DE/oauth.json +40 -0
  11. package/locales/en-US/models.json +12 -0
  12. package/locales/en-US/oauth.json +40 -0
  13. package/locales/es-ES/models.json +12 -0
  14. package/locales/es-ES/oauth.json +40 -0
  15. package/locales/fa-IR/models.json +12 -0
  16. package/locales/fa-IR/oauth.json +40 -0
  17. package/locales/fr-FR/models.json +12 -0
  18. package/locales/fr-FR/oauth.json +40 -0
  19. package/locales/it-IT/models.json +12 -0
  20. package/locales/it-IT/oauth.json +40 -0
  21. package/locales/ja-JP/models.json +12 -0
  22. package/locales/ja-JP/oauth.json +40 -0
  23. package/locales/ko-KR/models.json +12 -0
  24. package/locales/ko-KR/oauth.json +40 -0
  25. package/locales/nl-NL/models.json +12 -0
  26. package/locales/nl-NL/oauth.json +40 -0
  27. package/locales/pl-PL/models.json +12 -0
  28. package/locales/pl-PL/oauth.json +40 -0
  29. package/locales/pt-BR/models.json +12 -0
  30. package/locales/pt-BR/oauth.json +40 -0
  31. package/locales/ru-RU/models.json +12 -0
  32. package/locales/ru-RU/oauth.json +40 -0
  33. package/locales/tr-TR/models.json +12 -0
  34. package/locales/tr-TR/oauth.json +40 -0
  35. package/locales/vi-VN/models.json +12 -0
  36. package/locales/vi-VN/oauth.json +40 -0
  37. package/locales/zh-CN/models.json +12 -0
  38. package/locales/zh-CN/oauth.json +40 -0
  39. package/locales/zh-TW/models.json +12 -0
  40. package/locales/zh-TW/oauth.json +40 -0
  41. package/package.json +4 -1
  42. package/scripts/generate-oidc-jwk.mjs +59 -0
  43. package/scripts/migrateServerDB/index.ts +3 -1
  44. package/src/app/(backend)/oidc/[...oidc]/route.ts +96 -0
  45. package/src/app/(backend)/oidc/consent/route.ts +131 -0
  46. package/src/app/(backend)/trpc/async/[trpc]/route.ts +1 -1
  47. package/src/app/(backend)/trpc/edge/[trpc]/route.ts +2 -2
  48. package/src/app/(backend)/trpc/lambda/[trpc]/route.ts +2 -2
  49. package/src/app/(backend)/trpc/tools/[trpc]/route.ts +2 -2
  50. package/src/app/[variants]/(main)/files/[id]/page.tsx +1 -1
  51. package/src/app/[variants]/oauth/consent/[uid]/Client.tsx +224 -0
  52. package/src/app/[variants]/oauth/consent/[uid]/ClientError.tsx +46 -0
  53. package/src/app/[variants]/oauth/consent/[uid]/failed/page.tsx +36 -0
  54. package/src/app/[variants]/oauth/consent/[uid]/page.tsx +69 -0
  55. package/src/app/[variants]/oauth/consent/[uid]/success/page.tsx +30 -0
  56. package/src/components/Branding/ProductLogo/index.tsx +6 -1
  57. package/src/config/aiModels/openai.ts +63 -41
  58. package/src/database/client/migrations.json +27 -8
  59. package/src/database/migrations/0020_add_oidc.sql +124 -0
  60. package/src/database/migrations/meta/0020_snapshot.json +4975 -0
  61. package/src/database/migrations/meta/_journal.json +7 -0
  62. package/src/database/repositories/tableViewer/index.test.ts +1 -1
  63. package/src/database/schemas/index.ts +1 -0
  64. package/src/database/schemas/oidc.ts +158 -0
  65. package/src/database/server/models/__tests__/adapter.test.ts +499 -0
  66. package/src/envs/oidc.ts +18 -0
  67. package/src/libs/agent-runtime/azureOpenai/index.ts +4 -1
  68. package/src/libs/agent-runtime/utils/streams/protocol.ts +2 -4
  69. package/src/libs/oidc-provider/adapter.ts +541 -0
  70. package/src/libs/oidc-provider/config.ts +52 -0
  71. package/src/libs/oidc-provider/http-adapter.ts +311 -0
  72. package/src/libs/oidc-provider/interaction-policy.ts +37 -0
  73. package/src/libs/oidc-provider/provider.ts +288 -0
  74. package/src/libs/trpc/async/init.ts +1 -1
  75. package/src/{server → libs/trpc/edge}/context.ts +2 -2
  76. package/src/libs/trpc/{index.ts → edge/index.ts} +8 -8
  77. package/src/libs/trpc/{init.ts → edge/init.ts} +2 -2
  78. package/src/libs/trpc/{middleware → edge/middleware}/jwtPayload.test.ts +3 -3
  79. package/src/libs/trpc/{middleware → edge/middleware}/jwtPayload.ts +3 -2
  80. package/src/libs/trpc/lambda/context.ts +70 -0
  81. package/src/libs/trpc/lambda/index.ts +39 -1
  82. package/src/libs/trpc/lambda/init.ts +26 -0
  83. package/src/libs/trpc/lambda/middleware/index.ts +2 -0
  84. package/src/libs/trpc/{middleware → lambda/middleware}/keyVaults.ts +2 -1
  85. package/src/libs/trpc/lambda/{serverDatabase.ts → middleware/serverDatabase.ts} +2 -1
  86. package/src/libs/trpc/middleware/userAuth.test.ts +3 -3
  87. package/src/libs/trpc/middleware/userAuth.ts +1 -1
  88. package/src/libs/trpc/mock.ts +7 -0
  89. package/src/locales/default/index.ts +2 -0
  90. package/src/locales/default/oauth.ts +43 -0
  91. package/src/middleware.ts +94 -6
  92. package/src/server/routers/edge/appStatus.ts +1 -1
  93. package/src/server/routers/edge/config/index.test.ts +2 -3
  94. package/src/server/routers/edge/config/index.ts +1 -1
  95. package/src/server/routers/edge/index.ts +1 -1
  96. package/src/server/routers/edge/upload.ts +1 -1
  97. package/src/server/routers/lambda/_template.ts +2 -2
  98. package/src/server/routers/lambda/agent.ts +2 -2
  99. package/src/server/routers/lambda/aiModel.ts +2 -2
  100. package/src/server/routers/lambda/aiProvider.ts +2 -2
  101. package/src/server/routers/lambda/chunk.ts +2 -3
  102. package/src/server/routers/lambda/exporter.ts +2 -2
  103. package/src/server/routers/lambda/file.ts +2 -2
  104. package/src/server/routers/lambda/importer.ts +2 -2
  105. package/src/server/routers/lambda/index.ts +1 -1
  106. package/src/server/routers/lambda/knowledgeBase.ts +2 -2
  107. package/src/server/routers/lambda/message.ts +2 -2
  108. package/src/server/routers/lambda/plugin.ts +2 -2
  109. package/src/server/routers/lambda/ragEval.ts +2 -3
  110. package/src/server/routers/lambda/session.ts +2 -2
  111. package/src/server/routers/lambda/sessionGroup.ts +2 -2
  112. package/src/server/routers/lambda/thread.ts +2 -2
  113. package/src/server/routers/lambda/topic.ts +2 -2
  114. package/src/server/routers/lambda/user.ts +2 -2
  115. package/src/server/routers/tools/__tests__/search.test.ts +2 -2
  116. package/src/server/routers/tools/index.ts +1 -1
  117. package/src/server/routers/tools/search.ts +2 -1
  118. package/src/server/services/oidc/index.ts +64 -0
  119. package/src/server/services/oidc/oidcProvider.ts +25 -0
  120. package/src/server/mock.ts +0 -8
  121. /package/src/{server/asyncContext.ts → libs/trpc/async/context.ts} +0 -0
@@ -0,0 +1,311 @@
1
+ import debug from 'debug';
2
+ import { cookies } from 'next/headers';
3
+ import { NextRequest } from 'next/server';
4
+ import { IncomingMessage, ServerResponse } from 'node:http';
5
+ import urlJoin from 'url-join';
6
+
7
+ import { appEnv } from '@/config/app';
8
+
9
+ const log = debug('lobe-oidc:http-adapter'); // Namespaced debug logger used by every helper in this file
10
+
11
+ /**
12
+ * Copies every entry of a Next.js (Fetch API) Headers object into a plain string-to-string record for use as Node.js-style request headers.
13
+ */
14
+ export const convertHeadersToNodeHeaders = (nextHeaders: Headers): Record<string, string> => {
15
+ const headers: Record<string, string> = {};
16
+ nextHeaders.forEach((value, key) => {
17
+ headers[key] = value;
18
+ });
19
+ return headers;
20
+ };
21
+
22
+ /**
23
+ * Builds a minimal Node.js-style request object (headers, method, url, socket, optional parsed body) for the OIDC provider; the plain object is cast to IncomingMessage.
24
+ * @param req Next.js request object
25
+ */
26
+ export const createNodeRequest = async (req: NextRequest): Promise<IncomingMessage> => {
27
+ // Parse the request URL
28
+ const url = new URL(req.url);
29
+
30
+ // Path handed to the provider; no prefix is stripped here, the pathname is used as-is
31
+ let providerPath = url.pathname;
32
+
33
+ // Make sure the path always starts with '/' (NOTE(review): URL.pathname presumably already does — confirm this guard is needed)
34
+ if (!providerPath.startsWith('/')) {
35
+ providerPath = '/' + providerPath;
36
+ }
37
+
38
+ log('Creating Node.js request from Next.js request');
39
+ log('Original path: %s, Provider path: %s', url.pathname, providerPath);
40
+
41
+ // Attempt to parse and attach body for relevant methods
42
+ let parsedBody: any = undefined;
43
+ const methodsWithBody = ['POST', 'PUT', 'PATCH', 'DELETE'];
44
+ if (methodsWithBody.includes(req.method)) {
45
+ const contentType = req.headers.get('content-type')?.split(';')[0]; // Get content type without charset etc.
46
+ log(`Attempting to parse body for ${req.method} with Content-Type: ${contentType}`);
47
+ try {
48
+ // Check if body exists and has size before attempting to parse
49
+ if (req.body && req.headers.get('content-length') !== '0') {
50
+ if (contentType === 'application/x-www-form-urlencoded') {
51
+ const formData = await req.formData();
52
+ parsedBody = {};
53
+ formData.forEach((value, key) => {
54
+ // If a key appears multiple times, keep the last one (standard form behavior)
55
+ // Or convert to array if oidc-provider expects it:
56
+ // if (parsedBody[key]) {
57
+ // if (!Array.isArray(parsedBody[key])) parsedBody[key] = [parsedBody[key]];
58
+ // parsedBody[key].push(value);
59
+ // } else {
60
+ // parsedBody[key] = value;
61
+ // }
62
+ parsedBody[key] = value;
63
+ });
64
+ log('Parsed form data body: %O', parsedBody);
65
+ } else if (contentType === 'application/json') {
66
+ parsedBody = await req.json();
67
+ log('Parsed JSON body: %O', parsedBody);
68
+ } else {
69
+ log(`Body parsing skipped for Content-Type: ${contentType}. Trying text() as fallback.`);
70
+ // Fallback: try reading as text if content type is unknown but body exists
71
+ parsedBody = await req.text();
72
+ log('Parsed body as text fallback.');
73
+ }
74
+ } else {
75
+ log('Request has no body or content-length is 0, skipping parsing.');
76
+ }
77
+ } catch (error) {
78
+ log('Error parsing request body: %O', error);
79
+ // Keep parsedBody as undefined, let oidc-provider handle the potential issue
80
+ }
81
+ }
82
+ const nodeRequest = {
83
+ // Basic properties
84
+ headers: convertHeadersToNodeHeaders(req.headers),
85
+
86
+ method: req.method,
87
+ // Simulates readable-stream behaviour (oidc-provider might not rely on this if body is pre-parsed)
88
+ // eslint-disable-next-line @typescript-eslint/ban-types
89
+ on: (event: string, handler: Function) => {
90
+ if (event === 'end') {
91
+ // Simulate end immediately as body is already processed or will be attached
92
+ handler();
93
+ }
94
+ },
95
+ // Extra properties a Node.js server would expect. NOTE(review): the x-forwarded-for value is copied verbatim and may hold a comma-separated list — confirm consumers accept that
96
+ socket: {
97
+ remoteAddress: req.headers.get('x-forwarded-for') || '127.0.0.1',
98
+ },
99
+ url: providerPath + url.search,
100
+ ...(parsedBody !== undefined && { body: parsedBody }), // Attach body if it exists
101
+ };
102
+
103
+ log('Node.js request created with method %s and path %s', nodeRequest.method, nodeRequest.url);
104
+ if (nodeRequest.body) {
105
+ log('Attached parsed body to Node.js request.');
106
+ }
107
+ // Cast back to IncomingMessage for the function's return signature
108
+ return nodeRequest as unknown as IncomingMessage;
109
+ };
110
+
111
+ /**
112
+ * Response collector used to capture the OIDC provider's response: the fake Node.js response plus read-only views of the body, headers and status written to it.
113
+ */
114
+ export interface ResponseCollector {
115
+ nodeResponse: ServerResponse;
116
+ readonly responseBody: string | Buffer;
117
+ readonly responseHeaders: Record<string, string | string[]>;
118
+ readonly responseStatus: number;
119
+ }
120
+
121
+ /**
122
+ * Creates a fake Node.js HTTP response for the OIDC provider that records status, headers and body in memory and exposes them through getters.
123
+ * @param resolvePromise callback invoked once, on the first end() call, to signal that the response is complete
124
+ */
125
+ export const createNodeResponse = (resolvePromise: () => void): ResponseCollector => {
126
+ log('Creating Node.js response collector');
127
+
128
+ // Mutable object holding the captured response state
129
+ const state = {
130
+ responseBody: '' as string | Buffer,
131
+ responseHeaders: {} as Record<string, string | string[]>,
132
+ responseStatus: 200,
133
+ };
134
+
135
+ let promiseResolved = false;
136
+
137
+ const nodeResponse: any = {
138
+ end: (chunk?: string | Buffer) => {
139
+ log('NodeResponse.end called');
140
+ if (chunk) {
141
+ log('NodeResponse.end chunk: %s', typeof chunk === 'string' ? chunk : '(Buffer)');
142
+ // @ts-ignore
143
+ state.responseBody += chunk; // NOTE(review): the body starts as '' so += coerces a Buffer chunk to a string
144
+ }
145
+
146
+ const locationHeader = state.responseHeaders['location'];
147
+ if (locationHeader && state.responseStatus === 200) {
148
+ log('Location header detected with status 200, overriding to 302');
149
+ state.responseStatus = 302;
150
+ }
151
+
152
+ if (!promiseResolved) {
153
+ log('Resolving response promise');
154
+ promiseResolved = true;
155
+ resolvePromise();
156
+ }
157
+ },
158
+
159
+ getHeader: (name: string) => {
160
+ const lowerName = name.toLowerCase();
161
+ return state.responseHeaders[lowerName];
162
+ },
163
+
164
+ getHeaderNames: () => {
165
+ return Object.keys(state.responseHeaders);
166
+ },
167
+
168
+ getHeaders: () => {
169
+ return state.responseHeaders;
170
+ },
171
+
172
+ headersSent: false,
173
+
174
+ removeHeader: (name: string) => {
175
+ const lowerName = name.toLowerCase();
176
+ log('Removing header: %s', lowerName);
177
+ delete state.responseHeaders[lowerName];
178
+ },
179
+
180
+ setHeader: (name: string, value: string | string[]) => {
181
+ const lowerName = name.toLowerCase();
182
+ log('Setting header: %s = %s', lowerName, value);
183
+ state.responseHeaders[lowerName] = value;
184
+ },
185
+
186
+ write: (chunk: string | Buffer) => {
187
+ log('NodeResponse.write called with chunk');
188
+ // @ts-ignore
189
+ state.responseBody += chunk; // NOTE(review): same string coercion of Buffer chunks as in end()
190
+ },
191
+
192
+ writeHead: (status: number, headers?: Record<string, string | string[]>) => {
193
+ log('NodeResponse.writeHead called with status: %d', status);
194
+ state.responseStatus = status;
195
+
196
+ if (headers) {
197
+ const lowerCaseHeaders = Object.entries(headers).reduce(
198
+ (acc, [key, value]) => {
199
+ acc[key.toLowerCase()] = value;
200
+ return acc;
201
+ },
202
+ {} as Record<string, string | string[]>,
203
+ );
204
+ state.responseHeaders = { ...state.responseHeaders, ...lowerCaseHeaders };
205
+ }
206
+
207
+ (nodeResponse as any).headersSent = true;
208
+ },
209
+ } as unknown as ServerResponse;
210
+
211
+ log('Node.js response collector created successfully');
212
+
213
+ return {
214
+ nodeResponse,
215
+ get responseBody() {
216
+ return state.responseBody;
217
+ },
218
+ get responseHeaders() {
219
+ return state.responseHeaders;
220
+ },
221
+ get responseStatus() {
222
+ return state.responseStatus;
223
+ },
224
+ };
225
+ };
226
+
227
+ /**
228
+ * Creates the (req, res) pair needed to call provider.interactionDetails, built from APP_URL and the cookies of the current request.
229
+ * @param uid interaction ID
230
+ */
231
+ export const createContextForInteractionDetails = async (
232
+ uid: string,
233
+ ): Promise<{ req: IncomingMessage; res: ServerResponse }> => {
234
+ log('Creating context for interaction details for uid: %s', uid);
235
+ const baseUrl = appEnv.APP_URL!;
236
+ log('Using base URL: %s', baseUrl);
237
+
238
+ // Extract the host from baseUrl for the host header
239
+ const hostName = new URL(baseUrl).host;
240
+
241
+ // 1. Read the real cookies
242
+ const cookieStore = await cookies();
243
+ const realCookies: Record<string, string> = {};
244
+ cookieStore.getAll().forEach((cookie) => {
245
+ realCookies[cookie.name] = cookie.value;
246
+ });
247
+ log('Real cookies found: %o', Object.keys(realCookies));
248
+
249
+ // Specifically check for the interaction session cookie (logging only)
250
+ const interactionCookieName = `_interaction_${uid}`;
251
+ if (realCookies[interactionCookieName]) {
252
+ log('Found interaction session cookie: %s', interactionCookieName);
253
+ } else {
254
+ log('Warning: Interaction session cookie not found: %s', interactionCookieName);
255
+ }
256
+
257
+ // 2. Build headers that carry the real cookies
258
+ const headers = new Headers({ host: hostName });
259
+ const cookieString = Object.entries(realCookies)
260
+ .map(([name, value]) => `${name}=${value}`)
261
+ .join('; ');
262
+ if (cookieString) {
263
+ headers.set('cookie', cookieString);
264
+ log('Setting cookie header');
265
+ } else {
266
+ log('No cookies found to set in header');
267
+ }
268
+
269
+ // 3. Create a mock NextRequest
270
+ // Note: the IP, geo, ua etc. values here may be needed by some oidc-provider features;
271
+ // if related problems show up, they may have to be taken from the real request headers (e.g., 'x-forwarded-for', 'user-agent')
272
+ const interactionUrl = urlJoin(baseUrl, `/oauth/consent/${uid}`);
273
+ log('Creating interaction URL: %s', interactionUrl);
274
+
275
+ const mockNextRequest = {
276
+ cookies: {
277
+ // Mimics the NextRequestCookies interface
278
+ get: (name: string) => cookieStore.get(name)?.value,
279
+ getAll: () => cookieStore.getAll(),
280
+ has: (name: string) => cookieStore.has(name),
281
+ },
282
+ geo: {},
283
+ headers: headers,
284
+ ip: '127.0.0.1',
285
+ method: 'GET',
286
+ nextUrl: new URL(interactionUrl),
287
+ page: { name: undefined, params: undefined },
288
+ ua: undefined,
289
+ url: new URL(interactionUrl),
290
+ } as unknown as NextRequest;
291
+ log('Mock NextRequest created for url: %s', mockNextRequest.url);
292
+
293
+ // 4. Use createNodeRequest to build the mock Node.js IncomingMessage
294
+ // The /oauth/consent/:uid URL built above is passed through unchanged; createNodeRequest uses its pathname and query string as the request url
295
+ const req: IncomingMessage = await createNodeRequest(mockNextRequest);
296
+ // @ts-ignore - attach the collected cookies to the mock Node.js request
297
+ req.cookies = realCookies;
298
+ log('Node.js IncomingMessage created, attached real cookies');
299
+
300
+ // 5. Use createNodeResponse to build the mock Node.js ServerResponse
301
+ let resolveFunc: () => void;
302
+ new Promise<void>((resolve) => { // NOTE(review): this promise is never awaited or returned; it only supplies resolveFunc
303
+ resolveFunc = resolve;
304
+ });
305
+
306
+ const responseCollector: ResponseCollector = createNodeResponse(() => resolveFunc());
307
+ const res: ServerResponse = responseCollector.nodeResponse;
308
+ log('Node.js ServerResponse created');
309
+
310
+ return { req, res };
311
+ };
@@ -0,0 +1,37 @@
1
+ import debug from 'debug';
2
+ import { interactionPolicy } from 'oidc-provider';
3
+
4
+ const { base } = interactionPolicy; // Only the base policy factory is taken from interactionPolicy
5
+ const log = debug('lobe-oidc:interaction-policy');
6
+
7
+ /**
8
+ * Creates the interaction policy. As written it returns the base policy unmodified and only logs what the policy contains.
9
+ */
10
+ export const createInteractionPolicy = () => {
11
+ log('Creating custom interaction policy');
12
+ const policy = base();
13
+
14
+ log('Base policy details: %O', {
15
+ promptNames: Array.from(policy.keys()), // NOTE(review): if the policy is an Array subclass, keys() yields indices rather than prompt names — confirm
16
+ size: policy.length,
17
+ });
18
+
19
+ const loginPrompt = policy.get('login');
20
+ log('Accessing login prompt from policy: %O', !!loginPrompt);
21
+
22
+ if (loginPrompt) {
23
+ log('Login prompt details: %O', {
24
+ checks: Array.from(loginPrompt.checks.keys()),
25
+ name: loginPrompt.name,
26
+ requestable: loginPrompt.requestable,
27
+ });
28
+ } else {
29
+ console.warn( // NOTE(review): the message mentions a custom session check, but no custom check is added in either branch
30
+ "Could not find 'login' prompt in the base policy. Custom session check not applied.",
31
+ );
32
+ log('WARNING: login prompt not found in base policy');
33
+ }
34
+
35
+ log('Custom interaction policy created successfully');
36
+ return policy;
37
+ };
@@ -0,0 +1,288 @@
1
+ import debug from 'debug';
2
+ import Provider, { Configuration, KoaContextWithOIDC } from 'oidc-provider';
3
+ import urlJoin from 'url-join';
4
+
5
+ import { appEnv } from '@/config/app';
6
+ import { serverDBEnv } from '@/config/db';
7
+ import { UserModel } from '@/database/models/user';
8
+ import { LobeChatDatabase } from '@/database/type';
9
+ import { oidcEnv } from '@/envs/oidc';
10
+
11
+ import { DrizzleAdapter } from './adapter';
12
+ import { defaultClaims, defaultClients, defaultScopes } from './config';
13
+ import { createInteractionPolicy } from './interaction-policy';
14
+
15
+ const logProvider = debug('lobe-oidc:provider'); // Debug logger instance for the provider
16
+
17
+ /**
18
+ * Reads the JWKS from the OIDC_JWKS_KEY setting, parses it as JSON and validates it (non-empty keys array containing an RS256 RSA key).
19
+ * The JWKS is expected to be a JSON object holding an RS256 private key; every failure is logged and re-thrown with a common message prefix.
20
+ */
21
+ const getJWKS = (): object => {
22
+ try {
23
+ const jwksString = oidcEnv.OIDC_JWKS_KEY;
24
+
25
+ if (!jwksString) {
26
+ throw new Error(
27
+ 'OIDC_JWKS_KEY 环境变量是必需的。请使用 scripts/generate-oidc-jwk.mjs 生成 JWKS。',
28
+ );
29
+ }
30
+
31
+ // Try to parse the JWKS JSON string
32
+ const jwks = JSON.parse(jwksString);
33
+
34
+ // Check that the JWKS has the expected shape
35
+ if (!jwks.keys || !Array.isArray(jwks.keys) || jwks.keys.length === 0) {
36
+ throw new Error('JWKS 格式无效: 缺少或为空的 keys 数组');
37
+ }
38
+
39
+ // Check that at least one key is an RSA key declared for the RS256 algorithm
40
+ const hasRS256Key = jwks.keys.some((key: any) => key.alg === 'RS256' && key.kty === 'RSA');
41
+ if (!hasRS256Key) {
42
+ throw new Error('JWKS 中没有找到 RS256 算法的 RSA 密钥');
43
+ }
44
+
45
+ return jwks;
46
+ } catch (error) {
47
+ console.error('解析 JWKS 失败:', error);
48
+ throw new Error(`OIDC_JWKS_KEY 解析错误: ${(error as Error).message}`);
49
+ }
50
+ };
51
+
52
+ /**
53
+ * Returns the cookie key list: a single-element array holding KEY_VAULTS_SECRET. Throws when that secret is not set.
54
+ */
55
+ const getCookieKeys = () => {
56
+ const key = serverDBEnv.KEY_VAULTS_SECRET;
57
+ if (!key) {
58
+ throw new Error('KEY_VAULTS_SECRET is required for OIDC Provider cookie encryption');
59
+ }
60
+ return [key];
61
+ };
62
+
63
+ /**
64
+ * Creates the OIDC provider instance, issued under APP_URL + '/oidc'. Throws if the JWKS or the cookie secret is missing or invalid.
65
+ * @param db - database instance, used by the storage adapter and for account lookups
66
+ * @returns the configured OIDC provider instance
67
+ */
68
+ export const createOIDCProvider = async (db: LobeChatDatabase): Promise<Provider> => {
69
+ // Load the JWKS
70
+ const jwks = getJWKS();
71
+
72
+ const cookieKeys = getCookieKeys();
73
+
74
+ const configuration: Configuration = {
75
+ // 11. Database adapter
76
+ adapter: DrizzleAdapter.createAdapterFactory(db),
77
+
78
+ // 4. Claims definition
79
+ claims: defaultClaims,
80
+
81
+ // Added: per-client CORS control
82
+ clientBasedCORS(ctx, origin, client) {
83
+ // Check whether the client allows this origin
84
+ // A common strategy is to allow the origins of all registered redirect_uris
85
+ if (!client || !client.redirectUris) {
86
+ logProvider('clientBasedCORS: No client or redirectUris found, denying origin: %s', origin);
87
+ return false; // Deny when there is no client or no redirect URIs
88
+ }
89
+
90
+ const allowed = client.redirectUris.some((uri) => {
91
+ try {
92
+ // Compare origins (scheme, hostname, port)
93
+ return new URL(uri).origin === origin;
94
+ } catch {
95
+ // Skip any redirect_uri that URL() cannot parse
96
+ return false;
97
+ }
98
+ });
99
+
100
+ logProvider(
101
+ 'clientBasedCORS check for origin [%s] and client [%s]: %s',
102
+ origin,
103
+ client.clientId,
104
+ allowed ? 'Allowed' : 'Denied',
105
+ );
106
+ return allowed;
107
+ },
108
+
109
+ // 1. Client configuration
110
+ clients: defaultClients,
111
+
112
+ // 7. Cookie configuration
113
+ cookies: {
114
+ keys: cookieKeys,
115
+ long: { path: '/', signed: true },
116
+ short: { path: '/', signed: true },
117
+ },
118
+
119
+ // 5. Feature configuration
120
+ features: {
121
+ backchannelLogout: { enabled: true },
122
+ clientCredentials: { enabled: false },
123
+ devInteractions: { enabled: false },
124
+ deviceFlow: { enabled: false },
125
+ introspection: { enabled: true },
126
+ resourceIndicators: { enabled: false },
127
+ revocation: { enabled: true },
128
+ rpInitiatedLogout: { enabled: true },
129
+ userinfo: { enabled: true },
130
+ },
131
+ // 10. Account lookup
132
+ async findAccount(ctx: KoaContextWithOIDC, id: string) {
133
+ logProvider('findAccount called for id: %s', id);
134
+
135
+ // Check for a previously stored external account ID
136
+ // @ts-ignore - custom property
137
+ const externalAccountId = ctx.externalAccountId;
138
+ if (externalAccountId) {
139
+ logProvider('Found externalAccountId in context: %s', externalAccountId);
140
+ }
141
+
142
+ // Decide which account ID to look up
143
+ // Priority: 1. externalAccountId 2. ctx.oidc.session?.accountId 3. the id argument (NOTE(review): the argument is only used last — confirm this is intended)
144
+ const accountIdToFind = externalAccountId || ctx.oidc?.session?.accountId || id;
145
+
146
+ logProvider(
147
+ 'Attempting to find account with ID: %s (source: %s)',
148
+ accountIdToFind,
149
+ externalAccountId
150
+ ? 'externalAccountId'
151
+ : ctx.oidc?.session?.accountId
152
+ ? 'oidc_session'
153
+ : 'parameter_id',
154
+ );
155
+
156
+ // Return undefined when no ID is available
157
+ if (!accountIdToFind) {
158
+ logProvider('findAccount: No account ID available, returning undefined.');
159
+ return undefined;
160
+ }
161
+
162
+ try {
163
+ const user = await UserModel.findById(db, accountIdToFind);
164
+ logProvider(
165
+ 'UserModel.findById result for %s: %O',
166
+ accountIdToFind,
167
+ user ? { id: user.id, name: user.username } : null,
168
+ );
169
+
170
+ if (!user) {
171
+ logProvider('No user found for accountId: %s', accountIdToFind);
172
+ return undefined;
173
+ }
174
+
175
+ return {
176
+ accountId: user.id,
177
+ async claims(use, scope): Promise<{ [key: string]: any; sub: string }> {
178
+ logProvider('claims function called for user %s with scope: %s', user.id, scope);
179
+ const claims: { [key: string]: any; sub: string } = {
180
+ sub: user.id,
181
+ };
182
+
183
+ if (scope.includes('profile')) {
184
+ claims.name =
185
+ user.fullName ||
186
+ user.username ||
187
+ `${user.firstName || ''} ${user.lastName || ''}`.trim();
188
+ claims.picture = user.avatar;
189
+ }
190
+
191
+ if (scope.includes('email')) {
192
+ claims.email = user.email;
193
+ claims.email_verified = !!user.emailVerifiedAt;
194
+ }
195
+
196
+ logProvider('Returning claims: %O', claims);
197
+ return claims;
198
+ },
199
+ };
200
+ } catch (error) {
201
+ logProvider('Error finding account or generating claims: %O', error);
202
+ console.error('Error finding account:', error);
203
+ return undefined;
204
+ }
205
+ },
206
+
207
+ // 9. Interaction policy
208
+ interactions: {
209
+ policy: createInteractionPolicy(),
210
+ url(ctx, interaction) {
211
+ // ---> logging added <---
212
+ logProvider('interactions.url function called');
213
+ logProvider('Interaction details: %O', interaction);
214
+ const interactionUrl = `/oauth/consent/${interaction.uid}`;
215
+ logProvider('Generated interaction URL: %s', interactionUrl);
216
+ // ---> end of added logging <---
217
+ return interactionUrl;
218
+ },
219
+ },
220
+
221
+ // 6. Key configuration - uses the RS256 JWKS
222
+ jwks: jwks as { keys: any[] },
223
+
224
+ // 2. PKCE configuration
225
+ pkce: {
226
+ required: () => true,
227
+ },
228
+
229
+ // 12. Other configuration. NOTE(review): renderError interpolates JSON.stringify(error/out) into HTML without escaping — confirm these can never carry attacker-controlled markup
230
+ renderError: async (ctx, out, error) => {
231
+ ctx.type = 'html';
232
+ ctx.body = `
233
+ <html>
234
+ <head>
235
+ <title>LobeHub OIDC Error</title>
236
+ </head>
237
+ <body>
238
+ <h1>LobeHub OIDC Error</h1>
239
+ <p>${JSON.stringify(error, null, 2)}</p>
240
+ <p>${JSON.stringify(out, null, 2)}</p>
241
+ </body>
242
+ </html>
243
+ `;
244
+ },
245
+
246
+ // Added: enable refresh token rotation
247
+ rotateRefreshToken: true,
248
+
249
+ routes: {
250
+ authorization: '/oidc/auth',
251
+ end_session: '/oidc/session/end',
252
+ token: '/oidc/token',
253
+ },
254
+ // 3. Scopes definition
255
+ scopes: defaultScopes,
256
+
257
+ // 8. Token lifetimes
258
+ ttl: {
259
+ AccessToken: 3600, // 1 hour in seconds
260
+ AuthorizationCode: 600, // 10 minutes
261
+ DeviceCode: 600, // 10 minutes (if enabled)
262
+
263
+ IdToken: 3600, // 1 hour
264
+ Interaction: 3600, // 1 hour
265
+
266
+ RefreshToken: 30 * 24 * 60 * 60, // 30 days
267
+ Session: 30 * 24 * 60 * 60, // 30 days
268
+ },
269
+ };
270
+
271
+ // Create the provider instance
272
+ const baseUrl = urlJoin(appEnv.APP_URL!, '/oidc');
273
+
274
+ const provider = new Provider(baseUrl, configuration);
275
+
276
+ provider.on('server_error', (ctx, err) => {
277
+ logProvider('OIDC Provider Server Error: %O', err); // Use logProvider
278
+ console.error('OIDC Provider Error:', err);
279
+ });
280
+
281
+ provider.on('authorization.success', (ctx) => {
282
+ logProvider('Authorization successful for client: %s', ctx.oidc.client?.clientId); // Use logProvider
283
+ });
284
+
285
+ return provider;
286
+ };
287
+
288
+ export { type default as OIDCProvider } from 'oidc-provider';
@@ -1,7 +1,7 @@
1
1
  import { initTRPC } from '@trpc/server';
2
2
  import superjson from 'superjson';
3
3
 
4
- import { AsyncContext } from '@/server/asyncContext';
4
+ import { AsyncContext } from './context';
5
5
 
6
6
  export const asyncTrpc = initTRPC.context<AsyncContext>().create({
7
7
  errorFormatter({ shape }) {
@@ -28,13 +28,13 @@ export const createContextInner = async (params?: {
28
28
  userId: params?.userId,
29
29
  });
30
30
 
31
- export type Context = Awaited<ReturnType<typeof createContextInner>>;
31
+ export type EdgeContext = Awaited<ReturnType<typeof createContextInner>>;
32
32
 
33
33
  /**
34
34
  * Creates context for an incoming request
35
35
  * @link https://trpc.io/docs/v11/context
36
36
  */
37
- export const createContext = async (request: NextRequest): Promise<Context> => {
37
+ export const createEdgeContext = async (request: NextRequest): Promise<EdgeContext> => {
38
38
  // for API-response caching see https://trpc.io/docs/v11/caching
39
39
 
40
40
  const authorization = request.headers.get(LOBE_CHAT_AUTH_HEADER);