@lobehub/chat 1.114.5 → 1.115.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.cursor/rules/project-introduce.mdc +1 -15
- package/.cursor/rules/project-structure.mdc +227 -0
- package/.cursor/rules/testing-guide/db-model-test.mdc +5 -3
- package/.cursor/rules/testing-guide/testing-guide.mdc +153 -168
- package/.github/workflows/claude.yml +1 -1
- package/.github/workflows/test.yml +9 -0
- package/.prettierignore +0 -1
- package/.vscode/settings.json +86 -80
- package/CHANGELOG.md +50 -0
- package/CLAUDE.md +11 -27
- package/changelog/v1.json +10 -0
- package/docs/development/basic/feature-development.mdx +1 -1
- package/docs/development/basic/feature-development.zh-CN.mdx +1 -1
- package/package.json +5 -5
- package/packages/const/src/image.ts +28 -0
- package/packages/const/src/index.ts +1 -0
- package/packages/database/package.json +4 -2
- package/packages/database/src/repositories/aiInfra/index.ts +1 -1
- package/packages/database/tests/setup-db.ts +3 -0
- package/packages/database/vitest.config.mts +33 -0
- package/packages/model-runtime/src/utils/modelParse.ts +1 -1
- package/packages/utils/src/client/imageDimensions.test.ts +95 -0
- package/packages/utils/src/client/imageDimensions.ts +54 -0
- package/packages/utils/src/number.test.ts +3 -1
- package/packages/utils/src/number.ts +1 -2
- package/src/app/[variants]/(main)/files/[id]/page.tsx +0 -2
- package/src/app/[variants]/(main)/image/@menu/components/SeedNumberInput/index.tsx +1 -1
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/DimensionControlGroup.tsx +0 -1
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/ImageUpload.tsx +206 -185
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/ImageUrl.tsx +16 -4
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/ImageUrlsUpload.tsx +52 -3
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/MultiImagesUpload/ImageManageModal.tsx +33 -19
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/MultiImagesUpload/index.tsx +40 -12
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/hooks/useAutoDimensions.ts +56 -0
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/hooks/useUploadFilesValidation.ts +77 -0
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/index.tsx +82 -5
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/utils/__tests__/dimensionConstraints.test.ts +235 -0
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/utils/__tests__/imageValidation.test.ts +401 -0
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/utils/dimensionConstraints.ts +54 -0
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/utils/imageValidation.ts +117 -0
- package/src/app/[variants]/(main)/image/@topic/features/Topics/TopicItem.tsx +3 -1
- package/src/app/[variants]/(main)/image/@topic/features/Topics/TopicList.tsx +15 -2
- package/src/app/[variants]/(main)/image/features/GenerationFeed/GenerationItem/utils.ts +5 -4
- package/src/libs/standard-parameters/index.ts +4 -1
- package/src/locales/default/components.ts +8 -0
- package/src/server/services/generation/index.ts +1 -1
- package/src/store/aiInfra/slices/aiProvider/__tests__/action.test.ts +29 -29
- package/src/store/aiInfra/slices/aiProvider/action.ts +80 -36
- package/src/store/chat/slices/builtinTool/actions/dalle.test.ts +20 -13
- package/src/store/file/slices/upload/action.ts +18 -7
- package/src/store/image/slices/generationConfig/hooks.ts +11 -1
- package/tsconfig.json +1 -10
- package/packages/const/src/imageGeneration.ts +0 -16
- package/src/app/(backend)/trpc/desktop/[trpc]/route.ts +0 -26
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/AspectRatioSelect.tsx +0 -24
- package/src/app/[variants]/(main)/image/@menu/features/ConfigPanel/components/SizeSliderInput.tsx +0 -15
- package/src/app/[variants]/(main)/image/@topic/features/Topics/TopicItemContainer.tsx +0 -91
- package/src/app/desktop/devtools/page.tsx +0 -89
- package/src/app/desktop/layout.tsx +0 -31
- /package/apps/desktop/{vitest.config.ts → vitest.config.mts} +0 -0
- /package/packages/database/{vitest.config.ts → vitest.config.server.mts} +0 -0
- /package/packages/electron-server-ipc/{vitest.config.ts → vitest.config.mts} +0 -0
- /package/packages/file-loaders/{vitest.config.ts → vitest.config.mts} +0 -0
- /package/packages/model-runtime/{vitest.config.ts → vitest.config.mts} +0 -0
- /package/packages/prompts/{vitest.config.ts → vitest.config.mts} +0 -0
- /package/packages/utils/{vitest.config.ts → vitest.config.mts} +0 -0
- /package/packages/web-crawler/{vitest.config.ts → vitest.config.mts} +0 -0
- /package/{vitest.config.ts → vitest.config.mts} +0 -0
@@ -0,0 +1,117 @@
|
|
1
|
+
/**
|
2
|
+
* Image file validation utility functions
|
3
|
+
*/
|
4
|
+
|
5
|
+
/**
 * Format a byte count as a short human-readable string.
 *
 * Values are scaled by powers of 1024 and rounded to one decimal place. Trailing ".0" is
 * dropped via `parseFloat`.
 *
 * @param bytes - File size in bytes
 * @returns Formatted string like "1.5 MB" (e.g. `0` -> "0 B", `1536` -> "1.5 KB")
 */
