@lobehub/chat 0.149.3 → 0.149.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. package/.github/FUNDING.yml +1 -1
  2. package/CHANGELOG.md +58 -0
  3. package/package.json +1 -1
  4. package/src/app/chat/(desktop)/features/ChatHeader/Main.tsx +5 -5
  5. package/src/app/chat/(desktop)/features/ChatHeader/Tags.tsx +3 -3
  6. package/src/app/chat/(desktop)/features/ChatInput/Footer/DragUpload.tsx +9 -9
  7. package/src/app/chat/(desktop)/features/ChatInput/Footer/index.tsx +3 -3
  8. package/src/app/chat/(desktop)/features/SideBar/SystemRole/index.tsx +8 -3
  9. package/src/app/chat/(mobile)/mobile/ChatHeader/ChatHeaderTitle.tsx +2 -2
  10. package/src/app/chat/(mobile)/mobile/page.tsx +0 -6
  11. package/src/app/chat/_layout/Desktop/SessionList.tsx +2 -0
  12. package/src/app/chat/features/PageTitle/index.tsx +3 -3
  13. package/src/app/chat/features/PluginTag/PluginStatus.tsx +2 -2
  14. package/src/app/chat/features/SessionListContent/DefaultMode.tsx +4 -2
  15. package/src/app/chat/features/SessionListContent/List/Item/index.tsx +10 -17
  16. package/src/app/chat/features/SessionListContent/index.tsx +2 -0
  17. package/src/app/chat/features/ShareButton/Preview.tsx +15 -11
  18. package/src/app/chat/features/ShareButton/useScreenshot.ts +2 -2
  19. package/src/app/chat/settings/features/EditPage.tsx +10 -7
  20. package/src/app/chat/settings/features/SubmitAgentButton/SubmitAgentModal.tsx +5 -3
  21. package/src/app/metadata.ts +3 -3
  22. package/src/app/settings/(mobile)/features/AvatarBanner.tsx +1 -0
  23. package/src/config/modelProviders/ollama.ts +11 -12
  24. package/src/const/session.ts +1 -0
  25. package/src/database/client/models/session.ts +1 -0
  26. package/src/database/client/models/user.ts +6 -0
  27. package/src/features/ChatInput/ActionBar/FileUpload.tsx +11 -5
  28. package/src/features/ChatInput/ActionBar/History.tsx +3 -3
  29. package/src/features/ChatInput/ActionBar/ModelSwitch.tsx +2 -0
  30. package/src/features/ChatInput/ActionBar/Temperature.tsx +3 -3
  31. package/src/features/ChatInput/ActionBar/Token/TokenTag.tsx +4 -4
  32. package/src/features/ChatInput/ActionBar/Token/index.tsx +3 -3
  33. package/src/features/ChatInput/ActionBar/Tools/ToolItem.tsx +3 -3
  34. package/src/features/ChatInput/ActionBar/Tools/index.tsx +4 -4
  35. package/src/features/ChatInput/STT/browser.tsx +3 -3
  36. package/src/features/ChatInput/STT/openai.tsx +3 -3
  37. package/src/features/ChatInput/useChatInput.ts +3 -3
  38. package/src/features/Conversation/Extras/Assistant.test.tsx +7 -7
  39. package/src/features/Conversation/Extras/Assistant.tsx +3 -3
  40. package/src/features/Conversation/Extras/TTS/index.tsx +3 -3
  41. package/src/features/Conversation/components/ChatItem/ActionsBar.tsx +2 -2
  42. package/src/features/Conversation/components/ChatItem/index.tsx +6 -4
  43. package/src/features/Conversation/hooks/useInitConversation.ts +10 -7
  44. package/src/features/Conversation/index.tsx +6 -3
  45. package/src/features/ModelSwitchPanel/index.tsx +6 -4
  46. package/src/hooks/useTTS.ts +4 -4
  47. package/src/libs/agent-runtime/anthropic/index.test.ts +44 -32
  48. package/src/libs/agent-runtime/anthropic/index.ts +12 -9
  49. package/src/libs/agent-runtime/azureOpenai/index.ts +3 -4
  50. package/src/libs/agent-runtime/bedrock/index.ts +1 -1
  51. package/src/libs/agent-runtime/ollama/index.ts +7 -0
  52. package/src/libs/agent-runtime/perplexity/index.ts +1 -0
  53. package/src/libs/agent-runtime/types/chat.ts +2 -1
  54. package/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +1 -0
  55. package/src/services/chat.ts +18 -15
  56. package/src/services/session/client.ts +19 -0
  57. package/src/services/session/type.ts +2 -0
  58. package/src/store/agent/index.ts +2 -0
  59. package/src/store/agent/initialState.ts +7 -0
  60. package/src/store/agent/selectors.ts +1 -0
  61. package/src/store/{session/slices/agent → agent/slices/chat}/action.test.ts +26 -63
  62. package/src/store/agent/slices/chat/action.ts +107 -0
  63. package/src/store/agent/slices/chat/initialState.ts +14 -0
  64. package/src/store/agent/slices/chat/selectors.test.ts +82 -0
  65. package/src/store/agent/slices/chat/selectors.ts +81 -0
  66. package/src/store/agent/store.ts +27 -0
  67. package/src/store/chat/slices/message/action.test.ts +3 -2
  68. package/src/store/chat/slices/message/action.ts +3 -3
  69. package/src/store/chat/slices/message/selectors.test.ts +9 -2
  70. package/src/store/chat/slices/message/selectors.ts +6 -4
  71. package/src/store/chat/slices/share/action.ts +5 -3
  72. package/src/store/global/slices/preference/selectors.ts +3 -1
  73. package/src/store/session/selectors.ts +1 -2
  74. package/src/store/session/slices/session/action.test.ts +43 -0
  75. package/src/store/session/slices/session/action.ts +28 -18
  76. package/src/store/session/slices/session/helpers.ts +2 -3
  77. package/src/store/session/slices/session/initialState.ts +1 -17
  78. package/src/store/session/slices/session/selectors/index.ts +1 -0
  79. package/src/store/session/slices/session/selectors/list.test.ts +5 -3
  80. package/src/store/session/slices/session/selectors/list.ts +2 -3
  81. package/src/store/session/slices/session/selectors/meta.test.ts +108 -0
  82. package/src/store/session/slices/session/selectors/meta.ts +45 -0
  83. package/src/store/session/store.ts +1 -7
  84. package/src/types/session.ts +1 -0
  85. package/src/store/session/slices/agent/action.ts +0 -84
  86. package/src/store/session/slices/agent/selectors.test.ts +0 -180
  87. package/src/store/session/slices/agent/selectors.ts +0 -129
  88. /package/src/store/{session/slices/agent → agent/slices/chat}/index.ts +0 -0
