@lobehub/chat 1.79.7 → 1.79.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc.js +1 -0
- package/CHANGELOG.md +58 -0
- package/changelog/v1.json +18 -0
- package/docs/development/database-schema.dbml +119 -0
- package/locales/ar/models.json +12 -0
- package/locales/ar/oauth.json +40 -0
- package/locales/bg-BG/models.json +12 -0
- package/locales/bg-BG/oauth.json +40 -0
- package/locales/de-DE/models.json +12 -0
- package/locales/de-DE/oauth.json +40 -0
- package/locales/en-US/models.json +12 -0
- package/locales/en-US/oauth.json +40 -0
- package/locales/es-ES/models.json +12 -0
- package/locales/es-ES/oauth.json +40 -0
- package/locales/fa-IR/models.json +12 -0
- package/locales/fa-IR/oauth.json +40 -0
- package/locales/fr-FR/models.json +12 -0
- package/locales/fr-FR/oauth.json +40 -0
- package/locales/it-IT/models.json +12 -0
- package/locales/it-IT/oauth.json +40 -0
- package/locales/ja-JP/models.json +12 -0
- package/locales/ja-JP/oauth.json +40 -0
- package/locales/ko-KR/models.json +12 -0
- package/locales/ko-KR/oauth.json +40 -0
- package/locales/nl-NL/models.json +12 -0
- package/locales/nl-NL/oauth.json +40 -0
- package/locales/pl-PL/models.json +12 -0
- package/locales/pl-PL/oauth.json +40 -0
- package/locales/pt-BR/models.json +12 -0
- package/locales/pt-BR/oauth.json +40 -0
- package/locales/ru-RU/models.json +12 -0
- package/locales/ru-RU/oauth.json +40 -0
- package/locales/tr-TR/models.json +12 -0
- package/locales/tr-TR/oauth.json +40 -0
- package/locales/vi-VN/models.json +12 -0
- package/locales/vi-VN/oauth.json +40 -0
- package/locales/zh-CN/models.json +12 -0
- package/locales/zh-CN/oauth.json +40 -0
- package/locales/zh-TW/models.json +12 -0
- package/locales/zh-TW/oauth.json +40 -0
- package/package.json +4 -1
- package/scripts/generate-oidc-jwk.mjs +59 -0
- package/scripts/migrateServerDB/index.ts +3 -1
- package/src/app/(backend)/oidc/[...oidc]/route.ts +96 -0
- package/src/app/(backend)/oidc/consent/route.ts +131 -0
- package/src/app/(backend)/trpc/async/[trpc]/route.ts +1 -1
- package/src/app/(backend)/trpc/edge/[trpc]/route.ts +2 -2
- package/src/app/(backend)/trpc/lambda/[trpc]/route.ts +2 -2
- package/src/app/(backend)/trpc/tools/[trpc]/route.ts +2 -2
- package/src/app/[variants]/(main)/files/[id]/page.tsx +1 -1
- package/src/app/[variants]/oauth/consent/[uid]/Client.tsx +224 -0
- package/src/app/[variants]/oauth/consent/[uid]/ClientError.tsx +46 -0
- package/src/app/[variants]/oauth/consent/[uid]/failed/page.tsx +36 -0
- package/src/app/[variants]/oauth/consent/[uid]/page.tsx +69 -0
- package/src/app/[variants]/oauth/consent/[uid]/success/page.tsx +30 -0
- package/src/components/Branding/ProductLogo/index.tsx +6 -1
- package/src/config/aiModels/openai.ts +63 -41
- package/src/database/client/migrations.json +27 -8
- package/src/database/migrations/0020_add_oidc.sql +124 -0
- package/src/database/migrations/meta/0020_snapshot.json +4975 -0
- package/src/database/migrations/meta/_journal.json +7 -0
- package/src/database/repositories/tableViewer/index.test.ts +1 -1
- package/src/database/schemas/index.ts +1 -0
- package/src/database/schemas/oidc.ts +158 -0
- package/src/database/server/models/__tests__/adapter.test.ts +499 -0
- package/src/envs/oidc.ts +18 -0
- package/src/libs/agent-runtime/azureOpenai/index.ts +4 -1
- package/src/libs/agent-runtime/utils/streams/protocol.ts +2 -4
- package/src/libs/oidc-provider/adapter.ts +541 -0
- package/src/libs/oidc-provider/config.ts +52 -0
- package/src/libs/oidc-provider/http-adapter.ts +311 -0
- package/src/libs/oidc-provider/interaction-policy.ts +37 -0
- package/src/libs/oidc-provider/provider.ts +288 -0
- package/src/libs/trpc/async/init.ts +1 -1
- package/src/{server → libs/trpc/edge}/context.ts +2 -2
- package/src/libs/trpc/{index.ts → edge/index.ts} +8 -8
- package/src/libs/trpc/{init.ts → edge/init.ts} +2 -2
- package/src/libs/trpc/{middleware → edge/middleware}/jwtPayload.test.ts +3 -3
- package/src/libs/trpc/{middleware → edge/middleware}/jwtPayload.ts +3 -2
- package/src/libs/trpc/lambda/context.ts +70 -0
- package/src/libs/trpc/lambda/index.ts +39 -1
- package/src/libs/trpc/lambda/init.ts +26 -0
- package/src/libs/trpc/lambda/middleware/index.ts +2 -0
- package/src/libs/trpc/{middleware → lambda/middleware}/keyVaults.ts +2 -1
- package/src/libs/trpc/lambda/{serverDatabase.ts → middleware/serverDatabase.ts} +2 -1
- package/src/libs/trpc/middleware/userAuth.test.ts +3 -3
- package/src/libs/trpc/middleware/userAuth.ts +1 -1
- package/src/libs/trpc/mock.ts +7 -0
- package/src/locales/default/index.ts +2 -0
- package/src/locales/default/oauth.ts +43 -0
- package/src/middleware.ts +94 -6
- package/src/server/routers/edge/appStatus.ts +1 -1
- package/src/server/routers/edge/config/index.test.ts +2 -3
- package/src/server/routers/edge/config/index.ts +1 -1
- package/src/server/routers/edge/index.ts +1 -1
- package/src/server/routers/edge/upload.ts +1 -1
- package/src/server/routers/lambda/_template.ts +2 -2
- package/src/server/routers/lambda/agent.ts +2 -2
- package/src/server/routers/lambda/aiModel.ts +2 -2
- package/src/server/routers/lambda/aiProvider.ts +2 -2
- package/src/server/routers/lambda/chunk.ts +2 -3
- package/src/server/routers/lambda/exporter.ts +2 -2
- package/src/server/routers/lambda/file.ts +2 -2
- package/src/server/routers/lambda/importer.ts +2 -2
- package/src/server/routers/lambda/index.ts +1 -1
- package/src/server/routers/lambda/knowledgeBase.ts +2 -2
- package/src/server/routers/lambda/message.ts +2 -2
- package/src/server/routers/lambda/plugin.ts +2 -2
- package/src/server/routers/lambda/ragEval.ts +2 -3
- package/src/server/routers/lambda/session.ts +2 -2
- package/src/server/routers/lambda/sessionGroup.ts +2 -2
- package/src/server/routers/lambda/thread.ts +2 -2
- package/src/server/routers/lambda/topic.ts +2 -2
- package/src/server/routers/lambda/user.ts +2 -2
- package/src/server/routers/tools/__tests__/search.test.ts +2 -2
- package/src/server/routers/tools/index.ts +1 -1
- package/src/server/routers/tools/search.ts +2 -1
- package/src/server/services/oidc/index.ts +64 -0
- package/src/server/services/oidc/oidcProvider.ts +25 -0
- package/src/server/mock.ts +0 -8
- /package/src/{server/asyncContext.ts → libs/trpc/async/context.ts} +0 -0
@@ -0,0 +1,311 @@
|
|
1
|
+
import debug from 'debug';
|
2
|
+
import { cookies } from 'next/headers';
|
3
|
+
import { NextRequest } from 'next/server';
|
4
|
+
import { IncomingMessage, ServerResponse } from 'node:http';
|
5
|
+
import urlJoin from 'url-join';
|
6
|
+
|
7
|
+
import { appEnv } from '@/config/app';
|
8
|
+
|
9
|
+
const log = debug('lobe-oidc:http-adapter'); // debug logger for this module (namespace "lobe-oidc:http-adapter")
|
10
|
+
|
11
|
+
/**
 * Flatten a Fetch / Next.js `Headers` instance into a plain Node.js-style header record.
 *
 * Names appear as `Headers` iterates them (lower-cased). If a name is yielded more than once,
 * the value seen last is kept.
 *
 * @param nextHeaders - Headers of the incoming Next.js request.
 * @returns A plain object mapping header name to header value.
 */
