@ai-sdk/react 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.mjs ADDED
@@ -0,0 +1,653 @@
1
+ // src/use-chat.ts
2
+ import {
3
+ callChatApi,
4
+ generateId as generateIdFunc,
5
+ processChatStream
6
+ } from "@ai-sdk/ui-utils";
7
+ import { useCallback, useEffect, useId, useRef, useState } from "react";
8
+ import useSWR from "swr";
/**
 * Sends one chat request via `callChatApi` and streams the reply into the SWR caches.
 *
 * The outgoing messages are written to the messages cache immediately (optimistic update).
 * If `callChatApi` invokes `restoreMessagesOnFailure`, the messages that were in
 * `messagesRef` before this call are written back.
 *
 * @param api - Endpoint forwarded to `callChatApi`.
 * @param chatRequest - `messages` plus optional `options` (`body`/`headers`), `data`, and the
 *   legacy `functions` / `function_call` / `tools` / `tool_choice` fields.
 * @param mutate - SWR mutator for the messages cache.
 * @param mutateStreamData - SWR mutator for the stream-data cache.
 * @param existingData - Stream data present before this request; streamed data is appended to it.
 * @param extraMetadataRef - Ref holding hook-level `credentials`, `headers` and `body`.
 * @param messagesRef - Ref holding the current message list (used for rollback).
 * @param abortControllerRef - Ref holding the AbortController of the in-flight request.
 * @param generateId2 - Id generator forwarded to `callChatApi`.
 * @param sendExtraMessageFields - When truthy, messages are sent unchanged; otherwise each
 *   message is reduced to the known fields below.
 * @returns Whatever `callChatApi` resolves to.
 */
var getStreamedResponse = async (api, chatRequest, mutate, mutateStreamData, existingData, extraMetadataRef, messagesRef, abortControllerRef, generateId2, streamMode, onFinish, onResponse, onToolCall, sendExtraMessageFields) => {
  // Snapshot for rollback, taken before the optimistic update below.
  const rollbackMessages = messagesRef.current;
  mutate(chatRequest.messages, false);

  // Keeps only the known message fields; optional ones are omitted when undefined.
  const toPayloadMessage = (message) => {
    const {
      role,
      content,
      name,
      data,
      annotations,
      toolInvocations,
      function_call,
      tool_calls,
      tool_call_id
    } = message;
    const payload = { role, content };
    if (name !== void 0) payload.name = name;
    if (data !== void 0) payload.data = data;
    if (annotations !== void 0) payload.annotations = annotations;
    if (toolInvocations !== void 0) payload.toolInvocations = toolInvocations;
    // outdated function/tool call handling (TODO deprecate):
    payload.tool_call_id = tool_call_id;
    if (function_call !== void 0) payload.function_call = function_call;
    if (tool_calls !== void 0) payload.tool_calls = tool_calls;
    return payload;
  };

  const payloadMessages = sendExtraMessageFields
    ? chatRequest.messages
    : chatRequest.messages.map((message) => toPayloadMessage(message));

  const requestOptions = chatRequest.options;

  // Later sources win: request data < hook-level body < per-request body < legacy fields.
  const requestBody = {
    data: chatRequest.data,
    ...extraMetadataRef.current.body,
    ...(requestOptions == null ? void 0 : requestOptions.body)
  };
  if (chatRequest.functions !== void 0) requestBody.functions = chatRequest.functions;
  if (chatRequest.function_call !== void 0) requestBody.function_call = chatRequest.function_call;
  if (chatRequest.tools !== void 0) requestBody.tools = chatRequest.tools;
  if (chatRequest.tool_choice !== void 0) requestBody.tool_choice = chatRequest.tool_choice;

  // Per-request headers override hook-level headers.
  const requestHeaders = {
    ...extraMetadataRef.current.headers,
    ...(requestOptions == null ? void 0 : requestOptions.headers)
  };

  const response = await callChatApi({
    api,
    messages: payloadMessages,
    body: requestBody,
    streamMode,
    credentials: extraMetadataRef.current.credentials,
    headers: requestHeaders,
    abortController: () => abortControllerRef.current,
    restoreMessagesOnFailure: () => {
      mutate(rollbackMessages, false);
    },
    onResponse,
    onUpdate: (merged, data) => {
      mutate([...chatRequest.messages, ...merged], false);
      mutateStreamData([...(existingData || []), ...(data || [])], false);
    },
    onToolCall,
    onFinish,
    generateId: generateId2
  });
  return response;
};
/**
 * React hook that manages a chat conversation backed by a streaming chat endpoint.
 *
 * Messages, loading flag, stream data and error live in SWR caches keyed by
 * `[chatKey, "<slot>"]`, so hook instances with the same `api` and `id` share state.
 *
 * @param options.api - Endpoint to send chat requests to (default "/api/chat").
 * @param options.id - Chat id; falls back to a `useId()` value unique to this hook instance.
 * @param options.initialMessages - Used as SWR fallback data for the messages cache.
 * @param options.initialInput - Initial value of the `input` state.
 * @param options.sendExtraMessageFields - When truthy, messages are sent with all their fields.
 * @param options.maxToolRoundtrips - Maximum number of consecutive automatic follow-up requests
 *   after assistant messages whose tool calls all have results. Defaults to
 *   `maxAutomaticRoundtrips`, which defaults to `experimental_maxAutomaticRoundtrips` (0).
 * @param options.generateId - Id generator for messages appended without an id.
 * @returns Chat state (`messages`, `input`, `isLoading`, `error`, `data`) and actions
 *   (`append`, `reload`, `stop`, `setMessages`, `setInput`, `handleInputChange`,
 *   `handleSubmit`, `addToolResult` / `experimental_addToolResult`).
 */
