@lobehub/chat 1.79.8 → 1.79.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. package/.eslintrc.js +1 -0
  2. package/CHANGELOG.md +58 -0
  3. package/changelog/v1.json +18 -0
  4. package/locales/ar/models.json +9 -0
  5. package/locales/ar/oauth.json +7 -6
  6. package/locales/bg-BG/models.json +9 -0
  7. package/locales/bg-BG/oauth.json +7 -6
  8. package/locales/de-DE/models.json +9 -0
  9. package/locales/de-DE/oauth.json +7 -6
  10. package/locales/en-US/models.json +9 -0
  11. package/locales/en-US/oauth.json +7 -6
  12. package/locales/es-ES/models.json +9 -0
  13. package/locales/es-ES/oauth.json +7 -6
  14. package/locales/fa-IR/models.json +9 -0
  15. package/locales/fa-IR/oauth.json +7 -6
  16. package/locales/fr-FR/models.json +9 -0
  17. package/locales/fr-FR/oauth.json +7 -6
  18. package/locales/it-IT/models.json +9 -0
  19. package/locales/it-IT/oauth.json +7 -6
  20. package/locales/ja-JP/models.json +9 -0
  21. package/locales/ja-JP/oauth.json +7 -6
  22. package/locales/ko-KR/models.json +9 -0
  23. package/locales/ko-KR/oauth.json +7 -6
  24. package/locales/nl-NL/models.json +9 -0
  25. package/locales/nl-NL/oauth.json +7 -6
  26. package/locales/pl-PL/models.json +9 -0
  27. package/locales/pl-PL/oauth.json +7 -6
  28. package/locales/pt-BR/models.json +9 -0
  29. package/locales/pt-BR/oauth.json +7 -6
  30. package/locales/ru-RU/models.json +9 -0
  31. package/locales/ru-RU/oauth.json +7 -6
  32. package/locales/tr-TR/models.json +9 -0
  33. package/locales/tr-TR/oauth.json +7 -6
  34. package/locales/vi-VN/models.json +9 -0
  35. package/locales/vi-VN/oauth.json +7 -6
  36. package/locales/zh-CN/models.json +9 -0
  37. package/locales/zh-CN/oauth.json +7 -6
  38. package/locales/zh-TW/models.json +9 -0
  39. package/locales/zh-TW/oauth.json +7 -6
  40. package/package.json +1 -1
  41. package/src/app/(backend)/oidc/[...oidc]/route.ts +27 -201
  42. package/src/app/(backend)/oidc/consent/route.ts +58 -24
  43. package/src/app/(backend)/trpc/async/[trpc]/route.ts +1 -1
  44. package/src/app/(backend)/trpc/edge/[trpc]/route.ts +2 -2
  45. package/src/app/(backend)/trpc/lambda/[trpc]/route.ts +2 -2
  46. package/src/app/(backend)/trpc/tools/[trpc]/route.ts +2 -2
  47. package/src/app/[variants]/(main)/files/[id]/page.tsx +1 -1
  48. package/src/app/[variants]/oauth/consent/[uid]/Client.tsx +184 -57
  49. package/src/app/[variants]/oauth/consent/[uid]/ClientError.tsx +46 -0
  50. package/src/app/[variants]/oauth/consent/[uid]/page.tsx +19 -21
  51. package/src/components/Branding/ProductLogo/index.tsx +6 -1
  52. package/src/config/aiModels/openai.ts +63 -41
  53. package/src/config/modelProviders/openai.ts +17 -0
  54. package/src/const/settings/llm.ts +1 -1
  55. package/src/database/server/models/__tests__/adapter.test.ts +1 -5
  56. package/src/libs/oidc-provider/adapter.ts +47 -0
  57. package/src/libs/oidc-provider/config.ts +4 -5
  58. package/src/libs/oidc-provider/http-adapter.ts +60 -28
  59. package/src/libs/oidc-provider/provider.ts +41 -13
  60. package/src/libs/trpc/async/init.ts +1 -1
  61. package/src/{server → libs/trpc/edge}/context.ts +2 -2
  62. package/src/libs/trpc/{index.ts → edge/index.ts} +8 -8
  63. package/src/libs/trpc/{init.ts → edge/init.ts} +2 -2
  64. package/src/libs/trpc/{middleware → edge/middleware}/jwtPayload.test.ts +3 -3
  65. package/src/libs/trpc/{middleware → edge/middleware}/jwtPayload.ts +3 -2
  66. package/src/libs/trpc/lambda/context.ts +70 -0
  67. package/src/libs/trpc/lambda/index.ts +39 -1
  68. package/src/libs/trpc/lambda/init.ts +26 -0
  69. package/src/libs/trpc/lambda/middleware/index.ts +2 -0
  70. package/src/libs/trpc/{middleware → lambda/middleware}/keyVaults.ts +2 -1
  71. package/src/libs/trpc/lambda/{serverDatabase.ts → middleware/serverDatabase.ts} +2 -1
  72. package/src/libs/trpc/middleware/userAuth.test.ts +3 -3
  73. package/src/libs/trpc/middleware/userAuth.ts +1 -1
  74. package/src/libs/trpc/mock.ts +7 -0
  75. package/src/locales/default/oauth.ts +8 -6
  76. package/src/server/routers/edge/appStatus.ts +1 -1
  77. package/src/server/routers/edge/config/index.test.ts +2 -3
  78. package/src/server/routers/edge/config/index.ts +1 -1
  79. package/src/server/routers/edge/index.ts +1 -1
  80. package/src/server/routers/edge/upload.ts +1 -1
  81. package/src/server/routers/lambda/_template.ts +2 -2
  82. package/src/server/routers/lambda/agent.ts +2 -2
  83. package/src/server/routers/lambda/aiModel.ts +2 -2
  84. package/src/server/routers/lambda/aiProvider.ts +2 -2
  85. package/src/server/routers/lambda/chunk.ts +2 -3
  86. package/src/server/routers/lambda/exporter.ts +2 -2
  87. package/src/server/routers/lambda/file.ts +2 -2
  88. package/src/server/routers/lambda/importer.ts +2 -2
  89. package/src/server/routers/lambda/index.ts +1 -1
  90. package/src/server/routers/lambda/knowledgeBase.ts +2 -2
  91. package/src/server/routers/lambda/message.ts +2 -2
  92. package/src/server/routers/lambda/plugin.ts +2 -2
  93. package/src/server/routers/lambda/ragEval.ts +2 -3
  94. package/src/server/routers/lambda/session.ts +2 -2
  95. package/src/server/routers/lambda/sessionGroup.ts +2 -2
  96. package/src/server/routers/lambda/thread.ts +2 -2
  97. package/src/server/routers/lambda/topic.ts +2 -2
  98. package/src/server/routers/lambda/user.ts +2 -2
  99. package/src/server/routers/tools/__tests__/search.test.ts +2 -2
  100. package/src/server/routers/tools/index.ts +1 -1
  101. package/src/server/routers/tools/search.ts +3 -1
  102. package/src/server/services/oidc/index.ts +36 -1
  103. package/src/server/services/oidc/oidcProvider.ts +1 -3
  104. package/src/services/chat.ts +1 -0
  105. package/src/store/agent/slices/chat/selectors/__snapshots__/agent.test.ts.snap +1 -1
  106. package/src/store/user/slices/modelList/selectors/modelProvider.test.ts +1 -0
  107. package/src/store/user/slices/settings/selectors/__snapshots__/settings.test.ts.snap +8 -8
  108. package/src/server/mock.ts +0 -8
  109. /package/src/{server/asyncContext.ts → libs/trpc/async/context.ts} +0 -0