export const convertHeadersToNodeHeaders = (nextHeaders: Headers): Record<string, string> => {
  const nodeHeaders: Record<string, string> = {};
  for (const [name, value] of nextHeaders) {
    nodeHeaders[name] = value;
  }
  return nodeHeaders;
};
|
21
|
+
|
22
|
+
/** HTTP methods whose request body is parsed before the request is handed to oidc-provider. */
const METHODS_WITH_BODY = new Set(['POST', 'PUT', 'PATCH', 'DELETE']);

/**
 * Best-effort parse of a Next.js request body into a value that can be attached as `req.body`.
 *
 * - `application/x-www-form-urlencoded` → plain object (for repeated keys, the last value wins);
 * - `application/json` → parsed JSON;
 * - anything else → the raw text.
 *
 * The media type is compared case-insensitively, with parameters such as `charset` removed.
 * Parse errors are logged and swallowed on purpose, so oidc-provider can report the problem itself.
 *
 * @returns The parsed body, or `undefined` when the method carries no body, the body is empty,
 *   or parsing failed.
 */
const parseRequestBody = async (req: NextRequest): Promise<unknown> => {
  if (!METHODS_WITH_BODY.has(req.method)) return undefined;

  const contentType = req.headers.get('content-type')?.split(';')[0]?.trim().toLowerCase();
  log(`Attempting to parse body for ${req.method} with Content-Type: ${contentType}`);

  try {
    if (!req.body || req.headers.get('content-length') === '0') {
      log('Request has no body or content-length is 0, skipping parsing.');
      return undefined;
    }

    if (contentType === 'application/x-www-form-urlencoded') {
      const formData = await req.formData();
      const formBody: Record<string, FormDataEntryValue> = {};
      // Repeated keys: keep the last value (standard form behavior).
      formData.forEach((value, key) => {
        formBody[key] = value;
      });
      log('Parsed form data body: %O', formBody);
      return formBody;
    }

    if (contentType === 'application/json') {
      const jsonBody = await req.json();
      log('Parsed JSON body: %O', jsonBody);
      return jsonBody;
    }

    log(`Body parsing skipped for Content-Type: ${contentType}. Trying text() as fallback.`);
    // Fallback: the content type is unknown, but a body exists, so keep it as text.
    const textBody = await req.text();
    log('Parsed body as text fallback.');
    return textBody;
  } catch (error) {
    log('Error parsing request body: %O', error);
    // Leave the body undefined and let oidc-provider handle the resulting issue.
    return undefined;
  }
};

