@lobehub/chat 1.77.15 → 1.77.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/CHANGELOG.md +50 -0
  2. package/changelog/v1.json +18 -0
  3. package/docker-compose/local/docker-compose.yml +2 -1
  4. package/locales/ar/components.json +4 -0
  5. package/locales/ar/modelProvider.json +1 -0
  6. package/locales/ar/models.json +8 -5
  7. package/locales/bg-BG/components.json +4 -0
  8. package/locales/bg-BG/modelProvider.json +1 -0
  9. package/locales/bg-BG/models.json +8 -5
  10. package/locales/de-DE/components.json +4 -0
  11. package/locales/de-DE/modelProvider.json +1 -0
  12. package/locales/de-DE/models.json +8 -5
  13. package/locales/en-US/components.json +4 -0
  14. package/locales/en-US/modelProvider.json +1 -0
  15. package/locales/en-US/models.json +8 -5
  16. package/locales/es-ES/components.json +4 -0
  17. package/locales/es-ES/modelProvider.json +1 -0
  18. package/locales/es-ES/models.json +7 -4
  19. package/locales/fa-IR/components.json +4 -0
  20. package/locales/fa-IR/modelProvider.json +1 -0
  21. package/locales/fa-IR/models.json +7 -4
  22. package/locales/fr-FR/components.json +4 -0
  23. package/locales/fr-FR/modelProvider.json +1 -0
  24. package/locales/fr-FR/models.json +8 -5
  25. package/locales/it-IT/components.json +4 -0
  26. package/locales/it-IT/modelProvider.json +1 -0
  27. package/locales/it-IT/models.json +7 -4
  28. package/locales/ja-JP/components.json +4 -0
  29. package/locales/ja-JP/modelProvider.json +1 -0
  30. package/locales/ja-JP/models.json +8 -5
  31. package/locales/ko-KR/components.json +4 -0
  32. package/locales/ko-KR/modelProvider.json +1 -0
  33. package/locales/ko-KR/models.json +8 -5
  34. package/locales/nl-NL/components.json +4 -0
  35. package/locales/nl-NL/modelProvider.json +1 -0
  36. package/locales/nl-NL/models.json +8 -5
  37. package/locales/pl-PL/components.json +4 -0
  38. package/locales/pl-PL/modelProvider.json +1 -0
  39. package/locales/pl-PL/models.json +8 -5
  40. package/locales/pt-BR/components.json +4 -0
  41. package/locales/pt-BR/modelProvider.json +1 -0
  42. package/locales/pt-BR/models.json +7 -4
  43. package/locales/ru-RU/components.json +4 -0
  44. package/locales/ru-RU/modelProvider.json +1 -0
  45. package/locales/ru-RU/models.json +7 -4
  46. package/locales/tr-TR/components.json +4 -0
  47. package/locales/tr-TR/modelProvider.json +1 -0
  48. package/locales/tr-TR/models.json +8 -5
  49. package/locales/vi-VN/components.json +4 -0
  50. package/locales/vi-VN/modelProvider.json +1 -0
  51. package/locales/vi-VN/models.json +8 -5
  52. package/locales/zh-CN/components.json +4 -0
  53. package/locales/zh-CN/modelProvider.json +1 -0
  54. package/locales/zh-CN/models.json +9 -6
  55. package/locales/zh-TW/components.json +4 -0
  56. package/locales/zh-TW/modelProvider.json +1 -0
  57. package/locales/zh-TW/models.json +7 -4
  58. package/package.json +1 -1
  59. package/src/app/(backend)/webapi/models/[provider]/pull/route.ts +34 -0
  60. package/src/app/(backend)/webapi/{chat/models → models}/[provider]/route.ts +1 -2
  61. package/src/app/[variants]/(main)/settings/llm/ProviderList/Ollama/index.tsx +0 -7
  62. package/src/app/[variants]/(main)/settings/provider/(detail)/ollama/CheckError.tsx +1 -1
  63. package/src/components/FormAction/index.tsx +1 -1
  64. package/src/database/models/__tests__/aiProvider.test.ts +100 -0
  65. package/src/database/models/aiProvider.ts +11 -1
  66. package/src/features/Conversation/Error/OllamaBizError/InvalidOllamaModel.tsx +43 -0
  67. package/src/features/Conversation/Error/OllamaDesktopSetupGuide/index.tsx +61 -0
  68. package/src/features/Conversation/Error/index.tsx +7 -0
  69. package/src/features/DevPanel/SystemInspector/ServerConfig.tsx +18 -2
  70. package/src/features/DevPanel/SystemInspector/index.tsx +25 -6
  71. package/src/features/OllamaModelDownloader/index.tsx +149 -0
  72. package/src/libs/agent-runtime/AgentRuntime.ts +6 -0
  73. package/src/libs/agent-runtime/BaseAI.ts +7 -0
  74. package/src/libs/agent-runtime/ollama/index.ts +84 -2
  75. package/src/libs/agent-runtime/openrouter/__snapshots__/index.test.ts.snap +24 -3263
  76. package/src/libs/agent-runtime/openrouter/fixtures/frontendModels.json +25 -0
  77. package/src/libs/agent-runtime/openrouter/fixtures/models.json +0 -3353
  78. package/src/libs/agent-runtime/openrouter/index.test.ts +56 -1
  79. package/src/libs/agent-runtime/openrouter/index.ts +9 -4
  80. package/src/libs/agent-runtime/types/index.ts +1 -0
  81. package/src/libs/agent-runtime/types/model.ts +44 -0
  82. package/src/libs/agent-runtime/utils/streams/index.ts +1 -0
  83. package/src/libs/agent-runtime/utils/streams/model.ts +110 -0
  84. package/src/locales/default/components.ts +4 -0
  85. package/src/locales/default/modelProvider.ts +1 -0
  86. package/src/server/routers/async/file.ts +3 -4
  87. package/src/server/routers/lambda/file.ts +8 -11
  88. package/src/server/routers/lambda/importer.ts +3 -4
  89. package/src/server/routers/lambda/message.ts +9 -3
  90. package/src/server/routers/lambda/ragEval.ts +5 -6
  91. package/src/server/services/file/impls/index.ts +12 -0
  92. package/src/server/services/file/impls/s3.test.ts +110 -0
  93. package/src/server/services/file/impls/s3.ts +60 -0
  94. package/src/server/services/file/impls/type.ts +44 -0
  95. package/src/server/services/file/index.ts +65 -0
  96. package/src/services/__tests__/models.test.ts +21 -0
  97. package/src/services/_url.ts +4 -1
  98. package/src/services/chat.ts +1 -1
  99. package/src/services/electron/__tests__/devtools.test.ts +34 -0
  100. package/src/services/models.ts +153 -7
  101. package/src/store/aiInfra/slices/aiModel/action.ts +1 -1
  102. package/src/store/aiInfra/slices/aiProvider/action.ts +2 -1
  103. package/src/store/user/slices/modelList/action.test.ts +2 -2
  104. package/src/store/user/slices/modelList/action.ts +1 -1
  105. package/src/app/[variants]/(main)/settings/llm/ProviderList/Ollama/Checker.tsx +0 -73
  106. package/src/app/[variants]/(main)/settings/provider/(detail)/ollama/OllamaModelDownloader/index.tsx +0 -127
  107. package/src/features/Conversation/Error/OllamaBizError/InvalidOllamaModel/index.tsx +0 -154
  108. package/src/features/Conversation/Error/OllamaBizError/InvalidOllamaModel/useDownloadMonitor.ts +0 -29
  109. package/src/server/utils/files.test.ts +0 -37
  110. package/src/server/utils/files.ts +0 -20
  111. package/src/services/__tests__/ollama.test.ts +0 -28
  112. package/src/services/ollama.ts +0 -83
  113. /package/src/{app/[variants]/(main)/settings/provider/(detail)/ollama → features}/OllamaModelDownloader/useDownloadMonitor.ts +0 -0
@@ -4,6 +4,7 @@ import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
 import { testProvider } from '@/libs/agent-runtime/providerTestUtils';
 
+import frontendModels from './fixtures/frontendModels.json';
 import models from './fixtures/models.json';
 import { LobeOpenRouterAI } from './index';
 
@@ -137,12 +138,66 @@ describe('LobeOpenRouterAI', () => {
   });
 
   describe('models', () => {
-    it('should get models', async () => {
+    it('should get models with frontend models data', async () => {
       // mock the models.list method
       (instance['client'].models.list as Mock).mockResolvedValue({ data: models });
 
+      // mock a successful fetch response
+      vi.stubGlobal(
+        'fetch',
+        vi.fn().mockResolvedValue({
+          ok: true,
+          json: vi.fn().mockResolvedValue(frontendModels),
+        }),
+      );
+
+      const list = await instance.models();
+
+      // verify fetch was called with the expected URL
+      expect(fetch).toHaveBeenCalledWith('https://openrouter.ai/api/frontend/models');
+
+      // verify the model list contains the extra info fetched from the frontend API
+      const reflectionModel = list.find((model) => model.id === 'mattshumer/reflection-70b:free');
+      expect(reflectionModel).toBeDefined();
+      expect(reflectionModel?.reasoning).toBe(true);
+      expect(reflectionModel?.functionCall).toBe(true);
+
+      expect(list).toMatchSnapshot();
+    });
+
+    it('should handle fetch failure gracefully', async () => {
+      // mock the models.list method
+      (instance['client'].models.list as Mock).mockResolvedValue({ data: models });
+
+      // mock a failed fetch response
+      vi.stubGlobal(
+        'fetch',
+        vi.fn().mockResolvedValue({
+          ok: false,
+        }),
+      );
+
+      const list = await instance.models();
+
+      // verify the method still returns a valid model list even when fetch fails
+      expect(fetch).toHaveBeenCalledWith('https://openrouter.ai/api/frontend/models');
+      expect(list.length).toBeGreaterThan(0); // make sure a model list is returned
+      expect(list).toMatchSnapshot();
+    });
+
+    it('should handle fetch error gracefully', async () => {
+      // mock the models.list method
+      (instance['client'].models.list as Mock).mockResolvedValue({ data: models });
+
+      // in the test environment, override the fetch implementation first so the error is caught
+      vi.spyOn(global, 'fetch').mockImplementation(() => {
+        throw new Error('Network error');
+      });
+
       const list = await instance.models();
 
+      // verify the method still returns a valid model list even when fetch throws
+      expect(list.length).toBeGreaterThan(0); // make sure a model list is returned
       expect(list).toMatchSnapshot();
     });
   });
@@ -54,11 +54,16 @@ export const LobeOpenRouterAI = LobeOpenAICompatibleFactory({
     const modelsPage = (await client.models.list()) as any;
     const modelList: OpenRouterModelCard[] = modelsPage.data;
 
-    const response = await fetch('https://openrouter.ai/api/frontend/models');
     const modelsExtraInfo: OpenRouterModelExtraInfo[] = [];
-    if (response.ok) {
-      const data = await response.json();
-      modelsExtraInfo.push(...data['data']);
+    try {
+      const response = await fetch('https://openrouter.ai/api/frontend/models');
+      if (response.ok) {
+        const data = await response.json();
+        modelsExtraInfo.push(...data['data']);
+      }
+    } catch (error) {
+      // ignore fetch errors and continue with an empty modelsExtraInfo array
+      console.error('Failed to fetch OpenRouter frontend models:', error);
     }
 
     return modelList
@@ -1,5 +1,6 @@
 export * from './chat';
 export * from './embeddings';
+export * from './model';
 export * from './textToImage';
 export * from './tts';
 export * from './type';
@@ -0,0 +1,44 @@
+export interface ModelDetail {
+  details?: {
+    families?: string[];
+    family?: string;
+    format?: string;
+    parameter_size?: string;
+    quantization_level?: string;
+  };
+  digest?: string;
+  id: string;
+  modified_at?: Date;
+  name?: string;
+  size?: number;
+}
+
+export interface ModelProgressResponse {
+  completed?: number;
+  digest?: string;
+  model?: string;
+  status: string;
+  total?: number;
+}
+
+export interface ModelsParams {
+  name?: string;
+}
+
+export interface PullModelParams {
+  insecure?: boolean;
+  model: string;
+  stream?: boolean;
+}
+
+export interface ModelDetailParams {
+  model: string;
+}
+
+export interface DeleteModelParams {
+  model: string;
+}
+
+export interface ModelRequestOptions {
+  signal?: AbortSignal;
+}
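
These types describe the line-delimited JSON progress records emitted while pulling a model. As a minimal sketch of how a client might consume that stream, assuming a POST pull endpoint and a parsing helper that are illustrative only (the real client code lives in the src/services/models.ts change listed above but not shown here):

import type { ModelProgressResponse, PullModelParams } from '@/libs/agent-runtime/types';

// Sketch only: reads the NDJSON progress stream produced by the pull endpoint.
// The endpoint path and helper name are assumptions for illustration.
const pullModelWithProgress = async (
  params: PullModelParams,
  onProgress: (progress: ModelProgressResponse) => void,
  signal?: AbortSignal,
) => {
  const res = await fetch('/webapi/models/ollama/pull', {
    body: JSON.stringify(params),
    method: 'POST',
    signal,
  });
  if (!res.ok || !res.body) throw new Error(`pull failed with status ${res.status}`);

  const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
  let buffer = '';

  while (true) {
    const { value, done } = await reader.read();
    if (done) break;
    buffer += value;

    // each progress record is one JSON object per line
    const lines = buffer.split('\n');
    buffer = lines.pop() ?? '';
    for (const line of lines) {
      if (line.trim()) onProgress(JSON.parse(line) as ModelProgressResponse);
    }
  }
};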
@@ -1,6 +1,7 @@
 export * from './anthropic';
 export * from './bedrock';
 export * from './google-ai';
+export * from './model';
 export * from './ollama';
 export * from './openai';
 export * from './protocol';
@@ -0,0 +1,110 @@
+/**
+ * Convert an async iterator into a JSON-formatted ReadableStream
+ */
+export const createModelPullStream = <
+  T extends { completed?: number; digest?: string; status: string; total?: number },
+>(
+  iterable: AsyncIterable<T>,
+  model: string,
+  {
+    onCancel, // callback invoked when the stream is cancelled
+  }: {
+    onCancel?: (reason?: any) => void; // callback signature
+  } = {},
+): ReadableStream => {
+  let iterator: AsyncIterator<T>; // track the iterator outside so cancel() can call return()
+
+  return new ReadableStream({
+    // implement the cancel method
+    cancel(reason) {
+      // invoke the provided onCancel callback to run external cleanup (e.g. client.abort())
+      if (onCancel) {
+        onCancel(reason);
+      }
+
+      // try to terminate the iterator gracefully
+      // note: this relies on the AsyncIterable implementation supporting return/throw
+      if (iterator && typeof iterator.return === 'function') {
+        // no need to await; let the cleanup run in the background
+        iterator.return().catch();
+      }
+    },
+    async start(controller) {
+      iterator = iterable[Symbol.asyncIterator](); // grab the iterator
+
+      const encoder = new TextEncoder();
+
+      try {
+        // eslint-disable-next-line no-constant-condition
+        while (true) {
+          // wait for the next chunk or for the iteration to finish
+          const { value: progress, done } = await iterator.next();
+
+          // break out of the loop once the iteration is done
+          if (done) {
+            break;
+          }
+
+          // skip the 'pulling manifest' status since it carries no progress
+          if (progress.status === 'pulling manifest') continue;
+
+          // normalize the payload and write it to the stream
+          const progressData =
+            JSON.stringify({
+              completed: progress.completed,
+              digest: progress.digest,
+              model,
+              status: progress.status,
+              total: progress.total,
+            }) + '\n';
+
+          controller.enqueue(encoder.encode(progressData));
+        }
+
+        // finished normally
+        controller.close();
+      } catch (error) {
+        // handle errors
+
+        // if the error came from an abort, handle it quietly and try to close the stream
+        if (error instanceof DOMException && error.name === 'AbortError') {
+          // no need to enqueue an error message, as the connection may already be gone
+          // try to close normally; if already cancelled, the controller may be closed or errored
+          try {
+            controller.enqueue(new TextEncoder().encode(JSON.stringify({ status: 'cancelled' })));
+            controller.close();
+          } catch {
+            // ignore close errors; the cancel path may have already handled the stream
+          }
+        } else {
+          console.error('[createModelPullStream] model download stream error:', error);
+          // for other errors, try to send the error message to the client
+          const errorMessage = error instanceof Error ? error.message : String(error);
+          const errorData =
+            JSON.stringify({
+              error: errorMessage,
+              model,
+              status: 'error',
+            }) + '\n';
+
+          try {
+            // only enqueue while the stream still expects data
+            if (controller.desiredSize !== null && controller.desiredSize > 0) {
+              controller.enqueue(encoder.encode(errorData));
+            }
+          } catch (enqueueError) {
+            console.error('[createModelPullStream] Error enqueueing error message:', enqueueError);
+            // if this also fails, the connection is most likely gone
+          }
+
+          // try to close the stream, or mark it as errored
+          try {
+            controller.close(); // attempt a normal close
+          } catch {
+            controller.error(error); // if closing fails, put the stream into an errored state
+          }
+        }
+      }
+    },
+  });
+};
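
createModelPullStream turns any async iterable of progress chunks into a line-delimited JSON ReadableStream. Below is a minimal sketch of returning such a stream from a route handler; the fake progress generator, handler shape, and response headers are assumptions for illustration and are not the contents of the new pull route listed above:

import { createModelPullStream } from '@/libs/agent-runtime/utils/streams';

// Stand-in for a real Ollama client's pull iterator, used only to keep the sketch self-contained.
async function* fakePullProgress() {
  yield { completed: 0, digest: 'sha256:abc', status: 'downloading', total: 100 };
  yield { completed: 100, digest: 'sha256:abc', status: 'success', total: 100 };
}

export const POST = async (req: Request) => {
  const { model } = (await req.json()) as { model: string };

  const stream = createModelPullStream(fakePullProgress(), model, {
    // in a real handler this would abort the underlying client request
    onCancel: (reason) => console.log('model pull cancelled:', reason),
  });

  // each chunk is one JSON object per line, matching what a client-side NDJSON parser expects
  return new Response(stream, { headers: { 'Content-Type': 'application/x-ndjson' } });
};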
@@ -93,6 +93,10 @@ export default {
     provider: '服务商',
   },
   OllamaSetupGuide: {
+    action: {
+      close: '关闭提示',
+      start: '已安装并运行,开始对话',
+    },
     cors: {
       description: '因浏览器安全限制,你需要为 Ollama 进行跨域配置后方可正常使用。',
       linux: {
@@ -166,6 +166,7 @@ export default {
     },
     download: {
       desc: 'Ollama 正在下载该模型,请尽量不要关闭本页面。重新下载时将会中断处继续',
+      failed: '模型下载失败,请检查网络或者 Ollama 设置后重试',
       remainingTime: '剩余时间',
       speed: '下载速度',
       title: '正在下载模型 {{model}} ',
@@ -14,8 +14,8 @@ import { NewChunkItem, NewEmbeddingsItem } from '@/database/schemas';
 import { asyncAuthedProcedure, asyncRouter as router } from '@/libs/trpc/async';
 import { getServerDefaultFilesConfig } from '@/server/globalConfig';
 import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
-import { S3 } from '@/server/modules/S3';
 import { ChunkService } from '@/server/services/chunk';
+import { FileService } from '@/server/services/file';
 import {
   AsyncTaskError,
   AsyncTaskErrorType,
@@ -35,6 +35,7 @@ const fileProcedure = asyncAuthedProcedure.use(async (opts) => {
       chunkService: new ChunkService(ctx.userId),
       embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
       fileModel: new FileModel(ctx.serverDB, ctx.userId),
+      fileService: new FileService(),
     },
   });
 });
@@ -162,11 +163,9 @@ export const fileRouter = router({
         throw new TRPCError({ code: 'BAD_REQUEST', message: 'File not found' });
       }
 
-      const s3 = new S3();
-
       let content: Uint8Array | undefined;
       try {
-        content = await s3.getFileByteArray(file.url);
+        content = await ctx.fileService.getFileByteArray(file.url);
       } catch (e) {
         console.error(e);
         // if file not found, delete it from db
@@ -7,8 +7,7 @@ import { ChunkModel } from '@/database/models/chunk';
 import { FileModel } from '@/database/models/file';
 import { authedProcedure, router } from '@/libs/trpc';
 import { serverDatabase } from '@/libs/trpc/lambda';
-import { S3 } from '@/server/modules/S3';
-import { getFullFileUrl } from '@/server/utils/files';
+import { FileService } from '@/server/services/file';
 import { AsyncTaskStatus, AsyncTaskType } from '@/types/asyncTask';
 import { FileListItem, QueryFileListSchema, UploadFileSchema } from '@/types/files';
 
@@ -20,6 +19,7 @@ const fileProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
       asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
       chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
       fileModel: new FileModel(ctx.serverDB, ctx.userId),
+      fileService: new FileService(),
     },
   });
 });
@@ -50,7 +50,7 @@ export const fileRouter = router({
         !isExist,
       );
 
-      return { id, url: await getFullFileUrl(input.url) };
+      return { id, url: await ctx.fileService.getFullFileUrl(input.url) };
     }),
   findById: fileProcedure
     .input(
@@ -62,7 +62,7 @@ export const fileRouter = router({
       const item = await ctx.fileModel.findById(input.id);
       if (!item) throw new TRPCError({ code: 'BAD_REQUEST', message: 'File not found' });
 
-      return { ...item, url: await getFullFileUrl(item?.url) };
+      return { ...item, url: await ctx.fileService.getFullFileUrl(item?.url) };
     }),
 
   getFileItemById: fileProcedure
@@ -95,7 +95,7 @@ export const fileRouter = router({
         embeddingError: embeddingTask?.error,
         embeddingStatus: embeddingTask?.status as AsyncTaskStatus,
         finishEmbedding: embeddingTask?.status === AsyncTaskStatus.Success,
-        url: await getFullFileUrl(item.url!),
+        url: await ctx.fileService.getFullFileUrl(item.url!),
       };
     }),
 
@@ -132,7 +132,7 @@ export const fileRouter = router({
           embeddingError: embeddingTask?.error ?? null,
           embeddingStatus: embeddingTask?.status as AsyncTaskStatus,
           finishEmbedding: embeddingTask?.status === AsyncTaskStatus.Success,
-          url: await getFullFileUrl(item.url!),
+          url: await ctx.fileService.getFullFileUrl(item.url!),
        } as FileListItem;
        resultFiles.push(fileItem);
      }
@@ -150,8 +150,7 @@ export const fileRouter = router({
       if (!file) return;
 
       // delete the file from S3 if it is not used by other files
-      const s3Client = new S3();
-      await s3Client.deleteFile(file.url!);
+      await ctx.fileService.deleteFile(file.url!);
     }),
 
   removeFileAsyncTask: fileProcedure
@@ -184,9 +183,7 @@ export const fileRouter = router({
       if (!needToRemoveFileList || needToRemoveFileList.length === 0) return;
 
       // remove from S3
-      const s3Client = new S3();
-
-      await s3Client.deleteFiles(needToRemoveFileList.map((file) => file.url!));
+      await ctx.fileService.deleteFiles(needToRemoveFileList.map((file) => file.url!));
     }),
 });
 
@@ -4,7 +4,7 @@ import { z } from 'zod';
 import { DataImporterRepos } from '@/database/repositories/dataImporter';
 import { authedProcedure, router } from '@/libs/trpc';
 import { serverDatabase } from '@/libs/trpc/lambda';
-import { S3 } from '@/server/modules/S3';
+import { FileService } from '@/server/services/file';
 import { ImportPgDataStructure } from '@/types/export';
 import { ImportResultData, ImporterEntryData } from '@/types/importer';
 
@@ -13,7 +13,7 @@ const importProcedure = authedProcedure.use(serverDatabase).use(async (opts) =>
   const dataImporterService = new DataImporterRepos(ctx.serverDB, ctx.userId);
 
   return opts.next({
-    ctx: { dataImporterService },
+    ctx: { dataImporterService, fileService: new FileService() },
   });
 });
 
@@ -24,8 +24,7 @@ export const importerRouter = router({
       let data: ImporterEntryData | undefined;
 
      try {
-        const s3 = new S3();
-        const dataStr = await s3.getFileContent(input.pathname);
+        const dataStr = await ctx.fileService.getFileContent(input.pathname);
        data = JSON.parse(dataStr);
      } catch {
        data = undefined;
@@ -5,7 +5,7 @@ import { updateMessagePluginSchema } from '@/database/schemas';
 import { getServerDB } from '@/database/server';
 import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
 import { serverDatabase } from '@/libs/trpc/lambda';
-import { getFullFileUrl } from '@/server/utils/files';
+import { FileService } from '@/server/services/file';
 import { ChatMessage } from '@/types/message';
 import { BatchTaskResult } from '@/types/service';
 
@@ -15,7 +15,10 @@ const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) =>
   const { ctx } = opts;
 
   return opts.next({
-    ctx: { messageModel: new MessageModel(ctx.serverDB, ctx.userId) },
+    ctx: {
+      fileService: new FileService(),
+      messageModel: new MessageModel(ctx.serverDB, ctx.userId),
+    },
   });
 });
 
@@ -99,8 +102,11 @@ export const messageRouter = router({
     const serverDB = await getServerDB();
 
     const messageModel = new MessageModel(serverDB, ctx.userId);
+    const fileService = new FileService();
 
-    return messageModel.query(input, { postProcessUrl: (path) => getFullFileUrl(path) });
+    return messageModel.query(input, {
+      postProcessUrl: (path) => fileService.getFullFileUrl(path),
+    });
   }),
 
   rankModels: messageProcedure.query(async ({ ctx }) => {
@@ -16,9 +16,8 @@ import {
 import { authedProcedure, router } from '@/libs/trpc';
 import { serverDatabase } from '@/libs/trpc/lambda';
 import { keyVaults } from '@/libs/trpc/middleware/keyVaults';
-import { S3 } from '@/server/modules/S3';
 import { createAsyncServerClient } from '@/server/routers/async';
-import { getFullFileUrl } from '@/server/utils/files';
+import { FileService } from '@/server/services/file';
 import {
   EvalDatasetRecord,
   EvalEvaluationStatus,
@@ -42,7 +41,7 @@ const ragEvalProcedure = authedProcedure
       datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
       evaluationModel: new EvalEvaluationModel(ctx.userId),
       evaluationRecordModel: new EvaluationRecordModel(ctx.userId),
-      s3: new S3(),
+      fileService: new FileService(),
     },
   });
 });
@@ -144,7 +143,7 @@ export const ragEvalRouter = router({
       }),
     )
     .mutation(async ({ input, ctx }) => {
-      const dataStr = await ctx.s3.getFileContent(input.pathname);
+      const dataStr = await ctx.fileService.getFileContent(input.pathname);
      const items = JSONL.parse<InsertEvalDatasetRecord>(dataStr);

      insertEvalDatasetRecordSchema.array().parse(items);
@@ -262,12 +261,12 @@ export const ragEvalRouter = router({
        const filename = `${date}-eval_${evaluation.id}-${evaluation.name}.jsonl`;
        const path = `rag_eval_records/${filename}`;

-        await ctx.s3.uploadContent(path, JSONL.stringify(evalRecords));
+        await ctx.fileService.uploadContent(path, JSONL.stringify(evalRecords));

        // save the data
        await ctx.evaluationModel.update(input.id, {
          status: EvalEvaluationStatus.Success,
-          evalRecordsUrl: await getFullFileUrl(path),
+          evalRecordsUrl: await ctx.fileService.getFullFileUrl(path),
        });
      }
 
@@ -0,0 +1,12 @@
+import { S3StaticFileImpl } from './s3';
+import { FileServiceImpl } from './type';
+
+/**
+ * Create the file service module
+ */
+export const createFileServiceModule = (): FileServiceImpl => {
+  // use the S3 implementation by default
+  return new S3StaticFileImpl();
+};
+
+export type { FileServiceImpl } from './type';
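
The routers above construct FileService directly, while the facade itself (src/server/services/file/index.ts, +65) is not shown in this diff. A plausible sketch of the delegation pattern, with the method list inferred from the calls seen in the routers and the exact class shape assumed:

import { createFileServiceModule, type FileServiceImpl } from './impls';

// Sketch: the facade owns an impl chosen by createFileServiceModule and forwards calls to it.
export class FileService {
  private impl: FileServiceImpl = createFileServiceModule();

  getFullFileUrl = (url?: string | null) => this.impl.getFullFileUrl(url);
  getFileContent = (key: string) => this.impl.getFileContent(key);
  getFileByteArray = (key: string) => this.impl.getFileByteArray(key);
  deleteFile = (key: string) => this.impl.deleteFile(key);
  deleteFiles = (keys: string[]) => this.impl.deleteFiles(keys);
  uploadContent = (path: string, content: string) => this.impl.uploadContent(path, content);
}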
@@ -0,0 +1,110 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest';
+
+import { S3StaticFileImpl } from './s3';
+
+const config = {
+  S3_ENABLE_PATH_STYLE: false,
+  S3_PUBLIC_DOMAIN: 'https://example.com',
+  S3_BUCKET: 'my-bucket',
+  S3_SET_ACL: true,
+};
+
+// mock fileEnv
+vi.mock('@/config/file', () => ({
+  get fileEnv() {
+    return config;
+  },
+}));
+
+// mock the S3 class
+vi.mock('@/server/modules/S3', () => ({
+  S3: vi.fn().mockImplementation(() => ({
+    createPreSignedUrlForPreview: vi
+      .fn()
+      .mockResolvedValue('https://presigned.example.com/test.jpg'),
+    getFileContent: vi.fn().mockResolvedValue('file content'),
+    getFileByteArray: vi.fn().mockResolvedValue(new Uint8Array([1, 2, 3])),
+    deleteFile: vi.fn().mockResolvedValue({}),
+    deleteFiles: vi.fn().mockResolvedValue({}),
+    createPreSignedUrl: vi.fn().mockResolvedValue('https://upload.example.com/test.jpg'),
+    uploadContent: vi.fn().mockResolvedValue({}),
+  })),
+}));
+
+describe('S3StaticFileImpl', () => {
+  let fileService: S3StaticFileImpl;
+
+  beforeEach(() => {
+    fileService = new S3StaticFileImpl();
+  });
+
+  describe('getFullFileUrl', () => {
+    it('should return empty string for null or undefined input', async () => {
+      expect(await fileService.getFullFileUrl(null)).toBe('');
+      expect(await fileService.getFullFileUrl(undefined)).toBe('');
+    });
+
+    it('should return a presigned URL when S3_SET_ACL is false', async () => {
+      config.S3_SET_ACL = false;
+      const url = 'path/to/file.jpg';
+      expect(await fileService.getFullFileUrl(url)).toBe('https://presigned.example.com/test.jpg');
+      config.S3_SET_ACL = true;
+    });
+
+    it('should return correct URL when S3_ENABLE_PATH_STYLE is false', async () => {
+      const url = 'path/to/file.jpg';
+      expect(await fileService.getFullFileUrl(url)).toBe('https://example.com/path/to/file.jpg');
+    });
+
+    it('should return correct URL when S3_ENABLE_PATH_STYLE is true', async () => {
+      config.S3_ENABLE_PATH_STYLE = true;
+      const url = 'path/to/file.jpg';
+      expect(await fileService.getFullFileUrl(url)).toBe(
+        'https://example.com/my-bucket/path/to/file.jpg',
+      );
+      config.S3_ENABLE_PATH_STYLE = false;
+    });
+  });
+
+  describe('getFileContent', () => {
+    it('should return the file content', async () => {
+      expect(await fileService.getFileContent('test.txt')).toBe('file content');
+    });
+  });
+
+  describe('getFileByteArray', () => {
+    it('should return the file byte array', async () => {
+      const result = await fileService.getFileByteArray('test.jpg');
+      expect(result).toBeInstanceOf(Uint8Array);
+      expect(result.length).toBe(3);
+    });
+  });
+
+  describe('deleteFile', () => {
+    it('should call the S3 deleteFile method', async () => {
+      await fileService.deleteFile('test.jpg');
+      expect(fileService['s3'].deleteFile).toHaveBeenCalledWith('test.jpg');
+    });
+  });
+
+  describe('deleteFiles', () => {
+    it('should call the S3 deleteFiles method', async () => {
+      await fileService.deleteFiles(['test1.jpg', 'test2.jpg']);
+      expect(fileService['s3'].deleteFiles).toHaveBeenCalledWith(['test1.jpg', 'test2.jpg']);
+    });
+  });
+
+  describe('createPreSignedUrl', () => {
+    it('should call the S3 createPreSignedUrl method', async () => {
+      const result = await fileService.createPreSignedUrl('test.jpg');
+      expect(result).toBe('https://upload.example.com/test.jpg');
+    });
+  });
+
+  describe('uploadContent', () => {
+    it('should call the S3 uploadContent method', async () => {
+      await fileService.uploadContent('test.jpg', 'content');
+      expect(fileService['s3'].uploadContent).toHaveBeenCalledWith('test.jpg', 'content');
+    });
+  });
+});