export const formatFileSize = (bytes: number): string => {
  if (bytes === 0) return '0 B';

  const k = 1024;
  const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];

  // Pick the largest unit that keeps the magnitude >= 1, capped at the last unit.
  // A bounded loop is used instead of a Math.log ratio for two reasons:
  // - the index can never fall outside `sizes`, which previously produced "undefined" for
  //   sizes >= 1 TB, sub-byte values, and negative values;
  // - it avoids floating-point error exactly at unit boundaries.
  let i = 0;
  let magnitude = Math.abs(bytes);
  while (magnitude >= k && i < sizes.length - 1) {
    magnitude /= k;
    i += 1;
  }

  return `${parseFloat((bytes / Math.pow(k, i)).toFixed(1))} ${sizes[i]}`;
};
|
19
|
+
|
20
|
+
/**
 * Outcome of a single validation check.
 *
 * `valid` is always present. The other fields are only filled in by the validators in this
 * module when a check fails.
 */
export interface ValidationResult {
  // Additional details for error messages
  // Actual file size in bytes (set by validateImageFileSize on failure)
  actualSize?: number;
  // Error key, e.g. 'fileSizeExceeded' or 'imageCountExceeded'
  error?: string;
  // Name of the offending file (set by validateImageFileSize on failure)
  fileName?: string;
  // Size limit in bytes that was applied (set by validateImageFileSize on failure)
  maxSize?: number;
  // true when the check passed
  valid: boolean;
}
|
28
|
+
|
29
|
+
/**
 * Check that a single image file does not exceed a size limit.
 *
 * @param file - File whose `size` (bytes) is checked
 * @param maxSize - Limit in bytes; when omitted (undefined/null) a 10MB limit is used
 * @returns `{ valid: true }` when within the limit. Otherwise a failed result carries the
 *   'fileSizeExceeded' error key, the file name, its actual size, and the applied limit.
 */
export const validateImageFileSize = (file: File, maxSize?: number): ValidationResult => {
  // 10MB fallback limit when the caller does not supply one
  const limit = maxSize ?? 10 * 1024 * 1024;
  const { name, size } = file;
  const exceeded = size > limit;

  return exceeded
    ? {
        actualSize: size,
        error: 'fileSizeExceeded',
        fileName: name,
        maxSize: limit,
        valid: false,
      }
    : { valid: true };
};
|
51
|
+
|
52
|
+
/**
 * Check that an image count does not exceed a maximum.
 *
 * @param count - Number of images to check
 * @param maxCount - Maximum allowed count; any falsy value (undefined, 0) disables the check
 * @returns `{ valid: true }` when within the limit or unchecked; otherwise a failed result
 *   with the 'imageCountExceeded' error key
 */
export const validateImageCount = (count: number, maxCount?: number): ValidationResult => {
  // No limit configured means there is nothing to enforce
  const exceeded = !!maxCount && count > maxCount;

  return exceeded ? { error: 'imageCountExceeded', valid: false } : { valid: true };
};
|
70
|
+
|
71
|
+
/**
 * Validate a batch of image files against count and per-file size constraints.
 *
 * @param files - Files to validate
 * @param constraints - `maxAddedFiles` limits how many files the batch may contain (falsy
 *   disables the check); `maxFileSize` is the per-file byte limit (10MB when omitted)
 * @returns The outcome of the batch check:
 *   - `errors`: de-duplicated error keys, with the count error (if any) first;
 *   - `fileResults`: one result per input file, in input order;
 *   - `failedFiles`: the subset of `fileResults` that failed;
 *   - `valid`: true only when no check failed.
 */