function useChat({
  api = "/api/chat",
  id,
  initialMessages,
  initialInput = "",
  sendExtraMessageFields,
  experimental_onFunctionCall,
  experimental_onToolCall,
  onToolCall,
  experimental_maxAutomaticRoundtrips = 0,
  maxAutomaticRoundtrips = experimental_maxAutomaticRoundtrips,
  maxToolRoundtrips = maxAutomaticRoundtrips,
  streamMode,
  onResponse,
  onFinish,
  onError,
  credentials,
  headers,
  body,
  generateId: generateId2 = generateIdFunc
} = {}) {
  const hookId = useId();
  const idKey = id != null ? id : hookId;
  // Keyed by [api, id] when `api` is a string, otherwise by id alone.
  const chatKey = typeof api === "string" ? [api, idKey] : idKey;
  // useState keeps this empty array instance stable across renders.
  const [initialMessagesFallback] = useState([]);
  const { data: messages, mutate } = useSWR(
    [chatKey, "messages"],
    null,
    { fallbackData: initialMessages != null ? initialMessages : initialMessagesFallback }
  );
  const { data: isLoading = false, mutate: mutateLoading } = useSWR(
    [chatKey, "loading"],
    null
  );
  const { data: streamData, mutate: mutateStreamData } = useSWR([chatKey, "streamData"], null);
  const { data: error = void 0, mutate: setError } = useSWR([chatKey, "error"], null);

  // Latest messages, readable from async callbacks without re-creating them.
  const messagesRef = useRef(messages || []);
  useEffect(() => {
    messagesRef.current = messages || [];
  }, [messages]);

  const abortControllerRef = useRef(null);

  // Latest hook-level request metadata, read when a request is sent.
  const extraMetadataRef = useRef({
    credentials,
    headers,
    body
  });
  useEffect(() => {
    extraMetadataRef.current = {
      credentials,
      headers,
      body
    };
  }, [credentials, headers, body]);

  // Sends `chatRequest`, then (if enabled) automatically sends a follow-up request when the
  // conversation ends with an assistant message whose tool calls all have results.
  const triggerRequest = useCallback(
    async (chatRequest) => {
      // Size of the conversation sent with this request. Captured before `chatRequest` can be
      // replaced through `updateChatRequest`.
      const messageCount = chatRequest.messages.length;
      try {
        mutateLoading(true);
        setError(void 0);
        const abortController = new AbortController();
        abortControllerRef.current = abortController;
        await processChatStream({
          getStreamedResponse: () => getStreamedResponse(
            api,
            chatRequest,
            mutate,
            mutateStreamData,
            streamData,
            extraMetadataRef,
            messagesRef,
            abortControllerRef,
            generateId2,
            streamMode,
            onFinish,
            onResponse,
            onToolCall,
            sendExtraMessageFields
          ),
          experimental_onFunctionCall,
          experimental_onToolCall,
          updateChatRequest: (chatRequestParam) => {
            chatRequest = chatRequestParam;
          },
          getCurrentMessages: () => messagesRef.current
        });
        abortControllerRef.current = null;
      } catch (err) {
        // AbortError (presumably from `stop()` aborting the controller) is not reported.
        // `err != null` guards against a thrown null/undefined.
        if (err != null && err.name === "AbortError") {
          abortControllerRef.current = null;
          return null;
        }
        if (onError && err instanceof Error) {
          onError(err);
        }
        setError(err);
      } finally {
        mutateLoading(false);
      }
      const messages2 = messagesRef.current;
      const lastMessage = messages2[messages2.length - 1];
      if (
        // only continue when this request added messages. After a failed request the messages
        // are restored, and without this check an assistant message with completed tool calls
        // would re-trigger the same failing request indefinitely:
        messages2.length > messageCount &&
        // ensure there is a last message:
        lastMessage != null &&
        // check if the feature is enabled:
        maxToolRoundtrips > 0 &&
        // check that roundtrip is possible:
        isAssistantMessageWithCompletedToolCalls(lastMessage) &&
        // limit the number of automatic roundtrips:
        countTrailingAssistantMessages(messages2) <= maxToolRoundtrips
      ) {
        await triggerRequest({ messages: messages2 });
      }
    },
    [
      mutate,
      mutateLoading,
      api,
      extraMetadataRef,
      onResponse,
      onFinish,
      onError,
      setError,
      mutateStreamData,
      streamData,
      streamMode,
      sendExtraMessageFields,
      experimental_onFunctionCall,
      experimental_onToolCall,
      onToolCall,
      maxToolRoundtrips,
      messagesRef,
      abortControllerRef,
      generateId2
    ]
  );

  // Appends `message` (assigning an id if it has none; note this mutates `message`) and sends
  // the conversation.
  const append = useCallback(
    async (message, {
      options,
      functions,
      function_call,
      tools,
      tool_choice,
      data
    } = {}) => {
      if (!message.id) {
        message.id = generateId2();
      }
      const chatRequest = {
        messages: messagesRef.current.concat(message),
        options,
        data,
        ...functions !== void 0 && { functions },
        ...function_call !== void 0 && { function_call },
        ...tools !== void 0 && { tools },
        ...tool_choice !== void 0 && { tool_choice }
      };
      return triggerRequest(chatRequest);
    },
    [triggerRequest, generateId2]
  );

  // Re-sends the conversation. A trailing assistant message is dropped first so it is
  // regenerated. Returns null when there are no messages.
  const reload = useCallback(
    async ({
      options,
      functions,
      function_call,
      tools,
      tool_choice
    } = {}) => {
      if (messagesRef.current.length === 0)
        return null;
      const lastMessage = messagesRef.current[messagesRef.current.length - 1];
      if (lastMessage.role === "assistant") {
        const chatRequest2 = {
          messages: messagesRef.current.slice(0, -1),
          options,
          ...functions !== void 0 && { functions },
          ...function_call !== void 0 && { function_call },
          ...tools !== void 0 && { tools },
          ...tool_choice !== void 0 && { tool_choice }
        };
        return triggerRequest(chatRequest2);
      }
      const chatRequest = {
        messages: messagesRef.current,
        options,
        ...functions !== void 0 && { functions },
        ...function_call !== void 0 && { function_call },
        ...tools !== void 0 && { tools },
        ...tool_choice !== void 0 && { tool_choice }
      };
      return triggerRequest(chatRequest);
    },
    [triggerRequest]
  );

  // Aborts the in-flight request, if any.
  const stop = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
  }, []);

  // Replaces the messages locally (no revalidation) and updates the ref immediately.
  const setMessages = useCallback(
    (messages2) => {
      mutate(messages2, false);
      messagesRef.current = messages2;
    },
    [mutate]
  );

  const [input, setInput] = useState(initialInput);

  // Form submit handler: appends `input` as a user message and clears it. Non-empty `metadata`
  // is merged into the request metadata; the effect above overwrites it again whenever
  // `credentials`, `headers` or `body` change.
  const handleSubmit = useCallback(
    (e, options = {}, metadata) => {
      if (metadata) {
        extraMetadataRef.current = {
          ...extraMetadataRef.current,
          ...metadata
        };
      }
      e.preventDefault();
      if (!input)
        return;
      append(
        {
          content: input,
          role: "user",
          createdAt: /* @__PURE__ */ new Date()
        },
        options
      );
      setInput("");
    },
    [input, append]
  );

  const handleInputChange = (e) => {
    setInput(e.target.value);
  };

  // Stores `result` on the matching tool invocation of the last assistant message. Sends a
  // follow-up request once all of that message's tool calls have results.
  const addToolResult = ({
    toolCallId,
    result
  }) => {
    const updatedMessages = messagesRef.current.map(
      (message, index, arr) => (
        // update the tool calls in the last assistant message:
        index === arr.length - 1 && message.role === "assistant" && message.toolInvocations ? {
          ...message,
          toolInvocations: message.toolInvocations.map(
            (toolInvocation) => toolInvocation.toolCallId === toolCallId ? { ...toolInvocation, result } : toolInvocation
          )
        } : message
      )
    );
    mutate(updatedMessages, false);
    const lastMessage = updatedMessages[updatedMessages.length - 1];
    // Guard: with no messages there is nothing to continue (previously a TypeError).
    if (lastMessage != null && isAssistantMessageWithCompletedToolCalls(lastMessage)) {
      triggerRequest({ messages: updatedMessages });
    }
  };

  return {
    messages: messages || [],
    error,
    append,
    reload,
    stop,
    setMessages,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    data: streamData,
    addToolResult,
    experimental_addToolResult: addToolResult
  };
}
/**
 * Tells whether `message` is an assistant message with at least one tool invocation, where
 * every invocation already has a `result` property.
 *
 * Used as a truthiness check. When `toolInvocations` is falsy, that falsy value itself is
 * returned; otherwise the result is a boolean.
 */
