@lobehub/chat 0.147.21 → 0.148.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. package/CHANGELOG.md +42 -0
  2. package/locales/ar/setting.json +4 -0
  3. package/locales/bg-BG/setting.json +4 -0
  4. package/locales/de-DE/setting.json +4 -0
  5. package/locales/en-US/setting.json +4 -0
  6. package/locales/es-ES/setting.json +4 -0
  7. package/locales/fr-FR/setting.json +4 -0
  8. package/locales/it-IT/setting.json +4 -0
  9. package/locales/ja-JP/setting.json +4 -0
  10. package/locales/ko-KR/setting.json +4 -0
  11. package/locales/nl-NL/setting.json +4 -0
  12. package/locales/pl-PL/setting.json +4 -0
  13. package/locales/pt-BR/setting.json +4 -0
  14. package/locales/ru-RU/setting.json +4 -0
  15. package/locales/tr-TR/setting.json +4 -0
  16. package/locales/vi-VN/setting.json +4 -0
  17. package/locales/zh-CN/setting.json +4 -0
  18. package/locales/zh-TW/setting.json +4 -0
  19. package/package.json +3 -2
  20. package/public/favicon-32x32.ico +0 -0
  21. package/public/favicon.ico +0 -0
  22. package/public/icons/apple-touch-icon.png +0 -0
  23. package/src/app/api/chat/[provider]/route.test.ts +5 -7
  24. package/src/app/api/chat/[provider]/route.ts +13 -7
  25. package/src/app/api/chat/agentRuntime.test.ts +195 -451
  26. package/src/app/api/chat/agentRuntime.ts +197 -280
  27. package/src/app/api/chat/models/[provider]/route.ts +2 -2
  28. package/src/app/chat/features/TopicListContent/Topic/TopicContent.tsx +2 -2
  29. package/src/app/metadata.ts +3 -5
  30. package/src/app/settings/llm/components/ProviderConfig/index.tsx +23 -1
  31. package/src/app/settings/llm/index.tsx +2 -2
  32. package/src/app/settings/llm/page.tsx +1 -5
  33. package/src/features/ChatInput/Topic/index.tsx +6 -2
  34. package/src/features/Conversation/components/ChatItem/index.tsx +8 -3
  35. package/src/libs/agent-runtime/AgentRuntime.test.ts +400 -0
  36. package/src/libs/agent-runtime/AgentRuntime.ts +192 -0
  37. package/src/libs/agent-runtime/index.ts +1 -0
  38. package/src/libs/swr/index.ts +9 -0
  39. package/src/locales/default/setting.ts +4 -0
  40. package/src/services/__tests__/chat.test.ts +287 -1
  41. package/src/services/chat.ts +148 -2
  42. package/src/store/chat/slices/message/action.ts +80 -42
  43. package/src/store/chat/slices/message/initialState.ts +1 -1
  44. package/src/store/chat/slices/message/reducer.ts +32 -1
  45. package/src/store/chat/slices/topic/action.test.ts +25 -2
  46. package/src/store/chat/slices/topic/action.ts +24 -7
  47. package/src/store/chat/slices/topic/reducer.test.ts +141 -0
  48. package/src/store/chat/slices/topic/reducer.ts +67 -0
  49. package/src/store/global/slices/settings/selectors/modelConfig.ts +13 -0
  50. package/src/store/session/slices/session/action.ts +4 -5
  51. package/src/types/settings/modelProvider.ts +4 -0
  52. package/vercel.json +1 -1
@@ -1,16 +1,40 @@
1
1
  import { LobeChatPluginManifest } from '@lobehub/chat-plugin-sdk';
2
2
  import { act } from '@testing-library/react';
3
+ import { merge } from 'lodash';
3
4
  import { describe, expect, it, vi } from 'vitest';
4
5
 
5
6
  import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