export const validateImageFiles = (
  files: File[],
  constraints: {
    maxAddedFiles?: number;
    maxFileSize?: number;
  },
): {
  errors: string[];
  // Additional details for error messages
  failedFiles?: ValidationResult[];
  fileResults: ValidationResult[];
  valid: boolean;
} => {
  const { maxAddedFiles, maxFileSize } = constraints;
  const collectedErrors: string[] = [];

  // Batch-level check: how many files were added at once
  const countCheck = validateImageCount(files.length, maxAddedFiles);
  if (!countCheck.valid && countCheck.error) {
    collectedErrors.push(countCheck.error);
  }

  // Per-file check: size limit. Results stay index-aligned with `files`.
  const fileResults = files.map((file) => validateImageFileSize(file, maxFileSize));
  const failedFiles = fileResults.filter((result) => !result.valid && !!result.error);
  for (const { error } of failedFiles) {
    if (error) collectedErrors.push(error);
  }

  return {
    errors: Array.from(new Set(collectedErrors)), // Remove duplicates
    failedFiles,
    fileResults,
    valid: collectedErrors.length === 0,
  };
};
|
@@ -23,10 +23,11 @@ const formatTime = (date: Date, locale: string) => {
|
|
23
23
|
|
24
24
|
interface TopicItemProps {
|
25
25
|
showMoreInfo?: boolean;
|
26
|
+
style?: React.CSSProperties;
|
26
27
|
topic: ImageGenerationTopic;
|
27
28
|
}
|
28
29
|
|
29
|
-
const TopicItem = memo<TopicItemProps>(({ topic, showMoreInfo }) => {
|
30
|
+
const TopicItem = memo<TopicItemProps>(({ topic, showMoreInfo, style }) => {
|
30
31
|
const theme = useTheme();
|
31
32
|
const { t } = useTranslation('image');
|
32
33
|
const { modal } = App.useApp();
|
@@ -111,6 +112,7 @@ const TopicItem = memo<TopicItemProps>(({ topic, showMoreInfo }) => {
|
|
111
112
|
onClick={handleClick}
|
112
113
|
style={{
|
113
114
|
cursor: 'pointer',
|
115
|
+
...style,
|
114
116
|
}}
|
115
117
|
width={'100%'}
|
116
118
|
>
|
@@ -47,8 +47,21 @@ const TopicsList = memo(() => {
|
|
47
47
|
showMoreInfo={showMoreInfo}
|
48
48
|
/>
|
49
49
|
<Flexbox align="center" gap={12} ref={parent} width={'100%'}>
|
50
|
-
{generationTopics.map((topic) => (
|
51
|
-
<TopicItem
|
50
|
+
{generationTopics.map((topic, index) => (
|
51
|
+
<TopicItem
|
52
|
+
key={topic.id}
|
53
|
+
showMoreInfo={showMoreInfo}
|
54
|
+
style={{
|
55
|
+
padding:
|
56
|
+
// fix the avatar border is clipped by overflow hidden
|
57
|
+
generationTopics.length === 1
|
58
|
+
? '4px 0'
|
59
|
+
: index === generationTopics.length - 1
|
60
|
+
? '0 0 4px'
|
61
|
+
: '0',
|
62
|
+
}}
|
63
|
+
topic={topic}
|
64
|
+
/>
|
52
65
|
))}
|
53
66
|
</Flexbox>
|
54
67
|
</Flexbox>
|
@@ -111,13 +111,14 @@ export const getThumbnailMaxWidth = (
|
|
111
111
|
): number => {
|
112
112
|
const dimensions = getImageDimensions(generation, generationBatch);
|
113
113
|
|
114
|
-
// Return default width if
|
115
|
-
if (!dimensions.
|
114
|
+
// Return default width if no dimension information is available
|
115
|
+
if (!dimensions.aspectRatio) {
|
116
116
|
return DEFAULT_MAX_ITEM_WIDTH;
|
117
117
|
}
|
118
118
|
|
119
|
-
|
120
|
-
const
|
119
|
+
// Parse aspect ratio string (format: "16 / 9")
|
120
|
+
const [widthStr, heightStr] = dimensions.aspectRatio.split(' / ');
|
121
|
+
const aspectRatio = Number(widthStr) / Number(heightStr);
|
121
122
|
|
122
123
|
// Apply screen height constraint (half of screen height)
|
123
124
|
// Note: window.innerHeight is safe to use here as this function is client-side only
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import type { Simplify } from 'type-fest';
|
2
2
|
import { z } from 'zod';
|
3
3
|
|
4
|
-
|
4
|
+
import { MAX_SEED } from '@/const/image';
|
5
5
|
|
6
6
|
// 定义顶层的元规范 - 平铺结构
|
7
7
|
export const ModelParamsMetaSchema = z.object({
|
@@ -40,6 +40,7 @@ export const ModelParamsMetaSchema = z.object({
|
|
40
40
|
.object({
|
41
41
|
default: z.string().nullable().optional(),
|
42
42
|
description: z.string().optional(),
|
43
|
+
maxFileSize: z.number().optional(),
|
43
44
|
type: z.tuple([z.literal('string'), z.literal('null')]).optional(),
|
44
45
|
})
|
45
46
|
.optional(),
|
@@ -48,6 +49,8 @@ export const ModelParamsMetaSchema = z.object({
|
|
48
49
|
.object({
|
49
50
|
default: z.array(z.string()),
|
50
51
|
description: z.string().optional(),
|
52
|
+
maxCount: z.number().optional(),
|
53
|
+
maxFileSize: z.number().optional(),
|
51
54
|
type: z.literal('array').optional(),
|
52
55
|
})
|
53
56
|
.optional(),
|
@@ -133,6 +133,14 @@ export default {
|
|
133
133
|
progress: {
|
134
134
|
uploadingWithCount: '{{completed}}/{{total}} 已上传',
|
135
135
|
},
|
136
|
+
validation: {
|
137
|
+
fileSizeExceeded: 'File size exceeded limit',
|
138
|
+
fileSizeExceededDetail:
|
139
|
+
'{{fileName}} ({{actualSize}}) exceeds the maximum size limit of {{maxSize}}',
|
140
|
+
fileSizeExceededMultiple:
|
141
|
+
'{{count}} files exceed the maximum size limit of {{maxSize}}: {{fileList}}',
|
142
|
+
imageCountExceeded: 'Image count exceeded limit',
|
143
|
+
},
|
136
144
|
},
|
137
145
|
OllamaSetupGuide: {
|
138
146
|
action: {
|
@@ -4,7 +4,7 @@ import mime from 'mime';
|
|
4
4
|
import { nanoid } from 'nanoid';
|
5
5
|
import sharp from 'sharp';
|
6
6
|
|
7
|
-
import { IMAGE_GENERATION_CONFIG } from '@/const/
|
7
|
+
import { IMAGE_GENERATION_CONFIG } from '@/const/image';
|
8
8
|
import { LobeChatDatabase } from '@/database/type';
|
9
9
|
import { parseDataUri } from '@/libs/model-runtime/utils/uriParser';
|
10
10
|
import { FileService } from '@/server/services/file';
|
@@ -6,7 +6,7 @@ import { getModelListByType } from '../action';
|
|
6
6
|
|
7
7
|
// Mock getModelPropertyWithFallback
|
8
8
|
vi.mock('@/utils/getFallbackModelProperty', () => ({
|
9
|
-
getModelPropertyWithFallback: vi.fn().
|
9
|
+
getModelPropertyWithFallback: vi.fn().mockResolvedValue({ size: '1024x1024' }),
|
10
10
|
}));
|
11
11
|
|
12
12
|
describe('getModelListByType', () => {
|
@@ -48,9 +48,9 @@ describe('getModelListByType', () => {
|
|
48
48
|
abilities: {} as ModelAbilities,
|
49
49
|
displayName: 'DALL-E 3',
|
50
50
|
enabled: true,
|
51
|
-
parameters: {
|
51
|
+
parameters: {
|
52
52
|
prompt: { default: '' },
|
53
|
-
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] }
|
53
|
+
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
|
54
54
|
},
|
55
55
|
},
|
56
56
|
{
|
@@ -66,15 +66,15 @@ describe('getModelListByType', () => {
|
|
66
66
|
const allModels = [...mockChatModels, ...mockImageModels];
|
67
67
|
|
68
68
|
describe('basic functionality', () => {
|
69
|
-
it('should filter models by providerId and type correctly', () => {
|
70
|
-
const result = getModelListByType(allModels, 'openai', 'chat');
|
69
|
+
it('should filter models by providerId and type correctly', async () => {
|
70
|
+
const result = await getModelListByType(allModels, 'openai', 'chat');
|
71
71
|
|
72
72
|
expect(result).toHaveLength(2);
|
73
73
|
expect(result.map((m) => m.id)).toEqual(['gpt-4', 'gpt-3.5-turbo']);
|
74
74
|
});
|
75
75
|
|
76
|
-
it('should return correct model structure', () => {
|
77
|
-
const result = getModelListByType(allModels, 'openai', 'chat');
|
76
|
+
it('should return correct model structure', async () => {
|
77
|
+
const result = await getModelListByType(allModels, 'openai', 'chat');
|
78
78
|
|
79
79
|
expect(result[0]).toEqual({
|
80
80
|
abilities: { functionCall: true, files: true },
|
@@ -84,23 +84,23 @@ describe('getModelListByType', () => {
|
|
84
84
|
});
|
85
85
|
});
|
86
86
|
|
87
|
-
it('should add parameters field for image models', () => {
|
88
|
-
const result = getModelListByType(allModels, 'openai', 'image');
|
87
|
+
it('should add parameters field for image models', async () => {
|
88
|
+
const result = await getModelListByType(allModels, 'openai', 'image');
|
89
89
|
|
90
90
|
expect(result[0]).toEqual({
|
91
91
|
abilities: {},
|
92
92
|
contextWindowTokens: undefined,
|
93
93
|
displayName: 'DALL-E 3',
|
94
94
|
id: 'dall-e-3',
|
95
|
-
parameters: {
|
95
|
+
parameters: {
|
96
96
|
prompt: { default: '' },
|
97
|
-
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] }
|
97
|
+
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
|
98
98
|
},
|
99
99
|
});
|
100
100
|
});
|
101
101
|
|
102
|
-
it('should use fallback parameters for image models without parameters', () => {
|
103
|
-
const result = getModelListByType(allModels, 'midjourney', 'image');
|
102
|
+
it('should use fallback parameters for image models without parameters', async () => {
|
103
|
+
const result = await getModelListByType(allModels, 'midjourney', 'image');
|
104
104
|
|
105
105
|
expect(result[0]).toEqual({
|
106
106
|
abilities: {},
|
@@ -113,22 +113,22 @@ describe('getModelListByType', () => {
|
|
113
113
|
});
|
114
114
|
|
115
115
|
describe('edge cases', () => {
|
116
|
-
it('should handle empty model list', () => {
|
117
|
-
const result = getModelListByType([], 'openai', 'chat');
|
116
|
+
it('should handle empty model list', async () => {
|
117
|
+
const result = await getModelListByType([], 'openai', 'chat');
|
118
118
|
expect(result).toEqual([]);
|
119
119
|
});
|
120
120
|
|
121
|
-
it('should handle non-existent providerId', () => {
|
122
|
-
const result = getModelListByType(allModels, 'nonexistent', 'chat');
|
121
|
+
it('should handle non-existent providerId', async () => {
|
122
|
+
const result = await getModelListByType(allModels, 'nonexistent', 'chat');
|
123
123
|
expect(result).toEqual([]);
|
124
124
|
});
|
125
125
|
|
126
|
-
it('should handle non-existent type', () => {
|
127
|
-
const result = getModelListByType(allModels, 'openai', 'nonexistent');
|
126
|
+
it('should handle non-existent type', async () => {
|
127
|
+
const result = await getModelListByType(allModels, 'openai', 'nonexistent');
|
128
128
|
expect(result).toEqual([]);
|
129
129
|
});
|
130
130
|
|
131
|
-
it('should handle missing displayName', () => {
|
131
|
+
it('should handle missing displayName', async () => {
|
132
132
|
const modelsWithoutDisplayName: EnabledAiModel[] = [
|
133
133
|
{
|
134
134
|
id: 'test-model',
|
@@ -139,11 +139,11 @@ describe('getModelListByType', () => {
|
|
139
139
|
},
|
140
140
|
];
|
141
141
|
|
142
|
-
const result = getModelListByType(modelsWithoutDisplayName, 'test', 'chat');
|
142
|
+
const result = await getModelListByType(modelsWithoutDisplayName, 'test', 'chat');
|
143
143
|
expect(result[0].displayName).toBe('');
|
144
144
|
});
|
145
145
|
|
146
|
-
it('should handle missing abilities', () => {
|
146
|
+
it('should handle missing abilities', async () => {
|
147
147
|
const modelsWithoutAbilities: EnabledAiModel[] = [
|
148
148
|
{
|
149
149
|
id: 'test-model',
|
@@ -153,13 +153,13 @@ describe('getModelListByType', () => {
|
|
153
153
|
} as EnabledAiModel,
|
154
154
|
];
|
155
155
|
|
156
|
-
const result = getModelListByType(modelsWithoutAbilities, 'test', 'chat');
|
156
|
+
const result = await getModelListByType(modelsWithoutAbilities, 'test', 'chat');
|
157
157
|
expect(result[0].abilities).toEqual({});
|
158
158
|
});
|
159
159
|
});
|
160
160
|
|
161
161
|
describe('deduplication', () => {
|
162
|
-
it('should remove duplicate model IDs', () => {
|
162
|
+
it('should remove duplicate model IDs', async () => {
|
163
163
|
const duplicateModels: EnabledAiModel[] = [
|
164
164
|
{
|
165
165
|
id: 'gpt-4',
|
@@ -179,7 +179,7 @@ describe('getModelListByType', () => {
|
|
179
179
|
},
|
180
180
|
];
|
181
181
|
|
182
|
-
const result = getModelListByType(duplicateModels, 'openai', 'chat');
|
182
|
+
const result = await getModelListByType(duplicateModels, 'openai', 'chat');
|
183
183
|
|
184
184
|
expect(result).toHaveLength(1);
|
185
185
|
expect(result[0].displayName).toBe('GPT-4 Version 1');
|
@@ -187,7 +187,7 @@ describe('getModelListByType', () => {
|
|
187
187
|
});
|
188
188
|
|
189
189
|
describe('type casting', () => {
|
190
|
-
it('should handle image model type casting correctly', () => {
|
190
|
+
it('should handle image model type casting correctly', async () => {
|
191
191
|
const imageModel: EnabledAiModel[] = [
|
192
192
|
{
|
193
193
|
id: 'dall-e-3',
|
@@ -200,14 +200,14 @@ describe('getModelListByType', () => {
|
|
200
200
|
} as any, // Simulate AIImageModelCard type
|
201
201
|
];
|
202
202
|
|
203
|
-
const result = getModelListByType(imageModel, 'openai', 'image');
|
203
|
+
const result = await getModelListByType(imageModel, 'openai', 'image');
|
204
204
|
|
205
205
|
expect(result[0]).toHaveProperty('parameters');
|
206
206
|
expect(result[0].parameters).toEqual({ size: '1024x1024' });
|
207
207
|
});
|
208
208
|
|
209
|
-
it('should not add parameters field for non-image models', () => {
|
210
|
-
const result = getModelListByType(mockChatModels, 'openai', 'chat');
|
209
|
+
it('should not add parameters field for non-image models', async () => {
|
210
|
+
const result = await getModelListByType(mockChatModels, 'openai', 'chat');
|
211
211
|
|
212
212
|
result.forEach((model) => {
|
213
213
|
expect(model).not.toHaveProperty('parameters');
|
@@ -6,7 +6,12 @@ import { isDeprecatedEdition, isDesktop, isUsePgliteDB } from '@/const/version';
|
|
6
6
|
import { useClientDataSWR } from '@/libs/swr';
|
7
7
|
import { aiProviderService } from '@/services/aiProvider';
|
8
8
|
import { AiInfraStore } from '@/store/aiInfra/store';
|
9
|
-
import {
|
9
|
+
import {
|
10
|
+
AIImageModelCard,
|
11
|
+
EnabledAiModel,
|
12
|
+
LobeDefaultAiModelListItem,
|
13
|
+
ModelAbilities,
|
14
|
+
} from '@/types/aiModel';
|
10
15
|
import {
|
11
16
|
AiProviderDetailItem,
|
12
17
|
AiProviderListItem,
|
@@ -15,6 +20,7 @@ import {
|
|
15
20
|
AiProviderSourceEnum,
|
16
21
|
CreateAiProviderParams,
|
17
22
|
EnabledProvider,
|
23
|
+
EnabledProviderWithModels,
|
18
24
|
UpdateAiProviderConfigParams,
|
19
25
|
UpdateAiProviderParams,
|
20
26
|
} from '@/types/aiProvider';
|
@@ -23,10 +29,17 @@ import { getModelPropertyWithFallback } from '@/utils/getFallbackModelProperty';
|
|
23
29
|
/**
|
24
30
|
* Get models by provider ID and type, with proper formatting and deduplication
|
25
31
|
*/
|
26
|
-
export const getModelListByType = (
|
27
|
-
|
28
|
-
|
29
|
-
|
32
|
+
export const getModelListByType = async (
|
33
|
+
enabledAiModels: EnabledAiModel[],
|
34
|
+
providerId: string,
|
35
|
+
type: string,
|
36
|
+
) => {
|
37
|
+
const filteredModels = enabledAiModels.filter(
|
38
|
+
(model) => model.providerId === providerId && model.type === type,
|
39
|
+
);
|
40
|
+
|
41
|
+
const models = await Promise.all(
|
42
|
+
filteredModels.map(async (model) => ({
|
30
43
|
abilities: (model.abilities || {}) as ModelAbilities,
|
31
44
|
contextWindowTokens: model.contextWindowTokens,
|
32
45
|
displayName: model.displayName ?? '',
|
@@ -34,13 +47,31 @@ export const getModelListByType = (enabledAiModels: any[], providerId: string, t
|
|
34
47
|
...(model.type === 'image' && {
|
35
48
|
parameters:
|
36
49
|
(model as AIImageModelCard).parameters ||
|
37
|
-
getModelPropertyWithFallback(model.id, 'parameters'),
|
50
|
+
(await getModelPropertyWithFallback(model.id, 'parameters')),
|
38
51
|
}),
|
39
|
-
}))
|
52
|
+
})),
|
53
|
+
);
|
40
54
|
|
41
55
|
return uniqBy(models, 'id');
|
42
56
|
};
|
43
57
|
|
58
|
+
/**
 * Attach each provider's models of the requested `type` as `children`.
 *
 * Missing provider names fall back to the provider id. All providers are resolved
 * concurrently, because `getModelListByType` is async, and input order is preserved.
 */
const buildProviderModelLists = async (
  providers: EnabledProvider[],
  enabledAiModels: EnabledAiModel[],
  type: 'chat' | 'image',
) => {
  const withChildren = async (provider: EnabledProvider) => {
    const children = await getModelListByType(enabledAiModels, provider.id, type);

    return { ...provider, children, name: provider.name || provider.id };
  };

  return Promise.all(providers.map(withChildren));
};
|
74
|
+
|
44
75
|
enum AiProviderSwrKey {
|
45
76
|
fetchAiProviderItem = 'FETCH_AI_PROVIDER_ITEM',
|
46
77
|
fetchAiProviderList = 'FETCH_AI_PROVIDER',
|
@@ -49,6 +80,8 @@ enum AiProviderSwrKey {
|
|
49
80
|
|
50
81
|
type AiProviderRuntimeStateWithBuiltinModels = AiProviderRuntimeState & {
|
51
82
|
builtinAiModelList: LobeDefaultAiModelListItem[];
|
83
|
+
enabledChatModelList?: EnabledProviderWithModels[];
|
84
|
+
enabledImageModelList?: EnabledProviderWithModels[];
|
52
85
|
};
|
53
86
|
|
54
87
|
export interface AiProviderAction {
|
@@ -203,31 +236,54 @@ export const createAiProviderSlice: StateCreator<
|
|
203
236
|
|
204
237
|
if (isLogin) {
|
205
238
|
const data = await aiProviderService.getAiProviderRuntimeState();
|
239
|
+
|
240
|
+
// Build model lists with proper async handling
|
241
|
+
const [enabledChatModelList, enabledImageModelList] = await Promise.all([
|
242
|
+
buildProviderModelLists(data.enabledChatAiProviders, data.enabledAiModels, 'chat'),
|
243
|
+
buildProviderModelLists(data.enabledImageAiProviders, data.enabledAiModels, 'image'),
|
244
|
+
]);
|
245
|
+
|
206
246
|
return {
|
207
247
|
...data,
|
208
248
|
builtinAiModelList,
|
249
|
+
enabledChatModelList,
|
250
|
+
enabledImageModelList,
|
209
251
|
};
|
210
252
|
}
|
211
253
|
|
212
254
|
const enabledAiProviders: EnabledProvider[] = DEFAULT_MODEL_PROVIDER_LIST.filter(
|
213
255
|
(provider) => provider.enabled,
|
214
|
-
).map((item) => ({ id: item.id, name: item.name, source:
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
256
|
+
).map((item) => ({ id: item.id, name: item.name, source: AiProviderSourceEnum.Builtin }));
|
257
|
+
|
258
|
+
const enabledChatAiProviders = enabledAiProviders.filter((provider) => {
|
259
|
+
return builtinAiModelList.some(
|
260
|
+
(model) => model.providerId === provider.id && model.type === 'chat',
|
261
|
+
);
|
262
|
+
});
|
263
|
+
|
264
|
+
const enabledImageAiProviders = enabledAiProviders
|
265
|
+
.filter((provider) => {
|
220
266
|
return builtinAiModelList.some(
|
221
|
-
(model) => model.providerId === provider.id && model.type === '
|
267
|
+
(model) => model.providerId === provider.id && model.type === 'image',
|
222
268
|
);
|
223
|
-
})
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
269
|
+
})
|
270
|
+
.map((item) => ({ id: item.id, name: item.name, source: AiProviderSourceEnum.Builtin }));
|
271
|
+
|
272
|
+
// Build model lists for non-login state as well
|
273
|
+
const enabledAiModels = builtinAiModelList.filter((m) => m.enabled);
|
274
|
+
const [enabledChatModelList, enabledImageModelList] = await Promise.all([
|
275
|
+
buildProviderModelLists(enabledChatAiProviders, enabledAiModels, 'chat'),
|
276
|
+
buildProviderModelLists(enabledImageAiProviders, enabledAiModels, 'image'),
|
277
|
+
]);
|
278
|
+
|
279
|
+
return {
|
280
|
+
builtinAiModelList,
|
281
|
+
enabledAiModels,
|
282
|
+
enabledAiProviders,
|
283
|
+
enabledChatAiProviders,
|
284
|
+
enabledChatModelList,
|
285
|
+
enabledImageAiProviders,
|
286
|
+
enabledImageModelList,
|
231
287
|
runtimeConfig: {},
|
232
288
|
};
|
233
289
|
},
|
@@ -236,26 +292,14 @@ export const createAiProviderSlice: StateCreator<
|
|
236
292
|
onSuccess: (data) => {
|
237
293
|
if (!data) return;
|
238
294
|
|
239
|
-
const enabledChatModelList = data.enabledChatAiProviders.map((provider) => ({
|
240
|
-
...provider,
|
241
|
-
children: getModelListByType(data.enabledAiModels, provider.id, 'chat'),
|
242
|
-
name: provider.name || provider.id,
|
243
|
-
}));
|
244
|
-
|
245
|
-
const enabledImageModelList = data.enabledImageAiProviders.map((provider) => ({
|
246
|
-
...provider,
|
247
|
-
children: getModelListByType(data.enabledAiModels, provider.id, 'image'),
|
248
|
-
name: provider.name || provider.id,
|
249
|
-
}));
|
250
|
-
|
251
295
|
set(
|
252
296
|
{
|
253
297
|
aiProviderRuntimeConfig: data.runtimeConfig,
|
254
298
|
builtinAiModelList: data.builtinAiModelList,
|
255
299
|
enabledAiModels: data.enabledAiModels,
|
256
300
|
enabledAiProviders: data.enabledAiProviders,
|
257
|
-
enabledChatModelList,
|
258
|
-
enabledImageModelList,
|
301
|
+
enabledChatModelList: data.enabledChatModelList || [],
|
302
|
+
enabledImageModelList: data.enabledImageModelList || [],
|
259
303
|
},
|
260
304
|
false,
|
261
305
|
'useFetchAiProviderRuntimeState',
|
@@ -7,6 +7,7 @@ import { messageService } from '@/services/message';
|
|
7
7
|
import { imageGenerationService } from '@/services/textToImage';
|
8
8
|
import { uploadService } from '@/services/upload';
|
9
9
|
import { chatSelectors } from '@/store/chat/selectors';
|
10
|
+
import { useFileStore } from '@/store/file';
|
10
11
|
import { ChatMessage } from '@/types/message';
|
11
12
|
import { DallEImageItem } from '@/types/tool/dalle';
|
12
13
|
|
@@ -41,24 +42,28 @@ describe('chatToolSlice - dalle', () => {
|
|
41
42
|
vi.spyOn(uploadService, 'getImageFileByUrlWithCORS').mockResolvedValue(
|
42
43
|
new File(['1'], 'file.png', { type: 'image/png' }),
|
43
44
|
);
|
44
|
-
|
45
|
-
|
46
|
-
vi.spyOn(
|
47
|
-
|
48
|
-
|
49
|
-
|
45
|
+
|
46
|
+
// Mock the new uploadWithProgress method from useFileStore
|
47
|
+
vi.spyOn(useFileStore, 'getState').mockReturnValue({
|
48
|
+
uploadWithProgress: vi.fn().mockResolvedValue({
|
49
|
+
id: mockId,
|
50
|
+
url: '',
|
51
|
+
dimensions: { width: 512, height: 512 },
|
52
|
+
filename: 'file.png',
|
53
|
+
}),
|
54
|
+
} as any);
|
55
|
+
|
56
|
+
// Mock store methods that are called in the implementation
|
50
57
|
vi.spyOn(result.current, 'toggleDallEImageLoading');
|
51
|
-
vi.spyOn(
|
52
|
-
|
53
|
-
);
|
58
|
+
vi.spyOn(result.current, 'updatePluginState').mockResolvedValue(undefined);
|
59
|
+
vi.spyOn(result.current, 'internal_updateMessageContent').mockResolvedValue(undefined);
|
54
60
|
|
55
61
|
await act(async () => {
|
56
62
|
await result.current.generateImageFromPrompts(prompts, messageId);
|
57
63
|
});
|
58
64
|
// For each prompt, loading is toggled on and then off
|
59
65
|
expect(imageGenerationService.generateImage).toHaveBeenCalledTimes(prompts.length);
|
60
|
-
|
61
|
-
expect(uploadService.uploadToClientS3).toHaveBeenCalledTimes(prompts.length);
|
66
|
+
expect(useFileStore.getState().uploadWithProgress).toHaveBeenCalledTimes(prompts.length);
|
62
67
|
expect(result.current.toggleDallEImageLoading).toHaveBeenCalledTimes(prompts.length * 2);
|
63
68
|
});
|
64
69
|
});
|
@@ -74,7 +79,7 @@ describe('chatToolSlice - dalle', () => {
|
|
74
79
|
draft[0].previewUrl = 'new-url';
|
75
80
|
draft[0].imageId = 'new-id';
|
76
81
|
};
|
77
|
-
vi.spyOn(result.current, 'internal_updateMessageContent');
|
82
|
+
vi.spyOn(result.current, 'internal_updateMessageContent').mockResolvedValue(undefined);
|
78
83
|
|
79
84
|
// 模拟 getMessageById 返回消息内容
|
80
85
|
vi.spyOn(chatSelectors, 'getMessageById').mockImplementationOnce(
|
@@ -105,7 +110,9 @@ describe('chatToolSlice - dalle', () => {
|
|
105
110
|
const data = [{ prompt: 'prompt 1' }, { prompt: 'prompt 2' }] as DallEImageItem[];
|
106
111
|
|
107
112
|
// Mock generateImageFromPrompts
|
108
|
-
const generateImageFromPromptsMock = vi
|
113
|
+
const generateImageFromPromptsMock = vi
|
114
|
+
.spyOn(result.current, 'generateImageFromPrompts')
|
115
|
+
.mockResolvedValue(undefined);
|
109
116
|
|
110
117
|
await act(async () => {
|
111
118
|
await result.current.text2image(id, data);
|