function isAssistantMessageWithCompletedToolCalls(message) {
  if (message.role !== "assistant") {
    return false;
  }
  const invocations = message.toolInvocations;
  if (!invocations) {
    return invocations;
  }
  if (!(invocations.length > 0)) {
    return false;
  }
  return invocations.every((invocation) => "result" in invocation);
}
/**
 * Counts how many consecutive messages with role "assistant" end the list.
 * Returns 0 for an empty list or when the last message is not from the assistant.
 */
function countTrailingAssistantMessages(messages) {
  let trailing = 0;
  while (
    trailing < messages.length &&
    messages[messages.length - 1 - trailing].role === "assistant"
  ) {
    trailing += 1;
  }
  return trailing;
}
359
+
360
+ // src/use-completion.ts
361
+ import {
362
+ callCompletionApi
363
+ } from "@ai-sdk/ui-utils";
364
+ import { useCallback as useCallback2, useEffect as useEffect2, useId as useId2, useRef as useRef2, useState as useState2 } from "react";
365
+ import useSWR2 from "swr";
/**
 * React hook for a text completion streamed from a completion endpoint.
 *
 * Where state is stored:
 * - the completion text is cached in SWR under `[api, completionId]`;
 * - the loading flag is cached under `[completionId, "loading"]`;
 * - stream data is cached under `[completionId, "streamData"]`;
 * - the error and the AbortController are local React state.
 *
 * @param options.api - Endpoint passed to `callCompletionApi` (default "/api/completion").
 * @param options.id - Completion id; falls back (via `||`) to a `useId()` value.
 * @param options.initialCompletion - SWR fallback for the completion text.
 * @param options.initialInput - Initial value of the `input` state.
 * @returns Completion state (`completion`, `input`, `isLoading`, `error`, `data`) and actions
 *   (`complete`, `setCompletion`, `stop`, `setInput`, `handleInputChange`, `handleSubmit`).
 */
