@lobehub/chat 0.156.1 → 0.157.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +42 -0
- package/Dockerfile +4 -1
- package/package.json +3 -2
- package/src/config/modelProviders/anthropic.ts +3 -0
- package/src/config/modelProviders/google.ts +3 -0
- package/src/config/modelProviders/groq.ts +5 -1
- package/src/config/modelProviders/minimax.ts +10 -7
- package/src/config/modelProviders/mistral.ts +1 -0
- package/src/config/modelProviders/moonshot.ts +3 -0
- package/src/config/modelProviders/zhipu.ts +2 -6
- package/src/config/server/provider.ts +1 -1
- package/src/database/client/core/db.ts +32 -0
- package/src/database/client/core/schemas.ts +9 -0
- package/src/database/client/models/__tests__/message.test.ts +2 -2
- package/src/database/client/schemas/message.ts +8 -1
- package/src/features/AgentSetting/store/action.ts +15 -6
- package/src/features/Conversation/Actions/Tool.tsx +16 -0
- package/src/features/Conversation/Actions/index.ts +2 -2
- package/src/features/Conversation/Messages/Assistant/ToolCalls/index.tsx +78 -0
- package/src/features/Conversation/Messages/Assistant/ToolCalls/style.ts +25 -0
- package/src/features/Conversation/Messages/Assistant/index.tsx +47 -0
- package/src/features/Conversation/Messages/Default.tsx +4 -1
- package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/index.tsx +34 -35
- package/src/features/Conversation/Messages/Tool/index.tsx +44 -0
- package/src/features/Conversation/Messages/index.ts +3 -2
- package/src/features/Conversation/Plugins/Render/StandaloneType/Iframe.tsx +1 -1
- package/src/features/Conversation/components/SkeletonList.tsx +2 -2
- package/src/features/Conversation/index.tsx +2 -3
- package/src/libs/agent-runtime/BaseAI.ts +2 -9
- package/src/libs/agent-runtime/anthropic/index.test.ts +195 -0
- package/src/libs/agent-runtime/anthropic/index.ts +71 -15
- package/src/libs/agent-runtime/azureOpenai/index.ts +12 -13
- package/src/libs/agent-runtime/bedrock/index.ts +24 -18
- package/src/libs/agent-runtime/google/index.test.ts +154 -0
- package/src/libs/agent-runtime/google/index.ts +91 -10
- package/src/libs/agent-runtime/groq/index.test.ts +41 -72
- package/src/libs/agent-runtime/groq/index.ts +7 -0
- package/src/libs/agent-runtime/minimax/index.test.ts +2 -2
- package/src/libs/agent-runtime/minimax/index.ts +14 -37
- package/src/libs/agent-runtime/mistral/index.test.ts +0 -53
- package/src/libs/agent-runtime/mistral/index.ts +1 -0
- package/src/libs/agent-runtime/moonshot/index.test.ts +1 -71
- package/src/libs/agent-runtime/ollama/index.test.ts +197 -0
- package/src/libs/agent-runtime/ollama/index.ts +3 -3
- package/src/libs/agent-runtime/openai/index.test.ts +0 -53
- package/src/libs/agent-runtime/openrouter/index.test.ts +1 -53
- package/src/libs/agent-runtime/perplexity/index.test.ts +0 -71
- package/src/libs/agent-runtime/perplexity/index.ts +2 -3
- package/src/libs/agent-runtime/togetherai/__snapshots__/index.test.ts.snap +886 -0
- package/src/libs/agent-runtime/togetherai/fixtures/models.json +8111 -0
- package/src/libs/agent-runtime/togetherai/index.test.ts +16 -54
- package/src/libs/agent-runtime/types/chat.ts +19 -3
- package/src/libs/agent-runtime/utils/anthropicHelpers.test.ts +120 -1
- package/src/libs/agent-runtime/utils/anthropicHelpers.ts +67 -4
- package/src/libs/agent-runtime/utils/debugStream.test.ts +70 -0
- package/src/libs/agent-runtime/utils/debugStream.ts +39 -9
- package/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.test.ts +521 -0
- package/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +76 -5
- package/src/libs/agent-runtime/utils/response.ts +12 -0
- package/src/libs/agent-runtime/utils/streams/anthropic.test.ts +197 -0
- package/src/libs/agent-runtime/utils/streams/anthropic.ts +91 -0
- package/src/libs/agent-runtime/utils/streams/bedrock/claude.ts +21 -0
- package/src/libs/agent-runtime/utils/streams/bedrock/common.ts +32 -0
- package/src/libs/agent-runtime/utils/streams/bedrock/index.ts +3 -0
- package/src/libs/agent-runtime/utils/streams/bedrock/llama.test.ts +196 -0
- package/src/libs/agent-runtime/utils/streams/bedrock/llama.ts +51 -0
- package/src/libs/agent-runtime/utils/streams/google-ai.test.ts +97 -0
- package/src/libs/agent-runtime/utils/streams/google-ai.ts +68 -0
- package/src/libs/agent-runtime/utils/streams/index.ts +7 -0
- package/src/libs/agent-runtime/utils/streams/minimax.ts +39 -0
- package/src/libs/agent-runtime/utils/streams/ollama.test.ts +77 -0
- package/src/libs/agent-runtime/utils/streams/ollama.ts +38 -0
- package/src/libs/agent-runtime/utils/streams/openai.test.ts +263 -0
- package/src/libs/agent-runtime/utils/streams/openai.ts +79 -0
- package/src/libs/agent-runtime/utils/streams/protocol.ts +100 -0
- package/src/libs/agent-runtime/zeroone/index.test.ts +1 -53
- package/src/libs/agent-runtime/zhipu/index.test.ts +1 -1
- package/src/libs/agent-runtime/zhipu/index.ts +3 -2
- package/src/locales/default/plugin.ts +3 -4
- package/src/migrations/FromV4ToV5/fixtures/from-v1-to-v5-output.json +245 -0
- package/src/migrations/FromV4ToV5/fixtures/function-input-v4.json +96 -0
- package/src/migrations/FromV4ToV5/fixtures/function-output-v5.json +120 -0
- package/src/migrations/FromV4ToV5/index.ts +58 -0
- package/src/migrations/FromV4ToV5/migrations.test.ts +49 -0
- package/src/migrations/FromV4ToV5/types/v4.ts +21 -0
- package/src/migrations/FromV4ToV5/types/v5.ts +27 -0
- package/src/migrations/index.ts +8 -1
- package/src/services/__tests__/chat.test.ts +10 -20
- package/src/services/chat.ts +78 -65
- package/src/store/chat/slices/enchance/action.ts +15 -10
- package/src/store/chat/slices/message/action.test.ts +36 -86
- package/src/store/chat/slices/message/action.ts +70 -79
- package/src/store/chat/slices/message/reducer.ts +18 -1
- package/src/store/chat/slices/message/selectors.test.ts +38 -68
- package/src/store/chat/slices/message/selectors.ts +1 -22
- package/src/store/chat/slices/plugin/action.test.ts +147 -203
- package/src/store/chat/slices/plugin/action.ts +96 -82
- package/src/store/chat/slices/share/action.test.ts +3 -3
- package/src/store/chat/slices/share/action.ts +1 -1
- package/src/store/chat/slices/topic/action.ts +7 -2
- package/src/store/tool/selectors/tool.ts +6 -24
- package/src/store/tool/slices/builtin/action.test.ts +90 -0
- package/src/types/llm.ts +1 -1
- package/src/types/message/index.ts +9 -4
- package/src/types/message/tools.ts +57 -0
- package/src/types/openai/chat.ts +6 -0
- package/src/utils/fetch.test.ts +245 -1
- package/src/utils/fetch.ts +120 -44
- package/src/utils/toolCall.ts +21 -0
- package/src/features/Conversation/Messages/Assistant.tsx +0 -26
- package/src/features/Conversation/Messages/Function.tsx +0 -35
- package/src/libs/agent-runtime/ollama/stream.ts +0 -31
- /package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/PluginResultJSON.tsx +0 -0
- /package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/Settings.tsx +0 -0
- /package/src/features/Conversation/{Plugins → Messages/Tool}/Inspector/style.ts +0 -0
package/src/libs/agent-runtime/utils/streams/anthropic.test.ts
@@ -0,0 +1,197 @@
+import type { Stream } from '@anthropic-ai/sdk/streaming';
+import { describe, expect, it, vi } from 'vitest';
+
+import { AnthropicStream } from './anthropic';
+
+describe('AnthropicStream', () => {
+  it('should transform Anthropic stream to protocol stream', async () => {
+    // @ts-ignore
+    const mockAnthropicStream: Stream = {
+      [Symbol.asyncIterator]() {
+        let count = 0;
+        return {
+          next: async () => {
+            switch (count) {
+              case 0:
+                count++;
+                return {
+                  done: false,
+                  value: {
+                    type: 'message_start',
+                    message: { id: 'message_1', metadata: {} },
+                  },
+                };
+              case 1:
+                count++;
+                return {
+                  done: false,
+                  value: {
+                    type: 'content_block_delta',
+                    delta: { type: 'text_delta', text: 'Hello' },
+                  },
+                };
+              case 2:
+                count++;
+                return {
+                  done: false,
+                  value: {
+                    type: 'content_block_delta',
+                    delta: { type: 'text_delta', text: ' world!' },
+                  },
+                };
+              case 3:
+                count++;
+                return {
+                  done: false,
+                  value: {
+                    type: 'message_delta',
+                    delta: { stop_reason: 'stop' },
+                  },
+                };
+              default:
+                return { done: true, value: undefined };
+            }
+          },
+        };
+      },
+    };
+
+    const onStartMock = vi.fn();
+    const onTextMock = vi.fn();
+    const onTokenMock = vi.fn();
+    const onCompletionMock = vi.fn();
+
+    const protocolStream = AnthropicStream(mockAnthropicStream, {
+      onStart: onStartMock,
+      onText: onTextMock,
+      onToken: onTokenMock,
+      onCompletion: onCompletionMock,
+    });
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([
+      'id: message_1\n',
+      'event: data\n',
+      `data: {"id":"message_1","metadata":{}}\n\n`,
+      'id: message_1\n',
+      'event: text\n',
+      `data: "Hello"\n\n`,
+      'id: message_1\n',
+      'event: text\n',
+      `data: " world!"\n\n`,
+      'id: message_1\n',
+      'event: stop\n',
+      `data: "stop"\n\n`,
+    ]);
+
+    expect(onStartMock).toHaveBeenCalledTimes(1);
+    expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+    expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+    expect(onTokenMock).toHaveBeenCalledTimes(2);
+    expect(onCompletionMock).toHaveBeenCalledTimes(1);
+  });
+
+  it('should handle tool use event and ReadableStream input', async () => {
+    const toolUseEvent = {
+      type: 'content_block_delta',
+      delta: {
+        type: 'tool_use',
+        tool_use: {
+          id: 'tool_use_1',
+          name: 'example_tool',
+          input: { arg1: 'value1' },
+        },
+      },
+    };
+
+    const mockReadableStream = new ReadableStream({
+      start(controller) {
+        controller.enqueue({
+          type: 'message_start',
+          message: { id: 'message_1', metadata: {} },
+        });
+        controller.enqueue(toolUseEvent);
+        controller.enqueue({
+          type: 'message_stop',
+        });
+        controller.close();
+      },
+    });
+
+    const onToolCallMock = vi.fn();
+
+    const protocolStream = AnthropicStream(mockReadableStream, {
+      onToolCall: onToolCallMock,
+    });
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([
+      'id: message_1\n',
+      'event: data\n',
+      `data: {"id":"message_1","metadata":{}}\n\n`,
+      'id: message_1\n',
+      'event: tool_calls\n',
+      `data: [{"function":{"arguments":"{\\"arg1\\":\\"value1\\"}","name":"example_tool"},"id":"tool_use_1","index":0,"type":"function"}]\n\n`,
+      'id: message_1\n',
+      'event: stop\n',
+      `data: "message_stop"\n\n`,
+    ]);
+
+    expect(onToolCallMock).toHaveBeenCalledTimes(1);
+  });
+
+  it('should handle ReadableStream input', async () => {
+    const mockReadableStream = new ReadableStream({
+      start(controller) {
+        controller.enqueue({
+          type: 'message_start',
+          message: { id: 'message_1', metadata: {} },
+        });
+        controller.enqueue({
+          type: 'content_block_delta',
+          delta: { type: 'text_delta', text: 'Hello' },
+        });
+        controller.enqueue({
+          type: 'message_stop',
+        });
+        controller.close();
+      },
+    });
+
+    const protocolStream = AnthropicStream(mockReadableStream);
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([
+      'id: message_1\n',
+      'event: data\n',
+      `data: {"id":"message_1","metadata":{}}\n\n`,
+      'id: message_1\n',
+      'event: text\n',
+      `data: "Hello"\n\n`,
+      'id: message_1\n',
+      'event: stop\n',
+      `data: "message_stop"\n\n`,
+    ]);
+  });
+});
package/src/libs/agent-runtime/utils/streams/anthropic.ts
@@ -0,0 +1,91 @@
+import Anthropic from '@anthropic-ai/sdk';
+import type { Stream } from '@anthropic-ai/sdk/streaming';
+import { readableFromAsyncIterable } from 'ai';
+
+import { ChatStreamCallbacks } from '../../types';
+import {
+  StreamProtocolChunk,
+  StreamProtocolToolCallChunk,
+  StreamStack,
+  StreamToolCallChunkData,
+  createCallbacksTransformer,
+  createSSEProtocolTransformer,
+} from './protocol';
+
+export const transformAnthropicStream = (
+  chunk: Anthropic.MessageStreamEvent,
+  stack: StreamStack,
+): StreamProtocolChunk => {
+  // maybe need another structure to add support for multiple choices
+  switch (chunk.type) {
+    case 'message_start': {
+      stack.id = chunk.message.id;
+      return { data: chunk.message, id: chunk.message.id, type: 'data' };
+    }
+
+    // case 'content_block_start': {
+    //   return { data: chunk.content_block.text, id: stack.id, type: 'data' };
+    // }
+
+    case 'content_block_delta': {
+      switch (chunk.delta.type as string) {
+        default:
+        case 'text_delta': {
+          return { data: chunk.delta.text, id: stack.id, type: 'text' };
+        }
+
+        // TODO: due to anthropic currently don't support streaming tool calling
+        // we need to add this new `tool_use` type to support streaming
+        // and maybe we need to update it when the feature is available
+        case 'tool_use': {
+          const delta = (chunk.delta as any).tool_use as Anthropic.Beta.Tools.ToolUseBlock;
+
+          const toolCall: StreamToolCallChunkData = {
+            function: { arguments: JSON.stringify(delta.input), name: delta.name },
+            id: delta.id,
+            index: 0,
+            type: 'function',
+          };
+
+          return {
+            data: [toolCall],
+            id: stack.id,
+            type: 'tool_calls',
+          } as StreamProtocolToolCallChunk;
+        }
+      }
+    }
+
+    case 'message_delta': {
+      return { data: chunk.delta.stop_reason, id: stack.id, type: 'stop' };
+    }
+
+    case 'message_stop': {
+      return { data: 'message_stop', id: stack.id, type: 'stop' };
+    }
+
+    default: {
+      return { data: chunk, id: stack.id, type: 'data' };
+    }
+  }
+};
+
+const chatStreamable = async function* (stream: AsyncIterable<Anthropic.MessageStreamEvent>) {
+  for await (const response of stream) {
+    yield response;
+  }
+};
+
+export const AnthropicStream = (
+  stream: Stream<Anthropic.MessageStreamEvent> | ReadableStream,
+  callbacks?: ChatStreamCallbacks,
+) => {
+  const streamStack: StreamStack = { id: '' };
+
+  const readableStream =
+    stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+
+  return readableStream
+    .pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
+    .pipeThrough(createCallbacksTransformer(callbacks));
+};
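
Note: the expected chunks in the tests above illustrate the SSE-style protocol these streams emit — each record is an "id:" line, an "event:" line (data / text / tool_calls / stop), and a JSON "data:" line followed by a blank line. As a rough illustration only (not part of this package diff), a consumer could reassemble those frames as below; the helper name is invented, and it assumes a runtime where a ReadableStream of bytes can be drained via Response.

// Illustrative sketch, not from the package: parse the "id / event / data" frames
// produced by AnthropicStream (and the other protocol streams in this diff).
const readProtocolEvents = async (stream: ReadableStream<Uint8Array>) => {
  // Drain the byte stream to text (works in Node 18+ / browsers).
  const text = await new Response(stream).text();

  // Each record is "id: ...\nevent: ...\ndata: ...\n\n".
  return text
    .split('\n\n')
    .filter(Boolean)
    .map((block) => {
      const record: Record<string, string> = {};
      for (const line of block.split('\n')) {
        const idx = line.indexOf(': ');
        if (idx > 0) record[line.slice(0, idx)] = line.slice(idx + 2);
      }
      // e.g. { id: 'message_1', event: 'text', data: '"Hello"' }
      return record;
    });
};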
package/src/libs/agent-runtime/utils/streams/bedrock/claude.ts
@@ -0,0 +1,21 @@
+import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
+
+import { nanoid } from '@/utils/uuid';
+
+import { ChatStreamCallbacks } from '../../../types';
+import { transformAnthropicStream } from '../anthropic';
+import { StreamStack, createCallbacksTransformer, createSSEProtocolTransformer } from '../protocol';
+import { createBedrockStream } from './common';
+
+export const AWSBedrockClaudeStream = (
+  res: InvokeModelWithResponseStreamResponse | ReadableStream,
+  cb?: ChatStreamCallbacks,
+): ReadableStream<string> => {
+  const streamStack: StreamStack = { id: 'chat_' + nanoid() };
+
+  const stream = res instanceof ReadableStream ? res : createBedrockStream(res);
+
+  return stream
+    .pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
+    .pipeThrough(createCallbacksTransformer(cb));
+};
package/src/libs/agent-runtime/utils/streams/bedrock/common.ts
@@ -0,0 +1,32 @@
+import {
+  InvokeModelWithResponseStreamResponse,
+  ResponseStream,
+} from '@aws-sdk/client-bedrock-runtime';
+import { readableFromAsyncIterable } from 'ai';
+
+const chatStreamable = async function* (stream: AsyncIterable<ResponseStream>) {
+  for await (const response of stream) {
+    if (response.chunk) {
+      const decoder = new TextDecoder();
+
+      const value = decoder.decode(response.chunk.bytes, { stream: true });
+      try {
+        const chunk = JSON.parse(value);
+
+        yield chunk;
+      } catch (e) {
+        console.log('bedrock stream parser error:', e);
+
+        yield value;
+      }
+    } else {
+      yield response;
+    }
+  }
+};
+
+/**
+ * covert the bedrock response to a readable stream
+ */
+export const createBedrockStream = (res: InvokeModelWithResponseStreamResponse) =>
+  readableFromAsyncIterable(chatStreamable(res.body!));
package/src/libs/agent-runtime/utils/streams/bedrock/llama.test.ts
@@ -0,0 +1,196 @@
+import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
+import { Readable } from 'stream';
+import { describe, expect, it, vi } from 'vitest';
+
+import * as uuidModule from '@/utils/uuid';
+
+import { AWSBedrockLlamaStream } from './llama';
+
+describe('AWSBedrockLlamaStream', () => {
+  it('should transform Bedrock Llama stream to protocol stream', async () => {
+    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
+    const mockBedrockStream = new ReadableStream({
+      start(controller) {
+        controller.enqueue({ generation: 'Hello', generation_token_count: 1 });
+        controller.enqueue({ generation: ' world!', generation_token_count: 2 });
+        controller.enqueue({ stop_reason: 'stop' });
+        controller.close();
+      },
+    });
+
+    const onStartMock = vi.fn();
+    const onTextMock = vi.fn();
+    const onTokenMock = vi.fn();
+    const onCompletionMock = vi.fn();
+
+    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
+      onStart: onStartMock,
+      onText: onTextMock,
+      onToken: onTokenMock,
+      onCompletion: onCompletionMock,
+    });
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([
+      'id: chat_1\n',
+      'event: text\n',
+      `data: "Hello"\n\n`,
+      'id: chat_1\n',
+      'event: text\n',
+      `data: " world!"\n\n`,
+      'id: chat_1\n',
+      'event: stop\n',
+      `data: "finished"\n\n`,
+    ]);
+
+    expect(onStartMock).toHaveBeenCalledTimes(1);
+    expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+    expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+    expect(onTokenMock).toHaveBeenCalledTimes(2);
+    expect(onCompletionMock).toHaveBeenCalledTimes(1);
+  });
+
+  it('should transform Bedrock Llama AsyncIterator to protocol stream', async () => {
+    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
+
+    const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
+      body: {
+        // @ts-ignore
+        async *[Symbol.asyncIterator]() {
+          yield { generation: 'Hello', generation_token_count: 1 };
+          yield { generation: ' world!', generation_token_count: 2 };
+          yield { stop_reason: 'stop' };
+        },
+      },
+    };
+
+    const onStartMock = vi.fn();
+    const onTextMock = vi.fn();
+    const onTokenMock = vi.fn();
+    const onCompletionMock = vi.fn();
+
+    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
+      onStart: onStartMock,
+      onText: onTextMock,
+      onToken: onTokenMock,
+      onCompletion: onCompletionMock,
+    });
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([
+      'id: chat_1\n',
+      'event: text\n',
+      `data: "Hello"\n\n`,
+      'id: chat_1\n',
+      'event: text\n',
+      `data: " world!"\n\n`,
+      'id: chat_1\n',
+      'event: stop\n',
+      `data: "finished"\n\n`,
+    ]);
+
+    expect(onStartMock).toHaveBeenCalledTimes(1);
+    expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+    expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+    expect(onTokenMock).toHaveBeenCalledTimes(2);
+    expect(onCompletionMock).toHaveBeenCalledTimes(1);
+  });
+
+  it('should handle Bedrock response with chunk property', async () => {
+    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('2');
+
+    const mockBedrockStream: InvokeModelWithResponseStreamResponse = {
+      contentType: 'any',
+      body: {
+        // @ts-ignore
+        async *[Symbol.asyncIterator]() {
+          yield {
+            chunk: {
+              bytes: new TextEncoder().encode('{"generation":"Hello","generation_token_count":1}'),
+            },
+          };
+          yield {
+            chunk: {
+              bytes: new TextEncoder().encode(
+                '{"generation":" world!","generation_token_count":2}',
+              ),
+            },
+          };
+          yield { chunk: { bytes: new TextEncoder().encode('{"stop_reason":"stop"}') } };
+        },
+      },
+    };
+
+    const onStartMock = vi.fn();
+    const onTextMock = vi.fn();
+    const onTokenMock = vi.fn();
+    const onCompletionMock = vi.fn();
+
+    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream, {
+      onStart: onStartMock,
+      onText: onTextMock,
+      onToken: onTokenMock,
+      onCompletion: onCompletionMock,
+    });
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([
+      'id: chat_2\n',
+      'event: text\n',
+      `data: "Hello"\n\n`,
+      'id: chat_2\n',
+      'event: text\n',
+      `data: " world!"\n\n`,
+      'id: chat_2\n',
+      'event: stop\n',
+      `data: "finished"\n\n`,
+    ]);
+
+    expect(onStartMock).toHaveBeenCalledTimes(1);
+    expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+    expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+    expect(onTokenMock).toHaveBeenCalledTimes(2);
+    expect(onCompletionMock).toHaveBeenCalledTimes(1);
+  });
+
+  it('should handle empty stream', async () => {
+    const mockBedrockStream = new ReadableStream({
+      start(controller) {
+        controller.close();
+      },
+    });
+
+    const protocolStream = AWSBedrockLlamaStream(mockBedrockStream);
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([]);
+  });
+});
package/src/libs/agent-runtime/utils/streams/bedrock/llama.ts
@@ -0,0 +1,51 @@
+import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';
+
+import { nanoid } from '@/utils/uuid';
+
+import { ChatStreamCallbacks } from '../../../types';
+import {
+  StreamProtocolChunk,
+  StreamStack,
+  createCallbacksTransformer,
+  createSSEProtocolTransformer,
+} from '../protocol';
+import { createBedrockStream } from './common';
+
+interface AmazonBedrockInvocationMetrics {
+  firstByteLatency: number;
+  inputTokenCount: number;
+  invocationLatency: number;
+  outputTokenCount: number;
+}
+interface BedrockLlamaStreamChunk {
+  'amazon-bedrock-invocationMetrics'?: AmazonBedrockInvocationMetrics;
+  'generation': string;
+  'generation_token_count': number;
+  'prompt_token_count'?: number | null;
+  'stop_reason'?: null | 'stop' | string;
+}
+
+export const transformLlamaStream = (
+  chunk: BedrockLlamaStreamChunk,
+  stack: StreamStack,
+): StreamProtocolChunk => {
+  // maybe need another structure to add support for multiple choices
+  if (chunk.stop_reason) {
+    return { data: 'finished', id: stack.id, type: 'stop' };
+  }
+
+  return { data: chunk.generation, id: stack.id, type: 'text' };
+};
+
+export const AWSBedrockLlamaStream = (
+  res: InvokeModelWithResponseStreamResponse | ReadableStream,
+  cb?: ChatStreamCallbacks,
+): ReadableStream<string> => {
+  const streamStack: StreamStack = { id: 'chat_' + nanoid() };
+
+  const stream = res instanceof ReadableStream ? res : createBedrockStream(res);
+
+  return stream
+    .pipeThrough(createSSEProtocolTransformer(transformLlamaStream, streamStack))
+    .pipeThrough(createCallbacksTransformer(cb));
+};
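
The stream modules added in this diff all share one composition: a provider-specific transform function piped through createSSEProtocolTransformer and createCallbacksTransformer from the new protocol module. A hedged sketch of that pattern for a hypothetical extra provider follows; the Foo* names, chunk shape, and import paths are invented for illustration, while the helper functions and protocol chunk fields are the ones visible in the diff.

// Hypothetical example, mirroring anthropic.ts and bedrock/llama.ts above.
import { ChatStreamCallbacks } from '../../types';
import {
  StreamProtocolChunk,
  StreamStack,
  createCallbacksTransformer,
  createSSEProtocolTransformer,
} from './protocol';

interface FooStreamChunk {
  text?: string; // invented provider payload shape
  done?: boolean;
}

// Map one provider chunk to one protocol chunk ({ data, id, type }).
export const transformFooStream = (
  chunk: FooStreamChunk,
  stack: StreamStack,
): StreamProtocolChunk =>
  chunk.done
    ? { data: 'finished', id: stack.id, type: 'stop' }
    : { data: chunk.text ?? '', id: stack.id, type: 'text' };

// Pipe the provider stream through the shared SSE and callbacks transformers.
export const FooStream = (stream: ReadableStream, callbacks?: ChatStreamCallbacks) => {
  const streamStack: StreamStack = { id: 'chat_foo' };

  return stream
    .pipeThrough(createSSEProtocolTransformer(transformFooStream, streamStack))
    .pipeThrough(createCallbacksTransformer(callbacks));
};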
package/src/libs/agent-runtime/utils/streams/google-ai.test.ts
@@ -0,0 +1,97 @@
+import { EnhancedGenerateContentResponse } from '@google/generative-ai';
+import { describe, expect, it, vi } from 'vitest';
+
+import * as uuidModule from '@/utils/uuid';
+
+import { GoogleGenerativeAIStream } from './google-ai';
+
+describe('GoogleGenerativeAIStream', () => {
+  it('should transform Google Generative AI stream to protocol stream', async () => {
+    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
+
+    const mockGenerateContentResponse = (text: string, functionCalls?: any[]) =>
+      ({
+        text: () => text,
+        functionCall: () => functionCalls?.[0],
+        functionCalls: () => functionCalls,
+      }) as EnhancedGenerateContentResponse;
+
+    const mockGoogleStream = new ReadableStream({
+      start(controller) {
+        controller.enqueue(mockGenerateContentResponse('Hello'));
+
+        controller.enqueue(
+          mockGenerateContentResponse('', [{ name: 'testFunction', args: { arg1: 'value1' } }]),
+        );
+        controller.enqueue(mockGenerateContentResponse(' world!'));
+        controller.close();
+      },
+    });
+
+    const onStartMock = vi.fn();
+    const onTextMock = vi.fn();
+    const onTokenMock = vi.fn();
+    const onToolCallMock = vi.fn();
+    const onCompletionMock = vi.fn();
+
+    const protocolStream = GoogleGenerativeAIStream(mockGoogleStream, {
+      onStart: onStartMock,
+      onText: onTextMock,
+      onToken: onTokenMock,
+      onToolCall: onToolCallMock,
+      onCompletion: onCompletionMock,
+    });
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([
+      // text
+      'id: chat_1\n',
+      'event: text\n',
+      `data: "Hello"\n\n`,
+
+      // tool call
+      'id: chat_1\n',
+      'event: tool_calls\n',
+      `data: [{"function":{"arguments":"{\\"arg1\\":\\"value1\\"}","name":"testFunction"},"id":"testFunction_0","index":0,"type":"function"}]\n\n`,
+
+      // text
+      'id: chat_1\n',
+      'event: text\n',
+      `data: " world!"\n\n`,
+    ]);
+
+    expect(onStartMock).toHaveBeenCalledTimes(1);
+    expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
+    expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
+    expect(onTokenMock).toHaveBeenCalledTimes(2);
+    expect(onToolCallMock).toHaveBeenCalledTimes(1);
+    expect(onCompletionMock).toHaveBeenCalledTimes(1);
+  });
+
+  it('should handle empty stream', async () => {
+    const mockGoogleStream = new ReadableStream({
+      start(controller) {
+        controller.close();
+      },
+    });
+
+    const protocolStream = GoogleGenerativeAIStream(mockGoogleStream);
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual([]);
+  });
+});