/**
 * Derive a single client address for the mock `socket.remoteAddress`.
 *
 * `x-forwarded-for` may carry a proxy chain such as "client, proxy1, proxy2". Only the first
 * (left-most) entry is used. Falls back to '127.0.0.1' when the header is missing or empty.
 * NOTE(review): this header is client-controlled unless a trusted proxy overwrites it.
 */
const resolveRemoteAddress = (headers: Headers): string => {
  const firstHop = headers.get('x-forwarded-for')?.split(',')[0]?.trim();
  return firstHop || '127.0.0.1';
};

/**
 * Create a mock Node.js `IncomingMessage` from a Next.js request, for use with oidc-provider.
 *
 * The mock contains lower-cased headers, the method, `url` (path + query string),
 * `socket.remoteAddress`, and a pre-parsed `body` when one could be read. Its `on()` only
 * emulates the 'end' event: the body has already been consumed, so 'end' fires immediately and
 * no 'data' events are emitted. Other `IncomingMessage` members are absent at runtime despite
 * the cast.
 *
 * @param req Next.js request object.
 * @returns The mock request, typed as `IncomingMessage`.
 */
export const createNodeRequest = async (req: NextRequest): Promise<IncomingMessage> => {
  const url = new URL(req.url);

  // `URL#pathname` for http(s) URLs already begins with '/'; the guard is purely defensive.
  const providerPath = url.pathname.startsWith('/') ? url.pathname : `/${url.pathname}`;

  log('Creating Node.js request from Next.js request');
  log('Original path: %s, Provider path: %s', url.pathname, providerPath);

  const parsedBody = await parseRequestBody(req);

  const nodeRequest = {
    headers: convertHeadersToNodeHeaders(req.headers),
    method: req.method,
    // Minimal readable-stream emulation (see the function docs).
    on: (event: string, handler: (...args: unknown[]) => unknown) => {
      if (event === 'end') {
        handler();
      }
    },
    socket: {
      remoteAddress: resolveRemoteAddress(req.headers),
    },
    url: providerPath + url.search,
    ...(parsedBody !== undefined && { body: parsedBody }), // Attach body if it exists
  };

  log('Node.js request created with method %s and path %s', nodeRequest.method, nodeRequest.url);
  if (parsedBody) {
    log('Attached parsed body to Node.js request.');
  }

  // The object only implements the subset above; cast for the function's return signature.
  return nodeRequest as unknown as IncomingMessage;
};
|
110
|
+
|
111
|
+
/**
 * Result of `createNodeResponse`: a mock Node.js response meant to be handed to the OIDC
 * Provider, plus read-only views of what was written to it.
 */