7
+ import {
8
+ LobeAnthropicAI,
9
+ LobeAzureOpenAI,
10
+ LobeBedrockAI,
11
+ LobeGoogleAI,
12
+ LobeGroq,
13
+ LobeMistralAI,
14
+ LobeMoonshotAI,
15
+ LobeOllamaAI,
16
+ LobeOpenAI,
17
+ LobeOpenRouterAI,
18
+ LobePerplexityAI,
19
+ LobeTogetherAI,
20
+ LobeZeroOneAI,
21
+ LobeZhipuAI,
22
+ ModelProvider,
23
+ } from '@/libs/agent-runtime';
24
+ import { AgentRuntime } from '@/libs/agent-runtime';
6
25
  import { useFileStore } from '@/store/file';
26
+ import { GlobalStore } from '@/store/global';
27
+ import {
28
+ GlobalSettingsState,
29
+ initialSettingsState,
30
+ } from '@/store/global/slices/settings/initialState';
7
31
  import { useToolStore } from '@/store/tool';
8
32
  import { DalleManifest } from '@/tools/dalle';
9
33
  import { ChatMessage } from '@/types/message';
10
34
  import { ChatStreamPayload } from '@/types/openai/chat';
11
35
  import { LobeTool } from '@/types/tool';
12
36
 
13
- import { chatService } from '../chat';
37
+ import { chatService, initializeWithClientStore } from '../chat';
14
38
 
15
39
  // Mocking external dependencies
