@lobehub/chat 1.77.16 → 1.77.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. package/CHANGELOG.md +50 -0
  2. package/changelog/v1.json +18 -0
  3. package/contributing/Basic/Architecture.md +1 -1
  4. package/contributing/Basic/Architecture.zh-CN.md +1 -1
  5. package/contributing/Basic/Chat-API.md +326 -108
  6. package/contributing/Basic/Chat-API.zh-CN.md +313 -133
  7. package/contributing/Basic/Contributing-Guidelines.md +7 -4
  8. package/contributing/Basic/Contributing-Guidelines.zh-CN.md +7 -6
  9. package/contributing/Home.md +5 -5
  10. package/contributing/State-Management/State-Management-Intro.md +1 -1
  11. package/contributing/State-Management/State-Management-Intro.zh-CN.md +1 -1
  12. package/docker-compose/local/docker-compose.yml +2 -1
  13. package/locales/ar/components.json +4 -0
  14. package/locales/ar/modelProvider.json +1 -0
  15. package/locales/ar/models.json +8 -5
  16. package/locales/ar/tool.json +21 -1
  17. package/locales/bg-BG/components.json +4 -0
  18. package/locales/bg-BG/modelProvider.json +1 -0
  19. package/locales/bg-BG/models.json +8 -5
  20. package/locales/bg-BG/tool.json +21 -1
  21. package/locales/de-DE/components.json +4 -0
  22. package/locales/de-DE/modelProvider.json +1 -0
  23. package/locales/de-DE/models.json +8 -5
  24. package/locales/de-DE/tool.json +21 -1
  25. package/locales/en-US/components.json +4 -0
  26. package/locales/en-US/modelProvider.json +1 -0
  27. package/locales/en-US/models.json +8 -5
  28. package/locales/en-US/tool.json +21 -1
  29. package/locales/es-ES/components.json +4 -0
  30. package/locales/es-ES/modelProvider.json +1 -0
  31. package/locales/es-ES/models.json +7 -4
  32. package/locales/es-ES/tool.json +21 -1
  33. package/locales/fa-IR/components.json +4 -0
  34. package/locales/fa-IR/modelProvider.json +1 -0
  35. package/locales/fa-IR/models.json +7 -4
  36. package/locales/fa-IR/tool.json +21 -1
  37. package/locales/fr-FR/components.json +4 -0
  38. package/locales/fr-FR/modelProvider.json +1 -0
  39. package/locales/fr-FR/models.json +8 -5
  40. package/locales/fr-FR/tool.json +21 -1
  41. package/locales/it-IT/components.json +4 -0
  42. package/locales/it-IT/modelProvider.json +1 -0
  43. package/locales/it-IT/models.json +7 -4
  44. package/locales/it-IT/tool.json +21 -1
  45. package/locales/ja-JP/components.json +4 -0
  46. package/locales/ja-JP/modelProvider.json +1 -0
  47. package/locales/ja-JP/models.json +8 -5
  48. package/locales/ja-JP/tool.json +21 -1
  49. package/locales/ko-KR/components.json +4 -0
  50. package/locales/ko-KR/modelProvider.json +1 -0
  51. package/locales/ko-KR/models.json +8 -5
  52. package/locales/ko-KR/tool.json +21 -1
  53. package/locales/nl-NL/components.json +4 -0
  54. package/locales/nl-NL/modelProvider.json +1 -0
  55. package/locales/nl-NL/models.json +8 -5
  56. package/locales/nl-NL/tool.json +21 -1
  57. package/locales/pl-PL/components.json +4 -0
  58. package/locales/pl-PL/modelProvider.json +1 -0
  59. package/locales/pl-PL/models.json +8 -5
  60. package/locales/pl-PL/tool.json +21 -1
  61. package/locales/pt-BR/components.json +4 -0
  62. package/locales/pt-BR/modelProvider.json +1 -0
  63. package/locales/pt-BR/models.json +7 -4
  64. package/locales/pt-BR/tool.json +21 -1
  65. package/locales/ru-RU/components.json +4 -0
  66. package/locales/ru-RU/modelProvider.json +1 -0
  67. package/locales/ru-RU/models.json +7 -4
  68. package/locales/ru-RU/tool.json +21 -1
  69. package/locales/tr-TR/components.json +4 -0
  70. package/locales/tr-TR/modelProvider.json +1 -0
  71. package/locales/tr-TR/models.json +8 -5
  72. package/locales/tr-TR/tool.json +21 -1
  73. package/locales/vi-VN/components.json +4 -0
  74. package/locales/vi-VN/modelProvider.json +1 -0
  75. package/locales/vi-VN/models.json +8 -5
  76. package/locales/vi-VN/tool.json +21 -1
  77. package/locales/zh-CN/components.json +4 -0
  78. package/locales/zh-CN/modelProvider.json +1 -0
  79. package/locales/zh-CN/models.json +9 -6
  80. package/locales/zh-CN/tool.json +30 -1
  81. package/locales/zh-TW/components.json +4 -0
  82. package/locales/zh-TW/modelProvider.json +1 -0
  83. package/locales/zh-TW/models.json +7 -4
  84. package/locales/zh-TW/tool.json +21 -1
  85. package/package.json +1 -1
  86. package/src/app/(backend)/webapi/models/[provider]/pull/route.ts +34 -0
  87. package/src/app/(backend)/webapi/{chat/models → models}/[provider]/route.ts +1 -2
  88. package/src/app/[variants]/(main)/settings/llm/ProviderList/Ollama/index.tsx +0 -7
  89. package/src/app/[variants]/(main)/settings/provider/(detail)/ollama/CheckError.tsx +1 -1
  90. package/src/components/FormAction/index.tsx +1 -1
  91. package/src/database/models/__tests__/aiProvider.test.ts +100 -0
  92. package/src/database/models/aiProvider.ts +11 -1
  93. package/src/features/Conversation/Error/OllamaBizError/InvalidOllamaModel.tsx +43 -0
  94. package/src/features/Conversation/Error/OllamaDesktopSetupGuide/index.tsx +61 -0
  95. package/src/features/Conversation/Error/index.tsx +7 -0
  96. package/src/features/DevPanel/SystemInspector/ServerConfig.tsx +18 -2
  97. package/src/features/DevPanel/SystemInspector/index.tsx +25 -6
  98. package/src/features/OllamaModelDownloader/index.tsx +149 -0
  99. package/src/libs/agent-runtime/AgentRuntime.ts +6 -0
  100. package/src/libs/agent-runtime/BaseAI.ts +7 -0
  101. package/src/libs/agent-runtime/ollama/index.ts +84 -2
  102. package/src/libs/agent-runtime/openrouter/__snapshots__/index.test.ts.snap +24 -3263
  103. package/src/libs/agent-runtime/openrouter/fixtures/frontendModels.json +25 -0
  104. package/src/libs/agent-runtime/openrouter/fixtures/models.json +0 -3353
  105. package/src/libs/agent-runtime/openrouter/index.test.ts +56 -1
  106. package/src/libs/agent-runtime/openrouter/index.ts +9 -4
  107. package/src/libs/agent-runtime/types/index.ts +1 -0
  108. package/src/libs/agent-runtime/types/model.ts +44 -0
  109. package/src/libs/agent-runtime/utils/streams/index.ts +1 -0
  110. package/src/libs/agent-runtime/utils/streams/model.ts +110 -0
  111. package/src/locales/default/components.ts +4 -0
  112. package/src/locales/default/modelProvider.ts +1 -0
  113. package/src/locales/default/tool.ts +30 -1
  114. package/src/server/modules/SearXNG.ts +10 -2
  115. package/src/server/routers/tools/__test__/search.test.ts +3 -1
  116. package/src/server/routers/tools/search.ts +10 -2
  117. package/src/services/__tests__/models.test.ts +21 -0
  118. package/src/services/_url.ts +4 -1
  119. package/src/services/chat.ts +1 -1
  120. package/src/services/models.ts +153 -7
  121. package/src/services/search.ts +2 -2
  122. package/src/store/aiInfra/slices/aiModel/action.ts +1 -1
  123. package/src/store/aiInfra/slices/aiProvider/action.ts +2 -1
  124. package/src/store/chat/slices/builtinTool/actions/searXNG.test.ts +28 -8
  125. package/src/store/chat/slices/builtinTool/actions/searXNG.ts +22 -5
  126. package/src/store/user/slices/modelList/action.test.ts +2 -2
  127. package/src/store/user/slices/modelList/action.ts +1 -1
  128. package/src/tools/web-browsing/Portal/Search/index.tsx +1 -1
  129. package/src/tools/web-browsing/Render/Search/SearchQuery/SearchView.tsx +1 -1
  130. package/src/tools/web-browsing/Render/Search/SearchQuery/index.tsx +1 -1
  131. package/src/tools/web-browsing/Render/Search/SearchResult/index.tsx +1 -1
  132. package/src/tools/web-browsing/components/CategoryAvatar.tsx +27 -0
  133. package/src/tools/web-browsing/components/SearchBar.tsx +84 -4
  134. package/src/tools/web-browsing/const.ts +26 -0
  135. package/src/tools/web-browsing/index.ts +58 -28
  136. package/src/tools/web-browsing/systemRole.ts +62 -1
  137. package/src/types/tool/search.ts +10 -1
  138. package/src/app/[variants]/(main)/settings/llm/ProviderList/Ollama/Checker.tsx +0 -73
  139. package/src/app/[variants]/(main)/settings/provider/(detail)/ollama/OllamaModelDownloader/index.tsx +0 -127
  140. package/src/features/Conversation/Error/OllamaBizError/InvalidOllamaModel/index.tsx +0 -154
  141. package/src/features/Conversation/Error/OllamaBizError/InvalidOllamaModel/useDownloadMonitor.ts +0 -29
  142. package/src/helpers/url.ts +0 -17
  143. package/src/services/__tests__/ollama.test.ts +0 -28
  144. package/src/services/ollama.ts +0 -83
  145. /package/src/{app/[variants]/(main)/settings/provider/(detail)/ollama → features}/OllamaModelDownloader/useDownloadMonitor.ts +0 -0