@@ -1,4 +1,5 @@
1
1
  import debug from 'debug';
2
+ import { sql } from 'drizzle-orm';
2
3
  import { eq } from 'drizzle-orm/expressions';
3
4
 
4
5
  import {
@@ -161,6 +162,24 @@ class OIDCAdapter {
161
162
  if (payload.accountId) {
162
163
  record.userId = payload.accountId;
163
164
  log('[%s] Setting userId: %s', this.name, payload.accountId);
165
+ } else {
166
+ try {
167
+ const { getUserAuth } = await import('@/utils/server/auth');
168
+ try {
169
+ const { userId } = await getUserAuth();
170
+ if (userId) {
171
+ payload.accountId = userId;
172
+ record.userId = userId;
173
+ log('[%s] Setting userId from auth context: %s', this.name, userId);
174
+ }
175
+ } catch (authError) {
176
+ log('[%s] Error getting userId from auth context: %O', this.name, authError);
177
+ // 如果获取 userId 失败,继续处理而不抛出错误
178
+ }
179
+ } catch (importError) {
180
+ log('[%s] Error importing auth module: %O', this.name, importError);
181
+ // 如果导入模块失败,继续处理而不抛出错误
182
+ }
164
183
  }
165
184
 
166
185
  if (payload.clientId) {
@@ -180,6 +199,7 @@ class OIDCAdapter {
180
199
 
181
200
  try {
182
201
  log('[%s] Executing upsert DB operation', this.name);
202
+
183
203
  await this.db
184
204
  .insert(table)
185
205
  .values(record as any)
@@ -335,7 +355,34 @@ class OIDCAdapter {
335
355
  */
336
356
  async findByUid(uid: string): Promise<any> {
337
357
  log('[Interaction] findByUid called - uid: %s', uid);
358
+ const table = this.getTable();
359
+ if (this.name === 'Session') {
360
+ try {
361
+ const jsonbUidEq = sql`${(table as any).data}->>'uid' = ${uid}`;
362
+ // @ts-ignore
363
+ const results = await this.db.select().from(table).where(jsonbUidEq).limit(1);
364
+ log('[Session] Find by data.uid query results: %O', results);
338
365
 
366
+ if (!results || results.length === 0) {
367
+ log('[Session] No record found by data.uid: %s', uid);
368
+ return undefined;
369
+ }
370
+
371
+ const model = results[0] as any;
372
+ // 检查过期
373
+ if (model.expiresAt && model.expiresAt < new Date()) {
374
+ log('[Session] Record found by data.uid but expired: %s', uid);
375
+ await this.destroy(model.id); // 仍然使用主键 id 删除
376
+ return undefined;
377
+ }
378
+
379
+ log('[Session] Successfully found by data.uid and returning record data for uid %s', uid);
380
+ return model.data;
381
+ } catch (error) {
382
+ log('[Session] ERROR during findSessionByUid operation for %s: %O', uid, error);
383
+ console.error(`[OIDC Adapter] Error finding Session by uid:`, error);
384
+ }
385
+ }
339
386
  // 复用 find 方法实现
340
387
  log('[Interaction] Delegating to find() method');
341
388
  return this.find(uid);
@@ -5,27 +5,26 @@ import { ClientMetadata } from 'oidc-provider';
5
5
  */
6
6
  export const defaultClients: ClientMetadata[] = [
7
7
  {
8
- // 公共客户端,令牌端点无需认证
9
8
  application_type: 'native',
10
9
  client_id: 'lobehub-desktop',
11
- description: 'LobeHub Desktop',
10
+ client_name: 'LobeHub Desktop',
12
11
  // 仅支持授权码流程
13
12
  grant_types: ['authorization_code', 'refresh_token'],
13
+
14
14
  // 明确指明是原生应用
15
15
  isFirstParty: true,
16
16
 
17
- name: 'LobeHub Desktop',
17
+ logo_uri: 'https://hub-apac-1.lobeobjects.space/lobehub-desktop-icon.png',
18
18
 
19
19
  // 桌面端注册的自定义协议回调(使用反向域名格式)
20
20
  post_logout_redirect_uris: ['com.lobehub.desktop://auth/logout/callback'],
21
21
 
22
- // 公共客户端,无密钥
23
22
  redirect_uris: ['com.lobehub.desktop://auth/callback', 'https://oauthdebugger.com/debug'],
24
23
 
25
24
  // 支持授权码获取令牌和刷新令牌
26
25
  response_types: ['code'],
27
26
 
28
- // 标记为第一方客户端
27
+ // 标记为公共客户端客户端,无密钥
29
28
  token_endpoint_auth_method: 'none',
30
29
  },
31
30
  ];
@@ -22,22 +22,13 @@ export const convertHeadersToNodeHeaders = (nextHeaders: Headers): Record<string
22
22
  /**
23
23
  * 创建用于 OIDC Provider 的 Node.js HTTP 请求对象
24
24
  * @param req Next.js 请求对象
25
- * @param pathPrefix 路径前缀
26
- * @param bodyText 可选的请求体文本,用于 POST 请求
27
25
  */
28
- export const createNodeRequest = (
29
- req: NextRequest,
30
- pathPrefix: string = '/oidc',
31
- bodyText?: string,
32
- ): IncomingMessage => {
26
+ export const createNodeRequest = async (req: NextRequest): Promise<IncomingMessage> => {
33
27
  // 构建 URL 对象
34
28
  const url = new URL(req.url);
35
29
 
36
30
  // 计算相对于前缀的路径
37
31
  let providerPath = url.pathname;
38
- if (pathPrefix && url.pathname.startsWith(pathPrefix)) {
39
- providerPath = url.pathname.slice(pathPrefix.length);
40
- }
41
32
 
42
33
  // 确保路径始终以/开头
43
34
  if (!providerPath.startsWith('/')) {
@@ -47,37 +38,74 @@ export const createNodeRequest = (
47
38
  log('Creating Node.js request from Next.js request');
48
39
  log('Original path: %s, Provider path: %s', url.pathname, providerPath);
49
40
 
41
+ // Attempt to parse and attach body for relevant methods
42
+ let parsedBody: any = undefined;
43
+ const methodsWithBody = ['POST', 'PUT', 'PATCH', 'DELETE'];
44
+ if (methodsWithBody.includes(req.method)) {
45
+ const contentType = req.headers.get('content-type')?.split(';')[0]; // Get content type without charset etc.
46
+ log(`Attempting to parse body for ${req.method} with Content-Type: ${contentType}`);
47
+ try {
48
+ // Check if body exists and has size before attempting to parse
49
+ if (req.body && req.headers.get('content-length') !== '0') {
50
+ if (contentType === 'application/x-www-form-urlencoded') {
51
+ const formData = await req.formData();
52
+ parsedBody = {};
53
+ formData.forEach((value, key) => {
54
+ // If a key appears multiple times, keep the last one (standard form behavior)
55
+ // Or convert to array if oidc-provider expects it:
56
+ // if (parsedBody[key]) {
57
+ // if (!Array.isArray(parsedBody[key])) parsedBody[key] = [parsedBody[key]];
58
+ // parsedBody[key].push(value);
59
+ // } else {
60
+ // parsedBody[key] = value;
61
+ // }
62
+ parsedBody[key] = value;
63
+ });
64
+ log('Parsed form data body: %O', parsedBody);
65
+ } else if (contentType === 'application/json') {
66
+ parsedBody = await req.json();
67
+ log('Parsed JSON body: %O', parsedBody);
68
+ } else {
69
+ log(`Body parsing skipped for Content-Type: ${contentType}. Trying text() as fallback.`);
70
+ // Fallback: try reading as text if content type is unknown but body exists
71
+ parsedBody = await req.text();
72
+ log('Parsed body as text fallback.');
73
+ }
74
+ } else {
75
+ log('Request has no body or content-length is 0, skipping parsing.');
76
+ }
77
+ } catch (error) {
78
+ log('Error parsing request body: %O', error);
79
+ // Keep parsedBody as undefined, let oidc-provider handle the potential issue
80
+ }
81
+ }
50
82
  const nodeRequest = {
51
83
  // 基本属性
52
84
  headers: convertHeadersToNodeHeaders(req.headers),
85
+
53
86
  method: req.method,
54
- // 模拟可读流行为
87
+ // 模拟可读流行为 (oidc-provider might not rely on this if body is pre-parsed)
55
88
  // eslint-disable-next-line @typescript-eslint/ban-types
56
89
  on: (event: string, handler: Function) => {
57
- if (event === 'data' && bodyText) {
58
- handler(bodyText);
59
- }
60
90
  if (event === 'end') {
91
+ // Simulate end immediately as body is already processed or will be attached
61
92
  handler();
62
93
  }
63
94
  },
64
-
65
- url: providerPath + url.search,
66
-
67
- // POST 请求所需属性
68
- ...(bodyText && {
69
- body: bodyText,
70
- readable: true,
71
- }),
72
-
73
95
  // 添加 Node.js 服务器所期望的额外属性
74
96
  socket: {
75
97
  remoteAddress: req.headers.get('x-forwarded-for') || '127.0.0.1',
76
98
  },
77
- } as unknown as IncomingMessage;
99
+ url: providerPath + url.search,
100
+ ...(parsedBody !== undefined && { body: parsedBody }), // Attach body if it exists
101
+ };
78
102
 
79
103
  log('Node.js request created with method %s and path %s', nodeRequest.method, nodeRequest.url);
80
- return nodeRequest;
104
+ if (nodeRequest.body) {
105
+ log('Attached parsed body to Node.js request.');
106
+ }
107
+ // Cast back to IncomingMessage for the function's return signature
108
+ return nodeRequest as unknown as IncomingMessage;
81
109
  };
82
110
 
83
111
  /**
@@ -143,6 +171,12 @@ export const createNodeResponse = (resolvePromise: () => void): ResponseCollecto
143
171
 
144
172
  headersSent: false,
145
173
 
174
+ removeHeader: (name: string) => {
175
+ const lowerName = name.toLowerCase();
176
+ log('Removing header: %s', lowerName);
177
+ delete state.responseHeaders[lowerName];
178
+ },
179
+
146
180
  setHeader: (name: string, value: string | string[]) => {
147
181
  const lowerName = name.toLowerCase();
148
182
  log('Setting header: %s = %s', lowerName, value);
@@ -198,8 +232,6 @@ export const createContextForInteractionDetails = async (
198
232
  uid: string,
199
233
  ): Promise<{ req: IncomingMessage; res: ServerResponse }> => {
200
234
  log('Creating context for interaction details for uid: %s', uid);
201
-
202
- // 使用APP_URL环境变量来构建URL基础部分
203
235
  const baseUrl = appEnv.APP_URL!;
204
236
  log('Using base URL: %s', baseUrl);
205
237
 
@@ -260,7 +292,7 @@ export const createContextForInteractionDetails = async (
260
292
 
261
293
  // 4. 使用 createNodeRequest 创建模拟的 Node.js IncomingMessage
262
294
  // pathPrefix 设置为 '/' 因为我们的 URL 已经是 Provider 期望的路径格式 /interaction/:uid
263
- const req: IncomingMessage = createNodeRequest(mockNextRequest, '/');
295
+ const req: IncomingMessage = await createNodeRequest(mockNextRequest);
264
296
  // @ts-ignore - 将解析出的 cookies 附加到模拟的 Node.js 请求上
265
297
  req.cookies = realCookies;
266
298
  log('Node.js IncomingMessage created, attached real cookies');
@@ -1,6 +1,8 @@
1
1
  import debug from 'debug';
2
2
  import Provider, { Configuration, KoaContextWithOIDC } from 'oidc-provider';
3
+ import urlJoin from 'url-join';
3
4
 
5
+ import { appEnv } from '@/config/app';
4
6
  import { serverDBEnv } from '@/config/db';
5
7
  import { UserModel } from '@/database/models/user';
6
8
  import { LobeChatDatabase } from '@/database/type';
@@ -61,18 +63,9 @@ const getCookieKeys = () => {
61
63
  /**
62
64
  * 创建 OIDC Provider 实例
63
65
  * @param db - 数据库实例
64
- * @param baseUrl - 服务部署的基础URL
65
66
  * @returns 配置好的 OIDC Provider 实例
66
67
  */
67
- export const createOIDCProvider = async (
68
- db: LobeChatDatabase,
69
- baseUrl: string,
70
- ): Promise<Provider> => {
71
- const issuerUrl = `${baseUrl}/oidc`;
72
- if (!issuerUrl) {
73
- throw new Error('Base URL is required for OIDC Provider');
74
- }
75
-
68
+ export const createOIDCProvider = async (db: LobeChatDatabase): Promise<Provider> => {
76
69
  // 获取 JWKS
77
70
  const jwks = getJWKS();
78
71
 
@@ -85,13 +78,41 @@ export const createOIDCProvider = async (
85
78
  // 4. Claims 定义
86
79
  claims: defaultClaims,
87
80
 
81
+ // 新增:客户端 CORS 控制逻辑
82
+ clientBasedCORS(ctx, origin, client) {
83
+ // 检查客户端是否允许此来源
84
+ // 一个常见的策略是允许所有已注册的 redirect_uris 的来源
85
+ if (!client || !client.redirectUris) {
86
+ logProvider('clientBasedCORS: No client or redirectUris found, denying origin: %s', origin);
87
+ return false; // 如果没有客户端或重定向URI,则拒绝
88
+ }
89
+
90
+ const allowed = client.redirectUris.some((uri) => {
91
+ try {
92
+ // 比较来源 (scheme, hostname, port)
93
+ return new URL(uri).origin === origin;
94
+ } catch {
95
+ // 如果 redirect_uri 不是有效的 URL (例如自定义协议),则跳过
96
+ return false;
97
+ }
98
+ });
99
+
100
+ logProvider(
101
+ 'clientBasedCORS check for origin [%s] and client [%s]: %s',
102
+ origin,
103
+ client.clientId,
104
+ allowed ? 'Allowed' : 'Denied',
105
+ );
106
+ return allowed;
107
+ },
108
+
88
109
  // 1. 客户端配置
89
110
  clients: defaultClients,
90
111
 
91
112
  // 7. Cookie 配置
92
113
  cookies: {
93
114
  keys: cookieKeys,
94
- long: { signed: true },
115
+ long: { path: '/', signed: true },
95
116
  short: { path: '/', signed: true },
96
117
  },
97
118
 
@@ -107,7 +128,6 @@ export const createOIDCProvider = async (
107
128
  rpInitiatedLogout: { enabled: true },
108
129
  userinfo: { enabled: true },
109
130
  },
110
-
111
131
  // 10. 账户查找
112
132
  async findAccount(ctx: KoaContextWithOIDC, id: string) {
113
133
  logProvider('findAccount called for id: %s', id);
@@ -183,6 +203,7 @@ export const createOIDCProvider = async (
183
203
  return undefined;
184
204
  }
185
205
  },
206
+
186
207
  // 9. 交互策略
187
208
  interactions: {
188
209
  policy: createInteractionPolicy(),
@@ -225,6 +246,11 @@ export const createOIDCProvider = async (
225
246
  // 新增:启用 Refresh Token 轮换
226
247
  rotateRefreshToken: true,
227
248
 
249
+ routes: {
250
+ authorization: '/oidc/auth',
251
+ end_session: '/oidc/session/end',
252
+ token: '/oidc/token',
253
+ },
228
254
  // 3. Scopes 定义
229
255
  scopes: defaultScopes,
230
256
 
@@ -243,7 +269,9 @@ export const createOIDCProvider = async (
243
269
  };
244
270
 
245
271
  // 创建提供者实例
246
- const provider = new Provider(issuerUrl, configuration);
272
+ const baseUrl = urlJoin(appEnv.APP_URL!, '/oidc');
273
+
274
+ const provider = new Provider(baseUrl, configuration);
247
275
 
248
276
  provider.on('server_error', (ctx, err) => {
249
277
  logProvider('OIDC Provider Server Error: %O', err); // Use logProvider
@@ -1,7 +1,7 @@
1
1
  import { initTRPC } from '@trpc/server';
2
2
  import superjson from 'superjson';
3
3
 
4
- import { AsyncContext } from '@/server/asyncContext';
4
+ import { AsyncContext } from './context';
5
5
 
6
6
  export const asyncTrpc = initTRPC.context<AsyncContext>().create({
7
7
  errorFormatter({ shape }) {
@@ -28,13 +28,13 @@ export const createContextInner = async (params?: {
28
28
  userId: params?.userId,
29
29
  });
30
30
 
31
- export type Context = Awaited<ReturnType<typeof createContextInner>>;
31
+ export type EdgeContext = Awaited<ReturnType<typeof createContextInner>>;
32
32
 
33
33
  /**
34
34
  * Creates context for an incoming request
35
35
  * @link https://trpc.io/docs/v11/context
36
36
  */
37
- export const createContext = async (request: NextRequest): Promise<Context> => {
37
+ export const createEdgeContext = async (request: NextRequest): Promise<EdgeContext> => {
38
38
  // for API-response caching see https://trpc.io/docs/v11/caching
39
39
 
40
40
  const authorization = request.headers.get(LOBE_CHAT_AUTH_HEADER);
@@ -10,40 +10,40 @@
10
10
  import { DESKTOP_USER_ID } from '@/const/desktop';
11
11
  import { isDesktop } from '@/const/version';
12
12
 
13
- import { trpc } from './init';
13
+ import { userAuth } from '../middleware/userAuth';
14
+ import { edgeTrpc } from './init';
14
15
  import { jwtPayloadChecker } from './middleware/jwtPayload';
15
- import { userAuth } from './middleware/userAuth';
16
16
 
17
17
  /**
18
18
  * Create a router
19
19
  * @link https://trpc.io/docs/v11/router
20
20
  */
21
- export const router = trpc.router;
21
+ export const router = edgeTrpc.router;
22
22
 
23
23
  /**
24
24
  * Create an unprotected procedure
25
25
  * @link https://trpc.io/docs/v11/procedures
26
26
  **/
27
- export const publicProcedure = trpc.procedure.use(({ next, ctx }) => {
27
+ export const publicProcedure = edgeTrpc.procedure.use(({ next, ctx }) => {
28
28
  return next({
29
29
  ctx: { userId: isDesktop ? DESKTOP_USER_ID : ctx.userId },
30
30
  });
31
31
  });
32
32
 
33
33
  // procedure that asserts that the user is logged in
34
- export const authedProcedure = trpc.procedure.use(userAuth);
34
+ export const authedProcedure = edgeTrpc.procedure.use(userAuth);
35
35
 
36
36
  // procedure that asserts that the user add the password
37
- export const passwordProcedure = trpc.procedure.use(jwtPayloadChecker);
37
+ export const passwordProcedure = edgeTrpc.procedure.use(jwtPayloadChecker);
38
38
 
39
39
  /**
40
40
  * Merge multiple routers together
41
41
  * @link https://trpc.io/docs/v11/merging-routers
42
42
  */
43
- export const mergeRouters = trpc.mergeRouters;
43
+ export const mergeRouters = edgeTrpc.mergeRouters;
44
44
 
45
45
  /**
46
46
  * Create a server-side caller
47
47
  * @link https://trpc.io/docs/v11/server/server-side-calls
48
48
  */
49
- export const createCallerFactory = trpc.createCallerFactory;
49
+ export const createCallerFactory = edgeTrpc.createCallerFactory;
@@ -10,9 +10,9 @@
10
10
  import { initTRPC } from '@trpc/server';
11
11
  import superjson from 'superjson';
12
12
 
13
- import type { Context } from '@/server/context';
13
+ import type { EdgeContext } from './context';
14
14
 
15
- export const trpc = initTRPC.context<Context>().create({
15
+ export const edgeTrpc = initTRPC.context<EdgeContext>().create({
16
16
  /**
17
17
  * @link https://trpc.io/docs/v11/error-formatting
18
18
  */
@@ -2,9 +2,9 @@
2
2
  import { TRPCError } from '@trpc/server';
3
3
  import { beforeEach, describe, expect, it, vi } from 'vitest';
4
4
 
5
- import { createCallerFactory } from '@/libs/trpc';
6
- import { trpc } from '@/libs/trpc/init';
7
- import { AuthContext, createContextInner } from '@/server/context';
5
+ import { createCallerFactory } from '@/libs/trpc/edge';
6
+ import { AuthContext, createContextInner } from '@/libs/trpc/edge/context';
7
+ import { edgeTrpc as trpc } from '@/libs/trpc/edge/init';
8
8
  import * as utils from '@/utils/server/jwt';
9
9
 
10
10
  import { jwtPayloadChecker } from './jwtPayload';
@@ -1,9 +1,10 @@
1
1
  import { TRPCError } from '@trpc/server';
2
2
 
3
- import { trpc } from '@/libs/trpc/init';
4
3
  import { getJWTPayload } from '@/utils/server/jwt';
5
4
 
6
- export const jwtPayloadChecker = trpc.middleware(async (opts) => {
5
+ import { edgeTrpc } from '../init';
6
+
7
+ export const jwtPayloadChecker = edgeTrpc.middleware(async (opts) => {
7
8
  const { ctx } = opts;
8
9
 
9
10
  if (!ctx.authorizationHeader) throw new TRPCError({ code: 'UNAUTHORIZED' });
@@ -0,0 +1,70 @@
1
+ import { User } from 'next-auth';
2
+ import { NextRequest } from 'next/server';
3
+
4
+ import { JWTPayload, LOBE_CHAT_AUTH_HEADER, enableClerk, enableNextAuth } from '@/const/auth';
5
+ import { ClerkAuth, IClerkAuth } from '@/libs/clerk-auth';
6
+
7
+ export interface AuthContext {
8
+ authorizationHeader?: string | null;
9
+ clerkAuth?: IClerkAuth;
10
+ jwtPayload?: JWTPayload | null;
11
+ nextAuth?: User;
12
+ userId?: string | null;
13
+ }
14
+
15
+ /**
16
+ * Inner function for `createContext` where we create the context.
17
+ * This is useful for testing when we don't want to mock Next.js' request/response
18
+ */
19
+ export const createContextInner = async (params?: {
20
+ authorizationHeader?: string | null;
21
+ clerkAuth?: IClerkAuth;
22
+ nextAuth?: User;
23
+ userId?: string | null;
24
+ }): Promise<AuthContext> => ({
25
+ authorizationHeader: params?.authorizationHeader,
26
+ clerkAuth: params?.clerkAuth,
27
+ nextAuth: params?.nextAuth,
28
+ userId: params?.userId,
29
+ });
30
+
31
+ export type LambdaContext = Awaited<ReturnType<typeof createContextInner>>;
32
+
33
+ /**
34
+ * Creates context for an incoming request
35
+ * @link https://trpc.io/docs/v11/context
36
+ */
37
+ export const createLambdaContext = async (request: NextRequest): Promise<LambdaContext> => {
38
+ // for API-response caching see https://trpc.io/docs/v11/caching
39
+
40
+ const authorization = request.headers.get(LOBE_CHAT_AUTH_HEADER);
41
+
42
+ let userId;
43
+ let auth;
44
+
45
+ if (enableClerk) {
46
+ const clerkAuth = new ClerkAuth();
47
+ const result = clerkAuth.getAuthFromRequest(request);
48
+ auth = result.clerkAuth;
49
+ userId = result.userId;
50
+
51
+ return createContextInner({ authorizationHeader: authorization, clerkAuth: auth, userId });
52
+ }
53
+
54
+ if (enableNextAuth) {
55
+ try {
56
+ const { default: NextAuthEdge } = await import('@/libs/next-auth/edge');
57
+
58
+ const session = await NextAuthEdge.auth();
59
+ if (session && session?.user?.id) {
60
+ auth = session.user;
61
+ userId = session.user.id;
62
+ }
63
+ return createContextInner({ authorizationHeader: authorization, nextAuth: auth, userId });
64
+ } catch (e) {
65
+ console.error('next auth err', e);
66
+ }
67
+ }
68
+
69
+ return createContextInner({ authorizationHeader: authorization, userId });
70
+ };
@@ -1 +1,39 @@
1
- export * from './serverDatabase';
1
+ /**
2
+ * This is your entry point to setup the root configuration for tRPC on the server.
3
+ * - `initTRPC` should only be used once per app.
4
+ * - We export only the functionality that we use so we can enforce which base procedures should be used
5
+ *
6
+ * Learn how to create protected base procedures and other things below:
7
+ * @link https://trpc.io/docs/v11/router
8
+ * @link https://trpc.io/docs/v11/procedures
9
+ */
10
+ import { DESKTOP_USER_ID } from '@/const/desktop';
11
+ import { isDesktop } from '@/const/version';
12
+
13
+ import { userAuth } from '../middleware/userAuth';
14
+ import { trpc } from './init';
15
+
16
+ /**
17
+ * Create a router
18
+ * @link https://trpc.io/docs/v11/router
19
+ */
20
+ export const router = trpc.router;
21
+
22
+ /**
23
+ * Create an unprotected procedure
24
+ * @link https://trpc.io/docs/v11/procedures
25
+ **/
26
+ export const publicProcedure = trpc.procedure.use(({ next, ctx }) => {
27
+ return next({
28
+ ctx: { userId: isDesktop ? DESKTOP_USER_ID : ctx.userId },
29
+ });
30
+ });
31
+
32
+ // procedure that asserts that the user is logged in
33
+ export const authedProcedure = trpc.procedure.use(userAuth);
34
+
35
+ /**
36
+ * Create a server-side caller
37
+ * @link https://trpc.io/docs/v11/server/server-side-calls
38
+ */
39
+ export const createCallerFactory = trpc.createCallerFactory;
@@ -0,0 +1,26 @@
1
+ /**
2
+ * This is your entry point to setup the root configuration for tRPC on the server.
3
+ * - `initTRPC` should only be used once per app.
4
+ * - We export only the functionality that we use so we can enforce which base procedures should be used
5
+ *
6
+ * Learn how to create protected base procedures and other things below:
7
+ * @link https://trpc.io/docs/v11/router
8
+ * @link https://trpc.io/docs/v11/procedures
9
+ */
10
+ import { initTRPC } from '@trpc/server';
11
+ import superjson from 'superjson';
12
+
13
+ import type { LambdaContext } from './context';
14
+
15
+ export const trpc = initTRPC.context<LambdaContext>().create({
16
+ /**
17
+ * @link https://trpc.io/docs/v11/error-formatting
18
+ */
19
+ errorFormatter({ shape }) {
20
+ return shape;
21
+ },
22
+ /**
23
+ * @link https://trpc.io/docs/v11/data-transformers
24
+ */
25
+ transformer: superjson,
26
+ });
@@ -0,0 +1,2 @@
1
+ export * from './keyVaults';
2
+ export * from './serverDatabase';
@@ -1,8 +1,9 @@
1
1
  import { TRPCError } from '@trpc/server';
2
2
 
3
- import { trpc } from '@/libs/trpc/init';
4
3
  import { getJWTPayload } from '@/utils/server/jwt';
5
4
 
5
+ import { trpc } from '../init';
6
+
6
7
  export const keyVaults = trpc.middleware(async (opts) => {
7
8
  const { ctx } = opts;
8
9
 
@@ -1,5 +1,6 @@
1
1
  import { getServerDB } from '@/database/core/db-adaptor';
2
- import { trpc } from '@/libs/trpc/init';
2
+
3
+ import { trpc } from '../init';
3
4
 
4
5
  export const serverDatabase = trpc.middleware(async (opts) => {
5
6
  const serverDB = await getServerDB();
@@ -1,10 +1,10 @@
1
1
  import { TRPCError } from '@trpc/server';
2
2
  import { beforeEach, describe, expect, it, vi } from 'vitest';
3
3
 
4
- import { createCallerFactory } from '@/libs/trpc';
5
- import { AuthContext, createContextInner } from '@/server/context';
4
+ import { createCallerFactory } from '@/libs/trpc/lambda';
5
+ import { AuthContext, createContextInner } from '@/libs/trpc/lambda/context';
6
6
 
7
- import { trpc } from '../init';
7
+ import { trpc } from '../lambda/init';
8
8
  import { userAuth } from './userAuth';
9
9
 
10
10
  const appRouter = trpc.router({
@@ -4,7 +4,7 @@ import { enableClerk } from '@/const/auth';
4
4
  import { DESKTOP_USER_ID } from '@/const/desktop';
5
5
  import { isDesktop } from '@/const/version';
6
6
 
7
- import { trpc } from '../init';
7
+ import { trpc } from '../lambda/init';
8
8
 
9
9
  export const userAuth = trpc.middleware(async (opts) => {
10
10
  const { ctx } = opts;