export interface ResponseCollector {
  /** Mock response object to pass to the provider in place of a real `ServerResponse`. */
  nodeResponse: ServerResponse;
  /** Body accumulated from the `write()` / `end()` calls made on `nodeResponse`. */
  readonly responseBody: string | Buffer;
  /** Headers set on `nodeResponse`, keyed by header name. */
  readonly responseHeaders: Record<string, string | string[]>;
  /** HTTP status code recorded for the response. */
  readonly responseStatus: number;
}
|
120
|
+
|
121
|
+
/**
 * Create a mock Node.js `ServerResponse` that records what the OIDC Provider writes to it.
 *
 * Implemented members: `end`, `write`, `writeHead`, `setHeader`, `getHeader`, `getHeaderNames`,
 * `getHeaders`, `removeHeader`, and the `statusCode`, `headersSent` and `writableEnded`
 * properties. The object is cast to `ServerResponse`, so any other member is missing at runtime.
 *
 * Status handling:
 * - The status starts at 200.
 * - It changes through `writeHead()` or by assigning `statusCode`.
 *   NOTE(review): Koa-based handlers such as oidc-provider are expected to assign
 *   `res.statusCode` directly instead of calling `writeHead()`; confirm against the Koa
 *   version in use.
 * - If `end()` runs with a `location` header while the status is still 200, the status is
 *   rewritten to 302.
 *
 * Header names are stored lower-cased. Body chunks are concatenated into a string, with Buffers
 * decoded as UTF-8.
 *
 * @param resolvePromise Called once, on the first `end()` call.
 * @returns The mock response, plus live getters for the collected body, headers and status.
 */
export const createNodeResponse = (resolvePromise: () => void): ResponseCollector => {
  log('Creating Node.js response collector');

  // Mutable response state, exposed read-only through the returned collector.
  const state = {
    responseBody: '' as string | Buffer,
    responseHeaders: {} as Record<string, string | string[]>,
    responseStatus: 200,
  };

  let promiseResolved = false;
  let headersWereSent = false;
  let hasEnded = false;

  // String(Buffer) decodes as UTF-8, the same as implicit `string + Buffer` concatenation.
  const appendChunk = (chunk: string | Buffer) => {
    state.responseBody = String(state.responseBody) + String(chunk);
  };

  const nodeResponse = {
    end: (chunk?: unknown) => {
      log('NodeResponse.end called');
      // `end(callback)` is also a valid Node.js call, so only strings and Buffers count as body data.
      if ((typeof chunk === 'string' || Buffer.isBuffer(chunk)) && chunk.length > 0) {
        log('NodeResponse.end chunk: %s', typeof chunk === 'string' ? chunk : '(Buffer)');
        appendChunk(chunk);
      }

      // Safety net: a redirect target with a non-redirect status is treated as a 302.
      const locationHeader = state.responseHeaders['location'];
      if (locationHeader && state.responseStatus === 200) {
        log('Location header detected with status 200, overriding to 302');
        state.responseStatus = 302;
      }

      headersWereSent = true;
      hasEnded = true;

      if (!promiseResolved) {
        log('Resolving response promise');
        promiseResolved = true;
        resolvePromise();
      }
    },

    getHeader: (name: string) => state.responseHeaders[name.toLowerCase()],

    getHeaderNames: () => Object.keys(state.responseHeaders),

    getHeaders: () => state.responseHeaders,

    get headersSent() {
      return headersWereSent;
    },

    removeHeader: (name: string) => {
      const lowerName = name.toLowerCase();
      log('Removing header: %s', lowerName);
      delete state.responseHeaders[lowerName];
    },

    setHeader: (name: string, value: string | string[]) => {
      const lowerName = name.toLowerCase();
      log('Setting header: %s = %s', lowerName, value);
      state.responseHeaders[lowerName] = value;
    },

    // Backed by `state.responseStatus`, so direct `res.statusCode = n` assignments are recorded.
    get statusCode() {
      return state.responseStatus;
    },
    set statusCode(code: number) {
      log('NodeResponse.statusCode set to: %d', code);
      state.responseStatus = code;
    },

    get writableEnded() {
      return hasEnded;
    },

    write: (chunk: string | Buffer) => {
      log('NodeResponse.write called with chunk');
      appendChunk(chunk);
    },

    // Supports both writeHead(status, headers) and writeHead(status, statusMessage, headers).
    writeHead: (
      status: number,
      statusMessageOrHeaders?: string | Record<string, string | string[]>,
      maybeHeaders?: Record<string, string | string[]>,
    ) => {
      log('NodeResponse.writeHead called with status: %d', status);
      state.responseStatus = status;

      const headers =
        typeof statusMessageOrHeaders === 'string' ? maybeHeaders : statusMessageOrHeaders;
      if (headers) {
        const lowerCaseHeaders: Record<string, string | string[]> = {};
        for (const [key, value] of Object.entries(headers)) {
          lowerCaseHeaders[key.toLowerCase()] = value;
        }
        state.responseHeaders = { ...state.responseHeaders, ...lowerCaseHeaders };
      }

      headersWereSent = true;
    },
  };

  log('Node.js response collector created successfully');

  return {
    nodeResponse: nodeResponse as unknown as ServerResponse,
    get responseBody() {
      return state.responseBody;
    },
    get responseHeaders() {
      return state.responseHeaders;
    },
    get responseStatus() {
      return state.responseStatus;
    },
  };
};
|
226
|
+
|
227
|
+
/**
 * Build a synthetic Node.js `(req, res)` pair for calling `provider.interactionDetails`.
 *
 * The request is a mock GET to `<APP_URL>/oauth/consent/<uid>`. It carries every cookie
 * visible through Next.js `cookies()` for the current request, both as a `cookie` header and
 * as `req.cookies`. The response is a collector whose output is not read by this function.
 *
 * @param uid Interaction ID.
 * @returns The mock request and response.
 * @throws Error when `APP_URL` is not configured.
 */