16
40
  vi.mock('i18next', () => ({
@@ -649,3 +673,265 @@ Get data from users`,
649
673
  });
650
674
  });
651
675
  });
676
+
677
/**
 * Client-side AgentRuntime tests: check that `initializeWithClientStore`
 * produces an `AgentRuntime` wrapping the expected provider implementation,
 * given user credentials placed in the global settings state.
 */
vi.mock('../_auth', async (importOriginal) => importOriginal());

/**
 * Deep-merge `languageModel` into the shared `initialSettingsState` object.
 * lodash `merge` mutates its first argument, so these settings persist for the
 * rest of this file (each case only adds its own provider key).
 * NOTE(review): presumably the global store shares this object, which is how
 * the credentials reach `getProviderAuthPayload` — confirm.
 */
const mockLanguageModelSettings = (languageModel: object): GlobalStore =>
  merge(initialSettingsState, {
    settings: { languageModel },
  } as unknown as GlobalSettingsState) as unknown as GlobalStore;

describe('AgentRuntimeOnClient', () => {
  describe('initializeWithClientStore', () => {
    describe('should initialize with options correctly', () => {
      it('OpenAI provider: with apikey and endpoint', async () => {
        mockLanguageModelSettings({
          openai: { apiKey: 'user-openai-key', endpoint: 'user-openai-endpoint' },
        });

        const runtime = await initializeWithClientStore(ModelProvider.OpenAI, {});

        expect(runtime).toBeInstanceOf(AgentRuntime);
        expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI);
        expect(runtime['_runtime'].baseURL).toBe('user-openai-endpoint');
      });

      // Providers that only need a check on the concrete runtime class.
      // Registered in a plain loop so titles and execution order stay as written.
      const providerCases = [
        {
          title: 'Azure provider: with apiKey, apiVersion, endpoint',
          provider: ModelProvider.Azure,
          languageModel: {
            azure: {
              apiKey: 'user-azure-key',
              endpoint: 'user-azure-endpoint',
              apiVersion: '2024-02-01',
            },
          },
          RuntimeClass: LobeAzureOpenAI,
        },
        {
          title: 'Google provider: with apiKey',
          provider: ModelProvider.Google,
          languageModel: { google: { apiKey: 'user-google-key' } },
          RuntimeClass: LobeGoogleAI,
        },
        {
          title: 'Moonshot AI provider: with apiKey',
          provider: ModelProvider.Moonshot,
          languageModel: { moonshot: { apiKey: 'user-moonshot-key' } },
          RuntimeClass: LobeMoonshotAI,
        },
        {
          title: 'Bedrock provider: with accessKeyId, region, secretAccessKey',
          provider: ModelProvider.Bedrock,
          languageModel: {
            bedrock: {
              accessKeyId: 'user-bedrock-access-key',
              region: 'user-bedrock-region',
              secretAccessKey: 'user-bedrock-secret',
            },
          },
          RuntimeClass: LobeBedrockAI,
        },
        {
          title: 'Ollama provider: with endpoint',
          provider: ModelProvider.Ollama,
          languageModel: { ollama: { endpoint: 'user-ollama-endpoint' } },
          RuntimeClass: LobeOllamaAI,
        },
        {
          title: 'Perplexity provider: with apiKey',
          provider: ModelProvider.Perplexity,
          languageModel: { perplexity: { apiKey: 'user-perplexity-key' } },
          RuntimeClass: LobePerplexityAI,
        },
        {
          title: 'Anthropic provider: with apiKey',
          provider: ModelProvider.Anthropic,
          languageModel: { anthropic: { apiKey: 'user-anthropic-key' } },
          RuntimeClass: LobeAnthropicAI,
        },
        {
          title: 'Mistral provider: with apiKey',
          provider: ModelProvider.Mistral,
          languageModel: { mistral: { apiKey: 'user-mistral-key' } },
          RuntimeClass: LobeMistralAI,
        },
        {
          title: 'OpenRouter provider: with apiKey',
          provider: ModelProvider.OpenRouter,
          languageModel: { openrouter: { apiKey: 'user-openrouter-key' } },
          RuntimeClass: LobeOpenRouterAI,
        },
        {
          title: 'TogetherAI provider: with apiKey',
          provider: ModelProvider.TogetherAI,
          languageModel: { togetherai: { apiKey: 'user-togetherai-key' } },
          RuntimeClass: LobeTogetherAI,
        },
        {
          title: 'ZeroOneAI provider: with apiKey',
          provider: ModelProvider.ZeroOne,
          languageModel: { zeroone: { apiKey: 'user-zeroone-key' } },
          RuntimeClass: LobeZeroOneAI,
        },
        {
          title: 'Groq provider: with apiKey',
          provider: ModelProvider.Groq,
          languageModel: { groq: { apiKey: 'user-groq-key' } },
          RuntimeClass: LobeGroq,
        },
      ];

      for (const { title, provider, languageModel, RuntimeClass } of providerCases) {
        it(title, async () => {
          mockLanguageModelSettings(languageModel);

          const runtime = await initializeWithClientStore(provider, {});

          expect(runtime).toBeInstanceOf(AgentRuntime);
          expect(runtime['_runtime']).toBeInstanceOf(RuntimeClass);
        });
      }

      /**
       * The client should never see an unknown provider, but the server side
       * has similar cases, so check that it falls back to the OpenAI runtime.
       */
      it('Unknown provider: with apiKey', async () => {
        mockLanguageModelSettings({
          unknown: { apiKey: 'user-unknown-key', endpoint: 'user-unknown-endpoint' },
        });

        const runtime = await initializeWithClientStore('unknown' as ModelProvider, {});

        expect(runtime).toBeInstanceOf(AgentRuntime);
        expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI);
      });

      /**
       * The following test cases still need to be strengthened.
       */
      it('ZhiPu AI provider: with apiKey', async () => {
        // Replace token generation with a fixed JWT so no real signing happens.
        vi.mock('@/libs/agent-runtime/zhipu/authToken', () => ({
          generateApiToken: vi
            .fn()
            .mockResolvedValue(
              'eyJhbGciOiJIUzI1NiIsInNpZ25fdHlwZSI6IlNJR04iLCJ0eXAiOiJKV1QifQ.eyJhcGlfa2V5IjoiemhpcHUiLCJleHAiOjE3MTU5MTc2NzMsImlhdCI6MTcxMzMyNTY3M30.gt8o-hUDvJFPJLYcH4EhrT1LAmTXI8YnybHeQjpD9oM',
            ),
        }));
        mockLanguageModelSettings({ zhipu: { apiKey: 'zhipu.user-key' } });

        const runtime = await initializeWithClientStore(ModelProvider.ZhiPu, {});

        expect(runtime).toBeInstanceOf(AgentRuntime);
        expect(runtime['_runtime']).toBeInstanceOf(LobeZhipuAI);
      });
    });
  });
});
@@ -2,13 +2,15 @@ import { PluginRequestPayload, createHeadersWithPluginSettings } from '@lobehub/
2
2
  import { produce } from 'immer';
3
3
  import { merge } from 'lodash-es';
4
4
 
5
+ import { createErrorResponse } from '@/app/api/errorResponse';
5
6
  import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
6
7
  import { TracePayload, TraceTagMap } from '@/const/trace';
7
- import { ModelProvider } from '@/libs/agent-runtime';
8
+ import { AgentRuntime, ChatCompletionErrorPayload, ModelProvider } from '@/libs/agent-runtime';
8
9
  import { filesSelectors, useFileStore } from '@/store/file';
9
10
  import { useGlobalStore } from '@/store/global';
10
11
  import {
11
12
  commonSelectors,
13
+ modelConfigSelectors,
12
14
  modelProviderSelectors,
13
15
  preferenceSelectors,
14
16
  } from '@/store/global/selectors';
@@ -16,13 +18,14 @@ import { useSessionStore } from '@/store/session';
16
18
  import { agentSelectors } from '@/store/session/selectors';
17
19
  import { useToolStore } from '@/store/tool';
18
20
  import { pluginSelectors, toolSelectors } from '@/store/tool/selectors';
21
+ import { ChatErrorType } from '@/types/fetch';
19
22
  import { ChatMessage } from '@/types/message';
20
23
  import type { ChatStreamPayload, OpenAIChatMessage } from '@/types/openai/chat';
21
24
  import { UserMessageContentPart } from '@/types/openai/chat';
22
25
  import { FetchSSEOptions, OnFinishHandler, fetchSSE, getMessageError } from '@/utils/fetch';
23
26
  import { createTraceHeader, getTraceId } from '@/utils/trace';
24
27
 
25
- import { createHeaderWithAuth } from './_auth';
28
+ import { createHeaderWithAuth, getProviderAuthPayload } from './_auth';
26
29
  import { API_ENDPOINTS } from './_url';
27
30
 
28
31
  interface FetchOptions {
@@ -64,6 +67,119 @@ interface CreateAssistantMessageStream extends FetchSSEOptions {
64
67
  trace?: TracePayload;
65
68
  }
66
69
 
70
/**
 * Build an AgentRuntime in the browser from the credentials held in the
 * client-side settings store.
 *
 * Option precedence (later entries win):
 * common options -> auth payload -> provider-specific options -> `payload`.
 *
 * @param provider - Provider id (normally a `ModelProvider` value); unknown ids
 *   get the same options as OpenAI.
 * @param payload - Extra init options; they override everything else.
 * @returns A promise of the initialised AgentRuntime.
 *
 * **Note**: if you try to fetch directly, use `fetchOnClient` instead.
 */
export function initializeWithClientStore(provider: string, payload: any) {
  const auth = getProviderAuthPayload(provider);

  // Providers that take nothing beyond the raw auth payload.
  const authOnlyProviders: string[] = [
    ModelProvider.ZhiPu,
    ModelProvider.Moonshot,
    ModelProvider.Perplexity,
    ModelProvider.Mistral,
    ModelProvider.Groq,
    ModelProvider.OpenRouter,
    ModelProvider.TogetherAI,
    ModelProvider.ZeroOne,
  ];

  const resolveProviderOptions = (): object => {
    if (provider === ModelProvider.Azure) {
      // The Azure runtime reads a lowercase `apikey` field, so remap it here.
      return { apiVersion: auth?.azureApiVersion, apikey: auth?.apiKey };
    }

    if (provider === ModelProvider.Bedrock) {
      // Only map AWS credentials when the auth payload reports a key at all.
      return auth?.apiKey
        ? {
            accessKeyId: auth?.awsAccessKeyId,
            accessKeySecret: auth?.awsSecretAccessKey,
            region: auth?.awsRegion,
          }
        : {};
    }

    if (authOnlyProviders.includes(provider)) return {};

    // OpenAI, Google, Ollama, Anthropic and any unknown provider: a custom endpoint becomes baseURL.
    return { baseURL: auth?.endpoint };
  };

  return AgentRuntime.initializeWithProviderOptions(provider, {
    [provider]: {
      // Some providers are built on the OpenAI SDK, which refuses to run in a browser without this flag.
      dangerouslyAllowBrowser: true,
      ...auth,
      ...resolveProviderOptions(),
      ...payload,
    },
  });
}
170
+
171
/**
 * Run a chat completion entirely in the browser, using an AgentRuntime built
 * from the client-side settings store.
 *
 * @param provider - The provider name.
 * @param payload - The chat stream payload. It is also forwarded to
 *   `initializeWithClientStore` as runtime init options.
 * @returns A promise of the provider's chat response.
 */
export async function fetchOnClient(provider: string, payload: Partial<ChatStreamPayload>) {
  const runtime = await initializeWithClientStore(provider, payload);

  return runtime.chat(payload as ChatStreamPayload);
}
182
+
67
183
  class ChatService {
68
184
  createAssistantMessage = async (
69
185
  { plugins: enabledPlugins, messages, ...params }: GetChatCompletionPayload,
@@ -149,6 +265,36 @@ class ChatService {
149
265
  { ...res, model },
150
266
  );
151
267
 
268
/**
 * If the user enabled client-side fetching for this provider, run the
 * AgentRuntime directly in the browser instead of calling the server route.
 */
const enableFetchOnClient = modelConfigSelectors.isProviderFetchOnClient(provider)(
  useGlobalStore.getState(),
);
/**
 * Notes:
 * 1. The browser agent runtime skips the server-side auth check when the user
 *    supplies their own key and endpoint, which could allow abuse of plugin
 *    services.
 * 2. Per the original design note, this feature is disabled by default.
 */
if (enableFetchOnClient) {
  try {
    return await fetchOnClient(provider, payload);
  } catch (e) {
    // The thrown value may be a runtime error payload, a plain Error, or in
    // principle any value (even null/undefined), so guard before destructuring.
    const {
      errorType = ChatErrorType.BadRequest,
      error: errorContent,
      ...res
    } = (e ?? {}) as ChatCompletionErrorPayload;

    const error = errorContent || e;
    // This runs in the browser, so the error is only logged to the client console.
    console.error(`Route: [${provider}] ${errorType}:`, error);

    return createErrorResponse(errorType, { error, ...res, provider });
  }
}
297
+
152
298
  const traceHeader = createTraceHeader({ ...options?.trace });
153
299
 
154
300
  const headers = await createHeaderWithAuth({
@@ -1,6 +1,7 @@
1
1
  /* eslint-disable sort-keys-fix/sort-keys-fix, typescript-sort-keys/interface */
2
2
  // Disable the auto sort key eslint rule to make the code more logic and readable
3
3
  import { copyToClipboard } from '@lobehub/ui';
4
+ import { produce } from 'immer';
4
5
  import { template } from 'lodash-es';
5
6
  import { SWRResponse, mutate } from 'swr';
6
7
  import { StateCreator } from 'zustand/vanilla';
@@ -19,6 +20,7 @@ import { agentSelectors } from '@/store/session/selectors';
19
20
  import { ChatMessage } from '@/types/message';
20
21
  import { TraceEventPayloads } from '@/types/trace';
21
22
  import { setNamespace } from '@/utils/storeDebug';
23
+ import { nanoid } from '@/utils/uuid';
22
24
 
23
25
  import { chatSelectors } from '../../selectors';
24
26
  import { MessageDispatch, messagesReducer } from './reducer';
@@ -97,6 +99,7 @@ export interface ChatMessageAction {
97
99
  id?: string,
98
100
  action?: string,
99
101
  ) => AbortController | undefined;
102
+ toggleMessageLoading: (loading: boolean, id: string) => void;
100
103
  refreshMessages: () => Promise<void>;
101
104
  // TODO: 后续 smoothMessage 实现考虑落到 sse 这一层
102
105
  createSmoothMessage: (id: string) => {
@@ -111,6 +114,7 @@ export interface ChatMessageAction {
111
114
  * @param content
112
115
  */
113
116
  internalUpdateMessageContent: (id: string, content: string) => Promise<void>;
117
+ internalCreateMessage: (params: CreateMessageParams) => Promise<string>;
114
118
  internalResendMessage: (id: string, traceId?: string) => Promise<void>;
115
119
  internalTraceMessage: (id: string, payload: TraceEventPayloads) => Promise<void>;
116
120
  }
@@ -130,6 +134,7 @@ export const chatMessage: StateCreator<
130
134
  ChatMessageAction
131
135
  > = (set, get) => ({
132
136
  deleteMessage: async (id) => {
137
+ get().dispatchMessage({ type: 'deleteMessage', id });
133
138
  await messageService.removeMessage(id);
134
139
  await get().refreshMessages();
135
140
  },
@@ -167,43 +172,6 @@ export const chatMessage: StateCreator<
167
172
  await messageService.removeAllMessages();
168
173
  await refreshMessages();
169
174
  },
170
- internalResendMessage: async (messageId, traceId) => {
171
- // 1. 构造所有相关的历史记录
172
- const chats = chatSelectors.currentChats(get());
173
-
174
- const currentIndex = chats.findIndex((c) => c.id === messageId);
175
- if (currentIndex < 0) return;
176
-
177
- const currentMessage = chats[currentIndex];
178
-
179
- let contextMessages: ChatMessage[] = [];
180
-
181
- switch (currentMessage.role) {
182
- case 'function':
183
- case 'user': {
184
- contextMessages = chats.slice(0, currentIndex + 1);
185
- break;
186
- }
187
- case 'assistant': {
188
- // 消息是 AI 发出的因此需要找到它的 user 消息
189
- const userId = currentMessage.parentId;
190
- const userIndex = chats.findIndex((c) => c.id === userId);
191
- // 如果消息没有 parentId,那么同 user/function 模式
192
- contextMessages = chats.slice(0, userIndex < 0 ? currentIndex + 1 : userIndex + 1);
193
- break;
194
- }
195
- }
196
-
197
- if (contextMessages.length <= 0) return;
198
-
199
- const { coreProcessMessage } = get();
200
-
201
- const latestMsg = contextMessages.filter((s) => s.role === 'user').at(-1);
202
-
203
- if (!latestMsg) return;
204
-
205
- await coreProcessMessage(contextMessages, latestMsg.id, traceId);
206
- },
207
175
  sendMessage: async ({ message, files, onlyAddUserMessage }) => {
208
176
  const { coreProcessMessage, activeTopicId, activeId } = get();
209
177
  if (!activeId) return;
@@ -223,8 +191,7 @@ export const chatMessage: StateCreator<
223
191
  topicId: activeTopicId,
224
192
  };
225
193
 
226
- const id = await messageService.createMessage(newMessage);
227
- await get().refreshMessages();
194
+ const id = await get().internalCreateMessage(newMessage);
228
195
 
229
196
  // if only add user message, then stop
230
197
  if (onlyAddUserMessage) return;
@@ -315,8 +282,7 @@ export const chatMessage: StateCreator<
315
282
  topicId: activeTopicId, // if there is activeTopicId,then add it to topicId
316
283
  };
317
284
 
318
- const mid = await messageService.createMessage(assistantMessage);
319
- await refreshMessages();
285
+ const mid = await get().internalCreateMessage(assistantMessage);
320
286
 
321
287
  // 2. fetch the AI response
322
288
  const { isFunctionCall, content, functionCallAtEnd, functionCallContent, traceId } =
@@ -344,7 +310,7 @@ export const chatMessage: StateCreator<
344
310
  traceId,
345
311
  };
346
312
 
347
- functionId = await messageService.createMessage(functionMessage);
313
+ functionId = await get().internalCreateMessage(functionMessage);
348
314
  }
349
315
 
350
316
  await refreshMessages();
@@ -533,6 +499,62 @@ export const chatMessage: StateCreator<
533
499
  window.removeEventListener('beforeunload', preventLeavingFn);
534
500
  }
535
501
  },
502
/**
 * Add `id` to (loading = true) or remove it from (loading = false) the
 * `messageLoadingIds` list.
 */
toggleMessageLoading: (loading, id) => {
  set(
    {
      messageLoadingIds: produce(get().messageLoadingIds, (draft) => {
        const index = draft.indexOf(id);

        if (loading) {
          // Add each id at most once: a duplicate entry would survive a single
          // `toggleMessageLoading(false, id)` and leave the message stuck loading.
          if (index < 0) draft.push(id);
        } else if (index >= 0) {
          draft.splice(index, 1);
        }
      }),
    },
    false,
    'toggleMessageLoading',
  );
},
519
+
520
/**
 * Re-run generation for `messageId` using the chat history up to it.
 * - user / function message: history up to and including that message;
 * - assistant message: history up to its parent (the user message it answers),
 *   or up to the assistant message itself when the parent is not found;
 * - any other role: nothing is resent.
 * The latest user message in that history is passed as the anchor id.
 */
internalResendMessage: async (messageId, traceId) => {
  const chats = chatSelectors.currentChats(get());

  const targetIndex = chats.findIndex((chat) => chat.id === messageId);
  if (targetIndex < 0) return;

  const target = chats[targetIndex];

  // Exclusive end of the history slice; stays undefined for roles that cannot be resent.
  let endIndex: number | undefined;

  if (target.role === 'user' || target.role === 'function') {
    endIndex = targetIndex + 1;
  } else if (target.role === 'assistant') {
    const parentIndex = chats.findIndex((chat) => chat.id === target.parentId);
    endIndex = parentIndex >= 0 ? parentIndex + 1 : targetIndex + 1;
  }

  if (endIndex === undefined) return;

  const contextMessages: ChatMessage[] = chats.slice(0, endIndex);
  if (contextMessages.length === 0) return;

  const latestUserMessage = [...contextMessages].reverse().find((chat) => chat.role === 'user');
  if (!latestUserMessage) return;

  await get().coreProcessMessage(contextMessages, latestUserMessage.id, traceId);
},
557
+
536
558
  internalUpdateMessageContent: async (id, content) => {
537
559
  const { dispatchMessage, refreshMessages } = get();
538
560
 
@@ -545,6 +567,22 @@ export const chatMessage: StateCreator<
545
567
  await refreshMessages();
546
568
  },
547
569
 
570
/**
 * Persist a new message and return its id.
 * The message is shown immediately under a temporary id (optimistic update)
 * and flagged as loading until persistence and the refresh complete.
 * Errors from the message service are rethrown to the caller.
 */
internalCreateMessage: async (message) => {
  const { dispatchMessage, refreshMessages, toggleMessageLoading } = get();

  // Optimistic update: avoid making the user wait for the storage round-trip.
  const tempId = 'tmp_' + nanoid();
  dispatchMessage({ type: 'createMessage', id: tempId, value: message });
  toggleMessageLoading(true, tempId);

  try {
    const id = await messageService.createMessage(message);
    await refreshMessages();

    return id;
  } catch (error) {
    // Drop the optimistic placeholder so it does not linger after a failure.
    dispatchMessage({ type: 'deleteMessage', id: tempId });
    throw error;
  } finally {
    // Always clear the flag; otherwise a failure leaves the temp id "loading" forever.
    toggleMessageLoading(false, tempId);
  }
},
585
+
548
586
  createSmoothMessage: (id) => {
549
587
  const { dispatchMessage } = get();
550
588
 
@@ -9,7 +9,7 @@ export interface ChatMessageState {
9
9
  activeId: string;
10
10
  chatLoadingId?: string;
11
11
  inputMessage: string;
12
- messageLoadingIds: [];
12
+ messageLoadingIds: string[];
13
13
  messages: ChatMessage[];
14
14
  /**
15
15
  * whether messages have fetched