@ai-sdk/react 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,572 @@
1
+ import type {
2
+ ChatRequest,
3
+ ChatRequestOptions,
4
+ CreateMessage,
5
+ IdGenerator,
6
+ JSONValue,
7
+ Message,
8
+ UseChatOptions,
9
+ } from '@ai-sdk/ui-utils';
10
+ import {
11
+ callChatApi,
12
+ generateId as generateIdFunc,
13
+ processChatStream,
14
+ } from '@ai-sdk/ui-utils';
15
+ import { useCallback, useEffect, useId, useRef, useState } from 'react';
16
+ import useSWR, { KeyedMutator } from 'swr';
17
+
18
+ export type { CreateMessage, Message, UseChatOptions };
19
+
20
export type UseChatHelpers = {
  /** Current messages in the chat */
  messages: Message[];
  /** The error object of the API request */
  error: undefined | Error;
  /**
   * Append a user message to the chat list. This triggers the API call to fetch
   * the assistant's response.
   * @param message The message to append
   * @param chatRequestOptions Additional options to pass to the API call
   */
  append: (
    message: Message | CreateMessage,
    chatRequestOptions?: ChatRequestOptions,
  ) => Promise<string | null | undefined>;
  /**
   * Reload the last AI chat response for the given chat history. If the last
   * message isn't from the assistant, it will request the API to generate a
   * new response.
   * @param chatRequestOptions Additional options to pass to the API call
   */
  reload: (
    chatRequestOptions?: ChatRequestOptions,
  ) => Promise<string | null | undefined>;
  /**
   * Abort the current request immediately, keep the generated tokens if any.
   */
  stop: () => void;
  /**
   * Update the `messages` state locally. This is useful when you want to
   * edit the messages on the client, and then trigger the `reload` method
   * manually to regenerate the AI response.
   */
  setMessages: (messages: Message[]) => void;
  /** The current value of the input */
  input: string;
  /** setState-powered method to update the input value */
  setInput: React.Dispatch<React.SetStateAction<string>>;
  /** An input/textarea-ready onChange handler to control the value of the input */
  handleInputChange: (
    e:
      | React.ChangeEvent<HTMLInputElement>
      | React.ChangeEvent<HTMLTextAreaElement>,
  ) => void;
  /** Form submission handler to automatically reset input and append a user message */
  handleSubmit: (
    e: React.FormEvent<HTMLFormElement>,
    chatRequestOptions?: ChatRequestOptions,
  ) => void;
  /**
   * Optional metadata object.
   * NOTE(review): `useChat` in this module does not populate this field.
   */
  metadata?: object;
  /** Whether the API request is in progress */
  isLoading: boolean;
  /** Additional data added on the server via StreamData */
  data?: JSONValue[];
};
74
+
75
/**
 * Send one chat request to `api` and stream the reply into the SWR caches.
 *
 * The chat state is optimistically replaced with `chatRequest.messages` before
 * the request starts. `callChatApi` receives:
 * - a `restoreMessagesOnFailure` callback that puts back the messages
 *   `messagesRef` held when this function was called, and
 * - an `onUpdate` callback that appends streamed messages to the request
 *   messages and streamed data to `existingData`.
 *
 * Unless `sendExtraMessageFields` is true, each outgoing message is reduced to
 * `role`, `content`, `tool_call_id`, and whichever optional fields are defined.
 * Body fields are layered in this order, with later entries winning:
 * `data`, the hook-level body, the per-call `options.body`, and then any
 * defined function/tool settings.
 *
 * @returns Whatever `callChatApi` resolves with.
 */
