@chat-js/cli 0.3.0 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.js +11 -6
- package/package.json +1 -1
- package/templates/chat-app/app/(chat)/api/chat/prepare/route.ts +94 -0
- package/templates/chat-app/app/(chat)/api/chat/route.ts +97 -14
- package/templates/chat-app/chat.config.ts +141 -124
- package/templates/chat-app/components/chat-sync.tsx +6 -3
- package/templates/chat-app/components/feedback-actions.tsx +7 -3
- package/templates/chat-app/components/message-editor.tsx +8 -3
- package/templates/chat-app/components/message-siblings.tsx +14 -1
- package/templates/chat-app/components/model-selector.tsx +669 -407
- package/templates/chat-app/components/multimodal-input.tsx +252 -18
- package/templates/chat-app/components/parallel-response-cards.tsx +157 -0
- package/templates/chat-app/components/part/text-message-part.tsx +9 -5
- package/templates/chat-app/components/retry-button.tsx +25 -8
- package/templates/chat-app/components/user-message.tsx +136 -125
- package/templates/chat-app/hooks/chat-sync-hooks.ts +11 -0
- package/templates/chat-app/hooks/use-navigate-to-message.ts +39 -0
- package/templates/chat-app/lib/ai/types.ts +74 -3
- package/templates/chat-app/lib/config-schema.ts +5 -0
- package/templates/chat-app/lib/db/migrations/0044_gray_red_shift.sql +5 -0
- package/templates/chat-app/lib/db/migrations/meta/0044_snapshot.json +1567 -0
- package/templates/chat-app/lib/db/migrations/meta/_journal.json +8 -1
- package/templates/chat-app/lib/db/queries.ts +84 -4
- package/templates/chat-app/lib/db/schema.ts +4 -1
- package/templates/chat-app/lib/message-conversion.ts +14 -2
- package/templates/chat-app/lib/stores/hooks-threads.ts +37 -1
- package/templates/chat-app/lib/stores/with-threads.test.ts +137 -0
- package/templates/chat-app/lib/stores/with-threads.ts +157 -4
- package/templates/chat-app/lib/thread-utils.ts +23 -2
- package/templates/chat-app/providers/chat-input-provider.tsx +40 -2
- package/templates/chat-app/scripts/db-branch-delete.sh +7 -1
- package/templates/chat-app/scripts/db-branch-use.sh +7 -1
- package/templates/chat-app/scripts/with-db.sh +7 -1
- package/templates/chat-app/vitest.config.ts +2 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"use client";
|
|
2
2
|
import type { UseChatHelpers } from "@ai-sdk/react";
|
|
3
3
|
import { useChatActions, useChatStoreApi } from "@ai-sdk-tools/store";
|
|
4
|
-
import { useMutation } from "@tanstack/react-query";
|
|
4
|
+
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
|
5
5
|
import { CameraIcon, FileIcon, ImageIcon, PlusIcon } from "lucide-react";
|
|
6
6
|
import type React from "react";
|
|
7
7
|
import {
|
|
@@ -25,17 +25,22 @@ import {
|
|
|
25
25
|
} from "@/components/ai-elements/prompt-input";
|
|
26
26
|
import { ContextBar } from "@/components/context-bar";
|
|
27
27
|
import { ContextUsageFromParent } from "@/components/context-usage";
|
|
28
|
-
import { useSaveMessageMutation } from "@/hooks/chat-sync-hooks";
|
|
29
28
|
import { useArtifact } from "@/hooks/use-artifact";
|
|
30
29
|
import { useIsMobile } from "@/hooks/use-mobile";
|
|
31
30
|
import type { AppModelId } from "@/lib/ai/app-model-id";
|
|
32
|
-
import
|
|
31
|
+
import {
|
|
32
|
+
expandSelectedModelValue,
|
|
33
|
+
type Attachment,
|
|
34
|
+
type ChatMessage,
|
|
35
|
+
type SelectedModelValue,
|
|
36
|
+
type UiToolName,
|
|
37
|
+
} from "@/lib/ai/types";
|
|
33
38
|
import { config } from "@/lib/config";
|
|
34
39
|
import { processFilesForUpload } from "@/lib/files/upload-prep";
|
|
35
40
|
import { useLastMessageId } from "@/lib/stores/hooks-base";
|
|
36
41
|
import { useAddMessageToTree } from "@/lib/stores/hooks-threads";
|
|
37
42
|
import { ANONYMOUS_LIMITS } from "@/lib/types/anonymous";
|
|
38
|
-
import { cn, generateUUID } from "@/lib/utils";
|
|
43
|
+
import { cn, fetchWithErrorHandlers, generateUUID } from "@/lib/utils";
|
|
39
44
|
import { useChatId } from "@/providers/chat-id-provider";
|
|
40
45
|
import { useChatInput } from "@/providers/chat-input-provider";
|
|
41
46
|
import { useChatModels } from "@/providers/chat-models-provider";
|
|
@@ -57,6 +62,16 @@ import { LimitDisplay } from "./upgrade-cta/limit-display";
|
|
|
57
62
|
import { LoginPrompt } from "./upgrade-cta/login-prompt";
|
|
58
63
|
|
|
59
64
|
/** Matches exactly `/project/<segment>`; capture group 1 is that segment. */
const PROJECT_ROUTE_REGEX = /^\/project\/([^/]+)$/;

/**
 * Matches `/project/<segment>` optionally followed by `/chat/<segment>`;
 * capture group 1 is the first (project) segment.
 */
const PROJECT_CHAT_ROUTE_REGEX = /^\/project\/([^/]+)(?:\/chat\/[^/]+)?$/;

/** One per-model request that belongs to a multi-model ("parallel") submission. */
interface ParallelRequestSpec {
  /** Identifier shared by every request created for the same submission. */
  parallelGroupId: string;
  /** Position of this request's model within the requested model list. */
  parallelIndex: number;
  /** Whether this is the request at index 0 of the group. */
  isPrimary: boolean;
  /** Model that should answer this request. */
  modelId: AppModelId;
  /** Id assigned up front to the assistant message this request produces. */
  assistantMessageId: string;
  /** Timestamp recorded for the assistant message. */
  createdAt: Date;
}
|
|
60
75
|
|
|
61
76
|
/** Derive accept string for images only */
|
|
62
77
|
function getAcceptImages(acceptedTypes: Record<string, string[]>): string {
|
|
@@ -100,8 +115,8 @@ function PureMultimodalInput({
|
|
|
100
115
|
const { artifact, closeArtifact } = useArtifact();
|
|
101
116
|
const { data: session } = useSession();
|
|
102
117
|
const trpc = useTRPC();
|
|
118
|
+
const queryClient = useQueryClient();
|
|
103
119
|
const isMobile = useIsMobile();
|
|
104
|
-
const { mutate: saveChatMessage } = useSaveMessageMutation();
|
|
105
120
|
const addMessageToTree = useAddMessageToTree();
|
|
106
121
|
useChatId();
|
|
107
122
|
const {
|
|
@@ -117,7 +132,9 @@ function PureMultimodalInput({
|
|
|
117
132
|
attachments,
|
|
118
133
|
setAttachments,
|
|
119
134
|
selectedModelId,
|
|
135
|
+
selectedModelSelection,
|
|
120
136
|
handleModelChange,
|
|
137
|
+
handleModelSelectionChange,
|
|
121
138
|
getInputValue,
|
|
122
139
|
handleInputChange,
|
|
123
140
|
getInitialInput,
|
|
@@ -135,6 +152,18 @@ function PureMultimodalInput({
|
|
|
135
152
|
const stopStreamMutation = useMutation(
|
|
136
153
|
trpc.chat.stopStream.mutationOptions()
|
|
137
154
|
);
|
|
155
|
+
// Keep the full selection only when it expands to more than one model;
// otherwise fall back to the single selected model id.
const normalizedSelectedModel = useMemo<SelectedModelValue>(() => {
  const expandsToSeveralModels =
    expandSelectedModelValue(selectedModelSelection).length > 1;

  if (expandsToSeveralModels) {
    return selectedModelSelection;
  }
  return selectedModelId;
}, [selectedModelId, selectedModelSelection]);

// Flat list of model ids that a submission would target.
const requestedModelIds = useMemo(() => {
  return expandSelectedModelValue(normalizedSelectedModel);
}, [normalizedSelectedModel]);

// Feature flag read from the site config.
const parallelResponsesEnabled = config.features.parallelResponses;

// True when the flag is on and more than one model is requested.
const isParallelModelRequest =
  parallelResponsesEnabled && requestedModelIds.length > 1;
|
|
138
167
|
|
|
139
168
|
// Attachment configuration from site config
|
|
140
169
|
const { maxBytes, maxDimension, acceptedTypes } = config.attachments;
|
|
@@ -181,6 +210,18 @@ function PureMultimodalInput({
|
|
|
181
210
|
const submission = useMemo(():
|
|
182
211
|
| { enabled: false; message: string }
|
|
183
212
|
| { enabled: true } => {
|
|
213
|
+
if (isParallelModelRequest && !session?.user) {
|
|
214
|
+
return {
|
|
215
|
+
enabled: false,
|
|
216
|
+
message: "Log in to use multiple models",
|
|
217
|
+
};
|
|
218
|
+
}
|
|
219
|
+
if (isParallelModelRequest && attachments.length > 0) {
|
|
220
|
+
return {
|
|
221
|
+
enabled: false,
|
|
222
|
+
message: "Multiple models with attachments are not supported yet",
|
|
223
|
+
};
|
|
224
|
+
}
|
|
184
225
|
if (isModelDisallowedForAnonymous) {
|
|
185
226
|
return { enabled: false, message: "Log in to use this model" };
|
|
186
227
|
}
|
|
@@ -203,7 +244,15 @@ function PureMultimodalInput({
|
|
|
203
244
|
};
|
|
204
245
|
}
|
|
205
246
|
return { enabled: true };
|
|
206
|
-
}, [
|
|
247
|
+
}, [
|
|
248
|
+
attachments.length,
|
|
249
|
+
isEmpty,
|
|
250
|
+
isModelDisallowedForAnonymous,
|
|
251
|
+
isParallelModelRequest,
|
|
252
|
+
session?.user,
|
|
253
|
+
status,
|
|
254
|
+
uploadQueue.length,
|
|
255
|
+
]);
|
|
207
256
|
|
|
208
257
|
// Helper function to process and validate files
|
|
209
258
|
const processFiles = useCallback(
|
|
@@ -275,6 +324,11 @@ function PureMultimodalInput({
|
|
|
275
324
|
[session?.user]
|
|
276
325
|
);
|
|
277
326
|
|
|
327
|
+
// Reads the project id segment from the current URL path, or `undefined`
// when the path is not a project (or project chat) route.
const getCurrentProjectId = useCallback(
  () => PROJECT_CHAT_ROUTE_REGEX.exec(window.location.pathname)?.[1],
  []
);
|
|
331
|
+
|
|
278
332
|
// Trim messages in edit mode
|
|
279
333
|
const trimMessagesInEditMode = useCallback(
|
|
280
334
|
(parentId: string | null) => {
|
|
@@ -319,11 +373,122 @@ function PureMultimodalInput({
|
|
|
319
373
|
]
|
|
320
374
|
);
|
|
321
375
|
|
|
376
|
+
// Invalidates the cached `chat.getChatMessages` query for the current chat.
const invalidatePersistedMessages = useCallback(async () => {
  const queryKey = trpc.chat.getChatMessages.queryKey({ chatId });

  await queryClient.invalidateQueries({ queryKey });
}, [chatId, queryClient, trpc]);
|
|
381
|
+
|
|
382
|
+
// Sends one non-primary parallel request to `/api/chat` and reads the response
// stream to completion, discarding every chunk. Resolves once the stream ends
// (or immediately when the response has no body).
const drainSecondaryParallelRequest = useCallback(
  async ({
    message,
    requestSpec,
  }: {
    message: ChatMessage;
    requestSpec: ParallelRequestSpec;
  }) => {
    const payload = {
      id: chatId,
      message,
      prevMessages: [],
      projectId: getCurrentProjectId(),
      assistantMessageId: requestSpec.assistantMessageId,
      selectedModelId: requestSpec.modelId,
      parallelGroupId: requestSpec.parallelGroupId,
      parallelIndex: requestSpec.parallelIndex,
      isPrimaryParallel: false,
    };

    const response = await fetchWithErrorHandlers("/api/chat", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify(payload),
    });

    const stream = response.body;
    if (!stream) {
      return;
    }

    // Only completion matters here, so chunk values are never inspected.
    const reader = stream.getReader();
    let finished = false;
    while (!finished) {
      const chunk = await reader.read();
      finished = chunk.done;
    }
  },
  [chatId, getCurrentProjectId]
);
|
|
424
|
+
|
|
425
|
+
/**
 * Runs every non-primary request of a parallel submission and waits for all
 * of them to settle.
 *
 * Flow:
 * 1. POST the user message to `/api/chat/prepare`.
 * 2. Start one drained `/api/chat` request per secondary spec, concurrently.
 * 3. For each request that rejected, log the reason and re-add its assistant
 *    message to the tree with `activeStreamId: null` and no parts.
 *    NOTE(review): this presumably replaces the `pending:` placeholder added
 *    at submit time — confirm against `addMessageToTree`.
 * 4. Invalidate the cached persisted messages.
 *
 * If the prepare call rejects, no secondary request is started. Every
 * secondary spec is then settled as in step 3 before the error is rethrown,
 * so no placeholder is left carrying a pending stream id.
 *
 * @param message - The user message the assistant responses answer.
 * @param secondaryRequestSpecs - Specs for all requests except the primary one.
 * @throws Whatever the prepare request or the final invalidation rejects with.
 */
const runParallelSecondaryRequests = useCallback(
  async ({
    message,
    secondaryRequestSpecs,
  }: {
    message: ChatMessage;
    secondaryRequestSpecs: ParallelRequestSpec[];
  }) => {
    // Re-adds the assistant entry for a spec as an empty, non-streaming message.
    const settleAsFailed = (failedRequestSpec: ParallelRequestSpec) => {
      addMessageToTree({
        id: failedRequestSpec.assistantMessageId,
        parts: [],
        role: "assistant",
        metadata: {
          createdAt: failedRequestSpec.createdAt,
          parentMessageId: message.id,
          parallelGroupId: failedRequestSpec.parallelGroupId,
          parallelIndex: failedRequestSpec.parallelIndex,
          isPrimaryParallel: failedRequestSpec.isPrimary,
          selectedModel: failedRequestSpec.modelId,
          activeStreamId: null,
          selectedTool: undefined,
        },
      });
    };

    try {
      await fetchWithErrorHandlers("/api/chat/prepare", {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          id: chatId,
          message,
          projectId: getCurrentProjectId(),
        }),
      });
    } catch (error) {
      // Nothing will ever stream for these specs, so clear their pending
      // markers before handing the failure to the caller.
      secondaryRequestSpecs.forEach(settleAsFailed);
      throw error;
    }

    const results = await Promise.allSettled(
      secondaryRequestSpecs.map((requestSpec) =>
        drainSecondaryParallelRequest({ message, requestSpec })
      )
    );

    results.forEach((result, index) => {
      if (result.status === "fulfilled") {
        return;
      }

      const failedRequestSpec = secondaryRequestSpecs[index];
      if (!failedRequestSpec) {
        return;
      }

      // Keep the rejection reason visible instead of dropping it silently.
      console.error(
        "Parallel response request failed",
        failedRequestSpec.modelId,
        result.reason
      );
      settleAsFailed(failedRequestSpec);
    });

    await invalidatePersistedMessages();
  },
  [
    addMessageToTree,
    chatId,
    drainSecondaryParallelRequest,
    getCurrentProjectId,
    invalidatePersistedMessages,
  ]
);
|
|
488
|
+
|
|
322
489
|
const coreSubmitLogic = useCallback(() => {
|
|
323
490
|
const input = getInputValue();
|
|
324
491
|
|
|
325
|
-
updateChatUrl(chatId);
|
|
326
|
-
|
|
327
492
|
// Get the appropriate parent message ID
|
|
328
493
|
const effectiveParentMessageId = isEditMode
|
|
329
494
|
? parentMessageId
|
|
@@ -334,6 +499,21 @@ function PureMultimodalInput({
|
|
|
334
499
|
trimMessagesInEditMode(parentMessageId);
|
|
335
500
|
}
|
|
336
501
|
|
|
502
|
+
const isParallelRequest = parallelResponsesEnabled && requestedModelIds.length > 1;
|
|
503
|
+
const parallelGroupId = isParallelRequest ? generateUUID() : null;
|
|
504
|
+
const requestSpecs = isParallelRequest
|
|
505
|
+
? requestedModelIds.map(
|
|
506
|
+
(modelId, parallelIndex): ParallelRequestSpec => ({
|
|
507
|
+
assistantMessageId: generateUUID(),
|
|
508
|
+
createdAt: new Date(Date.now() + parallelIndex),
|
|
509
|
+
isPrimary: parallelIndex === 0,
|
|
510
|
+
modelId,
|
|
511
|
+
parallelGroupId: parallelGroupId || generateUUID(),
|
|
512
|
+
parallelIndex,
|
|
513
|
+
})
|
|
514
|
+
)
|
|
515
|
+
: [];
|
|
516
|
+
|
|
337
517
|
const message: ChatMessage = {
|
|
338
518
|
id: generateUUID(),
|
|
339
519
|
parts: [
|
|
@@ -351,7 +531,10 @@ function PureMultimodalInput({
|
|
|
351
531
|
metadata: {
|
|
352
532
|
createdAt: new Date(),
|
|
353
533
|
parentMessageId: effectiveParentMessageId,
|
|
354
|
-
|
|
534
|
+
parallelGroupId,
|
|
535
|
+
parallelIndex: null,
|
|
536
|
+
isPrimaryParallel: null,
|
|
537
|
+
selectedModel: normalizedSelectedModel,
|
|
355
538
|
activeStreamId: null,
|
|
356
539
|
selectedTool: selectedTool || undefined,
|
|
357
540
|
},
|
|
@@ -360,10 +543,53 @@ function PureMultimodalInput({
|
|
|
360
543
|
|
|
361
544
|
onSendMessage?.(message);
|
|
362
545
|
|
|
363
|
-
|
|
364
|
-
saveChatMessage({ message, chatId });
|
|
546
|
+
const primaryRequest = requestSpecs[0];
|
|
365
547
|
|
|
366
|
-
|
|
548
|
+
if (primaryRequest) {
|
|
549
|
+
sendMessage(message, {
|
|
550
|
+
body: {
|
|
551
|
+
assistantMessageId: primaryRequest.assistantMessageId,
|
|
552
|
+
selectedModelId: primaryRequest.modelId,
|
|
553
|
+
parallelGroupId: primaryRequest.parallelGroupId,
|
|
554
|
+
parallelIndex: primaryRequest.parallelIndex,
|
|
555
|
+
isPrimaryParallel: true,
|
|
556
|
+
},
|
|
557
|
+
});
|
|
558
|
+
|
|
559
|
+
addMessageToTree(message);
|
|
560
|
+
handleModelChange(primaryRequest.modelId);
|
|
561
|
+
for (const requestSpec of requestSpecs) {
|
|
562
|
+
addMessageToTree({
|
|
563
|
+
id: requestSpec.assistantMessageId,
|
|
564
|
+
parts: [],
|
|
565
|
+
role: "assistant",
|
|
566
|
+
metadata: {
|
|
567
|
+
createdAt: requestSpec.createdAt,
|
|
568
|
+
parentMessageId: message.id,
|
|
569
|
+
parallelGroupId: requestSpec.parallelGroupId,
|
|
570
|
+
parallelIndex: requestSpec.parallelIndex,
|
|
571
|
+
isPrimaryParallel: requestSpec.isPrimary,
|
|
572
|
+
selectedModel: requestSpec.modelId,
|
|
573
|
+
activeStreamId: `pending:${requestSpec.assistantMessageId}`,
|
|
574
|
+
selectedTool: undefined,
|
|
575
|
+
},
|
|
576
|
+
});
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
void runParallelSecondaryRequests({
|
|
580
|
+
message,
|
|
581
|
+
secondaryRequestSpecs: requestSpecs.slice(1),
|
|
582
|
+
}).catch((error) => {
|
|
583
|
+
console.error("Failed to complete parallel requests", error);
|
|
584
|
+
toast.error("Failed to complete all parallel responses");
|
|
585
|
+
void invalidatePersistedMessages();
|
|
586
|
+
});
|
|
587
|
+
} else {
|
|
588
|
+
sendMessage(message);
|
|
589
|
+
addMessageToTree(message);
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
updateChatUrl(chatId);
|
|
367
593
|
|
|
368
594
|
// Refocus after submit
|
|
369
595
|
if (!isMobile) {
|
|
@@ -374,15 +600,19 @@ function PureMultimodalInput({
|
|
|
374
600
|
attachments,
|
|
375
601
|
isMobile,
|
|
376
602
|
chatId,
|
|
603
|
+
handleModelChange,
|
|
604
|
+
invalidatePersistedMessages,
|
|
377
605
|
selectedTool,
|
|
378
606
|
isEditMode,
|
|
379
607
|
getInputValue,
|
|
380
|
-
saveChatMessage,
|
|
381
608
|
parentMessageId,
|
|
382
|
-
|
|
609
|
+
normalizedSelectedModel,
|
|
383
610
|
editorRef,
|
|
384
611
|
lastMessageId,
|
|
385
612
|
onSendMessage,
|
|
613
|
+
parallelResponsesEnabled,
|
|
614
|
+
requestedModelIds,
|
|
615
|
+
runParallelSecondaryRequests,
|
|
386
616
|
sendMessage,
|
|
387
617
|
updateChatUrl,
|
|
388
618
|
trimMessagesInEditMode,
|
|
@@ -680,10 +910,11 @@ function PureMultimodalInput({
|
|
|
680
910
|
acceptImages={acceptImages}
|
|
681
911
|
attachmentsEnabled={attachmentsEnabled}
|
|
682
912
|
fileInputRef={fileInputRef}
|
|
683
|
-
|
|
913
|
+
onModelSelectionChange={handleModelSelectionChange}
|
|
684
914
|
onStop={handleStop}
|
|
685
915
|
parentMessageId={parentMessageId}
|
|
686
916
|
selectedModelId={selectedModelId}
|
|
917
|
+
selectedModelSelection={selectedModelSelection}
|
|
687
918
|
selectedTool={selectedTool}
|
|
688
919
|
setSelectedTool={setSelectedTool}
|
|
689
920
|
status={status}
|
|
@@ -831,7 +1062,8 @@ const AttachmentsButton = memo(PureAttachmentsButton);
|
|
|
831
1062
|
|
|
832
1063
|
function PureChatInputBottomControls({
|
|
833
1064
|
selectedModelId,
|
|
834
|
-
|
|
1065
|
+
selectedModelSelection,
|
|
1066
|
+
onModelSelectionChange,
|
|
835
1067
|
selectedTool,
|
|
836
1068
|
setSelectedTool,
|
|
837
1069
|
fileInputRef,
|
|
@@ -846,7 +1078,8 @@ function PureChatInputBottomControls({
|
|
|
846
1078
|
onStop,
|
|
847
1079
|
}: {
|
|
848
1080
|
selectedModelId: AppModelId;
|
|
849
|
-
|
|
1081
|
+
selectedModelSelection: SelectedModelValue;
|
|
1082
|
+
onModelSelectionChange: (selection: SelectedModelValue) => void;
|
|
850
1083
|
selectedTool: UiToolName | null;
|
|
851
1084
|
setSelectedTool: Dispatch<SetStateAction<UiToolName | null>>;
|
|
852
1085
|
fileInputRef: React.MutableRefObject<HTMLInputElement | null>;
|
|
@@ -874,8 +1107,9 @@ function PureChatInputBottomControls({
|
|
|
874
1107
|
)}
|
|
875
1108
|
<ModelSelector
|
|
876
1109
|
className="@[500px]:h-10 h-8 w-fit max-w-none shrink justify-start truncate @[500px]:px-3 px-2 @[500px]:text-sm text-xs"
|
|
877
|
-
onModelChangeAction={onModelChange}
|
|
878
1110
|
selectedModelId={selectedModelId}
|
|
1111
|
+
selectedModelSelection={selectedModelSelection}
|
|
1112
|
+
onModelSelectionChangeAction={onModelSelectionChange}
|
|
879
1113
|
/>
|
|
880
1114
|
<ConnectorsDropdown />
|
|
881
1115
|
<ResponsiveTools
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"use client";
|
|
2
|
+
|
|
3
|
+
import { useMessageById } from "@ai-sdk-tools/store";
|
|
4
|
+
import { LoaderCircle } from "lucide-react";
|
|
5
|
+
import { memo, useMemo } from "react";
|
|
6
|
+
import { Button } from "@/components/ui/button";
|
|
7
|
+
import { useNavigateToMessage } from "@/hooks/use-navigate-to-message";
|
|
8
|
+
import {
|
|
9
|
+
type ChatMessage,
|
|
10
|
+
expandSelectedModelValue,
|
|
11
|
+
getPrimarySelectedModelId,
|
|
12
|
+
} from "@/lib/ai/types";
|
|
13
|
+
import { useParallelGroupInfo } from "@/lib/stores/hooks-threads";
|
|
14
|
+
import { cn } from "@/lib/utils";
|
|
15
|
+
import { useChatInput } from "@/providers/chat-input-provider";
|
|
16
|
+
import { useChatModels } from "@/providers/chat-models-provider";
|
|
17
|
+
|
|
18
|
+
/**
 * Shows one button per model requested by a parallel (multi-model) user message.
 *
 * Each button displays the model name and a status ("Selected", "Generating...",
 * or "Task completed"). Clicking a button whose response message exists
 * navigates to that message and switches the input's model to it. Buttons are
 * ordered by the model's position in the `models` list, with unknown models
 * last. Renders nothing unless the message is a user message with a parallel
 * group and more than one requested model.
 *
 * @param messageId - Id of the user message whose parallel responses are listed.
 */
function PureParallelResponseCards({
  messageId,
}: {
  messageId: string;
}) {
  const message = useMessageById<ChatMessage>(messageId);
  const parallelGroupInfo = useParallelGroupInfo(messageId);
  const navigateToMessage = useNavigateToMessage();
  const { handleModelChange } = useChatInput();
  const { getModelById, models } = useChatModels();

  // One slot per requested model, paired with the response message (if any)
  // that carries the same parallel index.
  const cardSlots = useMemo(() => {
    if (!message || message.role !== "user") {
      return [];
    }

    const { parallelGroupId, selectedModel } = message.metadata;
    if (!parallelGroupId || typeof selectedModel === "string") {
      return [];
    }

    const groupMessages = parallelGroupInfo?.messages;

    return expandSelectedModelValue(selectedModel).map(
      (modelId, parallelIndex) => {
        const responseMessage =
          groupMessages?.find(
            (candidate) => candidate.metadata.parallelIndex === parallelIndex
          ) ?? null;

        // Prefer the model recorded on the response over the requested one.
        const displayModelId = responseMessage?.metadata.selectedModel
          ? getPrimarySelectedModelId(responseMessage.metadata.selectedModel)
          : modelId;

        return {
          modelId,
          parallelIndex,
          message: responseMessage,
          displayModelId,
        };
      }
    );
  }, [message, parallelGroupInfo]);

  const sortedCardSlots = useMemo(() => {
    type CardSlot = (typeof cardSlots)[number];

    // Position in the model list; models that are missing or unknown sort last.
    const rankOf = ({ displayModelId }: CardSlot): number => {
      if (!displayModelId) {
        return Number.POSITIVE_INFINITY;
      }
      const position = models.findIndex(
        (model) => model.id === displayModelId
      );
      return position < 0 ? Number.POSITIVE_INFINITY : position;
    };

    // Stable tie-breaker for slots that share a rank.
    const tieKeyOf = (slot: CardSlot): string =>
      slot.message?.id ?? `${slot.modelId}:${slot.parallelIndex}`;

    return cardSlots.slice().sort((first, second) => {
      const firstRank = rankOf(first);
      const secondRank = rankOf(second);

      return firstRank === secondRank
        ? tieKeyOf(first).localeCompare(tieKeyOf(second))
        : firstRank - secondRank;
    });
  }, [cardSlots, models]);

  // Parallel index of the selected response; defaults to 0 when cards exist.
  const selectedParallelIndex = useMemo(() => {
    const selectedId = parallelGroupInfo?.selectedMessageId;
    const indexOfSelection = selectedId
      ? parallelGroupInfo?.messages.find(
          (candidate) => candidate.id === selectedId
        )?.metadata.parallelIndex
      : undefined;

    if (typeof indexOfSelection === "number") {
      return indexOfSelection;
    }
    return cardSlots.length === 0 ? null : 0;
  }, [cardSlots.length, parallelGroupInfo]);

  if (!message || sortedCardSlots.length <= 1) {
    return null;
  }

  const cards = sortedCardSlots.map((slot) => {
    const {
      displayModelId: modelId,
      message: responseMessage,
      parallelIndex,
    } = slot;

    const modelName = modelId
      ? (getModelById(modelId)?.name ?? modelId)
      : "Model";
    const isSelected = selectedParallelIndex === parallelIndex;
    // A slot without a message yet is treated as still generating.
    const isStreaming =
      !responseMessage || responseMessage.metadata.activeStreamId !== null;

    let statusLabel = "Task completed";
    if (isSelected) {
      statusLabel = "Selected";
    } else if (isStreaming) {
      statusLabel = "Generating...";
    }

    const selectCard = () => {
      if (!responseMessage) {
        return;
      }
      navigateToMessage(responseMessage.id);
      if (modelId) {
        handleModelChange(modelId);
      }
    };

    return (
      <Button
        key={`${message.id}-${parallelIndex}`}
        type="button"
        variant="outline"
        disabled={!responseMessage}
        onClick={selectCard}
        className={cn(
          "h-auto min-w-[160px] flex-col items-start gap-1 rounded-xl px-3 py-2 text-left",
          isSelected && "border-primary bg-primary/5 text-primary"
        )}
      >
        <span className="font-medium text-sm">{modelName}</span>
        <span className="flex items-center gap-1 text-muted-foreground text-xs">
          {isStreaming ? <LoaderCircle className="size-3 animate-spin" /> : null}
          {statusLabel}
        </span>
      </Button>
    );
  });

  return <div className="mt-3 flex flex-wrap justify-end gap-2">{cards}</div>;
}
|
|
153
|
+
|
|
154
|
+
/**
 * Memoized `PureParallelResponseCards`; props are considered equal whenever
 * `messageId` is unchanged.
 */
export const ParallelResponseCards = memo(
  PureParallelResponseCards,
  ({ messageId: previousId }, { messageId: upcomingId }) =>
    previousId === upcomingId
);
|
|
@@ -4,9 +4,13 @@ import { memo } from "react";
|
|
|
4
4
|
import { Response } from "../ai-elements/response";
|
|
5
5
|
|
|
6
6
|
export const TextMessagePart = memo(
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
7
|
+
({ text, isLoading }: { text: string; isLoading: boolean }) => (
|
|
8
|
+
<Response
|
|
9
|
+
animated
|
|
10
|
+
isAnimating={isLoading}
|
|
11
|
+
mode={isLoading ? "streaming" : "static"}
|
|
12
|
+
>
|
|
13
|
+
{text}
|
|
14
|
+
</Response>
|
|
15
|
+
),
|
|
12
16
|
);
|
|
@@ -7,7 +7,7 @@ import { RefreshCcw } from "lucide-react";
|
|
|
7
7
|
import { useCallback } from "react";
|
|
8
8
|
import { toast } from "sonner";
|
|
9
9
|
import { Action } from "@/components/ai-elements/actions";
|
|
10
|
-
import type
|
|
10
|
+
import { getPrimarySelectedModelId, type ChatMessage } from "@/lib/ai/types";
|
|
11
11
|
|
|
12
12
|
export function RetryButton({
|
|
13
13
|
messageId,
|
|
@@ -28,16 +28,20 @@ export function RetryButton({
|
|
|
28
28
|
|
|
29
29
|
// Find the current message (AI response) and its parent (user message)
|
|
30
30
|
const currentMessages = chatStore.getState().messages;
|
|
31
|
-
const
|
|
32
|
-
|
|
33
|
-
);
|
|
34
|
-
if (currentMessageIdx === -1) {
|
|
31
|
+
const currentMessage = currentMessages.find((msg) => msg.id === messageId);
|
|
32
|
+
if (!currentMessage) {
|
|
35
33
|
toast.error("Cannot find the message to retry");
|
|
36
34
|
return;
|
|
37
35
|
}
|
|
38
36
|
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
const currentMessageIdx = currentMessages.findIndex(
|
|
38
|
+
(msg) => msg.id === messageId
|
|
39
|
+
);
|
|
40
|
+
const parentMessageId = currentMessage.metadata?.parentMessageId ?? null;
|
|
41
|
+
const parentMessageIdx = parentMessageId
|
|
42
|
+
? currentMessages.findIndex((msg) => msg.id === parentMessageId)
|
|
43
|
+
: currentMessageIdx - 1;
|
|
44
|
+
|
|
41
45
|
if (parentMessageIdx < 0) {
|
|
42
46
|
toast.error("Cannot find the user message to retry");
|
|
43
47
|
return;
|
|
@@ -48,6 +52,16 @@ export function RetryButton({
|
|
|
48
52
|
toast.error("Parent message is not from user");
|
|
49
53
|
return;
|
|
50
54
|
}
|
|
55
|
+
|
|
56
|
+
const retryModelId =
|
|
57
|
+
getPrimarySelectedModelId(currentMessage.metadata?.selectedModel) ||
|
|
58
|
+
getPrimarySelectedModelId(parentMessage.metadata?.selectedModel);
|
|
59
|
+
|
|
60
|
+
if (!retryModelId) {
|
|
61
|
+
toast.error("Cannot determine which model to retry");
|
|
62
|
+
return;
|
|
63
|
+
}
|
|
64
|
+
|
|
51
65
|
setMessages(currentMessages.slice(0, parentMessageIdx));
|
|
52
66
|
|
|
53
67
|
// Resend the parent user message
|
|
@@ -57,7 +71,10 @@ export function RetryButton({
|
|
|
57
71
|
metadata: {
|
|
58
72
|
...parentMessage.metadata,
|
|
59
73
|
createdAt: parentMessage.metadata?.createdAt || new Date(),
|
|
60
|
-
|
|
74
|
+
parallelGroupId: null,
|
|
75
|
+
parallelIndex: null,
|
|
76
|
+
isPrimaryParallel: null,
|
|
77
|
+
selectedModel: retryModelId,
|
|
61
78
|
parentMessageId: parentMessage.metadata?.parentMessageId || null,
|
|
62
79
|
},
|
|
63
80
|
},
|