@lobehub/chat 0.156.1 → 0.157.0

This diff shows the changes between two package versions as published to one of the supported public registries. It is provided for informational purposes only and reflects the packages' contents as they appear in those registries.
Files changed (115)
  1. package/CHANGELOG.md +42 -0
  2. package/Dockerfile +4 -1
  3. package/package.json +3 -2
  4. package/src/config/modelProviders/anthropic.ts +3 -0
  5. package/src/config/modelProviders/google.ts +3 -0
  6. package/src/config/modelProviders/groq.ts +5 -1
  7. package/src/config/modelProviders/minimax.ts +10 -7
  8. package/src/config/modelProviders/mistral.ts +1 -0
  9. package/src/config/modelProviders/moonshot.ts +3 -0
  10. package/src/config/modelProviders/zhipu.ts +2 -6
  11. package/src/config/server/provider.ts +1 -1
  12. package/src/database/client/core/db.ts +32 -0
  13. package/src/database/client/core/schemas.ts +9 -0
  14. package/src/database/client/models/__tests__/message.test.ts +2 -2
  15. package/src/database/client/schemas/message.ts +8 -1
  16. package/src/features/AgentSetting/store/action.ts +15 -6
  17. package/src/features/Conversation/Actions/Tool.tsx +16 -0
  18. package/src/features/Conversation/Actions/index.ts +2 -2
  19. package/src/features/Conversation/Messages/Assistant/ToolCalls/index.tsx +78 -0
  20. package/src/features/Conversation/Messages/Assistant/ToolCalls/style.ts +25 -0
  21. package/src/features/Conversation/Messages/Assistant/index.tsx +47 -0
  22. package/src/features/Conversation/Messages/Default.tsx +4 -1
  23. package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/index.tsx +34 -35
  24. package/src/features/Conversation/Messages/Tool/index.tsx +44 -0
  25. package/src/features/Conversation/Messages/index.ts +3 -2
  26. package/src/features/Conversation/Plugins/Render/StandaloneType/Iframe.tsx +1 -1
  27. package/src/features/Conversation/components/SkeletonList.tsx +2 -2
  28. package/src/features/Conversation/index.tsx +2 -3
  29. package/src/libs/agent-runtime/BaseAI.ts +2 -9
  30. package/src/libs/agent-runtime/anthropic/index.test.ts +195 -0
  31. package/src/libs/agent-runtime/anthropic/index.ts +71 -15
  32. package/src/libs/agent-runtime/azureOpenai/index.ts +12 -13
  33. package/src/libs/agent-runtime/bedrock/index.ts +24 -18
  34. package/src/libs/agent-runtime/google/index.test.ts +154 -0
  35. package/src/libs/agent-runtime/google/index.ts +91 -10
  36. package/src/libs/agent-runtime/groq/index.test.ts +41 -72
  37. package/src/libs/agent-runtime/groq/index.ts +7 -0
  38. package/src/libs/agent-runtime/minimax/index.test.ts +2 -2
  39. package/src/libs/agent-runtime/minimax/index.ts +14 -37
  40. package/src/libs/agent-runtime/mistral/index.test.ts +0 -53
  41. package/src/libs/agent-runtime/mistral/index.ts +1 -0
  42. package/src/libs/agent-runtime/moonshot/index.test.ts +1 -71
  43. package/src/libs/agent-runtime/ollama/index.test.ts +197 -0
  44. package/src/libs/agent-runtime/ollama/index.ts +3 -3
  45. package/src/libs/agent-runtime/openai/index.test.ts +0 -53
  46. package/src/libs/agent-runtime/openrouter/index.test.ts +1 -53
  47. package/src/libs/agent-runtime/perplexity/index.test.ts +0 -71
  48. package/src/libs/agent-runtime/perplexity/index.ts +2 -3
  49. package/src/libs/agent-runtime/togetherai/__snapshots__/index.test.ts.snap +886 -0
  50. package/src/libs/agent-runtime/togetherai/fixtures/models.json +8111 -0
  51. package/src/libs/agent-runtime/togetherai/index.test.ts +16 -54
  52. package/src/libs/agent-runtime/types/chat.ts +19 -3
  53. package/src/libs/agent-runtime/utils/anthropicHelpers.test.ts +120 -1
  54. package/src/libs/agent-runtime/utils/anthropicHelpers.ts +67 -4
  55. package/src/libs/agent-runtime/utils/debugStream.test.ts +70 -0
  56. package/src/libs/agent-runtime/utils/debugStream.ts +39 -9
  57. package/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts +521 -0
  58. package/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +76 -5
  59. package/src/libs/agent-runtime/utils/response.ts +12 -0
  60. package/src/libs/agent-runtime/utils/streams/anthropic.test.ts +197 -0
  61. package/src/libs/agent-runtime/utils/streams/anthropic.ts +91 -0
  62. package/src/libs/agent-runtime/utils/streams/bedrock/claude.ts +21 -0
  63. package/src/libs/agent-runtime/utils/streams/bedrock/common.ts +32 -0
  64. package/src/libs/agent-runtime/utils/streams/bedrock/index.ts +3 -0
  65. package/src/libs/agent-runtime/utils/streams/bedrock/llama.test.ts +196 -0
  66. package/src/libs/agent-runtime/utils/streams/bedrock/llama.ts +51 -0
  67. package/src/libs/agent-runtime/utils/streams/google-ai.test.ts +97 -0
  68. package/src/libs/agent-runtime/utils/streams/google-ai.ts +68 -0
  69. package/src/libs/agent-runtime/utils/streams/index.ts +7 -0
  70. package/src/libs/agent-runtime/utils/streams/minimax.ts +39 -0
  71. package/src/libs/agent-runtime/utils/streams/ollama.test.ts +77 -0
  72. package/src/libs/agent-runtime/utils/streams/ollama.ts +38 -0
  73. package/src/libs/agent-runtime/utils/streams/openai.test.ts +263 -0
  74. package/src/libs/agent-runtime/utils/streams/openai.ts +79 -0
  75. package/src/libs/agent-runtime/utils/streams/protocol.ts +100 -0
  76. package/src/libs/agent-runtime/zeroone/index.test.ts +1 -53
  77. package/src/libs/agent-runtime/zhipu/index.test.ts +1 -1
  78. package/src/libs/agent-runtime/zhipu/index.ts +3 -2
  79. package/src/locales/default/plugin.ts +3 -4
  80. package/src/migrations/FromV4ToV5/fixtures/from-v1-to-v5-output.json +245 -0
  81. package/src/migrations/FromV4ToV5/fixtures/function-input-v4.json +96 -0
  82. package/src/migrations/FromV4ToV5/fixtures/function-output-v5.json +120 -0
  83. package/src/migrations/FromV4ToV5/index.ts +58 -0
  84. package/src/migrations/FromV4ToV5/migrations.test.ts +49 -0
  85. package/src/migrations/FromV4ToV5/types/v4.ts +21 -0
  86. package/src/migrations/FromV4ToV5/types/v5.ts +27 -0
  87. package/src/migrations/index.ts +8 -1
  88. package/src/services/__tests__/chat.test.ts +10 -20
  89. package/src/services/chat.ts +78 -65
  90. package/src/store/chat/slices/enchance/action.ts +15 -10
  91. package/src/store/chat/slices/message/action.test.ts +36 -86
  92. package/src/store/chat/slices/message/action.ts +70 -79
  93. package/src/store/chat/slices/message/reducer.ts +18 -1
  94. package/src/store/chat/slices/message/selectors.test.ts +38 -68
  95. package/src/store/chat/slices/message/selectors.ts +1 -22
  96. package/src/store/chat/slices/plugin/action.test.ts +147 -203
  97. package/src/store/chat/slices/plugin/action.ts +96 -82
  98. package/src/store/chat/slices/share/action.test.ts +3 -3
  99. package/src/store/chat/slices/share/action.ts +1 -1
  100. package/src/store/chat/slices/topic/action.ts +7 -2
  101. package/src/store/tool/selectors/tool.ts +6 -24
  102. package/src/store/tool/slices/builtin/action.test.ts +90 -0
  103. package/src/types/llm.ts +1 -1
  104. package/src/types/message/index.ts +9 -4
  105. package/src/types/message/tools.ts +57 -0
  106. package/src/types/openai/chat.ts +6 -0
  107. package/src/utils/fetch.test.ts +245 -1
  108. package/src/utils/fetch.ts +120 -44
  109. package/src/utils/toolCall.ts +21 -0
  110. package/src/features/Conversation/Messages/Assistant.tsx +0 -26
  111. package/src/features/Conversation/Messages/Function.tsx +0 -35
  112. package/src/libs/agent-runtime/ollama/stream.ts +0 -31
  113. package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/PluginResultJSON.tsx +0 -0
  114. package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/Settings.tsx +0 -0
  115. package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/style.ts +0 -0
package/src/libs/agent-runtime/utils/streams/anthropic.test.ts
@@ -0,0 +1,197 @@
+ import type { Stream } from '@anthropic-ai/sdk/streaming';
+ import { describe, expect, it, vi } from 'vitest';
+
+ import { AnthropicStream } from './anthropic';
+
+ describe('AnthropicStream', () => {
+   it('should transform Anthropic stream to protocol stream', async () => {
+     // @ts-ignore
+     const mockAnthropicStream: Stream = {
+       [Symbol.asyncIterator]() {
+         let count = 0;
+         return {
+           next: async () => {
+             switch (count) {
+               case 0:
+                 count++;
+                 return {
+                   done: false,
+                   value: {
+                     type: 'message_start',
+                     message: { id: 'message_1', metadata: {} },
+                   },
+                 };
+               case 1:
+                 count++;
+                 return {
+                   done: false,
+                   value: {
+                     type: 'content_block_delta',
+                     delta: { type: 'text_delta', text: 'Hello' },
+                   },
+                 };
+               case 2:
+                 count++;
+                 return {
+                   done: false,
+                   value: {
+                     type: 'content_block_delta',
+                     delta: { type: 'text_delta', text: ' world!' },
+                   },
+                 };
+               case 3:
+                 count++;
+                 return {
+                   done: false,
+                   value: {
+                     type: 'message_delta',
+                     delta: { stop_reason: 'stop' },
+                   },
+                 };
+               default:
+                 return { done: true, value: undefined };
+             }
+           },
+         };
+       },
+     };
+
+     const onStartMock = vi.fn();
+     const onTextMock = vi.fn();
+     const onTokenMock = vi.fn();
+     const onCompletionMock = vi.fn();
+
+     const protocolStream = AnthropicStream(mockAnthropicStream, {
+       onStart: onStartMock,
+       onText: onTextMock,
+       onToken: onTokenMock,
+       onCompletion: onCompletionMock,
+     });
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([
+       'id: message_1\n',
+       'event: data\n',
+       `data: {"id":"message_1","metadata":{}}\n\n`,
+       'id: message_1\n',
+       'event: text\n',
+       `data: "Hello"\n\n`,
+       'id: message_1\n',
+       'event: text\n',
+       `data: " world!"\n\n`,
+       'id: message_1\n',
+       'event: stop\n',
+       `data: "stop"\n\n`,
+     ]);
+
+     expect(onStartMock).toHaveBeenCalledTimes(1);
+     expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+     expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+     expect(onTokenMock).toHaveBeenCalledTimes(2);
+     expect(onCompletionMock).toHaveBeenCalledTimes(1);
+   });
+
+   it('should handle tool use event and ReadableStream input', async () => {
+     const toolUseEvent = {
+       type: 'content_block_delta',
+       delta: {
+         type: 'tool_use',
+         tool_use: {
+           id: 'tool_use_1',
+           name: 'example_tool',
+           input: { arg1: 'value1' },
+         },
+       },
+     };
+
+     const mockReadableStream = new ReadableStream({
+       start(controller) {
+         controller.enqueue({
+           type: 'message_start',
+           message: { id: 'message_1', metadata: {} },
+         });
+         controller.enqueue(toolUseEvent);
+         controller.enqueue({
+           type: 'message_stop',
+         });
+         controller.close();
+       },
+     });
+
+     const onToolCallMock = vi.fn();
+
+     const protocolStream = AnthropicStream(mockReadableStream, {
+       onToolCall: onToolCallMock,
+     });
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([
+       'id: message_1\n',
+       'event: data\n',
+       `data: {"id":"message_1","metadata":{}}\n\n`,
+       'id: message_1\n',
+       'event: tool_calls\n',
+       `data: [{"function":{"arguments":"{\\"arg1\\":\\"value1\\"}","name":"example_tool"},"id":"tool_use_1","index":0,"type":"function"}]\n\n`,
+       'id: message_1\n',
+       'event: stop\n',
+       `data: "message_stop"\n\n`,
+     ]);
+
+     expect(onToolCallMock).toHaveBeenCalledTimes(1);
+   });
+
+   it('should handle ReadableStream input', async () => {
+     const mockReadableStream = new ReadableStream({
+       start(controller) {
+         controller.enqueue({
+           type: 'message_start',
+           message: { id: 'message_1', metadata: {} },
+         });
+         controller.enqueue({
+           type: 'content_block_delta',
+           delta: { type: 'text_delta', text: 'Hello' },
+         });
+         controller.enqueue({
+           type: 'message_stop',
+         });
+         controller.close();
+       },
+     });
+
+     const protocolStream = AnthropicStream(mockReadableStream);
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([
+       'id: message_1\n',
+       'event: data\n',
+       `data: {"id":"message_1","metadata":{}}\n\n`,
+       'id: message_1\n',
+       'event: text\n',
+       `data: "Hello"\n\n`,
+       'id: message_1\n',
+       'event: stop\n',
+       `data: "message_stop"\n\n`,
+     ]);
+   });
+ });
package/src/libs/agent-runtime/utils/streams/anthropic.ts
@@ -0,0 +1,91 @@
+ import Anthropic from '@anthropic-ai/sdk';
+ import type { Stream } from '@anthropic-ai/sdk/streaming';
+ import { readableFromAsyncIterable } from 'ai';
+
+ import { ChatStreamCallbacks } from '../../types';
+ import {
+   StreamProtocolChunk,
+   StreamProtocolToolCallChunk,
+   StreamStack,
+   StreamToolCallChunkData,
+   createCallbacksTransformer,
+   createSSEProtocolTransformer,
+ } from './protocol';
+
+ export const transformAnthropicStream = (
+   chunk: Anthropic.MessageStreamEvent,
+   stack: StreamStack,
+ ): StreamProtocolChunk => {
+   // maybe need another structure to add support for multiple choices
+   switch (chunk.type) {
+     case 'message_start': {
+       stack.id = chunk.message.id;
+       return { data: chunk.message, id: chunk.message.id, type: 'data' };
+     }
+
+     // case 'content_block_start': {
+     //   return { data: chunk.content_block.text, id: stack.id, type: 'data' };
+     // }
+
+     case 'content_block_delta': {
+       switch (chunk.delta.type as string) {
+         default:
+         case 'text_delta': {
+           return { data: chunk.delta.text, id: stack.id, type: 'text' };
+         }
+
+         // TODO: Anthropic does not support streaming tool calls yet,
+         // so we add this custom `tool_use` delta type to stream them,
+         // and it may need updating once the feature is officially available
+         case 'tool_use': {
+           const delta = (chunk.delta as any).tool_use as Anthropic.Beta.Tools.ToolUseBlock;
+
+           const toolCall: StreamToolCallChunkData = {
+             function: { arguments: JSON.stringify(delta.input), name: delta.name },
+             id: delta.id,
+             index: 0,
+             type: 'function',
+           };
+
+           return {
+             data: [toolCall],
+             id: stack.id,
+             type: 'tool_calls',
+           } as StreamProtocolToolCallChunk;
+         }
+       }
+     }
+
+     case 'message_delta': {
+       return { data: chunk.delta.stop_reason, id: stack.id, type: 'stop' };
+     }
+
+     case 'message_stop': {
+       return { data: 'message_stop', id: stack.id, type: 'stop' };
+     }
+
+     default: {
+       return { data: chunk, id: stack.id, type: 'data' };
+     }
+   }
+ };
+
+ const chatStreamable = async function* (stream: AsyncIterable<Anthropic.MessageStreamEvent>) {
+   for await (const response of stream) {
+     yield response;
+   }
+ };
+
+ export const AnthropicStream = (
+   stream: Stream<Anthropic.MessageStreamEvent> | ReadableStream,
+   callbacks?: ChatStreamCallbacks,
+ ) => {
+   const streamStack: StreamStack = { id: '' };
+
+   const readableStream =
+     stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+
+   return readableStream
+     .pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
+     .pipeThrough(createCallbacksTransformer(callbacks));
+ };
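For orientation, a minimal usage sketch of the exported AnthropicStream follows. It is not part of this diff: the route-handler shape, the import path, the client construction, and the model name are illustrative assumptions.

import Anthropic from '@anthropic-ai/sdk';

// hypothetical import path; in this package the stream helpers live under
// src/libs/agent-runtime/utils/streams
import { AnthropicStream } from './streams';

export const POST = async (req: Request) => {
  const { messages } = await req.json();

  // assumes ANTHROPIC_API_KEY is set in the environment
  const client = new Anthropic();

  // stream: true makes the SDK return a Stream<MessageStreamEvent>
  const stream = await client.messages.create({
    max_tokens: 1024,
    messages,
    model: 'claude-3-sonnet-20240229',
    stream: true,
  });

  // AnthropicStream turns SDK events into `id:` / `event:` / `data:` SSE frames
  return new Response(AnthropicStream(stream), {
    headers: { 'content-type': 'text/event-stream' },
  });
};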
package/src/libs/agent-runtime/utils/streams/bedrock/claude.ts
@@ -0,0 +1,21 @@
+ import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
+
+ import { nanoid } from '@/utils/uuid';
+
+ import { ChatStreamCallbacks } from '../../../types';
+ import { transformAnthropicStream } from '../anthropic';
+ import { StreamStack, createCallbacksTransformer, createSSEProtocolTransformer } from '../protocol';
+ import { createBedrockStream } from './common';
+
+ export const AWSBedrockClaudeStream = (
+   res: InvokeModelWithResponseStreamResponse | ReadableStream,
+   cb?: ChatStreamCallbacks,
+ ): ReadableStream<string> => {
+   const streamStack: StreamStack = { id: 'chat_' + nanoid() };
+
+   const stream = res instanceof ReadableStream ? res : createBedrockStream(res);
+
+   return stream
+     .pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
+     .pipeThrough(createCallbacksTransformer(cb));
+ };
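Worth noting: AWSBedrockClaudeStream reuses transformAnthropicStream from the sibling anthropic module, since Bedrock surfaces Claude's native Anthropic-format events; the only extra work is decoding Bedrock's byte chunks via createBedrockStream below.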
package/src/libs/agent-runtime/utils/streams/bedrock/common.ts
@@ -0,0 +1,32 @@
+ import {
+   InvokeModelWithResponseStreamResponse,
+   ResponseStream,
+ } from '@aws-sdk/client-bedrock-runtime';
+ import { readableFromAsyncIterable } from 'ai';
+
+ const chatStreamable = async function* (stream: AsyncIterable<ResponseStream>) {
+   for await (const response of stream) {
+     if (response.chunk) {
+       const decoder = new TextDecoder();
+
+       const value = decoder.decode(response.chunk.bytes, { stream: true });
+       try {
+         const chunk = JSON.parse(value);
+
+         yield chunk;
+       } catch (e) {
+         console.log('bedrock stream parser error:', e);
+
+         yield value;
+       }
+     } else {
+       yield response;
+     }
+   }
+ };
+
+ /**
+  * convert the bedrock response to a readable stream
+  */
+ export const createBedrockStream = (res: InvokeModelWithResponseStreamResponse) =>
+   readableFromAsyncIterable(chatStreamable(res.body!));
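Note the fallback here: each decoded chunk is parsed as JSON, and on parse failure the raw string is yielded instead of throwing, so a payload split across frames degrades gracefully rather than aborting the stream.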
package/src/libs/agent-runtime/utils/streams/bedrock/index.ts
@@ -0,0 +1,3 @@
+ export * from './claude';
+ export * from './common';
+ export * from './llama';
package/src/libs/agent-runtime/utils/streams/bedrock/llama.test.ts
@@ -0,0 +1,196 @@
+ import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
+ import { Readable } from 'stream';
+ import { describe, expect, it, vi } from 'vitest';
+
+ import * as uuidModule from '@/utils/uuid';
+
+ import { AWSBedrockLlamaStream } from './llama';
+
+ describe('AWSBedrockLlamaStream', () => {
+   it('should transform Bedrock Llama stream to protocol stream', async () => {
+     vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
+     const mockBedrockStream = new ReadableStream({
+       start(controller) {
+         controller.enqueue({ generation: 'Hello', generation_token_count: 1 });
+         controller.enqueue({ generation: ' world!', generation_token_count: 2 });
+         controller.enqueue({ stop_reason: 'stop' });
+         controller.close();
+       },
+     });
+
+     const onStartMock = vi.fn();
+     const onTextMock = vi.fn();
+     const onTokenMock = vi.fn();
+     const onCompletionMock = vi.fn();
+
+     const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
+       onStart: onStartMock,
+       onText: onTextMock,
+       onToken: onTokenMock,
+       onCompletion: onCompletionMock,
+     });
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([
+       'id: chat_1\n',
+       'event: text\n',
+       `data: "Hello"\n\n`,
+       'id: chat_1\n',
+       'event: text\n',
+       `data: " world!"\n\n`,
+       'id: chat_1\n',
+       'event: stop\n',
+       `data: "finished"\n\n`,
+     ]);
+
+     expect(onStartMock).toHaveBeenCalledTimes(1);
+     expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+     expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+     expect(onTokenMock).toHaveBeenCalledTimes(2);
+     expect(onCompletionMock).toHaveBeenCalledTimes(1);
+   });
+
+   it('should transform Bedrock Llama AsyncIterator to protocol stream', async () => {
+     vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
+
+     const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
+       body: {
+         // @ts-ignore
+         async *[Symbol.asyncIterator]() {
+           yield { generation: 'Hello', generation_token_count: 1 };
+           yield { generation: ' world!', generation_token_count: 2 };
+           yield { stop_reason: 'stop' };
+         },
+       },
+     };
+
+     const onStartMock = vi.fn();
+     const onTextMock = vi.fn();
+     const onTokenMock = vi.fn();
+     const onCompletionMock = vi.fn();
+
+     const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
+       onStart: onStartMock,
+       onText: onTextMock,
+       onToken: onTokenMock,
+       onCompletion: onCompletionMock,
+     });
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([
+       'id: chat_1\n',
+       'event: text\n',
+       `data: "Hello"\n\n`,
+       'id: chat_1\n',
+       'event: text\n',
+       `data: " world!"\n\n`,
+       'id: chat_1\n',
+       'event: stop\n',
+       `data: "finished"\n\n`,
+     ]);
+
+     expect(onStartMock).toHaveBeenCalledTimes(1);
+     expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+     expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+     expect(onTokenMock).toHaveBeenCalledTimes(2);
+     expect(onCompletionMock).toHaveBeenCalledTimes(1);
+   });
+
+   it('should handle Bedrock response with chunk property', async () => {
+     vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('2');
+
+     const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
+       contentType: 'any',
+       body: {
+         // @ts-ignore
+         async *[Symbol.asyncIterator]() {
+           yield {
+             chunk: {
+               bytes: new TextEncoder().encode('{"generation":"Hello","generation_token_count":1}'),
+             },
+           };
+           yield {
+             chunk: {
+               bytes: new TextEncoder().encode(
+                 '{"generation":" world!","generation_token_count":2}',
+               ),
+             },
+           };
+           yield { chunk: { bytes: new TextEncoder().encode('{"stop_reason":"stop"}') } };
+         },
+       },
+     };
+
+     const onStartMock = vi.fn();
+     const onTextMock = vi.fn();
+     const onTokenMock = vi.fn();
+     const onCompletionMock = vi.fn();
+
+     const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
+       onStart: onStartMock,
+       onText: onTextMock,
+       onToken: onTokenMock,
+       onCompletion: onCompletionMock,
+     });
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([
+       'id: chat_2\n',
+       'event: text\n',
+       `data: "Hello"\n\n`,
+       'id: chat_2\n',
+       'event: text\n',
+       `data: " world!"\n\n`,
+       'id: chat_2\n',
+       'event: stop\n',
+       `data: "finished"\n\n`,
+     ]);
+
+     expect(onStartMock).toHaveBeenCalledTimes(1);
+     expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+     expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+     expect(onTokenMock).toHaveBeenCalledTimes(2);
+     expect(onCompletionMock).toHaveBeenCalledTimes(1);
+   });
+
+   it('should handle empty stream', async () => {
+     const mockBedrockStream = new ReadableStream({
+       start(controller) {
+         controller.close();
+       },
+     });
+
+     const protocolStream = AWSBedrockLlamaStream(mockBedrockStream);
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([]);
+   });
+ });
package/src/libs/agent-runtime/utils/streams/bedrock/llama.ts
@@ -0,0 +1,51 @@
+ import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
+
+ import { nanoid } from '@/utils/uuid';
+
+ import { ChatStreamCallbacks } from '../../../types';
+ import {
+   StreamProtocolChunk,
+   StreamStack,
+   createCallbacksTransformer,
+   createSSEProtocolTransformer,
+ } from '../protocol';
+ import { createBedrockStream } from './common';
+
+ interface AmazonBedrockInvocationMetrics {
+   firstByteLatency: number;
+   inputTokenCount: number;
+   invocationLatency: number;
+   outputTokenCount: number;
+ }
+ interface BedrockLlamaStreamChunk {
+   'amazon-bedrock-invocationMetrics'?: AmazonBedrockInvocationMetrics;
+   'generation': string;
+   'generation_token_count': number;
+   'prompt_token_count'?: number | null;
+   'stop_reason'?: null | 'stop' | string;
+ }
+
+ export const transformLlamaStream = (
+   chunk: BedrockLlamaStreamChunk,
+   stack: StreamStack,
+ ): StreamProtocolChunk => {
+   // maybe need another structure to add support for multiple choices
+   if (chunk.stop_reason) {
+     return { data: 'finished', id: stack.id, type: 'stop' };
+   }
+
+   return { data: chunk.generation, id: stack.id, type: 'text' };
+ };
+
+ export const AWSBedrockLlamaStream = (
+   res: InvokeModelWithResponseStreamResponse | ReadableStream,
+   cb?: ChatStreamCallbacks,
+ ): ReadableStream<string> => {
+   const streamStack: StreamStack = { id: 'chat_' + nanoid() };
+
+   const stream = res instanceof ReadableStream ? res : createBedrockStream(res);
+
+   return stream
+     .pipeThrough(createSSEProtocolTransformer(transformLlamaStream, streamStack))
+     .pipeThrough(createCallbacksTransformer(cb));
+ };
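Again for orientation, a hypothetical invocation sketch; it is not part of this diff, and the import path, region, modelId, and request body are illustrative assumptions.

import {
  BedrockRuntimeClient,
  InvokeModelWithResponseStreamCommand,
} from '@aws-sdk/client-bedrock-runtime';

// hypothetical import path
import { AWSBedrockLlamaStream } from './streams';

export const streamLlama = async (prompt: string) => {
  const client = new BedrockRuntimeClient({ region: 'us-east-1' });

  const res = await client.send(
    new InvokeModelWithResponseStreamCommand({
      body: JSON.stringify({ max_gen_len: 512, prompt }),
      contentType: 'application/json',
      modelId: 'meta.llama2-70b-chat-v1',
    }),
  );

  // res is an InvokeModelWithResponseStreamResponse; each decoded chunk has the
  // { generation, generation_token_count, stop_reason } shape handled above
  return AWSBedrockLlamaStream(res);
};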
package/src/libs/agent-runtime/utils/streams/google-ai.test.ts
@@ -0,0 +1,97 @@
+ import { EnhancedGenerateContentResponse } from '@google/generative-ai';
+ import { describe, expect, it, vi } from 'vitest';
+
+ import * as uuidModule from '@/utils/uuid';
+
+ import { GoogleGenerativeAIStream } from './google-ai';
+
+ describe('GoogleGenerativeAIStream', () => {
+   it('should transform Google Generative AI stream to protocol stream', async () => {
+     vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
+
+     const mockGenerateContentResponse = (text: string, functionCalls?: any[]) =>
+       ({
+         text: () => text,
+         functionCall: () => functionCalls?.[0],
+         functionCalls: () => functionCalls,
+       }) as EnhancedGenerateContentResponse;
+
+     const mockGoogleStream = new ReadableStream({
+       start(controller) {
+         controller.enqueue(mockGenerateContentResponse('Hello'));
+
+         controller.enqueue(
+           mockGenerateContentResponse('', [{ name: 'testFunction', args: { arg1: 'value1' } }]),
+         );
+         controller.enqueue(mockGenerateContentResponse(' world!'));
+         controller.close();
+       },
+     });
+
+     const onStartMock = vi.fn();
+     const onTextMock = vi.fn();
+     const onTokenMock = vi.fn();
+     const onToolCallMock = vi.fn();
+     const onCompletionMock = vi.fn();
+
+     const protocolStream = GoogleGenerativeAIStream(mockGoogleStream, {
+       onStart: onStartMock,
+       onText: onTextMock,
+       onToken: onTokenMock,
+       onToolCall: onToolCallMock,
+       onCompletion: onCompletionMock,
+     });
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([
+       // text
+       'id: chat_1\n',
+       'event: text\n',
+       `data: "Hello"\n\n`,
+
+       // tool call
+       'id: chat_1\n',
+       'event: tool_calls\n',
+       `data: [{"function":{"arguments":"{\\"arg1\\":\\"value1\\"}","name":"testFunction"},"id":"testFunction_0","index":0,"type":"function"}]\n\n`,
+
+       // text
+       'id: chat_1\n',
+       'event: text\n',
+       `data: " world!"\n\n`,
+     ]);
+
+     expect(onStartMock).toHaveBeenCalledTimes(1);
+     expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+     expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+     expect(onTokenMock).toHaveBeenCalledTimes(2);
+     expect(onToolCallMock).toHaveBeenCalledTimes(1);
+     expect(onCompletionMock).toHaveBeenCalledTimes(1);
+   });
+
+   it('should handle empty stream', async () => {
+     const mockGoogleStream = new ReadableStream({
+       start(controller) {
+         controller.close();
+       },
+     });
+
+     const protocolStream = GoogleGenerativeAIStream(mockGoogleStream);
+
+     const decoder = new TextDecoder();
+     const chunks = [];
+
+     // @ts-ignore
+     for await (const chunk of protocolStream) {
+       chunks.push(decoder.decode(chunk, { stream: true }));
+     }
+
+     expect(chunks).toEqual([]);
+   });
+ });
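As the expected frames show, Google's function calls carry no call id of their own, so the transformer synthesizes one from the function name and index (here "testFunction_0") to keep the tool_calls payload compatible with the shared protocol shape.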