export const createContextForInteractionDetails = async (
  uid: string,
): Promise<{ req: IncomingMessage; res: ServerResponse }> => {
  log('Creating context for interaction details for uid: %s', uid);

  const baseUrl = appEnv.APP_URL;
  if (!baseUrl) {
    throw new Error('APP_URL is required to create the OIDC interaction context');
  }
  log('Using base URL: %s', baseUrl);

  // Host name for the synthetic request's `host` header.
  const hostName = new URL(baseUrl).host;

  // 1. Collect the real cookies of the current request.
  const cookieStore = await cookies();
  const realCookies: Record<string, string> = {};
  cookieStore.getAll().forEach((cookie) => {
    realCookies[cookie.name] = cookie.value;
  });
  log('Real cookies found: %o', Object.keys(realCookies));

  // Diagnostic only: report whether the interaction session cookie for this uid is present.
  const interactionCookieName = `_interaction_${uid}`;
  if (realCookies[interactionCookieName]) {
    log('Found interaction session cookie: %s', interactionCookieName);
  } else {
    log('Warning: Interaction session cookie not found: %s', interactionCookieName);
  }

  // 2. Build headers that carry the real cookies.
  // NOTE(review): values are joined verbatim (no re-encoding). This assumes the values returned
  // by `cookies()` can be placed back into a Cookie header as-is; confirm.
  const headers = new Headers({ host: hostName });
  const cookieString = Object.entries(realCookies)
    .map(([name, value]) => `${name}=${value}`)
    .join('; ');
  if (cookieString) {
    headers.set('cookie', cookieString);
    log('Setting cookie header');
  } else {
    log('No cookies found to set in header');
  }

  // 3. Create the mock NextRequest. `ip`, `geo` and `ua` are placeholders. If a provider
  // feature needs them, derive them from the real request headers
  // (e.g. 'x-forwarded-for', 'user-agent').
  const interactionUrl = urlJoin(baseUrl, `/oauth/consent/${uid}`);
  log('Creating interaction URL: %s', interactionUrl);

  const mockNextRequest = {
    // Mimics the NextRequest cookies interface.
    cookies: {
      get: (name: string) => cookieStore.get(name)?.value,
      getAll: () => cookieStore.getAll(),
      has: (name: string) => cookieStore.has(name),
    },
    geo: {},
    headers,
    ip: '127.0.0.1',
    method: 'GET',
    nextUrl: new URL(interactionUrl),
    page: { name: undefined, params: undefined },
    ua: undefined,
    // NextRequest#url is a string; createNodeRequest parses it with `new URL(...)`.
    url: interactionUrl,
  } as unknown as NextRequest;
  log('Mock NextRequest created for url: %s', mockNextRequest.url);

  // 4. Convert to a mock Node.js IncomingMessage. Its `url` becomes the path and query of
  // interactionUrl (/oauth/consent/<uid>).
  const req: IncomingMessage = await createNodeRequest(mockNextRequest);
  // Also expose the parsed cookies directly on the mock Node.js request.
  (req as IncomingMessage & { cookies?: Record<string, string> }).cookies = realCookies;
  log('Node.js IncomingMessage created, attached real cookies');

  // 5. Create the mock Node.js ServerResponse. Nothing here waits for the response to finish,
  // so the completion callback is a no-op.
  const { nodeResponse: res } = createNodeResponse(() => {});
  log('Node.js ServerResponse created');

  return { req, res };
};
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import debug from 'debug';
|
2
|
+
import { interactionPolicy } from 'oidc-provider';
|
3
|
+
|
4
|
+
// Only `base` is used: it builds oidc-provider's default interaction policy.
const { base } = interactionPolicy;
const log = debug('lobe-oidc:interaction-policy');