function useCompletion({
  api = "/api/completion",
  id,
  initialCompletion = "",
  initialInput = "",
  credentials,
  headers,
  body,
  streamMode,
  onResponse,
  onFinish,
  onError
} = {}) {
  const hookId = useId2();
  const completionId = id || hookId;
  const { data, mutate } = useSWR2([api, completionId], null, {
    fallbackData: initialCompletion
  });
  const { data: isLoading = false, mutate: mutateLoading } = useSWR2(
    [completionId, "loading"],
    null
  );
  const { data: streamData, mutate: mutateStreamData } = useSWR2([completionId, "streamData"], null);
  const [error, setError] = useState2(void 0);
  const completion = data;
  const [abortController, setAbortController] = useState2(null);

  // Latest hook-level request metadata, read when a request is sent.
  const extraMetadataRef = useRef2({
    credentials,
    headers,
    body
  });
  useEffect2(() => {
    extraMetadataRef.current = {
      credentials,
      headers,
      body
    };
  }, [credentials, headers, body]);

  // Sends `prompt`; per-call `options.headers` / `options.body` override the hook-level ones.
  const triggerRequest = useCallback2(
    async (prompt, options) => callCompletionApi({
      api,
      prompt,
      credentials: extraMetadataRef.current.credentials,
      headers: { ...extraMetadataRef.current.headers, ...(options == null ? void 0 : options.headers) },
      body: {
        ...extraMetadataRef.current.body,
        ...(options == null ? void 0 : options.body)
      },
      streamMode,
      setCompletion: (completion2) => mutate(completion2, false),
      setLoading: mutateLoading,
      setError,
      setAbortController,
      onResponse,
      onFinish,
      onError,
      onData: (data2) => {
        // Functional update: build on the value currently in the cache, not on the
        // `streamData` captured when this callback was created. Otherwise each chunk of one
        // stream would overwrite the previous chunk instead of accumulating.
        mutateStreamData((currentData) => [...(currentData || []), ...(data2 || [])], false);
      }
    }),
    [
      mutate,
      mutateLoading,
      api,
      extraMetadataRef,
      setAbortController,
      onResponse,
      onFinish,
      onError,
      setError,
      streamMode,
      mutateStreamData
    ]
  );

  // Aborts the in-flight request, if any.
  const stop = useCallback2(() => {
    if (abortController) {
      abortController.abort();
      setAbortController(null);
    }
  }, [abortController]);

  // Replaces the completion text locally (no revalidation).
  const setCompletion = useCallback2(
    (completion2) => {
      mutate(completion2, false);
    },
    [mutate]
  );

  const complete = useCallback2(
    async (prompt, options) => {
      return triggerRequest(prompt, options);
    },
    [triggerRequest]
  );

  const [input, setInput] = useState2(initialInput);

  // Form submit handler: requests a completion for the current non-empty `input`.
  const handleSubmit = useCallback2(
    (e) => {
      e.preventDefault();
      if (!input)
        return;
      return complete(input);
    },
    [input, complete]
  );

  const handleInputChange = (e) => {
    setInput(e.target.value);
  };

  return {
    completion,
    complete,
    error,
    setCompletion,
    stop,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    data: streamData
  };
}
486
+
487
+ // src/use-assistant.ts
488
+ import { isAbortError } from "@ai-sdk/provider-utils";
489
+ import {
490
+ generateId,
491
+ readDataStream
492
+ } from "@ai-sdk/ui-utils";
493
+ import { useCallback as useCallback3, useRef as useRef3, useState as useState3 } from "react";
/**
 * React hook for an assistant endpoint that replies with a data stream.
 *
 * `append` POSTs `{ ...body, threadId, message, data }` as JSON to `api` and consumes the
 * response with `readDataStream`. The stream part types handled are:
 * - "assistant_message" adds a message;
 * - "text" extends the last message;
 * - "data_message" adds a "data" message;
 * - "assistant_control_data" sets the thread id and the last message's id;
 * - "error" sets `error`.
 *
 * @param options.api - Endpoint to POST to.
 * @param options.threadId - User-provided thread id; takes precedence over the one received
 *   from the stream.
 * @param options.onError - Called with request/stream exceptions that are `Error` instances
 *   (not with stream "error" parts, and not with aborts).
 * @returns State (`messages`, `threadId`, `input`, `status`, `error`) and actions
 *   (`append`, `setMessages`, `setInput`, `handleInputChange`, `submitMessage`, `stop`).
 */
