@ai-sdk/svelte 0.0.24 → 0.0.26

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/use-chat.ts CHANGED
@@ -6,7 +6,7 @@ import type {
6
6
  IdGenerator,
7
7
  JSONValue,
8
8
  Message,
9
- UseChatOptions,
9
+ UseChatOptions as SharedUseChatOptions,
10
10
  } from '@ai-sdk/ui-utils';
11
11
  import {
12
12
  callChatApi,
@@ -15,7 +15,23 @@ import {
15
15
  } from '@ai-sdk/ui-utils';
16
16
  import { useSWR } from 'sswr';
17
17
  import { Readable, Writable, derived, get, writable } from 'svelte/store';
18
- export type { CreateMessage, Message, UseChatOptions };
18
+ export type { CreateMessage, Message };
19
+
20
+ export type UseChatOptions = SharedUseChatOptions & {
21
+ /**
22
+ Maximum number of automatic roundtrips for tool calls.
23
+
24
+ An automatic tool call roundtrip is a call to the server with the
25
+ tool call results when all tool calls in the last assistant
26
+ message have results.
27
+
28
+ A maximum number is required to prevent infinite loops in the
29
+ case of misconfigured tools.
30
+
31
+ By default, it's set to 0, which will disable the feature.
32
+ */
33
+ maxToolRoundtrips?: number;
34
+ };
19
35
 
20
36
  export type UseChatHelpers = {
21
37
  /** Current messages in the chat */
@@ -66,7 +82,21 @@ export type UseChatHelpers = {
66
82
 
67
83
  /** Additional data added on the server via StreamData */
68
84
  data: Readable<JSONValue[] | undefined>;
85
+ /**
86
+ Maximum number of automatic roundtrips for tool calls.
87
+
88
+ An automatic tool call roundtrip is a call to the server with the
89
+ tool call results when all tool calls in the last assistant
90
+ message have results.
91
+
92
+ A maximum number is required to prevent infinite loops in the
93
+ case of misconfigured tools.
94
+
95
+ By default, it's set to 0, which will disable the feature.
96
+ */
97
+ maxToolRoundtrips?: number;
69
98
  };
99
+
70
100
  const getStreamedResponse = async (
71
101
  api: string,
72
102
  chatRequest: ChatRequest,
@@ -81,9 +111,10 @@ const getStreamedResponse = async (
81
111
  previousMessages: Message[],
82
112
  abortControllerRef: AbortController | null,
83
113
  generateId: IdGenerator,
84
- streamMode: 'stream-data' | 'text' | undefined,
85
- onFinish: ((message: Message) => void) | undefined,
114
+ streamProtocol: UseChatOptions['streamProtocol'],
115
+ onFinish: UseChatOptions['onFinish'],
86
116
  onResponse: ((response: Response) => void | Promise<void>) | undefined,
117
+ onToolCall: UseChatOptions['onToolCall'] | undefined,
87
118
  sendExtraMessageFields: boolean | undefined,
88
119
  fetch: FetchFunction | undefined,
89
120
  keepLastMessageOnError: boolean | undefined,
@@ -104,12 +135,14 @@ const getStreamedResponse = async (
104
135
  function_call,
105
136
  tool_calls,
106
137
  tool_call_id,
138
+ toolInvocations,
107
139
  }) => ({
108
140
  role,
109
141
  content,
110
142
  ...(name !== undefined && { name }),
111
143
  ...(data !== undefined && { data }),
112
144
  ...(annotations !== undefined && { annotations }),
145
+ ...(toolInvocations !== undefined && { toolInvocations }),
113
146
  // outdated function/tool call handling (TODO deprecate):
114
147
  tool_call_id,
115
148
  ...(function_call !== undefined && { function_call }),
@@ -137,7 +170,7 @@ const getStreamedResponse = async (
137
170
  tool_choice: chatRequest.tool_choice,
138
171
  }),
139
172
  },
140
- streamMode,
173
+ streamProtocol,
141
174
  credentials: extraMetadata.credentials,
142
175
  headers: {
143
176
  ...extraMetadata.headers,
@@ -156,7 +189,7 @@ const getStreamedResponse = async (
156
189
  },
157
190
  onFinish,
158
191
  generateId,
159
- onToolCall: undefined, // not implemented yet
192
+ onToolCall,
160
193
  fetch,
161
194
  });
162
195
  };
@@ -165,6 +198,36 @@ let uniqueId = 0;
165
198
 
166
199
  const store: Record<string, Message[] | undefined> = {};
167
200
 
201
+ /**
202
+ Check if the message is an assistant message with completed tool calls.
203
+ The message must have at least one tool invocation and all tool invocations
204
+ must have a result.
205
+ */
206
+ function isAssistantMessageWithCompletedToolCalls(message: Message) {
207
+ return (
208
+ message.role === 'assistant' &&
209
+ message.toolInvocations &&
210
+ message.toolInvocations.length > 0 &&
211
+ message.toolInvocations.every(toolInvocation => 'result' in toolInvocation)
212
+ );
213
+ }
214
+
215
+ /**
216
+ Returns the number of trailing assistant messages in the array.
217
+ */
218
+ function countTrailingAssistantMessages(messages: Message[]) {
219
+ let count = 0;
220
+ for (let i = messages.length - 1; i >= 0; i--) {
221
+ if (messages[i].role === 'assistant') {
222
+ count++;
223
+ } else {
224
+ break;
225
+ }
226
+ }
227
+
228
+ return count;
229
+ }
230
+
168
231
  export function useChat({
169
232
  api = '/api/chat',
170
233
  id,
@@ -174,16 +237,32 @@ export function useChat({
174
237
  experimental_onFunctionCall,
175
238
  experimental_onToolCall,
176
239
  streamMode,
240
+ streamProtocol,
177
241
  onResponse,
178
242
  onFinish,
179
243
  onError,
244
+ onToolCall,
180
245
  credentials,
181
246
  headers,
182
247
  body,
183
248
  generateId = generateIdFunc,
184
249
  fetch,
185
250
  keepLastMessageOnError = false,
186
- }: UseChatOptions = {}): UseChatHelpers {
251
+ maxToolRoundtrips = 0,
252
+ }: UseChatOptions = {}): UseChatHelpers & {
253
+ addToolResult: ({
254
+ toolCallId,
255
+ result,
256
+ }: {
257
+ toolCallId: string;
258
+ result: any;
259
+ }) => void;
260
+ } {
261
+ // streamMode is deprecated, use streamProtocol instead.
262
+ if (streamMode) {
263
+ streamProtocol ??= streamMode === 'text' ? 'text' : undefined;
264
+ }
265
+
187
266
  // Generate a unique id for the chat if not provided.
188
267
  const chatId = id || `chat-${uniqueId++}`;
189
268
 
@@ -226,6 +305,9 @@ export function useChat({
226
305
  // Actual mutation hook to send messages to the API endpoint and update the
227
306
  // chat state.
228
307
  async function triggerRequest(chatRequest: ChatRequest) {
308
+ const messagesSnapshot = get(messages);
309
+ const messageCount = messagesSnapshot.length;
310
+
229
311
  try {
230
312
  error.set(undefined);
231
313
  loading.set(true);
@@ -245,9 +327,10 @@ export function useChat({
245
327
  get(messages),
246
328
  abortController,
247
329
  generateId,
248
- streamMode,
330
+ streamProtocol,
249
331
  onFinish,
250
332
  onResponse,
333
+ onToolCall,
251
334
  sendExtraMessageFields,
252
335
  fetch,
253
336
  keepLastMessageOnError,
@@ -261,8 +344,6 @@ export function useChat({
261
344
  });
262
345
 
263
346
  abortController = null;
264
-
265
- return null;
266
347
  } catch (err) {
267
348
  // Ignore abort errors as they are expected.
268
349
  if ((err as any).name === 'AbortError') {
@@ -278,6 +359,25 @@ export function useChat({
278
359
  } finally {
279
360
  loading.set(false);
280
361
  }
362
+
363
+ // auto-submit when all tool calls in the last assistant message have results:
364
+ const newMessagesSnapshot = get(messages);
365
+ const lastMessage = newMessagesSnapshot[newMessagesSnapshot.length - 1];
366
+
367
+ if (
368
+ // ensure we actually have new messages (to prevent infinite loops in case of errors):
369
+ newMessagesSnapshot.length > messageCount &&
370
+ // ensure there is a last message:
371
+ lastMessage != null &&
372
+ // check if the feature is enabled:
373
+ maxToolRoundtrips > 0 &&
374
+ // check that roundtrip is possible:
375
+ isAssistantMessageWithCompletedToolCalls(lastMessage) &&
376
+ // limit the number of automatic roundtrips:
377
+ countTrailingAssistantMessages(newMessagesSnapshot) <= maxToolRoundtrips
378
+ ) {
379
+ await triggerRequest({ messages: newMessagesSnapshot });
380
+ }
281
381
  }
282
382
 
283
383
  const append: UseChatHelpers['append'] = async (
@@ -358,10 +458,6 @@ export function useChat({
358
458
  headers: requestOptions.headers,
359
459
  body: requestOptions.body,
360
460
  data,
361
- ...(functions !== undefined && { functions }),
362
- ...(function_call !== undefined && { function_call }),
363
- ...(tools !== undefined && { tools }),
364
- ...(tool_choice !== undefined && { tool_choice }),
365
461
  };
366
462
 
367
463
  return triggerRequest(chatRequest);
@@ -428,6 +524,40 @@ export function useChat({
428
524
  },
429
525
  );
430
526
 
527
+ const addToolResult = ({
528
+ toolCallId,
529
+ result,
530
+ }: {
531
+ toolCallId: string;
532
+ result: any;
533
+ }) => {
534
+ const messagesSnapshot = get(messages) ?? [];
535
+ const updatedMessages = messagesSnapshot.map((message, index, arr) =>
536
+ // update the tool calls in the last assistant message:
537
+ index === arr.length - 1 &&
538
+ message.role === 'assistant' &&
539
+ message.toolInvocations
540
+ ? {
541
+ ...message,
542
+ toolInvocations: message.toolInvocations.map(toolInvocation =>
543
+ toolInvocation.toolCallId === toolCallId
544
+ ? { ...toolInvocation, result }
545
+ : toolInvocation,
546
+ ),
547
+ }
548
+ : message,
549
+ );
550
+
551
+ messages.set(updatedMessages);
552
+
553
+ // auto-submit when all tool calls in the last assistant message have results:
554
+ const lastMessage = updatedMessages[updatedMessages.length - 1];
555
+
556
+ if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
557
+ triggerRequest({ messages: updatedMessages });
558
+ }
559
+ };
560
+
431
561
  return {
432
562
  messages,
433
563
  error,
@@ -439,5 +569,6 @@ export function useChat({
439
569
  handleSubmit,
440
570
  isLoading,
441
571
  data: streamData,
572
+ addToolResult,
442
573
  };
443
574
  }
@@ -61,11 +61,17 @@ export function useCompletion({
61
61
  headers,
62
62
  body,
63
63
  streamMode,
64
+ streamProtocol,
64
65
  onResponse,
65
66
  onFinish,
66
67
  onError,
67
68
  fetch,
68
69
  }: UseCompletionOptions = {}): UseCompletionHelpers {
70
+ // streamMode is deprecated, use streamProtocol instead.
71
+ if (streamMode) {
72
+ streamProtocol ??= streamMode === 'text' ? 'text' : undefined;
73
+ }
74
+
69
75
  // Generate a unique id for the completion if not provided.
70
76
  const completionId = id || `completion-${uniqueId++}`;
71
77
 
@@ -115,7 +121,7 @@ export function useCompletion({
115
121
  ...body,
116
122
  ...options?.body,
117
123
  },
118
- streamMode,
124
+ streamProtocol,
119
125
  setCompletion: mutate,
120
126
  setLoading: loadingState => loading.set(loadingState),
121
127
  setError: err => error.set(err),