@lobehub/lobehub 2.0.0-next.201 → 2.0.0-next.202
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +25 -0
- package/changelog/v1.json +9 -0
- package/locales/ar/chat.json +2 -0
- package/locales/ar/models.json +64 -7
- package/locales/ar/plugin.json +2 -1
- package/locales/ar/providers.json +1 -0
- package/locales/bg-BG/chat.json +2 -0
- package/locales/bg-BG/models.json +49 -5
- package/locales/bg-BG/plugin.json +2 -1
- package/locales/bg-BG/providers.json +1 -0
- package/locales/de-DE/chat.json +2 -0
- package/locales/de-DE/models.json +36 -7
- package/locales/de-DE/plugin.json +2 -1
- package/locales/de-DE/providers.json +1 -0
- package/locales/en-US/chat.json +2 -0
- package/locales/en-US/models.json +10 -10
- package/locales/en-US/plugin.json +2 -1
- package/locales/en-US/providers.json +1 -0
- package/locales/es-ES/chat.json +2 -0
- package/locales/es-ES/models.json +106 -7
- package/locales/es-ES/plugin.json +2 -1
- package/locales/es-ES/providers.json +1 -0
- package/locales/fa-IR/chat.json +2 -0
- package/locales/fa-IR/models.json +83 -5
- package/locales/fa-IR/plugin.json +2 -1
- package/locales/fa-IR/providers.json +1 -0
- package/locales/fr-FR/chat.json +2 -0
- package/locales/fr-FR/models.json +38 -7
- package/locales/fr-FR/plugin.json +2 -1
- package/locales/fr-FR/providers.json +1 -0
- package/locales/it-IT/chat.json +2 -0
- package/locales/it-IT/models.json +40 -5
- package/locales/it-IT/plugin.json +2 -1
- package/locales/it-IT/providers.json +1 -0
- package/locales/ja-JP/chat.json +2 -0
- package/locales/ja-JP/models.json +84 -7
- package/locales/ja-JP/plugin.json +2 -1
- package/locales/ja-JP/providers.json +1 -0
- package/locales/ko-KR/chat.json +2 -0
- package/locales/ko-KR/models.json +65 -7
- package/locales/ko-KR/plugin.json +2 -1
- package/locales/ko-KR/providers.json +1 -0
- package/locales/nl-NL/chat.json +2 -0
- package/locales/nl-NL/models.json +62 -5
- package/locales/nl-NL/plugin.json +2 -1
- package/locales/nl-NL/providers.json +1 -0
- package/locales/pl-PL/chat.json +2 -0
- package/locales/pl-PL/models.json +85 -0
- package/locales/pl-PL/plugin.json +2 -1
- package/locales/pl-PL/providers.json +1 -0
- package/locales/pt-BR/chat.json +2 -0
- package/locales/pt-BR/models.json +37 -6
- package/locales/pt-BR/plugin.json +2 -1
- package/locales/pt-BR/providers.json +1 -0
- package/locales/ru-RU/chat.json +2 -0
- package/locales/ru-RU/models.json +36 -7
- package/locales/ru-RU/plugin.json +2 -1
- package/locales/ru-RU/providers.json +1 -0
- package/locales/tr-TR/chat.json +2 -0
- package/locales/tr-TR/models.json +28 -7
- package/locales/tr-TR/plugin.json +2 -1
- package/locales/tr-TR/providers.json +1 -0
- package/locales/vi-VN/chat.json +2 -0
- package/locales/vi-VN/models.json +62 -5
- package/locales/vi-VN/plugin.json +2 -1
- package/locales/vi-VN/providers.json +1 -0
- package/locales/zh-CN/chat.json +2 -0
- package/locales/zh-CN/models.json +87 -6
- package/locales/zh-CN/plugin.json +2 -1
- package/locales/zh-CN/providers.json +1 -0
- package/locales/zh-TW/chat.json +2 -0
- package/locales/zh-TW/models.json +71 -7
- package/locales/zh-TW/plugin.json +2 -1
- package/locales/zh-TW/providers.json +1 -0
- package/package.json +1 -1
- package/packages/builtin-tool-gtd/src/client/Inspector/ExecTask/index.tsx +30 -15
- package/packages/builtin-tool-gtd/src/manifest.ts +1 -1
- package/packages/model-runtime/src/core/ModelRuntime.test.ts +44 -86
- package/packages/types/src/aiChat.ts +0 -1
- package/packages/types/src/message/ui/chat.ts +1 -1
- package/src/app/(backend)/middleware/auth/index.ts +16 -2
- package/src/app/(backend)/webapi/chat/[provider]/route.test.ts +30 -15
- package/src/app/(backend)/webapi/chat/[provider]/route.ts +44 -40
- package/src/app/(backend)/webapi/models/[provider]/pull/route.ts +4 -3
- package/src/app/(backend)/webapi/models/[provider]/route.test.ts +36 -13
- package/src/app/(backend)/webapi/models/[provider]/route.ts +4 -11
- package/src/features/Conversation/Messages/AssistantGroup/Tool/Render/index.tsx +21 -23
- package/src/features/Conversation/Messages/AssistantGroup/components/ContentBlock.tsx +16 -3
- package/src/features/Conversation/Messages/Task/TaskDetailPanel/index.tsx +17 -20
- package/src/features/Conversation/Messages/Tasks/shared/ErrorState.tsx +16 -11
- package/src/features/Conversation/Messages/Tasks/shared/InitializingState.tsx +6 -20
- package/src/features/Conversation/Messages/Tasks/shared/ProcessingState.tsx +10 -20
- package/src/features/User/DataStatistics.tsx +4 -4
- package/src/hooks/useQueryParam.ts +0 -2
- package/src/libs/trpc/async/asyncAuth.ts +0 -2
- package/src/libs/trpc/async/context.ts +3 -11
- package/src/locales/default/chat.ts +2 -0
- package/src/locales/default/plugin.ts +2 -1
- package/src/server/modules/AgentRuntime/RuntimeExecutors.ts +6 -6
- package/src/server/modules/AgentRuntime/__tests__/RuntimeExecutors.test.ts +3 -3
- package/src/server/modules/AgentRuntime/factory.ts +39 -20
- package/src/server/modules/ModelRuntime/index.ts +138 -1
- package/src/server/routers/async/__tests__/caller.test.ts +22 -27
- package/src/server/routers/async/caller.ts +4 -6
- package/src/server/routers/async/file.ts +10 -5
- package/src/server/routers/async/image.ts +5 -4
- package/src/server/routers/async/ragEval.ts +7 -5
- package/src/server/routers/lambda/__tests__/aiChat.test.ts +8 -37
- package/src/server/routers/lambda/aiChat.ts +5 -21
- package/src/server/routers/lambda/chunk.ts +9 -28
- package/src/server/routers/lambda/image.ts +1 -7
- package/src/server/routers/lambda/ragEval.ts +1 -1
- package/src/server/routers/lambda/userMemories/reembed.ts +4 -1
- package/src/server/routers/lambda/userMemories/search.ts +7 -7
- package/src/server/routers/lambda/userMemories/shared.ts +8 -10
- package/src/server/routers/lambda/userMemories/tools.ts +140 -118
- package/src/server/routers/lambda/userMemories.test.ts +3 -7
- package/src/server/routers/lambda/userMemories.ts +44 -29
- package/src/server/services/agentRuntime/AgentRuntimeService.test.ts +87 -0
- package/src/server/services/agentRuntime/AgentRuntimeService.ts +53 -2
- package/src/server/services/agentRuntime/__tests__/executeSync.test.ts +2 -6
- package/src/server/services/agentRuntime/__tests__/stepLifecycleCallbacks.test.ts +1 -1
- package/src/server/services/chunk/index.ts +6 -5
- package/src/server/services/toolExecution/types.ts +1 -2
- package/src/services/__tests__/_url.test.ts +0 -1
- package/src/services/_url.ts +0 -3
- package/src/services/aiChat.ts +5 -12
- package/src/store/chat/slices/aiChat/actions/streamingExecutor.ts +0 -2
- package/src/app/(backend)/webapi/text-to-image/[provider]/route.ts +0 -74
|
@@ -1,18 +1,25 @@
|
|
|
1
1
|
// @vitest-environment node
|
|
2
|
-
import { TraceNameMap } from '@lobechat/types';
|
|
3
2
|
import { ClientSecretPayload } from '@lobechat/types';
|
|
4
|
-
import { Langfuse } from 'langfuse';
|
|
5
|
-
import { LangfuseGenerationClient, LangfuseTraceClient } from 'langfuse-core';
|
|
6
3
|
import { ModelProvider } from 'model-bank';
|
|
7
4
|
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
|
8
5
|
|
|
9
|
-
import
|
|
10
|
-
import { createTraceOptions } from '@/server/modules/ModelRuntime';
|
|
11
|
-
|
|
12
|
-
import { ChatStreamPayload, LobeOpenAI, ModelRuntime } from '../index';
|
|
6
|
+
import { ChatStreamCallbacks, ChatStreamPayload, LobeOpenAI, ModelRuntime } from '../index';
|
|
13
7
|
import { providerRuntimeMap } from '../runtimeMap';
|
|
14
8
|
import { CreateImagePayload } from '../types/image';
|
|
15
|
-
|
|
9
|
+
|
|
10
|
+
/**
|
|
11
|
+
* Mock createTraceOptions for testing purposes.
|
|
12
|
+
* This avoids importing from @/server/modules/ModelRuntime which has database dependencies.
|
|
13
|
+
*/
|
|
14
|
+
const createMockTraceOptions = (callbacks?: Partial<ChatStreamCallbacks>) => ({
|
|
15
|
+
callback: {
|
|
16
|
+
onCompletion: callbacks?.onCompletion ?? vi.fn(),
|
|
17
|
+
onFinal: callbacks?.onFinal ?? vi.fn(),
|
|
18
|
+
onStart: callbacks?.onStart ?? vi.fn(),
|
|
19
|
+
onToolsCalling: callbacks?.onToolsCalling ?? vi.fn(),
|
|
20
|
+
} as ChatStreamCallbacks,
|
|
21
|
+
headers: new Headers(),
|
|
22
|
+
});
|
|
16
23
|
|
|
17
24
|
const specialProviders = [
|
|
18
25
|
{ id: 'openai', payload: { apiKey: 'user-openai-key', baseURL: 'user-endpoint' } },
|
|
@@ -99,85 +106,50 @@ describe('ModelRuntime', () => {
|
|
|
99
106
|
|
|
100
107
|
await mockModelRuntime.chat(payload);
|
|
101
108
|
});
|
|
102
|
-
it('should handle options correctly', async () => {
|
|
109
|
+
it('should handle options with callbacks correctly', async () => {
|
|
103
110
|
const payload: ChatStreamPayload = {
|
|
104
111
|
messages: [{ role: 'user', content: 'Hello, world!' }],
|
|
105
112
|
model: 'text-davinci-002',
|
|
106
113
|
temperature: 0,
|
|
107
114
|
};
|
|
108
115
|
|
|
109
|
-
const options: AgentChatOptions = {
|
|
110
|
-
provider: 'openai',
|
|
111
|
-
trace: {
|
|
112
|
-
traceId: 'test-trace-id',
|
|
113
|
-
traceName: TraceNameMap.SummaryTopicTitle,
|
|
114
|
-
sessionId: 'test-session-id',
|
|
115
|
-
topicId: 'test-topic-id',
|
|
116
|
-
tags: [],
|
|
117
|
-
userId: 'test-user-id',
|
|
118
|
-
},
|
|
119
|
-
};
|
|
120
|
-
|
|
121
116
|
vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response(''));
|
|
122
117
|
|
|
123
|
-
await mockModelRuntime.chat(payload,
|
|
118
|
+
await mockModelRuntime.chat(payload, createMockTraceOptions());
|
|
124
119
|
});
|
|
125
120
|
|
|
126
|
-
describe('callback',
|
|
121
|
+
describe('callback', () => {
|
|
127
122
|
const payload: ChatStreamPayload = {
|
|
128
123
|
messages: [{ role: 'user', content: 'Hello, world!' }],
|
|
129
124
|
model: 'text-davinci-002',
|
|
130
125
|
temperature: 0,
|
|
131
126
|
};
|
|
132
127
|
|
|
133
|
-
const options: AgentChatOptions = {
|
|
134
|
-
provider: 'openai',
|
|
135
|
-
trace: {
|
|
136
|
-
traceId: 'test-trace-id',
|
|
137
|
-
traceName: TraceNameMap.SummaryTopicTitle,
|
|
138
|
-
sessionId: 'test-session-id',
|
|
139
|
-
topicId: 'test-topic-id',
|
|
140
|
-
tags: [],
|
|
141
|
-
userId: 'test-user-id',
|
|
142
|
-
},
|
|
143
|
-
enableTrace: true,
|
|
144
|
-
};
|
|
145
|
-
|
|
146
|
-
const updateMock = vi.fn();
|
|
147
|
-
|
|
148
128
|
it('should call onToolsCalling correctly', async () => {
|
|
149
|
-
vi.
|
|
150
|
-
ENABLE_LANGFUSE: true,
|
|
151
|
-
LANGFUSE_PUBLIC_KEY: 'abc',
|
|
152
|
-
LANGFUSE_SECRET_KEY: 'DDD',
|
|
153
|
-
} as any);
|
|
129
|
+
const onToolsCallingMock = vi.fn();
|
|
154
130
|
|
|
155
|
-
// 使用 spyOn 模拟 chat 方法
|
|
156
131
|
vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
|
|
157
|
-
async (
|
|
158
|
-
// 模拟 onToolCall 回调的触发
|
|
132
|
+
async (_payload, { callback }: any) => {
|
|
159
133
|
if (callback?.onToolsCalling) {
|
|
160
134
|
await callback.onToolsCalling();
|
|
161
135
|
}
|
|
162
136
|
return new Response('abc');
|
|
163
137
|
},
|
|
164
138
|
);
|
|
165
|
-
vi.spyOn(LangfuseTraceClient.prototype, 'update').mockImplementation(updateMock);
|
|
166
139
|
|
|
167
|
-
await mockModelRuntime.chat(
|
|
140
|
+
await mockModelRuntime.chat(
|
|
141
|
+
payload,
|
|
142
|
+
createMockTraceOptions({ onToolsCalling: onToolsCallingMock }),
|
|
143
|
+
);
|
|
168
144
|
|
|
169
|
-
expect(
|
|
145
|
+
expect(onToolsCallingMock).toHaveBeenCalled();
|
|
170
146
|
});
|
|
147
|
+
|
|
171
148
|
it('should call onStart correctly', async () => {
|
|
172
|
-
vi.
|
|
173
|
-
ENABLE_LANGFUSE: true,
|
|
174
|
-
LANGFUSE_PUBLIC_KEY: 'abc',
|
|
175
|
-
LANGFUSE_SECRET_KEY: 'DDD',
|
|
176
|
-
} as any);
|
|
149
|
+
const onStartMock = vi.fn();
|
|
177
150
|
|
|
178
|
-
vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
|
|
179
151
|
vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
|
|
180
|
-
async (
|
|
152
|
+
async (_payload, { callback }: any) => {
|
|
181
153
|
if (callback?.onStart) {
|
|
182
154
|
callback.onStart();
|
|
183
155
|
}
|
|
@@ -185,22 +157,16 @@ describe('ModelRuntime', () => {
|
|
|
185
157
|
},
|
|
186
158
|
);
|
|
187
159
|
|
|
188
|
-
await mockModelRuntime.chat(payload,
|
|
160
|
+
await mockModelRuntime.chat(payload, createMockTraceOptions({ onStart: onStartMock }));
|
|
189
161
|
|
|
190
|
-
|
|
191
|
-
expect(updateMock).toHaveBeenCalledWith({ completionStartTime: expect.any(Date) });
|
|
162
|
+
expect(onStartMock).toHaveBeenCalled();
|
|
192
163
|
});
|
|
193
164
|
|
|
194
165
|
it('should call onCompletion correctly', async () => {
|
|
195
|
-
vi.
|
|
196
|
-
|
|
197
|
-
LANGFUSE_PUBLIC_KEY: 'abc',
|
|
198
|
-
LANGFUSE_SECRET_KEY: 'DDD',
|
|
199
|
-
} as any);
|
|
200
|
-
// Spy on the chat method and trigger onCompletion callback
|
|
201
|
-
vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
|
|
166
|
+
const onCompletionMock = vi.fn();
|
|
167
|
+
|
|
202
168
|
vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
|
|
203
|
-
async (
|
|
169
|
+
async (_payload, { callback }: any) => {
|
|
204
170
|
if (callback?.onCompletion) {
|
|
205
171
|
await callback.onCompletion({ text: 'Test completion' });
|
|
206
172
|
}
|
|
@@ -208,37 +174,29 @@ describe('ModelRuntime', () => {
|
|
|
208
174
|
},
|
|
209
175
|
);
|
|
210
176
|
|
|
211
|
-
await mockModelRuntime.chat(
|
|
177
|
+
await mockModelRuntime.chat(
|
|
178
|
+
payload,
|
|
179
|
+
createMockTraceOptions({ onCompletion: onCompletionMock }),
|
|
180
|
+
);
|
|
212
181
|
|
|
213
|
-
|
|
214
|
-
expect(updateMock).toHaveBeenCalledWith({
|
|
215
|
-
endTime: expect.any(Date),
|
|
216
|
-
metadata: {},
|
|
217
|
-
output: 'Test completion',
|
|
218
|
-
});
|
|
182
|
+
expect(onCompletionMock).toHaveBeenCalledWith({ text: 'Test completion' });
|
|
219
183
|
});
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
LANGFUSE_PUBLIC_KEY: 'abc',
|
|
224
|
-
LANGFUSE_SECRET_KEY: 'DDD',
|
|
225
|
-
} as any);
|
|
184
|
+
|
|
185
|
+
it('should call onFinal correctly', async () => {
|
|
186
|
+
const onFinalMock = vi.fn();
|
|
226
187
|
|
|
227
188
|
vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
|
|
228
|
-
async (
|
|
189
|
+
async (_payload, { callback }: any) => {
|
|
229
190
|
if (callback?.onFinal) {
|
|
230
191
|
await callback.onFinal('Test completion');
|
|
231
192
|
}
|
|
232
193
|
return new Response('Success');
|
|
233
194
|
},
|
|
234
195
|
);
|
|
235
|
-
const shutdownAsyncMock = vi.fn();
|
|
236
|
-
vi.spyOn(Langfuse.prototype, 'shutdownAsync').mockImplementation(shutdownAsyncMock);
|
|
237
196
|
|
|
238
|
-
await mockModelRuntime.chat(payload,
|
|
197
|
+
await mockModelRuntime.chat(payload, createMockTraceOptions({ onFinal: onFinalMock }));
|
|
239
198
|
|
|
240
|
-
|
|
241
|
-
expect(shutdownAsyncMock).toHaveBeenCalled();
|
|
199
|
+
expect(onFinalMock).toHaveBeenCalledWith('Test completion');
|
|
242
200
|
});
|
|
243
201
|
});
|
|
244
202
|
});
|
|
@@ -69,7 +69,7 @@ export interface TaskDetail {
|
|
|
69
69
|
/** Execution duration in milliseconds */
|
|
70
70
|
duration?: number;
|
|
71
71
|
/** Error message if task failed */
|
|
72
|
-
error?: string
|
|
72
|
+
error?: Record<string, any>;
|
|
73
73
|
/** Task start time (ISO string) */
|
|
74
74
|
startedAt?: string;
|
|
75
75
|
/** Task status */
|
|
@@ -15,6 +15,8 @@ import {
|
|
|
15
15
|
enableBetterAuth,
|
|
16
16
|
enableClerk,
|
|
17
17
|
} from '@/const/auth';
|
|
18
|
+
import { getServerDB } from '@/database/core/db-adaptor';
|
|
19
|
+
import { type LobeChatDatabase } from '@/database/type';
|
|
18
20
|
import { ClerkAuth } from '@/libs/clerk-auth';
|
|
19
21
|
import { validateOIDCJWT } from '@/libs/oidc-provider/jwt';
|
|
20
22
|
import { createErrorResponse } from '@/utils/errorResponse';
|
|
@@ -28,6 +30,8 @@ export type RequestHandler = (
|
|
|
28
30
|
req: Request,
|
|
29
31
|
options: RequestOptions & {
|
|
30
32
|
jwtPayload: ClientSecretPayload;
|
|
33
|
+
serverDB: LobeChatDatabase;
|
|
34
|
+
userId: string;
|
|
31
35
|
},
|
|
32
36
|
) => Promise<Response>;
|
|
33
37
|
|
|
@@ -38,10 +42,18 @@ export const checkAuth =
|
|
|
38
42
|
// This ensures the handler can safely read the request body
|
|
39
43
|
const clonedReq = req.clone();
|
|
40
44
|
|
|
45
|
+
// Get serverDB for database access
|
|
46
|
+
const serverDB = await getServerDB();
|
|
47
|
+
|
|
41
48
|
// we have a special header to debug the api endpoint in development mode
|
|
42
49
|
const isDebugApi = req.headers.get('lobe-auth-dev-backend-api') === '1';
|
|
43
50
|
if (process.env.NODE_ENV === 'development' && isDebugApi) {
|
|
44
|
-
return handler(clonedReq, {
|
|
51
|
+
return handler(clonedReq, {
|
|
52
|
+
...options,
|
|
53
|
+
jwtPayload: { userId: 'DEV_USER' },
|
|
54
|
+
serverDB,
|
|
55
|
+
userId: 'DEV_USER',
|
|
56
|
+
});
|
|
45
57
|
}
|
|
46
58
|
|
|
47
59
|
let jwtPayload: ClientSecretPayload;
|
|
@@ -124,5 +136,7 @@ export const checkAuth =
|
|
|
124
136
|
return createErrorResponse(errorType, { error, ...res, provider: params?.provider });
|
|
125
137
|
}
|
|
126
138
|
|
|
127
|
-
|
|
139
|
+
const userId = jwtPayload.userId || '';
|
|
140
|
+
|
|
141
|
+
return handler(clonedReq, { ...options, jwtPayload, serverDB, userId });
|
|
128
142
|
};
|
|
@@ -7,6 +7,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
|
|
7
7
|
|
|
8
8
|
import { checkAuthMethod } from '@/app/(backend)/middleware/auth/utils';
|
|
9
9
|
import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
|
|
10
|
+
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
|
10
11
|
|
|
11
12
|
import { POST } from './route';
|
|
12
13
|
|
|
@@ -22,6 +23,11 @@ vi.mock('@lobechat/utils/server', () => ({
|
|
|
22
23
|
getXorPayload: vi.fn(),
|
|
23
24
|
}));
|
|
24
25
|
|
|
26
|
+
vi.mock('@/server/modules/ModelRuntime', () => ({
|
|
27
|
+
initModelRuntimeFromDB: vi.fn(),
|
|
28
|
+
createTraceOptions: vi.fn().mockReturnValue({}),
|
|
29
|
+
}));
|
|
30
|
+
|
|
25
31
|
// Use vi.hoisted to ensure mockState is initialized before mocks are set up
|
|
26
32
|
const mockState = vi.hoisted(() => ({ enableClerk: false }));
|
|
27
33
|
|
|
@@ -60,7 +66,7 @@ describe('POST handler', () => {
|
|
|
60
66
|
it('should initialize ModelRuntime correctly with valid authorization', async () => {
|
|
61
67
|
const mockParams = Promise.resolve({ provider: 'test-provider' });
|
|
62
68
|
|
|
63
|
-
// 设置 getJWTPayload
|
|
69
|
+
// 设置 getJWTPayload 的模拟返回值
|
|
64
70
|
vi.mocked(getXorPayload).mockReturnValueOnce({
|
|
65
71
|
apiKey: 'test-api-key',
|
|
66
72
|
azureApiVersion: 'v1',
|
|
@@ -68,17 +74,19 @@ describe('POST handler', () => {
|
|
|
68
74
|
|
|
69
75
|
const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
|
|
70
76
|
|
|
71
|
-
//
|
|
72
|
-
|
|
73
|
-
.spyOn(ModelRuntime, 'initializeWithProvider')
|
|
74
|
-
.mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
77
|
+
// Mock initModelRuntimeFromDB
|
|
78
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
75
79
|
|
|
76
80
|
// 调用 POST 函数
|
|
77
81
|
await POST(request as unknown as Request, { params: mockParams });
|
|
78
82
|
|
|
79
83
|
// 验证是否正确调用了模拟函数
|
|
80
84
|
expect(getXorPayload).toHaveBeenCalledWith('Bearer some-valid-token');
|
|
81
|
-
expect(
|
|
85
|
+
expect(initModelRuntimeFromDB).toHaveBeenCalledWith(
|
|
86
|
+
expect.anything(),
|
|
87
|
+
expect.any(String),
|
|
88
|
+
'test-provider',
|
|
89
|
+
);
|
|
82
90
|
});
|
|
83
91
|
|
|
84
92
|
it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => {
|
|
@@ -109,15 +117,13 @@ describe('POST handler', () => {
|
|
|
109
117
|
});
|
|
110
118
|
|
|
111
119
|
const mockParams = Promise.resolve({ provider: 'test-provider' });
|
|
112
|
-
// 设置 initModelRuntimeWithUserPayload 的模拟返回值
|
|
113
120
|
vi.mocked(getAuth).mockReturnValue({} as any);
|
|
114
121
|
vi.mocked(checkAuthMethod).mockReset();
|
|
115
122
|
|
|
116
123
|
const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
|
|
117
124
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
);
|
|
125
|
+
// Mock initModelRuntimeFromDB
|
|
126
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
121
127
|
|
|
122
128
|
const request = new Request(new URL('https://test.com'), {
|
|
123
129
|
method: 'POST',
|
|
@@ -174,20 +180,23 @@ describe('POST handler', () => {
|
|
|
174
180
|
});
|
|
175
181
|
|
|
176
182
|
const mockChatResponse: any = { success: true, message: 'Reply from agent' };
|
|
183
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
184
|
+
baseURL: 'abc',
|
|
185
|
+
chat: vi.fn().mockResolvedValue(mockChatResponse),
|
|
186
|
+
};
|
|
177
187
|
|
|
178
|
-
vi.
|
|
188
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
179
189
|
|
|
180
190
|
const response = await POST(request as unknown as Request, { params: mockParams });
|
|
181
191
|
|
|
182
192
|
expect(response).toEqual(mockChatResponse);
|
|
183
|
-
expect(
|
|
184
|
-
user:
|
|
193
|
+
expect(mockRuntime.chat).toHaveBeenCalledWith(mockChatPayload, {
|
|
194
|
+
user: expect.any(String),
|
|
185
195
|
signal: expect.anything(),
|
|
186
196
|
});
|
|
187
197
|
});
|
|
188
198
|
|
|
189
199
|
it('should return an error response when chat completion fails', async () => {
|
|
190
|
-
// 设置 getJWTPayload 和 initAgentRuntimeWithUserPayload 的模拟返回值
|
|
191
200
|
vi.mocked(getXorPayload).mockReturnValueOnce({
|
|
192
201
|
apiKey: 'test-api-key',
|
|
193
202
|
azureApiVersion: 'v1',
|
|
@@ -203,10 +212,16 @@ describe('POST handler', () => {
|
|
|
203
212
|
|
|
204
213
|
const mockErrorResponse = {
|
|
205
214
|
errorType: ChatErrorType.InternalServerError,
|
|
215
|
+
error: { errorMessage: 'Something went wrong', errorType: 500 },
|
|
206
216
|
errorMessage: 'Something went wrong',
|
|
207
217
|
};
|
|
208
218
|
|
|
209
|
-
|
|
219
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
220
|
+
baseURL: 'abc',
|
|
221
|
+
chat: vi.fn().mockRejectedValue(mockErrorResponse),
|
|
222
|
+
};
|
|
223
|
+
|
|
224
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
210
225
|
|
|
211
226
|
const response = await POST(request, { params: mockParams });
|
|
212
227
|
|
|
@@ -6,7 +6,7 @@ import {
|
|
|
6
6
|
import { ChatErrorType } from '@lobechat/types';
|
|
7
7
|
|
|
8
8
|
import { checkAuth } from '@/app/(backend)/middleware/auth';
|
|
9
|
-
import { createTraceOptions,
|
|
9
|
+
import { createTraceOptions, initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
|
10
10
|
import { type ChatStreamPayload } from '@/types/openai/chat';
|
|
11
11
|
import { createErrorResponse } from '@/utils/errorResponse';
|
|
12
12
|
import { getTracePayload } from '@/utils/trace';
|
|
@@ -15,48 +15,52 @@ import { getTracePayload } from '@/utils/trace';
|
|
|
15
15
|
// this enforce user to enable fluid compute
|
|
16
16
|
export const maxDuration = 300;
|
|
17
17
|
|
|
18
|
-
export const POST = checkAuth(
|
|
19
|
-
|
|
18
|
+
export const POST = checkAuth(
|
|
19
|
+
async (req: Request, { params, userId, serverDB, createRuntime, jwtPayload }) => {
|
|
20
|
+
const provider = (await params)!.provider!;
|
|
20
21
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
22
|
+
try {
|
|
23
|
+
// ============ 1. init chat model ============ //
|
|
24
|
+
let modelRuntime: ModelRuntime;
|
|
25
|
+
if (createRuntime) {
|
|
26
|
+
// Legacy support for custom runtime creation
|
|
27
|
+
modelRuntime = createRuntime(jwtPayload);
|
|
28
|
+
} else {
|
|
29
|
+
// Read user's provider config from database
|
|
30
|
+
modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
|
|
31
|
+
}
|
|
29
32
|
|
|
30
|
-
|
|
33
|
+
// ============ 2. create chat completion ============ //
|
|
31
34
|
|
|
32
|
-
|
|
35
|
+
const data = (await req.json()) as ChatStreamPayload;
|
|
33
36
|
|
|
34
|
-
|
|
37
|
+
const tracePayload = getTracePayload(req);
|
|
35
38
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
39
|
+
let traceOptions = {};
|
|
40
|
+
// If user enable trace
|
|
41
|
+
if (tracePayload?.enabled) {
|
|
42
|
+
traceOptions = createTraceOptions(data, { provider, trace: tracePayload });
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
return await modelRuntime.chat(data, {
|
|
46
|
+
user: userId,
|
|
47
|
+
...traceOptions,
|
|
48
|
+
signal: req.signal,
|
|
49
|
+
});
|
|
50
|
+
} catch (e) {
|
|
51
|
+
const {
|
|
52
|
+
errorType = ChatErrorType.InternalServerError,
|
|
53
|
+
error: errorContent,
|
|
54
|
+
...res
|
|
55
|
+
} = e as ChatCompletionErrorPayload;
|
|
41
56
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
} = e as ChatCompletionErrorPayload;
|
|
53
|
-
|
|
54
|
-
const error = errorContent || e;
|
|
55
|
-
|
|
56
|
-
const logMethod = AGENT_RUNTIME_ERROR_SET.has(errorType as string) ? 'warn' : 'error';
|
|
57
|
-
// track the error at server side
|
|
58
|
-
console[logMethod](`Route: [${provider}] ${errorType}:`, error);
|
|
59
|
-
|
|
60
|
-
return createErrorResponse(errorType, { error, ...res, provider });
|
|
61
|
-
}
|
|
62
|
-
});
|
|
57
|
+
const error = errorContent || e;
|
|
58
|
+
|
|
59
|
+
const logMethod = AGENT_RUNTIME_ERROR_SET.has(errorType as string) ? 'warn' : 'error';
|
|
60
|
+
// track the error at server side
|
|
61
|
+
console[logMethod](`Route: [${provider}] ${errorType}:`, error);
|
|
62
|
+
|
|
63
|
+
return createErrorResponse(errorType, { error, ...res, provider });
|
|
64
|
+
}
|
|
65
|
+
},
|
|
66
|
+
);
|
|
@@ -2,14 +2,15 @@ import { type ChatCompletionErrorPayload, type PullModelParams } from '@lobechat
|
|
|
2
2
|
import { ChatErrorType } from '@lobechat/types';
|
|
3
3
|
|
|
4
4
|
import { checkAuth } from '@/app/(backend)/middleware/auth';
|
|
5
|
-
import {
|
|
5
|
+
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
|
6
6
|
import { createErrorResponse } from '@/utils/errorResponse';
|
|
7
7
|
|
|
8
|
-
export const POST = checkAuth(async (req, { params,
|
|
8
|
+
export const POST = checkAuth(async (req, { params, userId, serverDB }) => {
|
|
9
9
|
const provider = (await params)!.provider!;
|
|
10
10
|
|
|
11
11
|
try {
|
|
12
|
-
|
|
12
|
+
// Read user's provider config from database
|
|
13
|
+
const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
|
|
13
14
|
|
|
14
15
|
const data = (await req.json()) as PullModelParams;
|
|
15
16
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
// @vitest-environment node
|
|
2
|
-
import { ModelRuntime } from '@lobechat/model-runtime';
|
|
2
|
+
import { LobeRuntimeAI, ModelRuntime } from '@lobechat/model-runtime';
|
|
3
3
|
import { ChatErrorType } from '@lobechat/types';
|
|
4
4
|
import { getXorPayload } from '@lobechat/utils/server';
|
|
5
5
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
|
6
6
|
|
|
7
7
|
import { LOBE_CHAT_AUTH_HEADER } from '@/const/auth';
|
|
8
|
+
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
|
8
9
|
|
|
9
10
|
import { GET } from './route';
|
|
10
11
|
|
|
@@ -20,6 +21,10 @@ vi.mock('@lobechat/utils/server', () => ({
|
|
|
20
21
|
getXorPayload: vi.fn(),
|
|
21
22
|
}));
|
|
22
23
|
|
|
24
|
+
vi.mock('@/server/modules/ModelRuntime', () => ({
|
|
25
|
+
initModelRuntimeFromDB: vi.fn(),
|
|
26
|
+
}));
|
|
27
|
+
|
|
23
28
|
let request: Request;
|
|
24
29
|
|
|
25
30
|
beforeEach(() => {
|
|
@@ -48,9 +53,12 @@ describe('GET handler', () => {
|
|
|
48
53
|
errorWithStack.stack =
|
|
49
54
|
'Error: Something went wrong\n at Object.<anonymous> (/path/to/file.ts:10:15)';
|
|
50
55
|
|
|
51
|
-
|
|
56
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
57
|
+
baseURL: 'abc',
|
|
58
|
+
chat: vi.fn(),
|
|
52
59
|
models: vi.fn().mockRejectedValue(errorWithStack),
|
|
53
|
-
}
|
|
60
|
+
};
|
|
61
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
54
62
|
|
|
55
63
|
const response = await GET(request, { params: mockParams });
|
|
56
64
|
const responseBody = await response.json();
|
|
@@ -85,9 +93,12 @@ describe('GET handler', () => {
|
|
|
85
93
|
const customError = new CustomError('Custom error occurred');
|
|
86
94
|
customError.stack = 'CustomError: Custom error occurred\n at somewhere';
|
|
87
95
|
|
|
88
|
-
|
|
96
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
97
|
+
baseURL: 'abc',
|
|
98
|
+
chat: vi.fn(),
|
|
89
99
|
models: vi.fn().mockRejectedValue(customError),
|
|
90
|
-
}
|
|
100
|
+
};
|
|
101
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
91
102
|
|
|
92
103
|
const response = await GET(request, { params: mockParams });
|
|
93
104
|
const responseBody = await response.json();
|
|
@@ -109,9 +120,12 @@ describe('GET handler', () => {
|
|
|
109
120
|
error: { code: 'PROVIDER_ERROR', details: 'API limit exceeded' },
|
|
110
121
|
};
|
|
111
122
|
|
|
112
|
-
|
|
123
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
124
|
+
baseURL: 'abc',
|
|
125
|
+
chat: vi.fn(),
|
|
113
126
|
models: vi.fn().mockRejectedValue(structuredError),
|
|
114
|
-
}
|
|
127
|
+
};
|
|
128
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
115
129
|
|
|
116
130
|
const response = await GET(request, { params: mockParams });
|
|
117
131
|
const responseBody = await response.json();
|
|
@@ -128,9 +142,12 @@ describe('GET handler', () => {
|
|
|
128
142
|
apiKey: 'test-api-key',
|
|
129
143
|
});
|
|
130
144
|
|
|
131
|
-
|
|
145
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
146
|
+
baseURL: 'abc',
|
|
147
|
+
chat: vi.fn(),
|
|
132
148
|
models: vi.fn().mockRejectedValue(new Error('Failed')),
|
|
133
|
-
}
|
|
149
|
+
};
|
|
150
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
134
151
|
|
|
135
152
|
const response = await GET(request, { params: mockParams });
|
|
136
153
|
|
|
@@ -144,9 +161,12 @@ describe('GET handler', () => {
|
|
|
144
161
|
apiKey: 'test-api-key',
|
|
145
162
|
});
|
|
146
163
|
|
|
147
|
-
|
|
164
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
165
|
+
baseURL: 'abc',
|
|
166
|
+
chat: vi.fn(),
|
|
148
167
|
models: vi.fn().mockRejectedValue(new Error('Failed')),
|
|
149
|
-
}
|
|
168
|
+
};
|
|
169
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
150
170
|
|
|
151
171
|
const response = await GET(request, { params: mockParams });
|
|
152
172
|
const responseBody = await response.json();
|
|
@@ -168,9 +188,12 @@ describe('GET handler', () => {
|
|
|
168
188
|
{ id: 'gpt-3.5-turbo', name: 'GPT-3.5 Turbo' },
|
|
169
189
|
];
|
|
170
190
|
|
|
171
|
-
|
|
191
|
+
const mockRuntime: LobeRuntimeAI = {
|
|
192
|
+
baseURL: 'abc',
|
|
193
|
+
chat: vi.fn(),
|
|
172
194
|
models: vi.fn().mockResolvedValue(mockModelList),
|
|
173
|
-
}
|
|
195
|
+
};
|
|
196
|
+
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
|
174
197
|
|
|
175
198
|
const response = await GET(request, { params: mockParams });
|
|
176
199
|
const responseBody = await response.json();
|
|
@@ -1,24 +1,17 @@
|
|
|
1
1
|
import { type ChatCompletionErrorPayload } from '@lobechat/model-runtime';
|
|
2
2
|
import { ChatErrorType } from '@lobechat/types';
|
|
3
|
-
import { ModelProvider } from 'model-bank';
|
|
4
3
|
import { NextResponse } from 'next/server';
|
|
5
4
|
|
|
6
5
|
import { checkAuth } from '@/app/(backend)/middleware/auth';
|
|
7
|
-
import {
|
|
6
|
+
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
|
8
7
|
import { createErrorResponse } from '@/utils/errorResponse';
|
|
9
8
|
|
|
10
|
-
const
|
|
11
|
-
|
|
12
|
-
export const GET = checkAuth(async (req, { params, jwtPayload }) => {
|
|
9
|
+
export const GET = checkAuth(async (req, { params, userId, serverDB }) => {
|
|
13
10
|
const provider = (await params)!.provider!;
|
|
14
11
|
|
|
15
12
|
try {
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
const agentRuntime = await initModelRuntimeWithUserPayload(provider, {
|
|
19
|
-
...jwtPayload,
|
|
20
|
-
apiKey: noNeedAPIKey(provider) ? hasDefaultApiKey : jwtPayload.apiKey,
|
|
21
|
-
});
|
|
13
|
+
// Read user's provider config from database
|
|
14
|
+
const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
|
|
22
15
|
|
|
23
16
|
const list = await agentRuntime.models();
|
|
24
17
|
|