@lobehub/chat 1.108.0 → 1.108.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.cursor/rules/testing-guide/testing-guide.mdc +18 -0
- package/CHANGELOG.md +58 -0
- package/README.md +3 -2
- package/README.zh-CN.md +3 -2
- package/changelog/v1.json +21 -0
- package/package.json +3 -3
- package/src/app/(backend)/trpc/desktop/[trpc]/route.ts +1 -1
- package/src/app/[variants]/(main)/settings/provider/features/ProviderConfig/Checker.tsx +15 -2
- package/src/app/[variants]/(main)/settings/provider/features/ProviderConfig/index.tsx +30 -3
- package/src/app/[variants]/layout.tsx +1 -0
- package/src/components/Analytics/LobeAnalyticsProvider.tsx +10 -13
- package/src/components/Analytics/LobeAnalyticsProviderWrapper.tsx +16 -4
- package/src/database/models/__tests__/_test_template.ts +1 -1
- package/src/database/models/__tests__/agent.test.ts +1 -1
- package/src/database/models/__tests__/aiModel.test.ts +1 -1
- package/src/database/models/__tests__/aiProvider.test.ts +1 -1
- package/src/database/models/__tests__/asyncTask.test.ts +1 -1
- package/src/database/models/__tests__/chunk.test.ts +1 -1
- package/src/database/models/__tests__/file.test.ts +1 -1
- package/src/database/models/__tests__/generationTopic.test.ts +1 -1
- package/src/database/models/__tests__/knowledgeBase.test.ts +1 -1
- package/src/database/models/__tests__/message.test.ts +1 -1
- package/src/database/models/__tests__/session.test.ts +1 -1
- package/src/database/models/__tests__/sessionGroup.test.ts +1 -1
- package/src/database/models/__tests__/topic.test.ts +1 -1
- package/src/database/models/_template.ts +1 -1
- package/src/database/models/agent.ts +1 -1
- package/src/database/models/aiModel.ts +1 -1
- package/src/database/models/aiProvider.ts +1 -1
- package/src/database/models/apiKey.ts +1 -1
- package/src/database/models/asyncTask.ts +1 -1
- package/src/database/models/chunk.ts +1 -2
- package/src/database/models/document.ts +1 -1
- package/src/database/models/embedding.ts +1 -2
- package/src/database/models/file.ts +1 -2
- package/src/database/models/generationTopic.ts +1 -1
- package/src/database/models/knowledgeBase.ts +1 -1
- package/src/database/models/message.ts +1 -2
- package/src/database/models/plugin.ts +1 -1
- package/src/database/models/session.ts +15 -2
- package/src/database/models/sessionGroup.ts +1 -1
- package/src/database/models/thread.ts +1 -1
- package/src/database/models/topic.ts +1 -2
- package/src/database/models/user.ts +1 -1
- package/src/database/repositories/dataExporter/index.ts +1 -1
- package/src/database/repositories/dataImporter/__tests__/index.test.ts +1 -1
- package/src/database/repositories/dataImporter/deprecated/__tests__/index.test.ts +1 -1
- package/src/database/repositories/dataImporter/deprecated/index.ts +1 -2
- package/src/database/repositories/dataImporter/index.ts +1 -1
- package/src/database/server/models/__tests__/adapter.test.ts +1 -1
- package/src/database/server/models/__tests__/nextauth.test.ts +1 -1
- package/src/database/server/models/__tests__/user.test.ts +1 -1
- package/src/database/server/models/ragEval/dataset.ts +1 -1
- package/src/database/server/models/ragEval/datasetRecord.ts +1 -1
- package/src/database/server/models/ragEval/evaluation.ts +1 -2
- package/src/database/server/models/ragEval/evaluationRecord.ts +1 -1
- package/src/database/utils/genWhere.ts +1 -2
- package/src/features/User/UserAvatar.tsx +18 -2
- package/src/libs/model-runtime/RouterRuntime/createRuntime.test.ts +538 -0
- package/src/libs/model-runtime/RouterRuntime/createRuntime.ts +50 -13
- package/src/libs/model-runtime/RouterRuntime/index.ts +1 -1
- package/src/libs/model-runtime/aihubmix/index.ts +10 -5
- package/src/libs/model-runtime/ppio/index.test.ts +3 -6
- package/src/libs/model-runtime/utils/openaiCompatibleFactory/index.ts +8 -6
- package/src/libs/next-auth/adapter/index.ts +1 -1
- package/src/libs/oidc-provider/adapter.ts +1 -2
- package/src/server/globalConfig/genServerAiProviderConfig.test.ts +22 -25
- package/src/server/globalConfig/genServerAiProviderConfig.ts +34 -22
- package/src/server/globalConfig/index.ts +1 -1
- package/src/server/routers/lambda/chunk.ts +1 -1
- package/src/server/services/discover/index.ts +11 -2
- package/src/services/chat.ts +1 -1
- package/src/services/session/client.test.ts +1 -1
- package/src/store/aiInfra/slices/aiProvider/__tests__/action.test.ts +211 -0
- package/src/store/aiInfra/slices/aiProvider/action.ts +46 -35
- package/src/store/user/slices/modelList/action.test.ts +5 -5
- package/src/store/user/slices/modelList/action.ts +4 -4
- package/src/styles/antdOverride.ts +6 -0
- package/src/utils/getFallbackModelProperty.test.ts +52 -45
- package/src/utils/getFallbackModelProperty.ts +4 -3
- package/src/utils/parseModels.test.ts +107 -98
- package/src/utils/parseModels.ts +10 -8
@@ -1,13 +1,10 @@
|
|
1
1
|
// @vitest-environment node
|
2
|
-
import OpenAI from 'openai';
|
3
2
|
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
4
3
|
|
5
|
-
import {
|
4
|
+
import { LobeOpenAICompatibleRuntime } from '@/libs/model-runtime';
|
6
5
|
import { ModelProvider } from '@/libs/model-runtime';
|
7
|
-
import { AgentRuntimeErrorType } from '@/libs/model-runtime';
|
8
6
|
import { testProvider } from '@/libs/model-runtime/providerTestUtils';
|
9
7
|
|
10
|
-
import * as debugStreamModule from '../utils/debugStream';
|
11
8
|
import models from './fixtures/models.json';
|
12
9
|
import { LobePPIOAI } from './index';
|
13
10
|
|
@@ -30,7 +27,7 @@ let instance: LobeOpenAICompatibleRuntime;
|
|
30
27
|
beforeEach(() => {
|
31
28
|
instance = new LobePPIOAI({ apiKey: 'test' });
|
32
29
|
|
33
|
-
//
|
30
|
+
// Use vi.spyOn to mock the chat.completions.create method
|
34
31
|
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
|
35
32
|
new ReadableStream() as any,
|
36
33
|
);
|
@@ -44,7 +41,7 @@ afterEach(() => {
|
|
44
41
|
describe('PPIO', () => {
|
45
42
|
describe('models', () => {
|
46
43
|
it('should get models', async () => {
|
47
|
-
//
|
44
|
+
// Mock the models.list method
|
48
45
|
(instance['client'].models.list as Mock).mockResolvedValue({ data: models });
|
49
46
|
|
50
47
|
const list = await instance.models();
|
@@ -489,12 +489,14 @@ export const createOpenAICompatibleRuntime = <T extends Record<string, any> = an
|
|
489
489
|
.filter(Boolean) as ChatModelCard[];
|
490
490
|
}
|
491
491
|
|
492
|
-
return
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
492
|
+
return (await Promise.all(
|
493
|
+
resultModels.map(async (model) => {
|
494
|
+
return {
|
495
|
+
...model,
|
496
|
+
type: model.type || (await getModelPropertyWithFallback(model.id, 'type')),
|
497
|
+
};
|
498
|
+
}),
|
499
|
+
)) as ChatModelCard[];
|
498
500
|
}
|
499
501
|
|
500
502
|
async embeddings(
|
@@ -4,7 +4,7 @@ import type {
|
|
4
4
|
AdapterUser,
|
5
5
|
VerificationToken,
|
6
6
|
} from '@auth/core/adapters';
|
7
|
-
import { and, eq } from 'drizzle-orm/expressions';
|
7
|
+
import { and, eq } from 'drizzle-orm';
|
8
8
|
import type { NeonDatabase } from 'drizzle-orm/neon-serverless';
|
9
9
|
import { Adapter, AdapterAccount } from 'next-auth/adapters';
|
10
10
|
|
@@ -1,8 +1,5 @@
|
|
1
1
|
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
2
2
|
|
3
|
-
import { ModelProvider } from '@/libs/model-runtime';
|
4
|
-
import { AiFullModelCard } from '@/types/aiModel';
|
5
|
-
|
6
3
|
import { genServerAiProvidersConfig } from './genServerAiProviderConfig';
|
7
4
|
|
8
5
|
// Mock dependencies using importOriginal to preserve real provider data
|
@@ -23,11 +20,11 @@ vi.mock('@/config/llm', () => ({
|
|
23
20
|
}));
|
24
21
|
|
25
22
|
vi.mock('@/utils/parseModels', () => ({
|
26
|
-
extractEnabledModels: vi.fn((providerId: string, modelString?: string) => {
|
23
|
+
extractEnabledModels: vi.fn(async (providerId: string, modelString?: string) => {
|
27
24
|
if (!modelString) return undefined;
|
28
25
|
return [`${providerId}-model-1`, `${providerId}-model-2`];
|
29
26
|
}),
|
30
|
-
transformToAiModelList: vi.fn((params) => {
|
27
|
+
transformToAiModelList: vi.fn(async (params) => {
|
31
28
|
return params.defaultModels;
|
32
29
|
}),
|
33
30
|
}));
|
@@ -43,8 +40,8 @@ describe('genServerAiProvidersConfig', () => {
|
|
43
40
|
});
|
44
41
|
});
|
45
42
|
|
46
|
-
it('should generate basic provider config with default settings', () => {
|
47
|
-
const result = genServerAiProvidersConfig({});
|
43
|
+
it('should generate basic provider config with default settings', async () => {
|
44
|
+
const result = await genServerAiProvidersConfig({});
|
48
45
|
|
49
46
|
expect(result).toHaveProperty('openai');
|
50
47
|
expect(result).toHaveProperty('anthropic');
|
@@ -62,7 +59,7 @@ describe('genServerAiProvidersConfig', () => {
|
|
62
59
|
});
|
63
60
|
});
|
64
61
|
|
65
|
-
it('should use custom enabled settings from specificConfig', () => {
|
62
|
+
it('should use custom enabled settings from specificConfig', async () => {
|
66
63
|
const specificConfig = {
|
67
64
|
openai: {
|
68
65
|
enabled: false,
|
@@ -72,7 +69,7 @@ describe('genServerAiProvidersConfig', () => {
|
|
72
69
|
},
|
73
70
|
};
|
74
71
|
|
75
|
-
const result = genServerAiProvidersConfig(specificConfig);
|
72
|
+
const result = await genServerAiProvidersConfig(specificConfig);
|
76
73
|
|
77
74
|
expect(result.openai.enabled).toBe(false);
|
78
75
|
expect(result.anthropic.enabled).toBe(true);
|
@@ -93,7 +90,7 @@ describe('genServerAiProvidersConfig', () => {
|
|
93
90
|
CUSTOM_OPENAI_ENABLED: true,
|
94
91
|
} as any);
|
95
92
|
|
96
|
-
const result = genServerAiProvidersConfig(specificConfig);
|
93
|
+
const result = await genServerAiProvidersConfig(specificConfig);
|
97
94
|
|
98
95
|
expect(result.openai.enabled).toBe(true);
|
99
96
|
});
|
@@ -102,9 +99,9 @@ describe('genServerAiProvidersConfig', () => {
|
|
102
99
|
process.env.OPENAI_MODEL_LIST = '+gpt-4,+gpt-3.5-turbo';
|
103
100
|
|
104
101
|
const { extractEnabledModels } = vi.mocked(await import('@/utils/parseModels'));
|
105
|
-
extractEnabledModels.mockReturnValue(['gpt-4', 'gpt-3.5-turbo']);
|
102
|
+
extractEnabledModels.mockResolvedValue(['gpt-4', 'gpt-3.5-turbo']);
|
106
103
|
|
107
|
-
const result = genServerAiProvidersConfig({});
|
104
|
+
const result = await genServerAiProvidersConfig({});
|
108
105
|
|
109
106
|
expect(extractEnabledModels).toHaveBeenCalledWith('openai', '+gpt-4,+gpt-3.5-turbo', false);
|
110
107
|
expect(result.openai.enabledModels).toEqual(['gpt-4', 'gpt-3.5-turbo']);
|
@@ -121,7 +118,7 @@ describe('genServerAiProvidersConfig', () => {
|
|
121
118
|
|
122
119
|
const { extractEnabledModels } = vi.mocked(await import('@/utils/parseModels'));
|
123
120
|
|
124
|
-
genServerAiProvidersConfig(specificConfig);
|
121
|
+
await genServerAiProvidersConfig(specificConfig);
|
125
122
|
|
126
123
|
expect(extractEnabledModels).toHaveBeenCalledWith('openai', '+custom-model', false);
|
127
124
|
});
|
@@ -139,7 +136,7 @@ describe('genServerAiProvidersConfig', () => {
|
|
139
136
|
await import('@/utils/parseModels'),
|
140
137
|
);
|
141
138
|
|
142
|
-
genServerAiProvidersConfig(specificConfig);
|
139
|
+
await genServerAiProvidersConfig(specificConfig);
|
143
140
|
|
144
141
|
expect(extractEnabledModels).toHaveBeenCalledWith('openai', '+gpt-4->deployment1', true);
|
145
142
|
expect(transformToAiModelList).toHaveBeenCalledWith({
|
@@ -150,26 +147,26 @@ describe('genServerAiProvidersConfig', () => {
|
|
150
147
|
});
|
151
148
|
});
|
152
149
|
|
153
|
-
it('should include fetchOnClient when specified in config', () => {
|
150
|
+
it('should include fetchOnClient when specified in config', async () => {
|
154
151
|
const specificConfig = {
|
155
152
|
openai: {
|
156
153
|
fetchOnClient: true,
|
157
154
|
},
|
158
155
|
};
|
159
156
|
|
160
|
-
const result = genServerAiProvidersConfig(specificConfig);
|
157
|
+
const result = await genServerAiProvidersConfig(specificConfig);
|
161
158
|
|
162
159
|
expect(result.openai).toHaveProperty('fetchOnClient', true);
|
163
160
|
});
|
164
161
|
|
165
|
-
it('should not include fetchOnClient when not specified in config', () => {
|
166
|
-
const result = genServerAiProvidersConfig({});
|
162
|
+
it('should not include fetchOnClient when not specified in config', async () => {
|
163
|
+
const result = await genServerAiProvidersConfig({});
|
167
164
|
|
168
165
|
expect(result.openai).not.toHaveProperty('fetchOnClient');
|
169
166
|
});
|
170
167
|
|
171
|
-
it('should handle all available providers', () => {
|
172
|
-
const result = genServerAiProvidersConfig({});
|
168
|
+
it('should handle all available providers', async () => {
|
169
|
+
const result = await genServerAiProvidersConfig({});
|
173
170
|
|
174
171
|
// Check that result includes some key providers
|
175
172
|
expect(result).toHaveProperty('openai');
|
@@ -210,8 +207,8 @@ describe('genServerAiProvidersConfig Error Handling', () => {
|
|
210
207
|
}));
|
211
208
|
|
212
209
|
vi.doMock('@/utils/parseModels', () => ({
|
213
|
-
extractEnabledModels: vi.fn(() => undefined),
|
214
|
-
transformToAiModelList: vi.fn(() => []),
|
210
|
+
extractEnabledModels: vi.fn(async () => undefined),
|
211
|
+
transformToAiModelList: vi.fn(async () => []),
|
215
212
|
}));
|
216
213
|
|
217
214
|
// Mock ModelProvider to include the missing provider
|
@@ -228,8 +225,8 @@ describe('genServerAiProvidersConfig Error Handling', () => {
|
|
228
225
|
);
|
229
226
|
|
230
227
|
// This should throw because 'openai' is in ModelProvider but not in aiModels
|
231
|
-
expect(() => {
|
232
|
-
genServerAiProvidersConfig({});
|
233
|
-
}).toThrow();
|
228
|
+
await expect(async () => {
|
229
|
+
await genServerAiProvidersConfig({});
|
230
|
+
}).rejects.toThrow();
|
234
231
|
});
|
235
232
|
});
|
@@ -13,11 +13,14 @@ interface ProviderSpecificConfig {
|
|
13
13
|
withDeploymentName?: boolean;
|
14
14
|
}
|
15
15
|
|
16
|
-
export const genServerAiProvidersConfig = (
|
16
|
+
export const genServerAiProvidersConfig = async (
|
17
|
+
specificConfig: Record<any, ProviderSpecificConfig>,
|
18
|
+
) => {
|
17
19
|
const llmConfig = getLLMConfig() as Record<string, any>;
|
18
20
|
|
19
|
-
|
20
|
-
|
21
|
+
// 并发处理所有 providers
|
22
|
+
const providerConfigs = await Promise.all(
|
23
|
+
Object.values(ModelProvider).map(async (provider) => {
|
21
24
|
const providerUpperCase = provider.toUpperCase();
|
22
25
|
const aiModels = AiModels[provider] as AiFullModelCard[];
|
23
26
|
|
@@ -30,30 +33,39 @@ export const genServerAiProvidersConfig = (specificConfig: Record<any, ProviderS
|
|
30
33
|
const modelString =
|
31
34
|
process.env[providerConfig.modelListKey ?? `${providerUpperCase}_MODEL_LIST`];
|
32
35
|
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
: llmConfig[providerConfig.enabledKey || `ENABLED_${providerUpperCase}`],
|
38
|
-
|
39
|
-
enabledModels: extractEnabledModels(
|
40
|
-
provider,
|
41
|
-
modelString,
|
42
|
-
providerConfig.withDeploymentName || false,
|
43
|
-
),
|
44
|
-
serverModelLists: transformToAiModelList({
|
36
|
+
// 并发处理 extractEnabledModels 和 transformToAiModelList
|
37
|
+
const [enabledModels, serverModelLists] = await Promise.all([
|
38
|
+
extractEnabledModels(provider, modelString, providerConfig.withDeploymentName || false),
|
39
|
+
transformToAiModelList({
|
45
40
|
defaultModels: aiModels || [],
|
46
41
|
modelString,
|
47
42
|
providerId: provider,
|
48
43
|
withDeploymentName: providerConfig.withDeploymentName || false,
|
49
44
|
}),
|
50
|
-
|
51
|
-
fetchOnClient: providerConfig.fetchOnClient,
|
52
|
-
}),
|
53
|
-
};
|
45
|
+
]);
|
54
46
|
|
55
|
-
return
|
56
|
-
|
57
|
-
|
47
|
+
return {
|
48
|
+
config: {
|
49
|
+
enabled:
|
50
|
+
typeof providerConfig.enabled !== 'undefined'
|
51
|
+
? providerConfig.enabled
|
52
|
+
: llmConfig[providerConfig.enabledKey || `ENABLED_${providerUpperCase}`],
|
53
|
+
enabledModels,
|
54
|
+
serverModelLists,
|
55
|
+
...(providerConfig.fetchOnClient !== undefined && {
|
56
|
+
fetchOnClient: providerConfig.fetchOnClient,
|
57
|
+
}),
|
58
|
+
},
|
59
|
+
provider,
|
60
|
+
};
|
61
|
+
}),
|
58
62
|
);
|
63
|
+
|
64
|
+
// 将结果转换为对象
|
65
|
+
const config = {} as Record<string, ProviderConfig>;
|
66
|
+
for (const { provider, config: providerConfig } of providerConfigs) {
|
67
|
+
config[provider] = providerConfig;
|
68
|
+
}
|
69
|
+
|
70
|
+
return config;
|
59
71
|
};
|
@@ -17,7 +17,7 @@ export const getServerGlobalConfig = async () => {
|
|
17
17
|
const { ACCESS_CODES, DEFAULT_AGENT_CONFIG } = getAppConfig();
|
18
18
|
|
19
19
|
const config: GlobalServerConfig = {
|
20
|
-
aiProvider: genServerAiProvidersConfig({
|
20
|
+
aiProvider: await genServerAiProvidersConfig({
|
21
21
|
azure: {
|
22
22
|
enabledKey: 'ENABLED_AZURE_OPENAI',
|
23
23
|
withDeploymentName: true,
|
@@ -6,8 +6,6 @@ import matter from 'gray-matter';
|
|
6
6
|
import { cloneDeep, countBy, isString, merge, uniq, uniqBy } from 'lodash-es';
|
7
7
|
import urlJoin from 'url-join';
|
8
8
|
|
9
|
-
import { LOBE_DEFAULT_MODEL_LIST } from '@/config/aiModels';
|
10
|
-
import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
|
11
9
|
import {
|
12
10
|
DEFAULT_DISCOVER_ASSISTANT_ITEM,
|
13
11
|
DEFAULT_DISCOVER_PLUGIN_ITEM,
|
@@ -728,6 +726,10 @@ export class DiscoverService {
|
|
728
726
|
|
729
727
|
private _getProviderList = async (): Promise<DiscoverProviderItem[]> => {
|
730
728
|
log('_getProviderList: fetching provider list');
|
729
|
+
const [{ LOBE_DEFAULT_MODEL_LIST }, { DEFAULT_MODEL_PROVIDER_LIST }] = await Promise.all([
|
730
|
+
import('@/config/aiModels'),
|
731
|
+
import('@/config/modelProviders'),
|
732
|
+
]);
|
731
733
|
const result = DEFAULT_MODEL_PROVIDER_LIST.map((item) => {
|
732
734
|
const models = uniq(
|
733
735
|
LOBE_DEFAULT_MODEL_LIST.filter((m) => m.providerId === item.id).map((m) => m.id),
|
@@ -751,6 +753,7 @@ export class DiscoverService {
|
|
751
753
|
}): Promise<DiscoverProviderDetail | undefined> => {
|
752
754
|
log('getProviderDetail: params=%O', params);
|
753
755
|
const { identifier, locale, withReadme } = params;
|
756
|
+
const { LOBE_DEFAULT_MODEL_LIST } = await import('@/config/aiModels');
|
754
757
|
const all = await this._getProviderList();
|
755
758
|
let provider = all.find((item) => item.identifier === identifier);
|
756
759
|
if (!provider) {
|
@@ -886,6 +889,7 @@ export class DiscoverService {
|
|
886
889
|
|
887
890
|
private _getRawModelList = async (): Promise<DiscoverModelItem[]> => {
|
888
891
|
log('_getRawModelList: fetching raw model list');
|
892
|
+
const { LOBE_DEFAULT_MODEL_LIST } = await import('@/config/aiModels');
|
889
893
|
const result = LOBE_DEFAULT_MODEL_LIST.map((item) => {
|
890
894
|
const identifier = (item.id.split('/').at(-1) || item.id).toLowerCase();
|
891
895
|
const providers = uniq(
|
@@ -978,6 +982,7 @@ export class DiscoverService {
|
|
978
982
|
getModelCategories = async (params: CategoryListQuery = {}): Promise<CategoryItem[]> => {
|
979
983
|
log('getModelCategories: params=%O', params);
|
980
984
|
const { q } = params;
|
985
|
+
const { LOBE_DEFAULT_MODEL_LIST } = await import('@/config/aiModels');
|
981
986
|
let list = LOBE_DEFAULT_MODEL_LIST;
|
982
987
|
if (q) {
|
983
988
|
const originalCount = list.length;
|
@@ -1011,6 +1016,10 @@ export class DiscoverService {
|
|
1011
1016
|
identifier: string;
|
1012
1017
|
}): Promise<DiscoverModelDetail | undefined> => {
|
1013
1018
|
log('getModelDetail: params=%O', params);
|
1019
|
+
const [{ LOBE_DEFAULT_MODEL_LIST }, { DEFAULT_MODEL_PROVIDER_LIST }] = await Promise.all([
|
1020
|
+
import('@/config/aiModels'),
|
1021
|
+
import('@/config/modelProviders'),
|
1022
|
+
]);
|
1014
1023
|
const { identifier } = params;
|
1015
1024
|
const all = await this._getModelList();
|
1016
1025
|
let model = all.find((item) => item.identifier.toLowerCase() === identifier.toLowerCase());
|
package/src/services/chat.ts
CHANGED
@@ -2,7 +2,6 @@ import { PluginRequestPayload, createHeadersWithPluginSettings } from '@lobehub/
|
|
2
2
|
import { produce } from 'immer';
|
3
3
|
import { merge } from 'lodash-es';
|
4
4
|
|
5
|
-
import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
|
6
5
|
import { enableAuth } from '@/const/auth';
|
7
6
|
import { INBOX_GUIDE_SYSTEMROLE } from '@/const/guide';
|
8
7
|
import { INBOX_SESSION_ID } from '@/const/session';
|
@@ -404,6 +403,7 @@ class ChatService {
|
|
404
403
|
provider,
|
405
404
|
});
|
406
405
|
|
406
|
+
const { DEFAULT_MODEL_PROVIDER_LIST } = await import('@/config/modelProviders');
|
407
407
|
const providerConfig = DEFAULT_MODEL_PROVIDER_LIST.find((item) => item.id === provider);
|
408
408
|
|
409
409
|
let sdkType = provider;
|
@@ -0,0 +1,211 @@
|
|
1
|
+
import { describe, expect, it, vi } from 'vitest';
|
2
|
+
|
3
|
+
import type { EnabledAiModel, ModelAbilities } from '@/types/aiModel';
|
4
|
+
|
5
|
+
import { getModelListByType } from '../action';
|
6
|
+
|
7
|
+
// Mock getModelPropertyWithFallback
|
8
|
+
vi.mock('@/utils/getFallbackModelProperty', () => ({
|
9
|
+
getModelPropertyWithFallback: vi.fn().mockReturnValue({ size: '1024x1024' }),
|
10
|
+
}));
|
11
|
+
|
12
|
+
describe('getModelListByType', () => {
|
13
|
+
const mockChatModels: EnabledAiModel[] = [
|
14
|
+
{
|
15
|
+
id: 'gpt-4',
|
16
|
+
providerId: 'openai',
|
17
|
+
type: 'chat',
|
18
|
+
abilities: { functionCall: true, files: true } as ModelAbilities,
|
19
|
+
contextWindowTokens: 8192,
|
20
|
+
displayName: 'GPT-4',
|
21
|
+
enabled: true,
|
22
|
+
},
|
23
|
+
{
|
24
|
+
id: 'gpt-3.5-turbo',
|
25
|
+
providerId: 'openai',
|
26
|
+
type: 'chat',
|
27
|
+
abilities: { functionCall: true } as ModelAbilities,
|
28
|
+
contextWindowTokens: 4096,
|
29
|
+
displayName: 'GPT-3.5 Turbo',
|
30
|
+
enabled: true,
|
31
|
+
},
|
32
|
+
{
|
33
|
+
id: 'claude-3-opus',
|
34
|
+
providerId: 'anthropic',
|
35
|
+
type: 'chat',
|
36
|
+
abilities: { functionCall: false, files: true } as ModelAbilities,
|
37
|
+
contextWindowTokens: 200000,
|
38
|
+
displayName: 'Claude 3 Opus',
|
39
|
+
enabled: true,
|
40
|
+
},
|
41
|
+
];
|
42
|
+
|
43
|
+
const mockImageModels: EnabledAiModel[] = [
|
44
|
+
{
|
45
|
+
id: 'dall-e-3',
|
46
|
+
providerId: 'openai',
|
47
|
+
type: 'image',
|
48
|
+
abilities: {} as ModelAbilities,
|
49
|
+
displayName: 'DALL-E 3',
|
50
|
+
enabled: true,
|
51
|
+
parameters: { size: '1024x1024', quality: 'standard' },
|
52
|
+
},
|
53
|
+
{
|
54
|
+
id: 'midjourney',
|
55
|
+
providerId: 'midjourney',
|
56
|
+
type: 'image',
|
57
|
+
abilities: {} as ModelAbilities,
|
58
|
+
displayName: 'Midjourney',
|
59
|
+
enabled: true,
|
60
|
+
},
|
61
|
+
];
|
62
|
+
|
63
|
+
const allModels = [...mockChatModels, ...mockImageModels];
|
64
|
+
|
65
|
+
describe('basic functionality', () => {
|
66
|
+
it('should filter models by providerId and type correctly', () => {
|
67
|
+
const result = getModelListByType(allModels, 'openai', 'chat');
|
68
|
+
|
69
|
+
expect(result).toHaveLength(2);
|
70
|
+
expect(result.map((m) => m.id)).toEqual(['gpt-4', 'gpt-3.5-turbo']);
|
71
|
+
});
|
72
|
+
|
73
|
+
it('should return correct model structure', () => {
|
74
|
+
const result = getModelListByType(allModels, 'openai', 'chat');
|
75
|
+
|
76
|
+
expect(result[0]).toEqual({
|
77
|
+
abilities: { functionCall: true, files: true },
|
78
|
+
contextWindowTokens: 8192,
|
79
|
+
displayName: 'GPT-4',
|
80
|
+
id: 'gpt-4',
|
81
|
+
});
|
82
|
+
});
|
83
|
+
|
84
|
+
it('should add parameters field for image models', () => {
|
85
|
+
const result = getModelListByType(allModels, 'openai', 'image');
|
86
|
+
|
87
|
+
expect(result[0]).toEqual({
|
88
|
+
abilities: {},
|
89
|
+
contextWindowTokens: undefined,
|
90
|
+
displayName: 'DALL-E 3',
|
91
|
+
id: 'dall-e-3',
|
92
|
+
parameters: { size: '1024x1024', quality: 'standard' },
|
93
|
+
});
|
94
|
+
});
|
95
|
+
|
96
|
+
it('should use fallback parameters for image models without parameters', () => {
|
97
|
+
const result = getModelListByType(allModels, 'midjourney', 'image');
|
98
|
+
|
99
|
+
expect(result[0]).toEqual({
|
100
|
+
abilities: {},
|
101
|
+
contextWindowTokens: undefined,
|
102
|
+
displayName: 'Midjourney',
|
103
|
+
id: 'midjourney',
|
104
|
+
parameters: { size: '1024x1024' },
|
105
|
+
});
|
106
|
+
});
|
107
|
+
});
|
108
|
+
|
109
|
+
describe('edge cases', () => {
|
110
|
+
it('should handle empty model list', () => {
|
111
|
+
const result = getModelListByType([], 'openai', 'chat');
|
112
|
+
expect(result).toEqual([]);
|
113
|
+
});
|
114
|
+
|
115
|
+
it('should handle non-existent providerId', () => {
|
116
|
+
const result = getModelListByType(allModels, 'nonexistent', 'chat');
|
117
|
+
expect(result).toEqual([]);
|
118
|
+
});
|
119
|
+
|
120
|
+
it('should handle non-existent type', () => {
|
121
|
+
const result = getModelListByType(allModels, 'openai', 'nonexistent');
|
122
|
+
expect(result).toEqual([]);
|
123
|
+
});
|
124
|
+
|
125
|
+
it('should handle missing displayName', () => {
|
126
|
+
const modelsWithoutDisplayName: EnabledAiModel[] = [
|
127
|
+
{
|
128
|
+
id: 'test-model',
|
129
|
+
providerId: 'test',
|
130
|
+
type: 'chat',
|
131
|
+
abilities: {} as ModelAbilities,
|
132
|
+
enabled: true,
|
133
|
+
},
|
134
|
+
];
|
135
|
+
|
136
|
+
const result = getModelListByType(modelsWithoutDisplayName, 'test', 'chat');
|
137
|
+
expect(result[0].displayName).toBe('');
|
138
|
+
});
|
139
|
+
|
140
|
+
it('should handle missing abilities', () => {
|
141
|
+
const modelsWithoutAbilities: EnabledAiModel[] = [
|
142
|
+
{
|
143
|
+
id: 'test-model',
|
144
|
+
providerId: 'test',
|
145
|
+
type: 'chat',
|
146
|
+
enabled: true,
|
147
|
+
} as EnabledAiModel,
|
148
|
+
];
|
149
|
+
|
150
|
+
const result = getModelListByType(modelsWithoutAbilities, 'test', 'chat');
|
151
|
+
expect(result[0].abilities).toEqual({});
|
152
|
+
});
|
153
|
+
});
|
154
|
+
|
155
|
+
describe('deduplication', () => {
|
156
|
+
it('should remove duplicate model IDs', () => {
|
157
|
+
const duplicateModels: EnabledAiModel[] = [
|
158
|
+
{
|
159
|
+
id: 'gpt-4',
|
160
|
+
providerId: 'openai',
|
161
|
+
type: 'chat',
|
162
|
+
abilities: { functionCall: true } as ModelAbilities,
|
163
|
+
displayName: 'GPT-4 Version 1',
|
164
|
+
enabled: true,
|
165
|
+
},
|
166
|
+
{
|
167
|
+
id: 'gpt-4',
|
168
|
+
providerId: 'openai',
|
169
|
+
type: 'chat',
|
170
|
+
abilities: { functionCall: false } as ModelAbilities,
|
171
|
+
displayName: 'GPT-4 Version 2',
|
172
|
+
enabled: true,
|
173
|
+
},
|
174
|
+
];
|
175
|
+
|
176
|
+
const result = getModelListByType(duplicateModels, 'openai', 'chat');
|
177
|
+
|
178
|
+
expect(result).toHaveLength(1);
|
179
|
+
expect(result[0].displayName).toBe('GPT-4 Version 1');
|
180
|
+
});
|
181
|
+
});
|
182
|
+
|
183
|
+
describe('type casting', () => {
|
184
|
+
it('should handle image model type casting correctly', () => {
|
185
|
+
const imageModel: EnabledAiModel[] = [
|
186
|
+
{
|
187
|
+
id: 'dall-e-3',
|
188
|
+
providerId: 'openai',
|
189
|
+
type: 'image',
|
190
|
+
abilities: {} as ModelAbilities,
|
191
|
+
displayName: 'DALL-E 3',
|
192
|
+
enabled: true,
|
193
|
+
parameters: { size: '1024x1024' },
|
194
|
+
} as any, // Simulate AIImageModelCard type
|
195
|
+
];
|
196
|
+
|
197
|
+
const result = getModelListByType(imageModel, 'openai', 'image');
|
198
|
+
|
199
|
+
expect(result[0]).toHaveProperty('parameters');
|
200
|
+
expect(result[0].parameters).toEqual({ size: '1024x1024' });
|
201
|
+
});
|
202
|
+
|
203
|
+
it('should not add parameters field for non-image models', () => {
|
204
|
+
const result = getModelListByType(mockChatModels, 'openai', 'chat');
|
205
|
+
|
206
|
+
result.forEach((model) => {
|
207
|
+
expect(model).not.toHaveProperty('parameters');
|
208
|
+
});
|
209
|
+
});
|
210
|
+
});
|
211
|
+
});
|