/**
 * Create the interaction policy passed to the OIDC provider.
 *
 * Returns `interactionPolicy.base()` unchanged. The extra work here only debug-logs the policy's
 * prompts and the `login` prompt's checks.
 *
 * @returns The policy produced by `interactionPolicy.base()`.
 */
export const createInteractionPolicy = () => {
  log('Creating custom interaction policy');
  const policy = base();

  // The policy is array-like (it has `.length`), so `.keys()` would yield indices only.
  // Log the prompt names instead.
  log('Base policy details: %O', {
    promptNames: Array.from(policy, (prompt) => prompt.name),
    size: policy.length,
  });

  const loginPrompt = policy.get('login');
  log('Accessing login prompt from policy: %O', !!loginPrompt);

  if (loginPrompt) {
    log('Login prompt details: %O', {
      // Likewise, list each check's `reason` rather than array indices.
      checks: Array.from(loginPrompt.checks, (check) => check.reason),
      name: loginPrompt.name,
      requestable: loginPrompt.requestable,
    });
  } else {
    // NOTE(review): this function adds no custom session check; the message text is kept as-is.
    console.warn(
      "Could not find 'login' prompt in the base policy. Custom session check not applied.",
    );
    log('WARNING: login prompt not found in base policy');
  }

  log('Custom interaction policy created successfully');
  return policy;
};
|
@@ -0,0 +1,288 @@
|
|
1
|
+
import debug from 'debug';
|
2
|
+
import Provider, { Configuration, KoaContextWithOIDC } from 'oidc-provider';
|
3
|
+
import urlJoin from 'url-join';
|
4
|
+
|
5
|
+
import { appEnv } from '@/config/app';
|
6
|
+
import { serverDBEnv } from '@/config/db';
|
7
|
+
import { UserModel } from '@/database/models/user';
|
8
|
+
import { LobeChatDatabase } from '@/database/type';
|
9
|
+
import { oidcEnv } from '@/envs/oidc';
|
10
|
+
|
11
|
+
import { DrizzleAdapter } from './adapter';
|
12
|
+
import { defaultClaims, defaultClients, defaultScopes } from './config';
|
13
|
+
import { createInteractionPolicy } from './interaction-policy';
|
14
|
+
|
15
|
+
const logProvider = debug('lobe-oidc:provider'); // debug logger for the OIDC provider (namespace "lobe-oidc:provider")
|
16
|
+
|
17
|
+
/**
 * Read the provider's signing key set (JWKS) from the `OIDC_JWKS_KEY` environment variable.
 *
 * The value must be a JSON-serialized JWKS whose `keys` array is non-empty and contains at least
 * one key with `kty: 'RSA'` and `alg: 'RS256'`.
 *
 * @returns The parsed JWKS object.
 * @throws Error whose message starts with `OIDC_JWKS_KEY 解析错误:` ("OIDC_JWKS_KEY parse error")
 *   for every failure: missing variable, malformed JSON, or failed validation. The underlying
 *   error is logged with `console.error` first.
 */
const getJWKS = (): object => {
  // Predicate for the key type and algorithm this provider requires.
  const isRs256RsaKey = (candidate: any) => candidate.alg === 'RS256' && candidate.kty === 'RSA';

  try {
    const serialized = oidcEnv.OIDC_JWKS_KEY;
    if (!serialized) {
      throw new Error(
        'OIDC_JWKS_KEY 环境变量是必需的。请使用 scripts/generate-oidc-jwk.mjs 生成 JWKS。',
      );
    }

    const parsed = JSON.parse(serialized);

    // Structural check: `keys` must be a non-empty array.
    const hasKeyArray = parsed.keys && Array.isArray(parsed.keys) && parsed.keys.length !== 0;
    if (!hasKeyArray) {
      throw new Error('JWKS 格式无效: 缺少或为空的 keys 数组');
    }

    // Content check: at least one RSA key usable for RS256 signing.
    if (!parsed.keys.some(isRs256RsaKey)) {
      throw new Error('JWKS 中没有找到 RS256 算法的 RSA 密钥');
    }

    return parsed;
  } catch (error) {
    // Every failure, including the checks above, is logged and re-thrown with a common prefix.
    console.error('解析 JWKS 失败:', error);
    throw new Error(`OIDC_JWKS_KEY 解析错误: ${(error as Error).message}`);
  }
};
|
51
|
+
|
52
|
+
/**
 * Build the cookie key list for the OIDC Provider from `KEY_VAULTS_SECRET`.
 *
 * @returns A single-element array containing the secret.
 * @throws Error when `KEY_VAULTS_SECRET` is unset or empty.
 */
