@lobehub/lobehub 2.0.0-next.200 → 2.0.0-next.202

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. package/CHANGELOG.md +50 -0
  2. package/changelog/v1.json +18 -0
  3. package/locales/ar/chat.json +2 -0
  4. package/locales/ar/models.json +64 -7
  5. package/locales/ar/plugin.json +2 -1
  6. package/locales/ar/providers.json +1 -0
  7. package/locales/bg-BG/chat.json +2 -0
  8. package/locales/bg-BG/models.json +49 -5
  9. package/locales/bg-BG/plugin.json +2 -1
  10. package/locales/bg-BG/providers.json +1 -0
  11. package/locales/de-DE/chat.json +2 -0
  12. package/locales/de-DE/models.json +36 -7
  13. package/locales/de-DE/plugin.json +2 -1
  14. package/locales/de-DE/providers.json +1 -0
  15. package/locales/en-US/chat.json +2 -0
  16. package/locales/en-US/models.json +10 -10
  17. package/locales/en-US/plugin.json +2 -1
  18. package/locales/en-US/providers.json +1 -0
  19. package/locales/es-ES/chat.json +2 -0
  20. package/locales/es-ES/models.json +106 -7
  21. package/locales/es-ES/plugin.json +2 -1
  22. package/locales/es-ES/providers.json +1 -0
  23. package/locales/fa-IR/chat.json +2 -0
  24. package/locales/fa-IR/models.json +83 -5
  25. package/locales/fa-IR/plugin.json +2 -1
  26. package/locales/fa-IR/providers.json +1 -0
  27. package/locales/fr-FR/chat.json +2 -0
  28. package/locales/fr-FR/models.json +38 -7
  29. package/locales/fr-FR/plugin.json +2 -1
  30. package/locales/fr-FR/providers.json +1 -0
  31. package/locales/it-IT/chat.json +2 -0
  32. package/locales/it-IT/models.json +40 -5
  33. package/locales/it-IT/plugin.json +2 -1
  34. package/locales/it-IT/providers.json +1 -0
  35. package/locales/ja-JP/chat.json +2 -0
  36. package/locales/ja-JP/models.json +84 -7
  37. package/locales/ja-JP/plugin.json +2 -1
  38. package/locales/ja-JP/providers.json +1 -0
  39. package/locales/ko-KR/chat.json +2 -0
  40. package/locales/ko-KR/models.json +65 -7
  41. package/locales/ko-KR/plugin.json +2 -1
  42. package/locales/ko-KR/providers.json +1 -0
  43. package/locales/nl-NL/chat.json +2 -0
  44. package/locales/nl-NL/models.json +62 -5
  45. package/locales/nl-NL/plugin.json +2 -1
  46. package/locales/nl-NL/providers.json +1 -0
  47. package/locales/pl-PL/chat.json +2 -0
  48. package/locales/pl-PL/models.json +85 -0
  49. package/locales/pl-PL/plugin.json +2 -1
  50. package/locales/pl-PL/providers.json +1 -0
  51. package/locales/pt-BR/chat.json +2 -0
  52. package/locales/pt-BR/models.json +37 -6
  53. package/locales/pt-BR/plugin.json +2 -1
  54. package/locales/pt-BR/providers.json +1 -0
  55. package/locales/ru-RU/chat.json +2 -0
  56. package/locales/ru-RU/models.json +36 -7
  57. package/locales/ru-RU/plugin.json +2 -1
  58. package/locales/ru-RU/providers.json +1 -0
  59. package/locales/tr-TR/chat.json +2 -0
  60. package/locales/tr-TR/models.json +28 -7
  61. package/locales/tr-TR/plugin.json +2 -1
  62. package/locales/tr-TR/providers.json +1 -0
  63. package/locales/vi-VN/chat.json +2 -0
  64. package/locales/vi-VN/models.json +62 -5
  65. package/locales/vi-VN/plugin.json +2 -1
  66. package/locales/vi-VN/providers.json +1 -0
  67. package/locales/zh-CN/chat.json +2 -0
  68. package/locales/zh-CN/models.json +87 -6
  69. package/locales/zh-CN/plugin.json +2 -1
  70. package/locales/zh-CN/providers.json +1 -0
  71. package/locales/zh-TW/chat.json +2 -0
  72. package/locales/zh-TW/models.json +71 -7
  73. package/locales/zh-TW/plugin.json +2 -1
  74. package/locales/zh-TW/providers.json +1 -0
  75. package/package.json +2 -2
  76. package/packages/builtin-tool-gtd/src/client/Inspector/ExecTask/index.tsx +30 -15
  77. package/packages/builtin-tool-gtd/src/manifest.ts +1 -1
  78. package/packages/model-runtime/src/core/ModelRuntime.test.ts +44 -86
  79. package/packages/types/src/aiChat.ts +0 -1
  80. package/packages/types/src/message/ui/chat.ts +1 -1
  81. package/src/app/(backend)/middleware/auth/index.ts +16 -2
  82. package/src/app/(backend)/webapi/chat/[provider]/route.test.ts +30 -15
  83. package/src/app/(backend)/webapi/chat/[provider]/route.ts +44 -40
  84. package/src/app/(backend)/webapi/models/[provider]/pull/route.ts +4 -3
  85. package/src/app/(backend)/webapi/models/[provider]/route.test.ts +36 -13
  86. package/src/app/(backend)/webapi/models/[provider]/route.ts +4 -11
  87. package/src/app/[variants]/(desktop)/desktop-onboarding/index.tsx +8 -2
  88. package/src/features/Conversation/Messages/AssistantGroup/Tool/Render/index.tsx +21 -23
  89. package/src/features/Conversation/Messages/AssistantGroup/components/ContentBlock.tsx +16 -3
  90. package/src/features/Conversation/Messages/Task/TaskDetailPanel/index.tsx +17 -20
  91. package/src/features/Conversation/Messages/Tasks/shared/ErrorState.tsx +16 -11
  92. package/src/features/Conversation/Messages/Tasks/shared/InitializingState.tsx +6 -20
  93. package/src/features/Conversation/Messages/Tasks/shared/ProcessingState.tsx +10 -20
  94. package/src/features/User/DataStatistics.tsx +4 -4
  95. package/src/hooks/useQueryParam.ts +0 -2
  96. package/src/libs/trpc/async/asyncAuth.ts +0 -2
  97. package/src/libs/trpc/async/context.ts +3 -11
  98. package/src/locales/default/chat.ts +2 -0
  99. package/src/locales/default/plugin.ts +2 -1
  100. package/src/server/modules/AgentRuntime/RuntimeExecutors.ts +6 -6
  101. package/src/server/modules/AgentRuntime/__tests__/RuntimeExecutors.test.ts +3 -3
  102. package/src/server/modules/AgentRuntime/factory.ts +39 -20
  103. package/src/server/modules/ModelRuntime/index.ts +138 -1
  104. package/src/server/routers/async/__tests__/caller.test.ts +22 -27
  105. package/src/server/routers/async/caller.ts +4 -6
  106. package/src/server/routers/async/file.ts +10 -5
  107. package/src/server/routers/async/image.ts +5 -4
  108. package/src/server/routers/async/ragEval.ts +7 -5
  109. package/src/server/routers/lambda/__tests__/aiChat.test.ts +8 -37
  110. package/src/server/routers/lambda/aiChat.ts +5 -21
  111. package/src/server/routers/lambda/chunk.ts +9 -28
  112. package/src/server/routers/lambda/image.ts +1 -7
  113. package/src/server/routers/lambda/ragEval.ts +1 -1
  114. package/src/server/routers/lambda/userMemories/reembed.ts +4 -1
  115. package/src/server/routers/lambda/userMemories/search.ts +7 -7
  116. package/src/server/routers/lambda/userMemories/shared.ts +8 -10
  117. package/src/server/routers/lambda/userMemories/tools.ts +140 -118
  118. package/src/server/routers/lambda/userMemories.test.ts +3 -7
  119. package/src/server/routers/lambda/userMemories.ts +44 -29
  120. package/src/server/services/agentRuntime/AgentRuntimeService.test.ts +87 -0
  121. package/src/server/services/agentRuntime/AgentRuntimeService.ts +53 -2
  122. package/src/server/services/agentRuntime/__tests__/executeSync.test.ts +2 -6
  123. package/src/server/services/agentRuntime/__tests__/stepLifecycleCallbacks.test.ts +1 -1
  124. package/src/server/services/chunk/index.ts +6 -5
  125. package/src/server/services/toolExecution/types.ts +1 -2
  126. package/src/services/__tests__/_url.test.ts +0 -1
  127. package/src/services/_url.ts +0 -3
  128. package/src/services/aiChat.ts +5 -12
  129. package/src/store/chat/slices/aiChat/actions/streamingExecutor.ts +0 -2
  130. package/src/app/(backend)/webapi/text-to-image/[provider]/route.ts +0 -74
@@ -1,18 +1,25 @@
  // @vitest-environment node
- import { TraceNameMap } from '@lobechat/types';
  import { ClientSecretPayload } from '@lobechat/types';
- import { Langfuse } from 'langfuse';
- import { LangfuseGenerationClient, LangfuseTraceClient } from 'langfuse-core';
  import { ModelProvider } from 'model-bank';
  import { beforeEach, describe, expect, it, vi } from 'vitest';
 
- import * as langfuseCfg from '@/envs/langfuse';
- import { createTraceOptions } from '@/server/modules/ModelRuntime';
-
- import { ChatStreamPayload, LobeOpenAI, ModelRuntime } from '../index';
+ import { ChatStreamCallbacks, ChatStreamPayload, LobeOpenAI, ModelRuntime } from '../index';
  import { providerRuntimeMap } from '../runtimeMap';
  import { CreateImagePayload } from '../types/image';
- import { AgentChatOptions } from './ModelRuntime';
+
+ /**
+ * Mock createTraceOptions for testing purposes.
+ * This avoids importing from @/server/modules/ModelRuntime which has database dependencies.
+ */
+ const createMockTraceOptions = (callbacks?: Partial<ChatStreamCallbacks>) => ({
+ callback: {
+ onCompletion: callbacks?.onCompletion ?? vi.fn(),
+ onFinal: callbacks?.onFinal ?? vi.fn(),
+ onStart: callbacks?.onStart ?? vi.fn(),
+ onToolsCalling: callbacks?.onToolsCalling ?? vi.fn(),
+ } as ChatStreamCallbacks,
+ headers: new Headers(),
+ });
 
  const specialProviders = [
  { id: 'openai', payload: { apiKey: 'user-openai-key', baseURL: 'user-endpoint' } },
@@ -99,85 +106,50 @@ describe('ModelRuntime', () => {
 
  await mockModelRuntime.chat(payload);
  });
- it('should handle options correctly', async () => {
+ it('should handle options with callbacks correctly', async () => {
  const payload: ChatStreamPayload = {
  messages: [{ role: 'user', content: 'Hello, world!' }],
  model: 'text-davinci-002',
  temperature: 0,
  };
 
- const options: AgentChatOptions = {
- provider: 'openai',
- trace: {
- traceId: 'test-trace-id',
- traceName: TraceNameMap.SummaryTopicTitle,
- sessionId: 'test-session-id',
- topicId: 'test-topic-id',
- tags: [],
- userId: 'test-user-id',
- },
- };
-
  vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response(''));
 
- await mockModelRuntime.chat(payload, createTraceOptions(payload, options));
+ await mockModelRuntime.chat(payload, createMockTraceOptions());
  });
 
- describe('callback', async () => {
+ describe('callback', () => {
  const payload: ChatStreamPayload = {
  messages: [{ role: 'user', content: 'Hello, world!' }],
  model: 'text-davinci-002',
  temperature: 0,
  };
 
- const options: AgentChatOptions = {
- provider: 'openai',
- trace: {
- traceId: 'test-trace-id',
- traceName: TraceNameMap.SummaryTopicTitle,
- sessionId: 'test-session-id',
- topicId: 'test-topic-id',
- tags: [],
- userId: 'test-user-id',
- },
- enableTrace: true,
- };
-
- const updateMock = vi.fn();
-
  it('should call onToolsCalling correctly', async () => {
- vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({
- ENABLE_LANGFUSE: true,
- LANGFUSE_PUBLIC_KEY: 'abc',
- LANGFUSE_SECRET_KEY: 'DDD',
- } as any);
+ const onToolsCallingMock = vi.fn();
 
- // Use spyOn to mock the chat method
  vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
- async (payload, { callback }: any) => {
- // Simulate triggering the onToolCall callback
+ async (_payload, { callback }: any) => {
  if (callback?.onToolsCalling) {
  await callback.onToolsCalling();
  }
  return new Response('abc');
  },
  );
- vi.spyOn(LangfuseTraceClient.prototype, 'update').mockImplementation(updateMock);
 
- await mockModelRuntime.chat(payload, createTraceOptions(payload, options));
+ await mockModelRuntime.chat(
+ payload,
+ createMockTraceOptions({ onToolsCalling: onToolsCallingMock }),
+ );
 
- expect(updateMock).toHaveBeenCalledWith({ tags: ['Tools Calling'] });
+ expect(onToolsCallingMock).toHaveBeenCalled();
  });
+
  it('should call onStart correctly', async () => {
- vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({
- ENABLE_LANGFUSE: true,
- LANGFUSE_PUBLIC_KEY: 'abc',
- LANGFUSE_SECRET_KEY: 'DDD',
- } as any);
+ const onStartMock = vi.fn();
 
- vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
  vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
- async (payload, { callback }: any) => {
+ async (_payload, { callback }: any) => {
  if (callback?.onStart) {
  callback.onStart();
  }
@@ -185,22 +157,16 @@ describe('ModelRuntime', () => {
  },
  );
 
- await mockModelRuntime.chat(payload, createTraceOptions(payload, options));
+ await mockModelRuntime.chat(payload, createMockTraceOptions({ onStart: onStartMock }));
 
- // Verify onStart was called
- expect(updateMock).toHaveBeenCalledWith({ completionStartTime: expect.any(Date) });
+ expect(onStartMock).toHaveBeenCalled();
  });
 
  it('should call onCompletion correctly', async () => {
- vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({
- ENABLE_LANGFUSE: true,
- LANGFUSE_PUBLIC_KEY: 'abc',
- LANGFUSE_SECRET_KEY: 'DDD',
- } as any);
- // Spy on the chat method and trigger onCompletion callback
- vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
+ const onCompletionMock = vi.fn();
+
  vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
- async (payload, { callback }: any) => {
+ async (_payload, { callback }: any) => {
  if (callback?.onCompletion) {
  await callback.onCompletion({ text: 'Test completion' });
  }
@@ -208,37 +174,29 @@ describe('ModelRuntime', () => {
  },
  );
 
- await mockModelRuntime.chat(payload, createTraceOptions(payload, options));
+ await mockModelRuntime.chat(
+ payload,
+ createMockTraceOptions({ onCompletion: onCompletionMock }),
+ );
 
- // Verify onCompletion was called with expected output
- expect(updateMock).toHaveBeenCalledWith({
- endTime: expect.any(Date),
- metadata: {},
- output: 'Test completion',
- });
+ expect(onCompletionMock).toHaveBeenCalledWith({ text: 'Test completion' });
  });
- it.skip('should call onFinal correctly', async () => {
- vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({
- ENABLE_LANGFUSE: true,
- LANGFUSE_PUBLIC_KEY: 'abc',
- LANGFUSE_SECRET_KEY: 'DDD',
- } as any);
+
+ it('should call onFinal correctly', async () => {
+ const onFinalMock = vi.fn();
 
  vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
- async (payload, { callback }: any) => {
+ async (_payload, { callback }: any) => {
  if (callback?.onFinal) {
  await callback.onFinal('Test completion');
  }
  return new Response('Success');
  },
  );
- const shutdownAsyncMock = vi.fn();
- vi.spyOn(Langfuse.prototype, 'shutdownAsync').mockImplementation(shutdownAsyncMock);
 
- await mockModelRuntime.chat(payload, createTraceOptions(payload, options));
+ await mockModelRuntime.chat(payload, createMockTraceOptions({ onFinal: onFinalMock }));
 
- // Verify onCompletion was called with expected output
- expect(shutdownAsyncMock).toHaveBeenCalled();
+ expect(onFinalMock).toHaveBeenCalledWith('Test completion');
  });
  });
  });
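For orientation, a minimal sketch (assuming only the `ChatStreamCallbacks` export and the helper shape shown in the hunk above; everything else is illustrative) of the options object that `createMockTraceOptions` hands to `ModelRuntime.chat`:

```ts
import { vi } from 'vitest';

import type { ChatStreamCallbacks } from '../index';

// Illustrative shape of the second argument built by createMockTraceOptions:
// a plain callbacks-plus-headers object, no Langfuse clients involved.
const traceOptions: { callback: ChatStreamCallbacks; headers: Headers } = {
  callback: {
    onCompletion: vi.fn(),   // receives the resolved completion, e.g. { text: '...' }
    onFinal: vi.fn(),        // fires once the stream has fully finished
    onStart: vi.fn(),        // fires when streaming begins
    onToolsCalling: vi.fn(), // fires when the model requests a tool call
  },
  headers: new Headers(),
};

// Each test above overrides exactly one callback and asserts on it, e.g.:
// await mockModelRuntime.chat(payload, createMockTraceOptions({ onFinal: traceOptions.callback.onFinal }));
```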
@@ -120,7 +120,6 @@ export const StructureSchema = z.object({
  });
 
  export const StructureOutputSchema = z.object({
- keyVaultsPayload: z.string(),
  messages: z.array(z.any()),
  model: z.string(),
  provider: z.string(),
@@ -69,7 +69,7 @@ export interface TaskDetail {
  /** Execution duration in milliseconds */
  duration?: number;
  /** Error message if task failed */
- error?: string;
+ error?: Record<string, any>;
  /** Task start time (ISO string) */
  startedAt?: string;
  /** Task status */
@@ -15,6 +15,8 @@ import {
  enableBetterAuth,
  enableClerk,
  } from '@/const/auth';
+ import { getServerDB } from '@/database/core/db-adaptor';
+ import { type LobeChatDatabase } from '@/database/type';
  import { ClerkAuth } from '@/libs/clerk-auth';
  import { validateOIDCJWT } from '@/libs/oidc-provider/jwt';
  import { createErrorResponse } from '@/utils/errorResponse';
@@ -28,6 +30,8 @@ export type RequestHandler = (
  req: Request,
  options: RequestOptions & {
  jwtPayload: ClientSecretPayload;
+ serverDB: LobeChatDatabase;
+ userId: string;
  },
  ) => Promise<Response>;
 
@@ -38,10 +42,18 @@ export const checkAuth =
  // This ensures the handler can safely read the request body
  const clonedReq = req.clone();
 
+ // Get serverDB for database access
+ const serverDB = await getServerDB();
+
  // we have a special header to debug the api endpoint in development mode
  const isDebugApi = req.headers.get('lobe-auth-dev-backend-api') === '1';
  if (process.env.NODE_ENV === 'development' && isDebugApi) {
- return handler(clonedReq, { ...options, jwtPayload: { userId: 'DEV_USER' } });
+ return handler(clonedReq, {
+ ...options,
+ jwtPayload: { userId: 'DEV_USER' },
+ serverDB,
+ userId: 'DEV_USER',
+ });
  }
 
  let jwtPayload: ClientSecretPayload;
@@ -124,5 +136,7 @@ export const checkAuth =
  return createErrorResponse(errorType, { error, ...res, provider: params?.provider });
  }
 
- return handler(clonedReq, { ...options, jwtPayload });
+ const userId = jwtPayload.userId || '';
+
+ return handler(clonedReq, { ...options, jwtPayload, serverDB, userId });
  };
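A minimal sketch of how a wrapped route sees the widened handler options; `checkAuth`, `serverDB`, and `userId` come from the hunk above, while the route body itself is hypothetical:

```ts
import { checkAuth } from '@/app/(backend)/middleware/auth';

// Hypothetical route sketch: handlers wrapped by checkAuth now also receive the
// shared database handle and the resolved user id, in addition to jwtPayload.
export const GET = checkAuth(async (_req, { jwtPayload, serverDB, userId }) => {
  // serverDB comes from getServerDB(); userId falls back to '' when the
  // payload carries no user id (see the middleware change above).
  const sameUser = userId === (jwtPayload.userId ?? '');

  // A handler could now query per-user provider config via serverDB here.
  return Response.json({ ok: true, sameUser, hasDB: Boolean(serverDB) });
});
```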
@@ -7,6 +7,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
  import { checkAuthMethod } from '@/app/(backend)/middleware/auth/utils';
  import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
+ import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
 
  import { POST } from './route';
 
@@ -22,6 +23,11 @@ vi.mock('@lobechat/utils/server', () => ({
  getXorPayload: vi.fn(),
  }));
 
+ vi.mock('@/server/modules/ModelRuntime', () => ({
+ initModelRuntimeFromDB: vi.fn(),
+ createTraceOptions: vi.fn().mockReturnValue({}),
+ }));
+
  // Use vi.hoisted to ensure mockState is initialized before mocks are set up
  const mockState = vi.hoisted(() => ({ enableClerk: false }));
 
@@ -60,7 +66,7 @@ describe('POST handler', () => {
  it('should initialize ModelRuntime correctly with valid authorization', async () => {
  const mockParams = Promise.resolve({ provider: 'test-provider' });
 
- // Set up mock return values for getJWTPayload and initModelRuntimeWithUserPayload
+ // Set up the mock return value for getJWTPayload
  vi.mocked(getXorPayload).mockReturnValueOnce({
  apiKey: 'test-api-key',
  azureApiVersion: 'v1',
@@ -68,17 +74,19 @@ describe('POST handler', () => {
 
  const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
 
- // migrate to new ModelRuntime init api
- const spy = vi
- .spyOn(ModelRuntime, 'initializeWithProvider')
- .mockResolvedValue(new ModelRuntime(mockRuntime));
+ // Mock initModelRuntimeFromDB
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  // Call the POST function
  await POST(request as unknown as Request, { params: mockParams });
 
  // Verify the mocked functions were called correctly
  expect(getXorPayload).toHaveBeenCalledWith('Bearer some-valid-token');
- expect(spy).toHaveBeenCalledWith('test-provider', expect.anything());
+ expect(initModelRuntimeFromDB).toHaveBeenCalledWith(
+ expect.anything(),
+ expect.any(String),
+ 'test-provider',
+ );
  });
 
  it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => {
@@ -109,15 +117,13 @@ describe('POST handler', () => {
  });
 
  const mockParams = Promise.resolve({ provider: 'test-provider' });
- // Set up the mock return value for initModelRuntimeWithUserPayload
  vi.mocked(getAuth).mockReturnValue({} as any);
  vi.mocked(checkAuthMethod).mockReset();
 
  const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
 
- vi.spyOn(ModelRuntime, 'initializeWithProvider').mockResolvedValue(
- new ModelRuntime(mockRuntime),
- );
+ // Mock initModelRuntimeFromDB
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const request = new Request(new URL('https://test.com'), {
  method: 'POST',
@@ -174,20 +180,23 @@ describe('POST handler', () => {
  });
 
  const mockChatResponse: any = { success: true, message: 'Reply from agent' };
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn().mockResolvedValue(mockChatResponse),
+ };
 
- vi.spyOn(ModelRuntime.prototype, 'chat').mockResolvedValue(mockChatResponse);
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await POST(request as unknown as Request, { params: mockParams });
 
  expect(response).toEqual(mockChatResponse);
- expect(ModelRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload, {
- user: 'abc',
+ expect(mockRuntime.chat).toHaveBeenCalledWith(mockChatPayload, {
+ user: expect.any(String),
  signal: expect.anything(),
  });
  });
 
  it('should return an error response when chat completion fails', async () => {
- // Set up mock return values for getJWTPayload and initAgentRuntimeWithUserPayload
  vi.mocked(getXorPayload).mockReturnValueOnce({
  apiKey: 'test-api-key',
  azureApiVersion: 'v1',
@@ -203,10 +212,16 @@ describe('POST handler', () => {
 
  const mockErrorResponse = {
  errorType: ChatErrorType.InternalServerError,
+ error: { errorMessage: 'Something went wrong', errorType: 500 },
  errorMessage: 'Something went wrong',
  };
 
- vi.spyOn(ModelRuntime.prototype, 'chat').mockRejectedValue(mockErrorResponse);
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn().mockRejectedValue(mockErrorResponse),
+ };
+
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await POST(request, { params: mockParams });
 
@@ -6,7 +6,7 @@ import {
  import { ChatErrorType } from '@lobechat/types';
 
  import { checkAuth } from '@/app/(backend)/middleware/auth';
- import { createTraceOptions, initModelRuntimeWithUserPayload } from '@/server/modules/ModelRuntime';
+ import { createTraceOptions, initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
  import { type ChatStreamPayload } from '@/types/openai/chat';
  import { createErrorResponse } from '@/utils/errorResponse';
  import { getTracePayload } from '@/utils/trace';
@@ -15,48 +15,52 @@ import { getTracePayload } from '@/utils/trace';
  // this enforce user to enable fluid compute
  export const maxDuration = 300;
 
- export const POST = checkAuth(async (req: Request, { params, jwtPayload, createRuntime }) => {
- const provider = (await params)!.provider!;
+ export const POST = checkAuth(
+ async (req: Request, { params, userId, serverDB, createRuntime, jwtPayload }) => {
+ const provider = (await params)!.provider!;
 
- try {
- // ============ 1. init chat model ============ //
- let modelRuntime: ModelRuntime;
- if (createRuntime) {
- modelRuntime = createRuntime(jwtPayload);
- } else {
- modelRuntime = await initModelRuntimeWithUserPayload(provider, jwtPayload);
- }
+ try {
+ // ============ 1. init chat model ============ //
+ let modelRuntime: ModelRuntime;
+ if (createRuntime) {
+ // Legacy support for custom runtime creation
+ modelRuntime = createRuntime(jwtPayload);
+ } else {
+ // Read user's provider config from database
+ modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
+ }
 
- // ============ 2. create chat completion ============ //
+ // ============ 2. create chat completion ============ //
 
- const data = (await req.json()) as ChatStreamPayload;
+ const data = (await req.json()) as ChatStreamPayload;
 
- const tracePayload = getTracePayload(req);
+ const tracePayload = getTracePayload(req);
 
- let traceOptions = {};
- // If user enable trace
- if (tracePayload?.enabled) {
- traceOptions = createTraceOptions(data, { provider, trace: tracePayload });
- }
+ let traceOptions = {};
+ // If user enable trace
+ if (tracePayload?.enabled) {
+ traceOptions = createTraceOptions(data, { provider, trace: tracePayload });
+ }
+
+ return await modelRuntime.chat(data, {
+ user: userId,
+ ...traceOptions,
+ signal: req.signal,
+ });
+ } catch (e) {
+ const {
+ errorType = ChatErrorType.InternalServerError,
+ error: errorContent,
+ ...res
+ } = e as ChatCompletionErrorPayload;
 
- return await modelRuntime.chat(data, {
- user: jwtPayload.userId,
- ...traceOptions,
- signal: req.signal,
- });
- } catch (e) {
- const {
- errorType = ChatErrorType.InternalServerError,
- error: errorContent,
- ...res
- } = e as ChatCompletionErrorPayload;
-
- const error = errorContent || e;
-
- const logMethod = AGENT_RUNTIME_ERROR_SET.has(errorType as string) ? 'warn' : 'error';
- // track the error at server side
- console[logMethod](`Route: [${provider}] ${errorType}:`, error);
-
- return createErrorResponse(errorType, { error, ...res, provider });
- }
- });
+ const error = errorContent || e;
+
+ const logMethod = AGENT_RUNTIME_ERROR_SET.has(errorType as string) ? 'warn' : 'error';
+ // track the error at server side
+ console[logMethod](`Route: [${provider}] ${errorType}:`, error);
+
+ return createErrorResponse(errorType, { error, ...res, provider });
+ }
+ },
+ );
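Condensed, the route migration above (repeated in the model routes that follow) is a single swap; a minimal sketch assuming the call shape shown in the hunk, with the handler body otherwise hypothetical:

```ts
import { checkAuth } from '@/app/(backend)/middleware/auth';
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';

// Hypothetical minimal handler following the migrated pattern above.
export const POST = checkAuth(async (req, { params, userId, serverDB }) => {
  const provider = (await params)!.provider!;

  // before (next.200): await initModelRuntimeWithUserPayload(provider, jwtPayload)
  // after  (next.202): provider credentials are resolved from the user's DB config
  const modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);

  const data = await req.json();
  return modelRuntime.chat(data, { signal: req.signal, user: userId });
});
```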
@@ -2,14 +2,15 @@ import { type ChatCompletionErrorPayload, type PullModelParams } from '@lobechat
  import { ChatErrorType } from '@lobechat/types';
 
  import { checkAuth } from '@/app/(backend)/middleware/auth';
- import { initModelRuntimeWithUserPayload } from '@/server/modules/ModelRuntime';
+ import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
  import { createErrorResponse } from '@/utils/errorResponse';
 
- export const POST = checkAuth(async (req, { params, jwtPayload }) => {
+ export const POST = checkAuth(async (req, { params, userId, serverDB }) => {
  const provider = (await params)!.provider!;
 
  try {
- const agentRuntime = await initModelRuntimeWithUserPayload(provider, jwtPayload);
+ // Read user's provider config from database
+ const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
 
  const data = (await req.json()) as PullModelParams;
 
@@ -1,10 +1,11 @@
  // @vitest-environment node
- import { ModelRuntime } from '@lobechat/model-runtime';
+ import { LobeRuntimeAI, ModelRuntime } from '@lobechat/model-runtime';
  import { ChatErrorType } from '@lobechat/types';
  import { getXorPayload } from '@lobechat/utils/server';
  import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
  import { LOBE_CHAT_AUTH_HEADER } from '@/const/auth';
+ import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
 
  import { GET } from './route';
 
@@ -20,6 +21,10 @@ vi.mock('@lobechat/utils/server', () => ({
  getXorPayload: vi.fn(),
  }));
 
+ vi.mock('@/server/modules/ModelRuntime', () => ({
+ initModelRuntimeFromDB: vi.fn(),
+ }));
+
  let request: Request;
 
  beforeEach(() => {
@@ -48,9 +53,12 @@ describe('GET handler', () => {
  errorWithStack.stack =
  'Error: Something went wrong\n at Object.<anonymous> (/path/to/file.ts:10:15)';
 
- vi.spyOn(ModelRuntime, 'initializeWithProvider').mockResolvedValue({
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn(),
  models: vi.fn().mockRejectedValue(errorWithStack),
- } as any);
+ };
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await GET(request, { params: mockParams });
  const responseBody = await response.json();
@@ -85,9 +93,12 @@ describe('GET handler', () => {
  const customError = new CustomError('Custom error occurred');
  customError.stack = 'CustomError: Custom error occurred\n at somewhere';
 
- vi.spyOn(ModelRuntime, 'initializeWithProvider').mockResolvedValue({
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn(),
  models: vi.fn().mockRejectedValue(customError),
- } as any);
+ };
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await GET(request, { params: mockParams });
  const responseBody = await response.json();
@@ -109,9 +120,12 @@ describe('GET handler', () => {
  error: { code: 'PROVIDER_ERROR', details: 'API limit exceeded' },
  };
 
- vi.spyOn(ModelRuntime, 'initializeWithProvider').mockResolvedValue({
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn(),
  models: vi.fn().mockRejectedValue(structuredError),
- } as any);
+ };
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await GET(request, { params: mockParams });
  const responseBody = await response.json();
@@ -128,9 +142,12 @@ describe('GET handler', () => {
  apiKey: 'test-api-key',
  });
 
- vi.spyOn(ModelRuntime, 'initializeWithProvider').mockResolvedValue({
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn(),
  models: vi.fn().mockRejectedValue(new Error('Failed')),
- } as any);
+ };
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await GET(request, { params: mockParams });
 
@@ -144,9 +161,12 @@ describe('GET handler', () => {
  apiKey: 'test-api-key',
  });
 
- vi.spyOn(ModelRuntime, 'initializeWithProvider').mockResolvedValue({
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn(),
  models: vi.fn().mockRejectedValue(new Error('Failed')),
- } as any);
+ };
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await GET(request, { params: mockParams });
  const responseBody = await response.json();
@@ -168,9 +188,12 @@ describe('GET handler', () => {
  { id: 'gpt-3.5-turbo', name: 'GPT-3.5 Turbo' },
  ];
 
- vi.spyOn(ModelRuntime, 'initializeWithProvider').mockResolvedValue({
+ const mockRuntime: LobeRuntimeAI = {
+ baseURL: 'abc',
+ chat: vi.fn(),
  models: vi.fn().mockResolvedValue(mockModelList),
- } as any);
+ };
+ vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
 
  const response = await GET(request, { params: mockParams });
  const responseBody = await response.json();
@@ -1,24 +1,17 @@
  import { type ChatCompletionErrorPayload } from '@lobechat/model-runtime';
  import { ChatErrorType } from '@lobechat/types';
- import { ModelProvider } from 'model-bank';
  import { NextResponse } from 'next/server';
 
  import { checkAuth } from '@/app/(backend)/middleware/auth';
- import { initModelRuntimeWithUserPayload } from '@/server/modules/ModelRuntime';
+ import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
  import { createErrorResponse } from '@/utils/errorResponse';
 
- const noNeedAPIKey = (provider: string) => [ModelProvider.OpenRouter].includes(provider as any);
-
- export const GET = checkAuth(async (req, { params, jwtPayload }) => {
+ export const GET = checkAuth(async (req, { params, userId, serverDB }) => {
  const provider = (await params)!.provider!;
 
  try {
- const hasDefaultApiKey = jwtPayload.apiKey || 'dont-need-api-key-for-model-list';
-
- const agentRuntime = await initModelRuntimeWithUserPayload(provider, {
- ...jwtPayload,
- apiKey: noNeedAPIKey(provider) ? hasDefaultApiKey : jwtPayload.apiKey,
- });
+ // Read user's provider config from database
+ const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
 
  const list = await agentRuntime.models();