@@ -32,15 +32,18 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
32
32
  const user_messages = messages.filter((m) => m.role !== 'system');
33
33
 
34
34
  try {
35
- const response = await this.client.messages.create({
36
- max_tokens: max_tokens || 4096,
37
- messages: buildAnthropicMessages(user_messages),
38
- model: model,
39
- stream: true,
40
- system: system_message?.content as string,
41
- temperature: temperature,
42
- top_p: top_p,
43
- });
35
+ const response = await this.client.messages.create(
36
+ {
37
+ max_tokens: max_tokens || 4096,
38
+ messages: buildAnthropicMessages(user_messages),
39
+ model: model,
40
+ stream: true,
41
+ system: system_message?.content as string,
42
+ temperature: temperature,
43
+ top_p: top_p,
44
+ },
45
+ { signal: options?.signal },
46
+ );
44
47
 
45
48
  const [prod, debug] = response.tee();
46
49
 
@@ -8,7 +8,7 @@ import { OpenAIStream, StreamingTextResponse } from 'ai';
8
8
 
9
9
  import { LobeRuntimeAI } from '../BaseAI';
10
10
  import { AgentRuntimeErrorType } from '../error';
11
- import { ChatStreamPayload, ModelProvider } from '../types';
11
+ import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
12
12
  import { AgentRuntimeError } from '../utils/createError';
13
13
  import { debugStream } from '../utils/debugStream';
14
14
 
@@ -26,7 +26,7 @@ export class LobeAzureOpenAI implements LobeRuntimeAI {
26
26
 
27
27
  baseURL: string;
28
28
 
29
- async chat(payload: ChatStreamPayload) {
29
+ async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
30
30
  // ============ 1. preprocess messages ============ //
31
31
  const { messages, model, ...params } = payload;
32
32
 
@@ -36,10 +36,9 @@ export class LobeAzureOpenAI implements LobeRuntimeAI {
36
36
  const response = await this.client.streamChatCompletions(
37
37
  model,
38
38
  messages as ChatRequestMessage[],
39
- params as GetChatCompletionsOptions,
39
+ { ...params, abortSignal: options?.signal } as GetChatCompletionsOptions,
40
40
  );
41
41
 
42
- // TODO: we need to refactor this part in the future
43
42
  const stream = OpenAIStream(response as any);
44
43
 
45
44
  const [debug, prod] = stream.tee();
@@ -68,7 +68,7 @@ export class LobeBedrockAI implements LobeRuntimeAI {
68
68
 
69
69
  try {
70
70
  // Ask Claude for a streaming chat completion given the prompt
71
- const bedrockResponse = await this.client.send(command);
71
+ const bedrockResponse = await this.client.send(command, { abortSignal: options?.signal });
72
72
 
73
73
  // Convert the response into a friendly text-stream
74
74
  const stream = AWSBedrockStream(
@@ -31,6 +31,13 @@ export class LobeOllamaAI implements LobeRuntimeAI {
31
31
 
32
32
  async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
33
33
  try {
34
+ const abort = () => {
35
+ this.client.abort();
36
+ options?.signal?.removeEventListener('abort', abort);
37
+ };
38
+
39
+ options?.signal?.addEventListener('abort', abort);
40
+
34
41
  const response = await this.client.chat({
35
42
  messages: this.buildOllamaMessages(payload.messages),
36
43
  model: payload.model,
@@ -33,6 +33,7 @@ export class LobePerplexityAI implements LobeRuntimeAI {
33
33
  };
34
34
  const response = await this.client.chat.completions.create(
35
35
  chatPayload as unknown as OpenAI.ChatCompletionCreateParamsStreaming,
36
+ { signal: options?.signal },
36
37
  );
37
38
  const [prod, debug] = response.tee();
38
39
 
@@ -90,8 +90,9 @@ export interface ChatStreamPayload {
90
90
  }
91
91
 
92
92
  export interface ChatCompetitionOptions {
93
- callback: ChatStreamCallbacks;
93
+ callback?: ChatStreamCallbacks;
94
94
  headers?: Record<string, any>;
95
+ signal?: AbortSignal;
95
96
  }
96
97
 
97
98
  export interface ChatCompletionFunctions {
@@ -76,6 +76,7 @@ export const LobeOpenAICompatibleFactory = ({
76
76
  const response = await this.client.chat.completions.create(postPayload, {
77
77
  // https://github.com/lobehub/lobe-chat/pull/318
78
78
  headers: { Accept: '*/*' },
79
+ signal: options?.signal,
79
80
  });
80
81
 
81
82
  const [prod, useForDebug] = response.tee();
@@ -15,7 +15,7 @@ import {
15
15
  preferenceSelectors,
16
16
  } from '@/store/global/selectors';
17
17
  import { useSessionStore } from '@/store/session';
18
- import { agentSelectors } from '@/store/session/selectors';
18
+ import { sessionMetaSelectors } from '@/store/session/selectors';
19
19
  import { useToolStore } from '@/store/tool';
20
20
  import { pluginSelectors, toolSelectors } from '@/store/tool/selectors';
21
21
  import { ChatErrorType } from '@/types/fetch';
@@ -168,18 +168,6 @@ export function initializeWithClientStore(provider: string, payload: any) {
168
168
  });
169
169
  }
170
170
 
171
- /**
172
- * Fetch chat completion on the client side.
173
- * @param provider - The provider name.
174
- * @param payload - The payload data for the chat stream.
175
- * @returns A promise that resolves to the chat response.
176
- */
177
- export async function fetchOnClient(provider: string, payload: Partial<ChatStreamPayload>) {
178
- const agentRuntime = await initializeWithClientStore(provider, payload);
179
- const data = payload as ChatStreamPayload;
180
- return await agentRuntime.chat(data);
181
- }
182
-
183
171
  class ChatService {
184
172
  createAssistantMessage = async (
185
173
  { plugins: enabledPlugins, messages, ...params }: GetChatCompletionPayload,
@@ -279,7 +267,7 @@ class ChatService {
279
267
  */
280
268
  if (enableFetchOnClient) {
281
269
  try {
282
- return await fetchOnClient(provider, payload);
270
+ return await this.fetchOnClient({ payload, provider, signal });
283
271
  } catch (e) {
284
272
  const {
285
273
  errorType = ChatErrorType.BadRequest,
@@ -459,7 +447,7 @@ class ChatService {
459
447
  };
460
448
 
461
449
  private mapTrace(trace?: TracePayload, tag?: TraceTagMap): TracePayload {
462
- const tags = agentSelectors.currentAgentMeta(useSessionStore.getState()).tags || [];
450
+ const tags = sessionMetaSelectors.currentAgentMeta(useSessionStore.getState()).tags || [];
463
451
 
464
452
  const enabled = preferenceSelectors.userAllowTrace(useGlobalStore.getState());
465
453
 
@@ -472,6 +460,21 @@ class ChatService {
472
460
  userId: commonSelectors.userId(useGlobalStore.getState()),
473
461
  };
474
462
  }
463
+
464
+ /**
465
+ * Fetch chat completion on the client side.
466
+
467
+ */
468
+ private fetchOnClient = async (params: {
469
+ payload: Partial<ChatStreamPayload>;
470
+ provider: string;
471
+ signal?: AbortSignal;
472
+ }) => {
473
+ const agentRuntime = await initializeWithClientStore(params.provider, params.payload);
474
+ const data = params.payload as ChatStreamPayload;
475
+
476
+ return agentRuntime.chat(data, { signal: params.signal });
477
+ };
475
478
  }
476
479
 
477
480
  export const chatService = new ChatService();
@@ -1,7 +1,10 @@
1
1
  import { DeepPartial } from 'utility-types';
2
2
 
3
+ import { INBOX_SESSION_ID } from '@/const/session';
3
4
  import { SessionModel } from '@/database/client/models/session';
4
5
  import { SessionGroupModel } from '@/database/client/models/sessionGroup';
6
+ import { UserModel } from '@/database/client/models/user';
7
+ import { useGlobalStore } from '@/store/global';
5
8
  import { LobeAgentConfig } from '@/types/agent';
6
9
  import {
7
10
  ChatSessionList,
@@ -40,6 +43,18 @@ export class ClientService implements ISessionService {
40
43
  return SessionModel.queryWithGroups();
41
44
  }
42
45
 
46
+ async getSessionConfig(id: string): Promise<LobeAgentConfig> {
47
+ if (!id || id === INBOX_SESSION_ID) {
48
+ return UserModel.getAgentConfig();
49
+ }
50
+
51
+ const res = await SessionModel.findById(id);
52
+
53
+ if (!res) throw new Error('Session not found');
54
+
55
+ return res.config as LobeAgentConfig;
56
+ }
57
+
43
58
  async getSessionsByType(type: 'agent' | 'group' | 'all' = 'all'): Promise<LobeSessions> {
44
59
  switch (type) {
45
60
  // TODO: add a filter to get only agents or agents
@@ -81,6 +96,10 @@ export class ClientService implements ISessionService {
81
96
  }
82
97
 
83
98
  async updateSessionConfig(activeId: string, config: DeepPartial<LobeAgentConfig>) {
99
+ if (activeId === INBOX_SESSION_ID) {
100
+ return useGlobalStore.getState().updateDefaultAgent({ config });
101
+ }
102
+
84
103
  return SessionModel.updateConfig(activeId, config);
85
104
  }
86
105
 
@@ -28,6 +28,8 @@ export interface ISessionService {
28
28
  id: string,
29
29
  data: Partial<{ group?: SessionGroupId; pinned?: boolean }>,
30
30
  ): Promise<any>;
31
+
32
+ getSessionConfig(id: string): Promise<LobeAgentConfig>;
31
33
  updateSessionConfig(id: string, config: DeepPartial<LobeAgentConfig>): Promise<any>;
32
34
 
33
35
  removeSession(id: string): Promise<any>;
@@ -0,0 +1,2 @@
1
+ export type { AgentStore } from './store';
2
+ export { useAgentStore } from './store';
@@ -0,0 +1,7 @@
1
+ import { AgentState, initialSessionState } from './slices/chat/initialState';
2
+
3
+ export type SessionStoreState = AgentState;
4
+
5
+ export const initialState: SessionStoreState = {
6
+ ...initialSessionState,
7
+ };
@@ -0,0 +1 @@
1
+ export { agentSelectors } from './slices/chat/selectors';
@@ -3,14 +3,14 @@ import * as immer from 'immer';
3
3
  import { describe, expect, it, vi } from 'vitest';
4
4
 
5
5
  import { sessionService } from '@/services/session';
6
+ import { useAgentStore } from '@/store/agent';
7
+ import { agentSelectors } from '@/store/agent/selectors';
6
8
  import { useGlobalStore } from '@/store/global';
7
- import { useSessionStore } from '@/store/session';
8
- import { agentSelectors, sessionSelectors } from '@/store/session/selectors';
9
9
 
10
10
  describe('AgentSlice', () => {
11
11
  describe('removePlugin', () => {
12
12
  it('should call togglePlugin with the provided id and false', async () => {
13
- const { result } = renderHook(() => useSessionStore());
13
+ const { result } = renderHook(() => useAgentStore());
14
14
  const pluginId = 'plugin-id';
15
15
  const togglePluginMock = vi.spyOn(result.current, 'togglePlugin');
16
16
 
@@ -25,7 +25,7 @@ describe('AgentSlice', () => {
25
25
 
26
26
  describe('togglePlugin', () => {
27
27
  it('should add plugin id to plugins array if not present and open is true or undefined', async () => {
28
- const { result } = renderHook(() => useSessionStore());
28
+ const { result } = renderHook(() => useAgentStore());
29
29
  const pluginId = 'plugin-id';
30
30
  const updateAgentConfigMock = vi.spyOn(result.current, 'updateAgentConfig');
31
31
 
@@ -43,7 +43,7 @@ describe('AgentSlice', () => {
43
43
  });
44
44
 
45
45
  it('should remove plugin id from plugins array if present and open is false', async () => {
46
- const { result } = renderHook(() => useSessionStore());
46
+ const { result } = renderHook(() => useAgentStore());
47
47
  const pluginId = 'plugin-id';
48
48
  const updateAgentConfigMock = vi.spyOn(result.current, 'updateAgentConfig');
49
49
 
@@ -61,7 +61,7 @@ describe('AgentSlice', () => {
61
61
  });
62
62
 
63
63
  it('should not modify plugins array if plugin id is not present and open is false', async () => {
64
- const { result } = renderHook(() => useSessionStore());
64
+ const { result } = renderHook(() => useAgentStore());
65
65
  const pluginId = 'plugin-id';
66
66
  const updateAgentConfigMock = vi.spyOn(result.current, 'updateAgentConfig');
67
67
 
@@ -79,49 +79,53 @@ describe('AgentSlice', () => {
79
79
 
80
80
  describe('updateAgentConfig', () => {
81
81
  it('should update global config if current session is inbox session', async () => {
82
- const { result } = renderHook(() => useSessionStore());
82
+ const { result } = renderHook(() => useAgentStore());
83
83
  const config = { model: 'gpt-3.5-turbo' };
84
- const updateDefaultAgentMock = vi.spyOn(useGlobalStore.getState(), 'updateDefaultAgent');
85
-
86
- // 模拟当前会话是收件箱会话
87
- vi.spyOn(sessionSelectors, 'isInboxSession').mockReturnValue(true);
84
+ const updateSessionConfigMock = vi.spyOn(sessionService, 'updateSessionConfig');
85
+ const refreshMock = vi.spyOn(result.current, 'internal_refreshAgentConfig');
88
86
 
89
87
  await act(async () => {
90
88
  await result.current.updateAgentConfig(config);
91
89
  });
92
90
 
93
- expect(updateDefaultAgentMock).toHaveBeenCalledWith({ config });
94
- updateDefaultAgentMock.mockRestore();
91
+ expect(updateSessionConfigMock).toHaveBeenCalledWith('inbox', config);
92
+ expect(refreshMock).toHaveBeenCalled();
93
+ updateSessionConfigMock.mockRestore();
94
+ refreshMock.mockRestore();
95
95
  });
96
96
 
97
97
  it('should update session config if current session is not inbox session', async () => {
98
- const { result } = renderHook(() => useSessionStore());
98
+ const { result } = renderHook(() => useAgentStore());
99
99
  const config = { model: 'gpt-3.5-turbo' };
100
100
  const updateSessionConfigMock = vi.spyOn(sessionService, 'updateSessionConfig');
101
- const refreshSessionsMock = vi.spyOn(result.current, 'refreshSessions');
101
+ const refreshMock = vi.spyOn(result.current, 'internal_refreshAgentConfig');
102
102
 
103
103
  // 模拟当前会话不是收件箱会话
104
- vi.spyOn(sessionSelectors, 'isInboxSession').mockReturnValue(false);
105
- vi.spyOn(sessionSelectors, 'currentSession').mockReturnValue({ id: 'session-id' } as any);
106
- vi.spyOn(result.current, 'activeId', 'get').mockReturnValue('session-id');
104
+ act(() => {
105
+ useAgentStore.setState({
106
+ activeId: 'session-id',
107
+ });
108
+ });
107
109
 
108
110
  await act(async () => {
109
111
  await result.current.updateAgentConfig(config);
110
112
  });
111
113
 
112
114
  expect(updateSessionConfigMock).toHaveBeenCalledWith('session-id', config);
113
- expect(refreshSessionsMock).toHaveBeenCalled();
115
+ expect(refreshMock).toHaveBeenCalled();
114
116
  updateSessionConfigMock.mockRestore();
115
- refreshSessionsMock.mockRestore();
117
+ refreshMock.mockRestore();
116
118
  });
117
119
 
118
120
  it('should not update config if there is no current session', async () => {
119
- const { result } = renderHook(() => useSessionStore());
121
+ const { result } = renderHook(() => useAgentStore());
120
122
  const config = { model: 'gpt-3.5-turbo' };
121
123
  const updateSessionConfigMock = vi.spyOn(sessionService, 'updateSessionConfig');
122
124
 
123
125
  // 模拟没有当前会话
124
- vi.spyOn(sessionSelectors, 'currentSession').mockReturnValue(null as any);
126
+ act(() => {
127
+ useAgentStore.setState({ activeId: null as any });
128
+ });
125
129
 
126
130
  await act(async () => {
127
131
  await result.current.updateAgentConfig(config);
@@ -131,45 +135,4 @@ describe('AgentSlice', () => {
131
135
  updateSessionConfigMock.mockRestore();
132
136
  });
133
137
  });
134
-
135
- describe('updateAgentMeta', () => {
136
- it('should not update meta if there is no current session', async () => {
137
- const { result } = renderHook(() => useSessionStore());
138
- const meta = { title: 'Test Agent' };
139
- const updateSessionMock = vi.spyOn(sessionService, 'updateSession');
140
- const refreshSessionsMock = vi.spyOn(result.current, 'refreshSessions');
141
-
142
- // 模拟没有当前会话
143
- vi.spyOn(sessionSelectors, 'currentSession').mockReturnValue(null as any);
144
-
145
- await act(async () => {
146
- await result.current.updateAgentMeta(meta as any);
147
- });
148
-
149
- expect(updateSessionMock).not.toHaveBeenCalled();
150
- expect(refreshSessionsMock).not.toHaveBeenCalled();
151
- updateSessionMock.mockRestore();
152
- refreshSessionsMock.mockRestore();
153
- });
154
-
155
- it('should update session meta and refresh sessions', async () => {
156
- const { result } = renderHook(() => useSessionStore());
157
- const meta = { title: 'Test Agent' };
158
- const updateSessionMock = vi.spyOn(sessionService, 'updateSession');
159
- const refreshSessionsMock = vi.spyOn(result.current, 'refreshSessions');
160
-
161
- // 模拟有当前会话
162
- vi.spyOn(sessionSelectors, 'currentSession').mockReturnValue({ id: 'session-id' } as any);
163
- vi.spyOn(result.current, 'activeId', 'get').mockReturnValue('session-id');
164
-
165
- await act(async () => {
166
- await result.current.updateAgentMeta(meta);
167
- });
168
-
169
- expect(updateSessionMock).toHaveBeenCalledWith('session-id', { meta });
170
- expect(refreshSessionsMock).toHaveBeenCalled();
171
- updateSessionMock.mockRestore();
172
- refreshSessionsMock.mockRestore();
173
- });
174
- });
175
138
  });
@@ -0,0 +1,107 @@
1
+ import isEqual from 'fast-deep-equal';
2
+ import { produce } from 'immer';
3
+ import { SWRResponse, mutate } from 'swr';
4
+ import { DeepPartial } from 'utility-types';
5
+ import { StateCreator } from 'zustand/vanilla';
6
+
7
+ import { useClientDataSWR } from '@/libs/swr';
8
+ import { sessionService } from '@/services/session';
9
+ import { useSessionStore } from '@/store/session';
10
+ import { LobeAgentConfig } from '@/types/agent';
11
+ import { merge } from '@/utils/merge';
12
+
13
+ import { AgentStore } from '../../store';
14
+ import { agentSelectors } from './selectors';
15
+
16
+ /**
17
+ * 助手接口
18
+ */
19
+ export interface AgentChatAction {
20
+ removePlugin: (id: string) => void;
21
+ togglePlugin: (id: string, open?: boolean) => Promise<void>;
22
+ updateAgentConfig: (config: Partial<LobeAgentConfig>) => Promise<void>;
23
+
24
+ useFetchAgentConfig: (id: string) => SWRResponse<LobeAgentConfig>;
25
+
26
+ /* eslint-disable typescript-sort-keys/interface */
27
+
28
+ internal_updateAgentConfig: (id: string, data: DeepPartial<LobeAgentConfig>) => Promise<void>;
29
+ internal_refreshAgentConfig: (id: string) => Promise<void>;
30
+ /* eslint-enable */
31
+ }
32
+
33
+ const FETCH_AGENT_CONFIG_KEY = 'FETCH_AGENT_CONFIG';
34
+
35
+ export const createChatSlice: StateCreator<
36
+ AgentStore,
37
+ [['zustand/devtools', never]],
38
+ [],
39
+ AgentChatAction
40
+ > = (set, get) => ({
41
+ removePlugin: async (id) => {
42
+ await get().togglePlugin(id, false);
43
+ },
44
+
45
+ togglePlugin: async (id, open) => {
46
+ const originConfig = agentSelectors.currentAgentConfig(get());
47
+
48
+ const config = produce(originConfig, (draft) => {
49
+ draft.plugins = produce(draft.plugins || [], (plugins) => {
50
+ const index = plugins.indexOf(id);
51
+ const shouldOpen = open !== undefined ? open : index === -1;
52
+
53
+ if (shouldOpen) {
54
+ // 如果 open 为 true 或者 id 不存在于 plugins 中,则添加它
55
+ if (index === -1) {
56
+ plugins.push(id);
57
+ }
58
+ } else {
59
+ // 如果 open 为 false 或者 id 存在于 plugins 中,则移除它
60
+ if (index !== -1) {
61
+ plugins.splice(index, 1);
62
+ }
63
+ }
64
+ });
65
+ });
66
+
67
+ await get().updateAgentConfig(config);
68
+ },
69
+ updateAgentConfig: async (config) => {
70
+ const { activeId } = get();
71
+
72
+ if (!activeId) return;
73
+
74
+ await get().internal_updateAgentConfig(activeId, config);
75
+ },
76
+
77
+ useFetchAgentConfig: (sessionId) =>
78
+ useClientDataSWR<LobeAgentConfig>(
79
+ [FETCH_AGENT_CONFIG_KEY, sessionId],
80
+ ([, id]: string[]) => sessionService.getSessionConfig(id),
81
+ {
82
+ onSuccess: (data) => {
83
+ if (get().isAgentConfigInit && isEqual(get().agentConfig, data)) return;
84
+
85
+ set({ agentConfig: data, isAgentConfigInit: true }, false, 'fetchAgentConfig');
86
+ },
87
+ },
88
+ ),
89
+
90
+ /* eslint-disable sort-keys-fix/sort-keys-fix */
91
+
92
+ internal_updateAgentConfig: async (id, data) => {
93
+ const prevModel = agentSelectors.currentAgentModel(get());
94
+ // optimistic update at frontend
95
+ set({ agentConfig: merge(get().agentConfig, data) }, false, 'optimistic_updateAgentConfig');
96
+
97
+ await sessionService.updateSessionConfig(id, data);
98
+ await get().internal_refreshAgentConfig(id);
99
+
100
+ // refresh sessions to update the agent config if the model has changed
101
+ if (prevModel !== data.model) await useSessionStore.getState().refreshSessions();
102
+ },
103
+
104
+ internal_refreshAgentConfig: async (id) => {
105
+ await mutate([FETCH_AGENT_CONFIG_KEY, id]);
106
+ },
107
+ });
@@ -0,0 +1,14 @@
1
+ import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
2
+ import { LobeAgentConfig } from '@/types/agent';
3
+
4
+ export interface AgentState {
5
+ activeId: string;
6
+ agentConfig: LobeAgentConfig;
7
+ isAgentConfigInit: boolean;
8
+ }
9
+
10
+ export const initialSessionState: AgentState = {
11
+ activeId: 'inbox',
12
+ agentConfig: DEFAULT_AGENT_CONFIG,
13
+ isAgentConfigInit: false,
14
+ };
@@ -0,0 +1,82 @@
1
+ import { describe, expect, it } from 'vitest';
2
+
3
+ import { DEFAULT_AGENT_CONFIG, DEFAUTT_AGENT_TTS_CONFIG } from '@/const/settings';
4
+ import { AgentStore } from '@/store/agent';
5
+
6
+ import { agentSelectors } from './selectors';
7
+
8
+ vi.mock('i18next', () => ({
9
+ t: vi.fn((key) => key), // Simplified mock return value
10
+ }));
11
+
12
+ const mockSessionStore = {
13
+ activeId: '1',
14
+ agentConfig: DEFAULT_AGENT_CONFIG,
15
+ } as AgentStore;
16
+
17
+ describe('agentSelectors', () => {
18
+ describe('currentAgentConfig', () => {
19
+ it('should return the merged default and session-specific agent config', () => {
20
+ const config = agentSelectors.currentAgentConfig(mockSessionStore);
21
+ expect(config).toEqual(expect.objectContaining(mockSessionStore.agentConfig));
22
+ });
23
+ });
24
+
25
+ describe('currentAgentModel', () => {
26
+ it('should return the model from the agent config', () => {
27
+ const model = agentSelectors.currentAgentModel(mockSessionStore);
28
+ expect(model).toBe(mockSessionStore.agentConfig.model);
29
+ });
30
+ });
31
+
32
+ describe('hasSystemRole', () => {
33
+ it('should return true if the system role is defined in the agent config', () => {
34
+ const hasRole = agentSelectors.hasSystemRole(mockSessionStore);
35
+ expect(hasRole).toBe(false);
36
+ });
37
+
38
+ it('should return false if the system role is not defined in the agent config', () => {
39
+ const modifiedSessionStore = {
40
+ ...mockSessionStore,
41
+ agentConfig: {
42
+ ...mockSessionStore.agentConfig,
43
+ systemRole: 'test',
44
+ },
45
+ };
46
+ const hasRole = agentSelectors.hasSystemRole(modifiedSessionStore);
47
+ expect(hasRole).toBe(true);
48
+ });
49
+ });
50
+
51
+ describe('currentAgentTTS', () => {
52
+ it('should return the TTS config from the agent config', () => {
53
+ const ttsConfig = agentSelectors.currentAgentTTS(mockSessionStore);
54
+ expect(ttsConfig).toEqual(mockSessionStore.agentConfig.tts);
55
+ });
56
+
57
+ it('should return the default TTS config if none is defined in the agent config', () => {
58
+ const modifiedSessionStore = {
59
+ ...mockSessionStore,
60
+ sessions: [
61
+ {
62
+ ...mockSessionStore.agentConfig,
63
+ config: {
64
+ ...mockSessionStore.agentConfig,
65
+ tts: DEFAUTT_AGENT_TTS_CONFIG,
66
+ },
67
+ },
68
+ ],
69
+ };
70
+ const ttsConfig = agentSelectors.currentAgentTTS(modifiedSessionStore);
71
+ expect(ttsConfig).toEqual(DEFAUTT_AGENT_TTS_CONFIG);
72
+ });
73
+ });
74
+
75
+ describe('currentAgentTTSVoice', () => {
76
+ it('should return the appropriate TTS voice based on the service and language', () => {
77
+ const lang = 'en';
78
+ const ttsVoice = agentSelectors.currentAgentTTSVoice(lang)(mockSessionStore);
79
+ expect(ttsVoice).toBe(mockSessionStore.agentConfig.tts.voice.openai);
80
+ });
81
+ });
82
+ });