const getCookieKeys = () => {
  const secret = serverDBEnv.KEY_VAULTS_SECRET;
  if (secret) return [secret];

  throw new Error('KEY_VAULTS_SECRET is required for OIDC Provider cookie encryption');
};
|
62
|
+
|
63
|
+
/** Replacement entities for characters that are unsafe inside HTML text. */
const HTML_ESCAPE_MAP: Record<string, string> = {
  '"': '&quot;',
  '&': '&amp;',
  "'": '&#39;',
  '<': '&lt;',
  '>': '&gt;',
};

/** Escape `"`, `&`, `'`, `<` and `>` so arbitrary text can be interpolated into HTML. */
const escapeHtml = (text: string): string =>
  text.replace(/["&'<>]/g, (char) => HTML_ESCAPE_MAP[char] ?? char);

/**
 * Create the OIDC Provider instance.
 *
 * Configuration inputs:
 * - `OIDC_JWKS_KEY`: signing keys (read by `getJWKS`);
 * - `KEY_VAULTS_SECRET`: cookie keys (read by `getCookieKeys`);
 * - `APP_URL`: the issuer is `<APP_URL>/oidc`.
 *
 * @param db - Database instance, used by the storage adapter and for account lookup.
 * @returns The configured OIDC Provider instance.
 * @throws Error when any of the inputs above is missing or invalid.
 */
export const createOIDCProvider = async (db: LobeChatDatabase): Promise<Provider> => {
  // Both helpers throw on missing or invalid configuration.
  const jwks = getJWKS();
  const cookieKeys = getCookieKeys();

  const configuration: Configuration = {
    // Storage adapter backed by `db`.
    adapter: DrizzleAdapter.createAdapterFactory(db),

    // Claim definitions.
    claims: defaultClaims,

    // Client-based CORS: allow an origin only if it equals the origin of one of the client's
    // registered redirect_uris.
    clientBasedCORS(ctx, origin, client) {
      if (!client || !client.redirectUris) {
        logProvider('clientBasedCORS: No client or redirectUris found, denying origin: %s', origin);
        return false;
      }

      const allowed = client.redirectUris.some((uri) => {
        try {
          // Compare scheme, hostname and port.
          return new URL(uri).origin === origin;
        } catch {
          // redirect_uri is not a parseable URL, so it cannot match an origin.
          return false;
        }
      });

      logProvider(
        'clientBasedCORS check for origin [%s] and client [%s]: %s',
        origin,
        client.clientId,
        allowed ? 'Allowed' : 'Denied',
      );
      return allowed;
    },

    // Client registrations.
    clients: defaultClients,

    // Cookie configuration (signed cookies).
    cookies: {
      keys: cookieKeys,
      long: { path: '/', signed: true },
      short: { path: '/', signed: true },
    },

    // Feature toggles.
    features: {
      backchannelLogout: { enabled: true },
      clientCredentials: { enabled: false },
      devInteractions: { enabled: false },
      deviceFlow: { enabled: false },
      introspection: { enabled: true },
      resourceIndicators: { enabled: false },
      revocation: { enabled: true },
      rpInitiatedLogout: { enabled: true },
      userinfo: { enabled: true },
    },

    // Account lookup: resolves a user and exposes their claims.
    async findAccount(ctx: KoaContextWithOIDC, id: string) {
      logProvider('findAccount called for id: %s', id);

      // A pre-stored external account ID, if an earlier step attached one to the context.
      const externalAccountId = (ctx as KoaContextWithOIDC & { externalAccountId?: string })
        .externalAccountId;
      if (externalAccountId) {
        logProvider('Found externalAccountId in context: %s', externalAccountId);
      }

      // Precedence: 1. externalAccountId 2. ctx.oidc.session?.accountId 3. the `id` argument.
      // NOTE(review): the session account wins over `id`. Confirm this is intended for
      // token-based calls, where `id` identifies the token's subject.
      const accountIdToFind = externalAccountId || ctx.oidc?.session?.accountId || id;

      logProvider(
        'Attempting to find account with ID: %s (source: %s)',
        accountIdToFind,
        externalAccountId
          ? 'externalAccountId'
          : ctx.oidc?.session?.accountId
            ? 'oidc_session'
            : 'parameter_id',
      );

      if (!accountIdToFind) {
        logProvider('findAccount: No account ID available, returning undefined.');
        return undefined;
      }

      try {
        const user = await UserModel.findById(db, accountIdToFind);
        logProvider(
          'UserModel.findById result for %s: %O',
          accountIdToFind,
          user ? { id: user.id, name: user.username } : null,
        );

        if (!user) {
          logProvider('No user found for accountId: %s', accountIdToFind);
          return undefined;
        }

        return {
          accountId: user.id,
          async claims(use, scope): Promise<{ [key: string]: any; sub: string }> {
            logProvider('claims function called for user %s with scope: %s', user.id, scope);
            const claims: { [key: string]: any; sub: string } = {
              sub: user.id,
            };

            // `scope` is space-delimited; match whole scope tokens, not substrings.
            const grantedScopes = new Set(scope.split(' '));

            if (grantedScopes.has('profile')) {
              claims.name =
                user.fullName ||
                user.username ||
                `${user.firstName || ''} ${user.lastName || ''}`.trim();
              claims.picture = user.avatar;
            }

            if (grantedScopes.has('email')) {
              claims.email = user.email;
              claims.email_verified = !!user.emailVerifiedAt;
            }

            logProvider('Returning claims: %O', claims);
            return claims;
          },
        };
      } catch (error) {
        logProvider('Error finding account or generating claims: %O', error);
        console.error('Error finding account:', error);
        return undefined;
      }
    },

    // Interaction policy and the URL of the consent/interaction page.
    interactions: {
      policy: createInteractionPolicy(),
      url(ctx, interaction) {
        logProvider('interactions.url function called');
        logProvider('Interaction details: %O', interaction);
        const interactionUrl = `/oauth/consent/${interaction.uid}`;
        logProvider('Generated interaction URL: %s', interactionUrl);
        return interactionUrl;
      },
    },

    // Signing keys (validated by getJWKS to contain an RS256 RSA key).
    jwks: jwks as { keys: any[] },

    // PKCE is required for every client.
    pkce: {
      required: () => true,
    },

    // Error page. `out` and `error` may contain request-derived values, so they are treated as
    // untrusted and HTML-escaped. String() keeps the literal "undefined" output when
    // JSON.stringify returns undefined.
    renderError: async (ctx, out, error) => {
      ctx.type = 'html';
      const errorText = escapeHtml(String(JSON.stringify(error, null, 2)));
      const outText = escapeHtml(String(JSON.stringify(out, null, 2)));
      ctx.body = `
      <html>
        <head>
          <title>LobeHub OIDC Error</title>
        </head>
        <body>
          <h1>LobeHub OIDC Error</h1>
          <p>${errorText}</p>
          <p>${outText}</p>
        </body>
      </html>
    `;
    },

    // Issue a new refresh token on each use (refresh token rotation).
    rotateRefreshToken: true,

    routes: {
      authorization: '/oidc/auth',
      end_session: '/oidc/session/end',
      token: '/oidc/token',
    },

    // Scope definitions.
    scopes: defaultScopes,

    // Token / artifact lifetimes, in seconds.
    ttl: {
      AccessToken: 3600, // 1 hour
      AuthorizationCode: 600, // 10 minutes
      DeviceCode: 600, // 10 minutes (if enabled)

      IdToken: 3600, // 1 hour
      Interaction: 3600, // 1 hour

      RefreshToken: 30 * 24 * 60 * 60, // 30 days
      Session: 30 * 24 * 60 * 60, // 30 days
    },
  };

  // Issuer URL: <APP_URL>/oidc
  const appUrl = appEnv.APP_URL;
  if (!appUrl) {
    throw new Error('APP_URL is required to create the OIDC provider');
  }
  const baseUrl = urlJoin(appUrl, '/oidc');

  const provider = new Provider(baseUrl, configuration);

  provider.on('server_error', (ctx, err) => {
    logProvider('OIDC Provider Server Error: %O', err);
    console.error('OIDC Provider Error:', err);
  });

  provider.on('authorization.success', (ctx) => {
    logProvider('Authorization successful for client: %s', ctx.oidc.client?.clientId);
  });

  return provider;
};
|
287
|
+
|
288
|
+
export { type default as OIDCProvider } from 'oidc-provider';
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import { initTRPC } from '@trpc/server';
|
2
2
|
import superjson from 'superjson';
|
3
3
|
|
4
|
-
import { AsyncContext } from '
|
4
|
+
import { AsyncContext } from './context';
|
5
5
|
|
6
6
|
export const asyncTrpc = initTRPC.context<AsyncContext>().create({
|
7
7
|
errorFormatter({ shape }) {
|
@@ -28,13 +28,13 @@ export const createContextInner = async (params?: {
|
|
28
28
|
userId: params?.userId,
|
29
29
|
});
|
30
30
|
|
31
|
-
export type
|
31
|
+
/** Context type, inferred from the resolved return value of `createContextInner`. */
export type EdgeContext = Awaited<ReturnType<typeof createContextInner>>;
|
32
32
|
|
33
33
|
/**
|
34
34
|
* Creates context for an incoming request
|
35
35
|
* @link https://trpc.io/docs/v11/context
|
36
36
|
*/
|
37
|
-
export const
|
37
|
+
export const createEdgeContext = async (request: NextRequest): Promise<EdgeContext> => {
|
38
38
|
// for API-response caching see https://trpc.io/docs/v11/caching
|
39
39
|
|
40
40
|
const authorization = request.headers.get(LOBE_CHAT_AUTH_HEADER);
|