const getStreamedResponse = async (
  api: string,
  chatRequest: ChatRequest,
  mutate: KeyedMutator<Message[]>,
  mutateStreamData: KeyedMutator<JSONValue[] | undefined>,
  existingData: JSONValue[] | undefined,
  extraMetadataRef: React.MutableRefObject<any>,
  messagesRef: React.MutableRefObject<Message[]>,
  abortControllerRef: React.MutableRefObject<AbortController | null>,
  generateId: IdGenerator,
  streamMode?: 'stream-data' | 'text',
  onFinish?: (message: Message) => void,
  onResponse?: (response: Response) => void | Promise<void>,
  onToolCall?: UseChatOptions['onToolCall'],
  sendExtraMessageFields?: boolean,
) => {
  // Snapshot the current messages so a failed request can roll back the
  // optimistic update performed right after.
  const messagesBeforeRequest = messagesRef.current;
  mutate(chatRequest.messages, false);

  // Strip a message down to the fields the API payload needs. The key order
  // is kept stable so the serialized payload is unchanged.
  const toPayloadMessage = ({
    role,
    content,
    name,
    data,
    annotations,
    toolInvocations,
    function_call,
    tool_calls,
    tool_call_id,
  }: Message) => ({
    role,
    content,
    ...(name === undefined ? {} : { name }),
    ...(data === undefined ? {} : { data }),
    ...(annotations === undefined ? {} : { annotations }),
    ...(toolInvocations === undefined ? {} : { toolInvocations }),
    // outdated function/tool call handling (TODO deprecate):
    tool_call_id,
    ...(function_call === undefined ? {} : { function_call }),
    ...(tool_calls === undefined ? {} : { tool_calls }),
  });

  const payloadMessages = sendExtraMessageFields
    ? chatRequest.messages
    : chatRequest.messages.map(message => toPayloadMessage(message));

  const metadata = extraMetadataRef.current;

  const requestBody = {
    data: chatRequest.data,
    ...metadata.body,
    ...chatRequest.options?.body,
    ...(chatRequest.functions === undefined
      ? {}
      : { functions: chatRequest.functions }),
    ...(chatRequest.function_call === undefined
      ? {}
      : { function_call: chatRequest.function_call }),
    ...(chatRequest.tools === undefined ? {} : { tools: chatRequest.tools }),
    ...(chatRequest.tool_choice === undefined
      ? {}
      : { tool_choice: chatRequest.tool_choice }),
  };

  const requestHeaders = {
    ...metadata.headers,
    ...chatRequest.options?.headers,
  };

  const response = await callChatApi({
    api,
    messages: payloadMessages,
    body: requestBody,
    streamMode,
    credentials: metadata.credentials,
    headers: requestHeaders,
    abortController: () => abortControllerRef.current,
    restoreMessagesOnFailure: () => {
      mutate(messagesBeforeRequest, false);
    },
    onResponse,
    onUpdate: (merged, data) => {
      mutate([...chatRequest.messages, ...merged], false);
      mutateStreamData([...(existingData ?? []), ...(data ?? [])], false);
    },
    onToolCall,
    onFinish,
    generateId,
  });

  return response;
};
163
+
164
/**
 * React hook for building a chat UI on top of a streaming chat API endpoint.
 *
 * Messages, loading state, streamed data and the last error are stored in SWR.
 * Their keys derive from `api` and the chat id (`id`, or a per-component
 * `useId()` value when `id` is omitted), so hook instances with the same `api`
 * and `id` share one chat state.
 *
 * When `maxToolRoundtrips` is greater than 0 and a request ends with an
 * assistant message whose tool calls all have results, the hook automatically
 * sends the messages back to the server. It does this only while the number
 * of trailing assistant messages is at most `maxToolRoundtrips`.
 *
 * @returns The chat state and helpers described by `UseChatHelpers`, plus
 *   `addToolResult` and its deprecated alias `experimental_addToolResult`.
 */