@@ -4,6 +4,7 @@ import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
4
4
  import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
5
5
  import { testProvider } from '@/libs/agent-runtime/providerTestUtils';
6
6
 
7
+ import frontendModels from './fixtures/frontendModels.json';
7
8
  import models from './fixtures/models.json';
8
9
  import { LobeOpenRouterAI } from './index';
9
10
 
@@ -137,12 +138,66 @@ describe('LobeOpenRouterAI', () => {
137
138
  });
138
139
 
139
140
  describe('models', () => {
140
- it('should get models', async () => {
141
+ it('should get models with frontend models data', async () => {
141
142
  // mock the models.list method
142
143
  (instance['client'].models.list as Mock).mockResolvedValue({ data: models });
143
144
 
145
+ // 模拟成功的 fetch 响应
146
+ vi.stubGlobal(
147
+ 'fetch',
148
+ vi.fn().mockResolvedValue({
149
+ ok: true,
150
+ json: vi.fn().mockResolvedValue(frontendModels),
151
+ }),
152
+ );
153
+
154
+ const list = await instance.models();
155
+
156
+ // 验证 fetch 被正确调用
157
+ expect(fetch).toHaveBeenCalledWith('https://openrouter.ai/api/frontend/models');
158
+
159
+ // 验证模型列表中包含了从前端 API 获取的额外信息
160
+ const reflectionModel = list.find((model) => model.id === 'mattshumer/reflection-70b:free');
161
+ expect(reflectionModel).toBeDefined();
162
+ expect(reflectionModel?.reasoning).toBe(true);
163
+ expect(reflectionModel?.functionCall).toBe(true);
164
+
165
+ expect(list).toMatchSnapshot();
166
+ });
167
+
168
+ it('should handle fetch failure gracefully', async () => {
169
+ // mock the models.list method
170
+ (instance['client'].models.list as Mock).mockResolvedValue({ data: models });
171
+
172
+ // 模拟失败的 fetch 响应
173
+ vi.stubGlobal(
174
+ 'fetch',
175
+ vi.fn().mockResolvedValue({
176
+ ok: false,
177
+ }),
178
+ );
179
+
180
+ const list = await instance.models();
181
+
182
+ // 验证即使 fetch 失败,方法仍然能返回有效的模型列表
183
+ expect(fetch).toHaveBeenCalledWith('https://openrouter.ai/api/frontend/models');
184
+ expect(list.length).toBeGreaterThan(0); // 确保返回了模型列表
185
+ expect(list).toMatchSnapshot();
186
+ });
187
+
188
+ it('should handle fetch error gracefully', async () => {
189
+ // mock the models.list method
190
+ (instance['client'].models.list as Mock).mockResolvedValue({ data: models });
191
+
192
+ // 在测试环境中,需要先修改 fetch 的实现,确保错误被捕获
193
+ vi.spyOn(global, 'fetch').mockImplementation(() => {
194
+ throw new Error('Network error');
195
+ });
196
+
144
197
  const list = await instance.models();
145
198
 
199
+ // 验证即使 fetch 出错,方法仍然能返回有效的模型列表
200
+ expect(list.length).toBeGreaterThan(0); // 确保返回了模型列表
146
201
  expect(list).toMatchSnapshot();
147
202
  });
148
203
  });
@@ -54,11 +54,16 @@ export const LobeOpenRouterAI = LobeOpenAICompatibleFactory({
54
54
  const modelsPage = (await client.models.list()) as any;
55
55
  const modelList: OpenRouterModelCard[] = modelsPage.data;
56
56
 
57
- const response = await fetch('https://openrouter.ai/api/frontend/models');
58
57
  const modelsExtraInfo: OpenRouterModelExtraInfo[] = [];
59
- if (response.ok) {
60
- const data = await response.json();
61
- modelsExtraInfo.push(...data['data']);
58
+ try {
59
+ const response = await fetch('https://openrouter.ai/api/frontend/models');
60
+ if (response.ok) {
61
+ const data = await response.json();
62
+ modelsExtraInfo.push(...data['data']);
63
+ }
64
+ } catch (error) {
65
+ // 忽略 fetch 错误,使用空的 modelsExtraInfo 数组继续处理
66
+ console.error('Failed to fetch OpenRouter frontend models:', error);
62
67
  }
63
68
 
64
69
  return modelList
@@ -1,5 +1,6 @@
1
1
  export * from './chat';
2
2
  export * from './embeddings';
3
+ export * from './model';
3
4
  export * from './textToImage';
4
5
  export * from './tts';
5
6
  export * from './type';
@@ -0,0 +1,44 @@
1
+ export interface ModelDetail {
2
+ details?: {
3
+ families?: string[];
4
+ family?: string;
5
+ format?: string;
6
+ parameter_size?: string;
7
+ quantization_level?: string;
8
+ };
9
+ digest?: string;
10
+ id: string;
11
+ modified_at?: Date;
12
+ name?: string;
13
+ size?: number;
14
+ }
15
+
16
+ export interface ModelProgressResponse {
17
+ completed?: number;
18
+ digest?: string;
19
+ model?: string;
20
+ status: string;
21
+ total?: number;
22
+ }
23
+
24
+ export interface ModelsParams {
25
+ name?: string;
26
+ }
27
+
28
+ export interface PullModelParams {
29
+ insecure?: boolean;
30
+ model: string;
31
+ stream?: boolean;
32
+ }
33
+
34
+ export interface ModelDetailParams {
35
+ model: string;
36
+ }
37
+
38
+ export interface DeleteModelParams {
39
+ model: string;
40
+ }
41
+
42
+ export interface ModelRequestOptions {
43
+ signal?: AbortSignal;
44
+ }
@@ -1,6 +1,7 @@
1
1
  export * from './anthropic';
2
2
  export * from './bedrock';
3
3
  export * from './google-ai';
4
+ export * from './model';
4
5
  export * from './ollama';
5
6
  export * from './openai';
6
7
  export * from './protocol';
@@ -0,0 +1,110 @@
1
+ /**
2
+ * 将异步迭代器转换为 JSON 格式的 ReadableStream
3
+ */
4
+ export const createModelPullStream = <
5
+ T extends { completed?: number; digest?: string; status: string; total?: number },
6
+ >(
7
+ iterable: AsyncIterable<T>,
8
+ model: string,
9
+ {
10
+ onCancel, // 新增:取消时调用的回调函数
11
+ }: {
12
+ onCancel?: (reason?: any) => void; // 回调函数签名
13
+ } = {},
14
+ ): ReadableStream => {
15
+ let iterator: AsyncIterator<T>; // 在外部跟踪迭代器以便取消时可以调用 return
16
+
17
+ return new ReadableStream({
18
+ // 实现 cancel 方法
19
+ cancel(reason) {
20
+ // 调用传入的 onCancel 回调,执行外部的清理逻辑(如 client.abort())
21
+ if (onCancel) {
22
+ onCancel(reason);
23
+ }
24
+
25
+ // 尝试优雅地终止迭代器
26
+ // 注意:这依赖于 AsyncIterable 的实现是否支持 return/throw
27
+ if (iterator && typeof iterator.return === 'function') {
28
+ // 不需要 await,让它在后台执行清理
29
+ iterator.return().catch();
30
+ }
31
+ },
32
+ async start(controller) {
33
+ iterator = iterable[Symbol.asyncIterator](); // 获取迭代器
34
+
35
+ const encoder = new TextEncoder();
36
+
37
+ try {
38
+ // eslint-disable-next-line no-constant-condition
39
+ while (true) {
40
+ // 等待下一个数据块或迭代完成
41
+ const { value: progress, done } = await iterator.next();
42
+
43
+ // 如果迭代完成,跳出循环
44
+ if (done) {
45
+ break;
46
+ }
47
+
48
+ // 忽略 'pulling manifest' 状态,因为它不包含进度
49
+ if (progress.status === 'pulling manifest') continue;
50
+
51
+ // 格式化为标准格式并写入流
52
+ const progressData =
53
+ JSON.stringify({
54
+ completed: progress.completed,
55
+ digest: progress.digest,
56
+ model,
57
+ status: progress.status,
58
+ total: progress.total,
59
+ }) + '\n';
60
+
61
+ controller.enqueue(encoder.encode(progressData));
62
+ }
63
+
64
+ // 正常完成
65
+ controller.close();
66
+ } catch (error) {
67
+ // 处理错误
68
+
69
+ // 如果错误是由于中止操作引起的,则静默处理或记录日志,然后尝试关闭流
70
+ if (error instanceof DOMException && error.name === 'AbortError') {
71
+ // 不需要再 enqueue 错误信息,因为连接可能已断开
72
+ // 尝试正常关闭,如果已经取消,controller 可能已关闭或出错
73
+ try {
74
+ controller.enqueue(new TextEncoder().encode(JSON.stringify({ status: 'cancelled' })));
75
+ controller.close();
76
+ } catch {
77
+ // 忽略关闭错误,可能流已经被取消机制处理了
78
+ }
79
+ } else {
80
+ console.error('[createModelPullStream] model download stream error:', error);
81
+ // 对于其他错误,尝试将错误信息发送给客户端
82
+ const errorMessage = error instanceof Error ? error.message : String(error);
83
+ const errorData =
84
+ JSON.stringify({
85
+ error: errorMessage,
86
+ model,
87
+ status: 'error',
88
+ }) + '\n';
89
+
90
+ try {
91
+ // 只有在流还期望数据时才尝试 enqueue
92
+ if (controller.desiredSize !== null && controller.desiredSize > 0) {
93
+ controller.enqueue(encoder.encode(errorData));
94
+ }
95
+ } catch (enqueueError) {
96
+ console.error('[createModelPullStream] Error enqueueing error message:', enqueueError);
97
+ // 如果这里也失败,很可能连接已断开
98
+ }
99
+
100
+ // 尝试关闭流或标记为错误状态
101
+ try {
102
+ controller.close(); // 尝试正常关闭
103
+ } catch {
104
+ controller.error(error); // 如果关闭失败,则将流置于错误状态
105
+ }
106
+ }
107
+ }
108
+ },
109
+ });
110
+ };
@@ -93,6 +93,10 @@ export default {
93
93
  provider: '服务商',
94
94
  },
95
95
  OllamaSetupGuide: {
96
+ action: {
97
+ close: '关闭提示',
98
+ start: '已安装并运行,开始对话',
99
+ },
96
100
  cors: {
97
101
  description: '因浏览器安全限制,你需要为 Ollama 进行跨域配置后方可正常使用。',
98
102
  linux: {
@@ -166,6 +166,7 @@ export default {
166
166
  },
167
167
  download: {
168
168
  desc: 'Ollama 正在下载该模型,请尽量不要关闭本页面。重新下载时将会中断处继续',
169
+ failed: '模型下载失败,请检查网络或者 Ollama 设置后重试',
169
170
  remainingTime: '剩余时间',
170
171
  speed: '下载速度',
171
172
  title: '正在下载模型 {{model}} ',
@@ -19,8 +19,37 @@ export default {
19
19
  placeholder: '关键词',
20
20
  tooltip: '将会重新获取搜索结果,并创建一条新的总结消息',
21
21
  },
22
- searchEngine: '搜索引擎:',
22
+ searchCategory: {
23
+ placeholder: '搜索类别',
24
+ title: '搜索类别:',
25
+ value: {
26
+ 'files': '文件',
27
+ 'general': '通用',
28
+ 'images': '图片',
29
+ 'it': '信息技术',
30
+ 'map': '地图',
31
+ 'music': '音乐',
32
+ 'news': '新闻',
33
+ 'science': '科学',
34
+ 'social_media': '社交媒体',
35
+ 'videos': '视频',
36
+ },
37
+ },
38
+ searchEngine: {
39
+ placeholder: '搜索引擎',
40
+ title: '搜索引擎:',
41
+ },
23
42
  searchResult: '搜索数量:',
43
+ searchTimeRange: {
44
+ title: '时间范围:',
45
+ value: {
46
+ anytime: '时间不限',
47
+ day: '一天内',
48
+ month: '一月内',
49
+ week: '一周内',
50
+ year: '一年内',
51
+ },
52
+ },
24
53
  summary: '总结',
25
54
  summaryTooltip: '总结当前内容',
26
55
  viewMoreResults: '查看更多 {{results}} 个结果',
@@ -10,10 +10,18 @@ export class SearXNGClient {
10
10
  this.baseUrl = baseUrl;
11
11
  }
12
12
 
13
- async search(query: string, engines?: string[]): Promise<SearchResponse> {
13
+ async search(query: string, optionalParams: Record<string, any> = {}): Promise<SearchResponse> {
14
14
  try {
15
+ const { time_range, ...otherParams } = optionalParams;
16
+
17
+ const processedParams = Object.entries(otherParams).reduce<Record<string, any>>((acc, [key, value]) => {
18
+ acc[key] = Array.isArray(value) ? value.join(',') : value;
19
+ return acc;
20
+ }, {});
21
+
15
22
  const searchParams = qs.stringify({
16
- engines: engines?.join(','),
23
+ ...processedParams,
24
+ ...(time_range !== 'anytime' && { time_range }),
17
25
  format: 'json',
18
26
  q: query,
19
27
  });
@@ -98,8 +98,10 @@ describe('searchRouter', () => {
98
98
  const caller = searchRouter.createCaller(mockContext as any);
99
99
 
100
100
  const result = await caller.query({
101
+ optionalParams: {
102
+ searchEngines: ['google'],
103
+ },
101
104
  query: 'test query',
102
- searchEngine: ['google'],
103
105
  });
104
106
 
105
107
  expect(result).toEqual(mockSearchResult);
@@ -43,8 +43,12 @@ export const searchRouter = router({
43
43
  query: searchProcedure
44
44
  .input(
45
45
  z.object({
46
+ optionalParams: z.object({
47
+ searchCategories: z.array(z.string()).optional(),
48
+ searchEngines: z.array(z.string()).optional(),
49
+ searchTimeRange: z.string().optional(),
50
+ }).optional(),
46
51
  query: z.string(),
47
- searchEngine: z.array(z.string()).optional(),
48
52
  }),
49
53
  )
50
54
  .query(async ({ input }) => {
@@ -55,7 +59,11 @@ export const searchRouter = router({
55
59
  const client = new SearXNGClient(toolsEnv.SEARXNG_URL);
56
60
 
57
61
  try {
58
- return await client.search(input.query, input.searchEngine);
62
+ return await client.search(input.query, {
63
+ categories: input.optionalParams?.searchCategories,
64
+ engines: input.optionalParams?.searchEngines,
65
+ time_range: input.optionalParams?.searchTimeRange,
66
+ });
59
67
  } catch (e) {
60
68
  console.error(e);
61
69
 
@@ -0,0 +1,21 @@
1
+ import { Mock, describe, expect, it, vi } from 'vitest';
2
+
3
+ import { ModelsService } from '../models';
4
+
5
+ vi.stubGlobal('fetch', vi.fn());
6
+
7
+ // 创建一个测试用的 ModelsService 实例
8
+
9
+ const modelsService = new ModelsService();
10
+
11
+ describe('ModelsService', () => {
12
+ describe('getModels', () => {
13
+ it('should call the appropriate endpoint for a generic provider', async () => {
14
+ (fetch as Mock).mockResolvedValueOnce(new Response(JSON.stringify({ models: [] })));
15
+
16
+ await modelsService.getModels('openai');
17
+
18
+ expect(fetch).toHaveBeenCalled();
19
+ });
20
+ });
21
+ });
@@ -32,7 +32,10 @@ export const API_ENDPOINTS = mapWithBasePath({
32
32
 
33
33
  // chat
34
34
  chat: (provider: string) => withBasePath(`/webapi/chat/${provider}`),
35
- chatModels: (provider: string) => withBasePath(`/webapi/chat/models/${provider}`),
35
+
36
+ // models
37
+ models: (provider: string) => withBasePath(`/webapi/models/${provider}`),
38
+ modelPull: (provider: string) => withBasePath(`/webapi/models/${provider}/pull`),
36
39
 
37
40
  // image
38
41
  images: (provider: string) => `/webapi/text-to-image/${provider}`,
@@ -133,7 +133,7 @@ interface CreateAssistantMessageStream extends FetchSSEOptions {
133
133
  *
134
134
  * **Note**: if you try to fetch directly, use `fetchOnClient` instead.
135
135
  */
136
- export function initializeWithClientStore(provider: string, payload: any) {
136
+ export function initializeWithClientStore(provider: string, payload?: any) {
137
137
  /**
138
138
  * Since #5267, we map parameters for client-fetch in function `getProviderAuthPayload`
139
139
  * which called by `createPayloadWithKeyVaults` below.
@@ -1,13 +1,42 @@
1
+ import { isDeprecatedEdition } from '@/const/version';
1
2
  import { createHeaderWithAuth } from '@/services/_auth';
3
+ import { aiProviderSelectors, getAiInfraStoreState } from '@/store/aiInfra';
2
4
  import { useUserStore } from '@/store/user';
3
5
  import { modelConfigSelectors } from '@/store/user/selectors';
4
6
  import { ChatModelCard } from '@/types/llm';
7
+ import { getMessageError } from '@/utils/fetch';
5
8
 
6
9
  import { API_ENDPOINTS } from './_url';
7
10
  import { initializeWithClientStore } from './chat';
8
11
 
9
- class ModelsService {
10
- getChatModels = async (provider: string): Promise<ChatModelCard[] | undefined> => {
12
+ const isEnableFetchOnClient = (provider: string) => {
13
+ // TODO: remove this condition in V2.0
14
+ if (isDeprecatedEdition) {
15
+ return modelConfigSelectors.isProviderFetchOnClient(provider)(useUserStore.getState());
16
+ } else {
17
+ return aiProviderSelectors.isProviderFetchOnClient(provider)(getAiInfraStoreState());
18
+ }
19
+ };
20
+
21
+ // 进度信息接口
22
+ export interface ModelProgressInfo {
23
+ completed?: number;
24
+ digest?: string;
25
+ model?: string;
26
+ status?: string;
27
+ total?: number;
28
+ }
29
+
30
+ // 进度回调函数类型
31
+ export type ProgressCallback = (progress: ModelProgressInfo) => void;
32
+ export type ErrorCallback = (error: { message: string }) => void;
33
+
34
+ export class ModelsService {
35
+ // 用于中断下载的控制器
36
+ private _abortController: AbortController | null = null;
37
+
38
+ // 获取模型列表
39
+ getModels = async (provider: string): Promise<ChatModelCard[] | undefined> => {
11
40
  const headers = await createHeaderWithAuth({
12
41
  headers: { 'Content-Type': 'application/json' },
13
42
  provider,
@@ -16,15 +45,13 @@ class ModelsService {
16
45
  /**
17
46
  * Use browser agent runtime
18
47
  */
19
- const enableFetchOnClient = modelConfigSelectors.isProviderFetchOnClient(provider)(
20
- useUserStore.getState(),
21
- );
48
+ const enableFetchOnClient = isEnableFetchOnClient(provider);
22
49
  if (enableFetchOnClient) {
23
- const agentRuntime = await initializeWithClientStore(provider, {});
50
+ const agentRuntime = await initializeWithClientStore(provider);
24
51
  return agentRuntime.models();
25
52
  }
26
53
 
27
- const res = await fetch(API_ENDPOINTS.chatModels(provider), { headers });
54
+ const res = await fetch(API_ENDPOINTS.models(provider), { headers });
28
55
  if (!res.ok) return;
29
56
 
30
57
  return res.json();
@@ -32,6 +59,125 @@ class ModelsService {
32
59
  return;
33
60
  }
34
61
  };
62
+
63
+ /**
64
+ * 下载模型并通过回调函数返回进度信息
65
+ */
66
+ downloadModel = async (
67
+ { model, provider }: { model: string; provider: string },
68
+ { onProgress }: { onError?: ErrorCallback; onProgress?: ProgressCallback } = {},
69
+ ): Promise<void> => {
70
+ try {
71
+ // 创建一个新的 AbortController
72
+ this._abortController = new AbortController();
73
+ const signal = this._abortController.signal;
74
+
75
+ const headers = await createHeaderWithAuth({
76
+ headers: { 'Content-Type': 'application/json' },
77
+ provider,
78
+ });
79
+
80
+ const enableFetchOnClient = isEnableFetchOnClient(provider);
81
+
82
+ console.log('enableFetchOnClient:', enableFetchOnClient);
83
+ let res: Response;
84
+ if (enableFetchOnClient) {
85
+ const agentRuntime = await initializeWithClientStore(provider);
86
+ res = (await agentRuntime.pullModel({ model }, { signal }))!;
87
+ } else {
88
+ res = await fetch(API_ENDPOINTS.modelPull(provider), {
89
+ body: JSON.stringify({ model }),
90
+ headers,
91
+ method: 'POST',
92
+ signal,
93
+ });
94
+ }
95
+
96
+ if (!res.ok) {
97
+ throw await getMessageError(res);
98
+ }
99
+
100
+ // 处理响应流
101
+ if (res.body) {
102
+ await this.processModelPullStream(res, { onProgress });
103
+ }
104
+ } catch (error) {
105
+ // 如果是取消操作,不需要继续抛出错误
106
+ if (error instanceof DOMException && error.name === 'AbortError') {
107
+ return;
108
+ }
109
+
110
+ console.error('download model error:', error);
111
+ throw error;
112
+ } finally {
113
+ // 清理 AbortController
114
+ this._abortController = null;
115
+ }
116
+ };
117
+
118
+ // 中断模型下载
119
+ abortPull = () => {
120
+ // 使用 AbortController 中断下载
121
+ if (this._abortController) {
122
+ this._abortController.abort();
123
+ this._abortController = null;
124
+ }
125
+ };
126
+
127
+ /**
128
+ * 处理模型下载流,解析进度信息并通过回调函数返回
129
+ * @param response 响应对象
130
+ * @param onProgress 进度回调函数
131
+ * @returns Promise<void>
132
+ */
133
+ private processModelPullStream = async (
134
+ response: Response,
135
+ { onProgress, onError }: { onError?: ErrorCallback; onProgress?: ProgressCallback },
136
+ ): Promise<void> => {
137
+ // 处理响应流
138
+ const reader = response.body?.getReader();
139
+ if (!reader) return;
140
+
141
+ // 读取和处理流数据
142
+ // eslint-disable-next-line no-constant-condition
143
+ while (true) {
144
+ const { done, value } = await reader.read();
145
+ if (done) break;
146
+
147
+ // 解析进度数据
148
+ const progressText = new TextDecoder().decode(value);
149
+ // 一行可能包含多个进度更新
150
+ const progressUpdates = progressText.trim().split('\n');
151
+
152
+ for (const update of progressUpdates) {
153
+ let progress;
154
+ try {
155
+ progress = JSON.parse(update);
156
+ } catch (e) {
157
+ console.error('Error parsing progress update:', e);
158
+ console.error('raw data', update);
159
+ }
160
+
161
+ if (progress.status === 'canceled') {
162
+ console.log('progress:', progress);
163
+ // const abortError = new Error('abort');
164
+ // abortError.name = 'AbortError';
165
+ //
166
+ // throw abortError;
167
+ }
168
+
169
+ if (progress.status === 'error') {
170
+ onError?.({ message: progress.error });
171
+ throw new Error(progress.error);
172
+ }
173
+
174
+ // 调用进度回调
175
+ if (progress.completed !== undefined || progress.status) {
176
+ onProgress?.(progress);
177
+ }
178
+ }
179
+ }
180
+ };
35
181
  }
36
182
 
37
183
  export const modelsService = new ModelsService();
@@ -1,8 +1,8 @@
1
1
  import { toolsClient } from '@/libs/trpc/client';
2
2
 
3
3
  class SearchService {
4
- search(query: string, searchEngine?: string[]) {
5
- return toolsClient.search.query.query({ query, searchEngine });
4
+ search(query: string, optionalParams?: object) {
5
+ return toolsClient.search.query.query({ optionalParams, query});
6
6
  }
7
7
 
8
8
  crawlPage(url: string) {
@@ -71,7 +71,7 @@ export const createAiModelSlice: StateCreator<
71
71
  fetchRemoteModelList: async (providerId) => {
72
72
  const { modelsService } = await import('@/services/models');
73
73
 
74
- const data = await modelsService.getChatModels(providerId);
74
+ const data = await modelsService.getModels(providerId);
75
75
  if (data) {
76
76
  await get().batchUpdateAiModels(
77
77
  data.map((model) => ({
@@ -3,7 +3,7 @@ import { SWRResponse, mutate } from 'swr';
3
3
  import { StateCreator } from 'zustand/vanilla';
4
4
 
5
5
  import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
6
- import { isDeprecatedEdition } from '@/const/version';
6
+ import { isDeprecatedEdition, isDesktop, isUsePgliteDB } from '@/const/version';
7
7
  import { useClientDataSWR } from '@/libs/swr';
8
8
  import { aiProviderService } from '@/services/aiProvider';
9
9
  import { AiInfraStore } from '@/store/aiInfra/store';
@@ -184,6 +184,7 @@ export const createAiProviderSlice: StateCreator<
184
184
  };
185
185
  },
186
186
  {
187
+ focusThrottleInterval: isDesktop || isUsePgliteDB ? 100 : undefined,
187
188
  onSuccess: async (data) => {
188
189
  if (!data) return;
189
190