function useAssistant({
  api,
  threadId: threadIdParam,
  credentials,
  headers,
  body,
  onError
}) {
  const [messages, setMessages] = useState3([]);
  const [input, setInput] = useState3("");
  const [threadId, setThreadId] = useState3(void 0);
  const [status, setStatus] = useState3("awaiting_message");
  const [error, setError] = useState3(void 0);

  const handleInputChange = (event) => {
    setInput(event.target.value);
  };

  const abortControllerRef = useRef3(null);

  // Aborts the in-flight request, if any.
  const stop = useCallback3(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
  }, []);

  // Adds `message` to the list (generating an id when missing), sends it, and applies the
  // streamed response parts to state. `status` is "in_progress" until the request ends.
  const append = async (message, requestOptions) => {
    setStatus("in_progress");
    setMessages((messages2) => [
      ...messages2,
      {
        ...message,
        id: message.id != null ? message.id : generateId()
      }
    ]);
    setInput("");
    const abortController = new AbortController();
    try {
      abortControllerRef.current = abortController;
      // always use user-provided threadId when available:
      const requestThreadId = threadIdParam != null ? threadIdParam : threadId;
      const result = await fetch(api, {
        method: "POST",
        credentials,
        signal: abortController.signal,
        headers: { "Content-Type": "application/json", ...headers },
        body: JSON.stringify({
          ...body,
          threadId: requestThreadId != null ? requestThreadId : null,
          message: message.content,
          // optional request data:
          data: requestOptions == null ? void 0 : requestOptions.data
        })
      });
      if (result.body == null) {
        throw new Error("The response body is empty.");
      }
      for await (const { type, value } of readDataStream(
        result.body.getReader()
      )) {
        switch (type) {
          case "assistant_message": {
            setMessages((messages2) => [
              ...messages2,
              {
                id: value.id,
                role: value.role,
                content: value.content[0].text.value
              }
            ]);
            break;
          }
          case "text": {
            setMessages((messages2) => {
              const lastMessage = messages2[messages2.length - 1];
              // Nothing to extend (e.g. messages were cleared mid-stream); avoid a TypeError.
              if (lastMessage == null) {
                return messages2;
              }
              return [
                ...messages2.slice(0, messages2.length - 1),
                {
                  id: lastMessage.id,
                  role: lastMessage.role,
                  content: lastMessage.content + value
                }
              ];
            });
            break;
          }
          case "data_message": {
            setMessages((messages2) => [
              ...messages2,
              {
                id: value.id != null ? value.id : generateId(),
                role: "data",
                content: "",
                data: value.data
              }
            ]);
            break;
          }
          case "assistant_control_data": {
            setThreadId(value.threadId);
            // set id of last message. Build a new object instead of mutating the previous
            // state's message in place (React state must be treated as read-only).
            setMessages((messages2) => {
              const lastMessage = messages2[messages2.length - 1];
              if (lastMessage == null) {
                return messages2;
              }
              return [
                ...messages2.slice(0, messages2.length - 1),
                { ...lastMessage, id: value.messageId }
              ];
            });
            break;
          }
          case "error": {
            setError(new Error(value));
            break;
          }
        }
      }
    } catch (error2) {
      // An abort caused by our own controller is not an error.
      if (isAbortError(error2) && abortController.signal.aborted) {
        abortControllerRef.current = null;
        return;
      }
      if (onError && error2 instanceof Error) {
        onError(error2);
      }
      setError(error2);
    } finally {
      abortControllerRef.current = null;
      setStatus("awaiting_message");
    }
  };

  // Form submit handler: prevents the default action (when available) and appends the
  // current non-empty `input` as a user message. Does not wait for `append` to finish.
  const submitMessage = async (event, requestOptions) => {
    const preventDefault = event == null ? void 0 : event.preventDefault;
    if (preventDefault != null) {
      preventDefault.call(event);
    }
    if (input === "") {
      return;
    }
    append({ role: "user", content: input }, requestOptions);
  };

  return {
    append,
    messages,
    setMessages,
    threadId,
    input,
    setInput,
    handleInputChange,
    submitMessage,
    status,
    error,
    stop
  };
}
/** Alias of `useAssistant`, exported under the `experimental_`-prefixed name. */
const experimental_useAssistant = useAssistant;
647
+ export {
648
+ experimental_useAssistant,
649
+ useAssistant,
650
+ useChat,
651
+ useCompletion
652
+ };
653
+ //# sourceMappingURL=index.mjs.map