export function useChat({
  api = '/api/chat',
  id,
  initialMessages,
  initialInput = '',
  sendExtraMessageFields,
  experimental_onFunctionCall,
  experimental_onToolCall,
  onToolCall,
  experimental_maxAutomaticRoundtrips = 0,
  maxAutomaticRoundtrips = experimental_maxAutomaticRoundtrips,
  maxToolRoundtrips = maxAutomaticRoundtrips,
  streamMode,
  onResponse,
  onFinish,
  onError,
  credentials,
  headers,
  body,
  generateId = generateIdFunc,
}: Omit<UseChatOptions, 'api'> & {
  api?: string;
  key?: string;
  /**
@deprecated Use `maxToolRoundtrips` instead.
   */
  experimental_maxAutomaticRoundtrips?: number;

  /**
@deprecated Use `maxToolRoundtrips` instead.
   */
  maxAutomaticRoundtrips?: number;

  /**
Maximal number of automatic roundtrips for tool calls.

An automatic tool call roundtrip is a call to the server with the
tool call results when all tool calls in the last assistant
message have results.

A maximum number is required to prevent infinite loops in the
case of misconfigured tools.

By default, it's set to 0, which will disable the feature.
   */
  maxToolRoundtrips?: number;
} = {}): UseChatHelpers & {
  /**
   * @deprecated Use `addToolResult` instead.
   */
  experimental_addToolResult: ({
    toolCallId,
    result,
  }: {
    toolCallId: string;
    result: any;
  }) => void;
  addToolResult: ({
    toolCallId,
    result,
  }: {
    toolCallId: string;
    result: any;
  }) => void;
} {
  // Generate a unique id for the chat if not provided.
  const hookId = useId();
  const idKey = id ?? hookId;
  const chatKey = typeof api === 'string' ? [api, idKey] : idKey;

  // Store an empty array as the initial messages
  // (instead of using a default parameter value that gets re-created each time)
  // to avoid re-renders:
  const [initialMessagesFallback] = useState([]);

  // Store the chat state in SWR, using the chatId as the key to share states.
  const { data: messages, mutate } = useSWR<Message[]>(
    [chatKey, 'messages'],
    null,
    { fallbackData: initialMessages ?? initialMessagesFallback },
  );

  // We store loading state in another hook to sync loading states across hook invocations
  const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
    [chatKey, 'loading'],
    null,
  );

  const { data: streamData, mutate: mutateStreamData } = useSWR<
    JSONValue[] | undefined
  >([chatKey, 'streamData'], null);

  const { data: error = undefined, mutate: setError } = useSWR<
    undefined | Error
  >([chatKey, 'error'], null);

  // Keep the latest messages in a ref.
  const messagesRef = useRef<Message[]>(messages || []);
  useEffect(() => {
    messagesRef.current = messages || [];
  }, [messages]);

  // Abort controller to cancel the current API call.
  const abortControllerRef = useRef<AbortController | null>(null);

  const extraMetadataRef = useRef({
    credentials,
    headers,
    body,
  });

  useEffect(() => {
    extraMetadataRef.current = {
      credentials,
      headers,
      body,
    };
  }, [credentials, headers, body]);

  /**
   * Run one chat request (including any function/tool-call follow-ups handled
   * by `processChatStream`). Afterwards, optionally start an automatic tool
   * roundtrip.
   *
   * Resolves to `null` when the request was aborted.
   */
  const triggerRequest = useCallback(
    async (chatRequest: ChatRequest) => {
      // Message count before this request. A failed request restores the
      // previous messages, so the chat does not grow. Without checking growth
      // below, a persistent error during a tool roundtrip would re-trigger
      // the same request indefinitely.
      const messageCount = messagesRef.current.length;

      try {
        mutateLoading(true);
        setError(undefined);

        const abortController = new AbortController();
        abortControllerRef.current = abortController;

        await processChatStream({
          getStreamedResponse: () =>
            getStreamedResponse(
              api,
              chatRequest,
              mutate,
              mutateStreamData,
              streamData,
              extraMetadataRef,
              messagesRef,
              abortControllerRef,
              generateId,
              streamMode,
              onFinish,
              onResponse,
              onToolCall,
              sendExtraMessageFields,
            ),
          experimental_onFunctionCall,
          experimental_onToolCall,
          updateChatRequest: chatRequestParam => {
            chatRequest = chatRequestParam;
          },
          getCurrentMessages: () => messagesRef.current,
        });

        abortControllerRef.current = null;
      } catch (err) {
        // Ignore abort errors as they are expected. Guard against non-object
        // throwables (e.g. `null`) so the check itself cannot throw.
        if (
          err != null &&
          (err as { name?: unknown }).name === 'AbortError'
        ) {
          abortControllerRef.current = null;
          return null;
        }

        if (onError && err instanceof Error) {
          onError(err);
        }

        setError(err as Error);
      } finally {
        mutateLoading(false);
      }

      // auto-submit when all tool calls in the last assistant message have results:
      const messages = messagesRef.current;
      const lastMessage = messages[messages.length - 1];
      if (
        // ensure new messages were added (prevents endless retries after errors):
        messages.length > messageCount &&
        // ensure there is a last message:
        lastMessage != null &&
        // check if the feature is enabled:
        maxToolRoundtrips > 0 &&
        // check that roundtrip is possible:
        isAssistantMessageWithCompletedToolCalls(lastMessage) &&
        // limit the number of automatic roundtrips:
        countTrailingAssistantMessages(messages) <= maxToolRoundtrips
      ) {
        await triggerRequest({ messages });
      }
    },
    [
      mutate,
      mutateLoading,
      api,
      extraMetadataRef,
      onResponse,
      onFinish,
      onError,
      setError,
      mutateStreamData,
      streamData,
      streamMode,
      sendExtraMessageFields,
      experimental_onFunctionCall,
      experimental_onToolCall,
      onToolCall,
      maxToolRoundtrips,
      messagesRef,
      abortControllerRef,
      generateId,
    ],
  );

  const append = useCallback(
    async (
      message: Message | CreateMessage,
      {
        options,
        functions,
        function_call,
        tools,
        tool_choice,
        data,
      }: ChatRequestOptions = {},
    ) => {
      // Assign an id without mutating the caller's object. Messages that
      // already carry an id are passed through unchanged.
      const messageWithId: Message = message.id
        ? (message as Message)
        : { ...(message as Message), id: generateId() };

      const chatRequest: ChatRequest = {
        messages: messagesRef.current.concat(messageWithId),
        options,
        data,
        ...(functions !== undefined && { functions }),
        ...(function_call !== undefined && { function_call }),
        ...(tools !== undefined && { tools }),
        ...(tool_choice !== undefined && { tool_choice }),
      };

      return triggerRequest(chatRequest);
    },
    [triggerRequest, generateId],
  );

  const reload = useCallback(
    async ({
      options,
      functions,
      function_call,
      tools,
      tool_choice,
      data,
    }: ChatRequestOptions = {}) => {
      const currentMessages = messagesRef.current;
      if (currentMessages.length === 0) return null;

      // Drop a trailing assistant message so it is regenerated; otherwise ask
      // for a response to the current history as-is.
      const lastMessage = currentMessages[currentMessages.length - 1];
      const chatRequest: ChatRequest = {
        messages:
          lastMessage.role === 'assistant'
            ? currentMessages.slice(0, -1)
            : currentMessages,
        options,
        // forward `data` the same way `append` does:
        data,
        ...(functions !== undefined && { functions }),
        ...(function_call !== undefined && { function_call }),
        ...(tools !== undefined && { tools }),
        ...(tool_choice !== undefined && { tool_choice }),
      };

      return triggerRequest(chatRequest);
    },
    [triggerRequest],
  );

  const stop = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
  }, []);

  const setMessages = useCallback(
    (messages: Message[]) => {
      mutate(messages, false);
      messagesRef.current = messages;
    },
    [mutate],
  );

  // Input state and handlers.
  const [input, setInput] = useState(initialInput);

  const handleSubmit = useCallback(
    (
      e: React.FormEvent<HTMLFormElement>,
      options: ChatRequestOptions = {},
      metadata?: object,
    ) => {
      // Merge per-submit metadata into the stored request metadata; the merged
      // values stay in the ref (until `credentials`/`headers`/`body` change).
      if (metadata) {
        extraMetadataRef.current = {
          ...extraMetadataRef.current,
          ...metadata,
        };
      }

      e.preventDefault();
      if (!input) return;

      void append(
        {
          content: input,
          role: 'user',
          createdAt: new Date(),
        },
        options,
      );
      setInput('');
    },
    [input, append],
  );

  const handleInputChange = (
    e:
      | React.ChangeEvent<HTMLInputElement>
      | React.ChangeEvent<HTMLTextAreaElement>,
  ) => {
    setInput(e.target.value);
  };

  /**
   * Attach `result` to the tool invocation `toolCallId` in the last assistant
   * message. Once every tool call in that message has a result, the messages
   * are submitted. This does not depend on `maxToolRoundtrips`.
   */
  const addToolResult = ({
    toolCallId,
    result,
  }: {
    toolCallId: string;
    result: any;
  }) => {
    const currentMessages = messagesRef.current;
    const lastIndex = currentMessages.length - 1;

    const updatedMessages = currentMessages.map((message, index) =>
      // update the tool calls in the last assistant message:
      index === lastIndex &&
      message.role === 'assistant' &&
      message.toolInvocations
        ? {
            ...message,
            toolInvocations: message.toolInvocations.map(toolInvocation =>
              toolInvocation.toolCallId === toolCallId
                ? { ...toolInvocation, result }
                : toolInvocation,
            ),
          }
        : message,
    );

    mutate(updatedMessages, false);
    // Update the ref synchronously (as `setMessages` does) so a second
    // `addToolResult` call before the next render builds on this result
    // instead of the stale message list.
    messagesRef.current = updatedMessages;

    // auto-submit when all tool calls in the last assistant message have results:
    const lastMessage = updatedMessages[updatedMessages.length - 1];
    if (
      lastMessage != null &&
      isAssistantMessageWithCompletedToolCalls(lastMessage)
    ) {
      void triggerRequest({ messages: updatedMessages });
    }
  };

  return {
    messages: messages || [],
    error,
    append,
    reload,
    stop,
    setMessages,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    data: streamData,
    addToolResult,
    experimental_addToolResult: addToolResult,
  };
}
544
+
545
/**
 * Check whether `message` is an assistant message with completed tool calls.
 * The message must have at least one tool invocation, and every invocation
 * must have a `result` property.
 *
 * @param message The message to inspect. `undefined` (for example, the "last
 *   message" of an empty list) is accepted and yields `false`.
 * @returns `true` only for an assistant message whose tool calls all have
 *   results; always a real boolean, never `undefined`.
 */
function isAssistantMessageWithCompletedToolCalls(
  message: Message | undefined,
): boolean {
  if (message == null || message.role !== 'assistant') {
    return false;
  }

  const invocations = message.toolInvocations;
  return (
    invocations != null &&
    invocations.length > 0 &&
    invocations.every(toolInvocation => 'result' in toolInvocation)
  );
}
558
+
559
/**
 * Count the consecutive assistant messages at the end of `messages`.
 *
 * The scan walks backwards from the last entry and stops at the first message
 * whose role is not `'assistant'`. An empty array yields 0.
 */
function countTrailingAssistantMessages(messages: Message[]) {
  let firstTrailingIndex = messages.length;
  while (
    firstTrailingIndex > 0 &&
    messages[firstTrailingIndex - 1].role === 'assistant'
  ) {
    firstTrailingIndex -= 1;
  }
  return messages.length - firstTrailingIndex;
}