@chat-js/cli 0.2.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. package/dist/index.js +216 -171
  2. package/package.json +1 -1
  3. package/templates/chat-app/CHANGELOG.md +19 -0
  4. package/templates/chat-app/app/(chat)/actions.ts +9 -9
  5. package/templates/chat-app/app/(chat)/api/chat/prepare/route.ts +94 -0
  6. package/templates/chat-app/app/(chat)/api/chat/route.ts +97 -14
  7. package/templates/chat-app/chat.config.ts +144 -156
  8. package/templates/chat-app/components/chat-sync.tsx +6 -3
  9. package/templates/chat-app/components/feedback-actions.tsx +7 -3
  10. package/templates/chat-app/components/message-editor.tsx +8 -3
  11. package/templates/chat-app/components/message-siblings.tsx +14 -1
  12. package/templates/chat-app/components/model-selector.tsx +669 -407
  13. package/templates/chat-app/components/multimodal-input.tsx +252 -18
  14. package/templates/chat-app/components/parallel-response-cards.tsx +157 -0
  15. package/templates/chat-app/components/part/text-message-part.tsx +9 -5
  16. package/templates/chat-app/components/retry-button.tsx +25 -8
  17. package/templates/chat-app/components/user-message.tsx +136 -125
  18. package/templates/chat-app/hooks/chat-sync-hooks.ts +11 -0
  19. package/templates/chat-app/hooks/use-navigate-to-message.ts +39 -0
  20. package/templates/chat-app/lib/ai/gateway-model-defaults.ts +154 -100
  21. package/templates/chat-app/lib/ai/gateways/openrouter-gateway.ts +2 -2
  22. package/templates/chat-app/lib/ai/tools/generate-image.ts +9 -2
  23. package/templates/chat-app/lib/ai/tools/generate-video.ts +3 -0
  24. package/templates/chat-app/lib/ai/types.ts +74 -3
  25. package/templates/chat-app/lib/config-schema.ts +131 -132
  26. package/templates/chat-app/lib/config.ts +2 -2
  27. package/templates/chat-app/lib/db/migrations/0044_gray_red_shift.sql +5 -0
  28. package/templates/chat-app/lib/db/migrations/meta/0044_snapshot.json +1567 -0
  29. package/templates/chat-app/lib/db/migrations/meta/_journal.json +8 -1
  30. package/templates/chat-app/lib/db/queries.ts +84 -4
  31. package/templates/chat-app/lib/db/schema.ts +4 -1
  32. package/templates/chat-app/lib/message-conversion.ts +14 -2
  33. package/templates/chat-app/lib/stores/hooks-threads.ts +37 -1
  34. package/templates/chat-app/lib/stores/with-threads.test.ts +137 -0
  35. package/templates/chat-app/lib/stores/with-threads.ts +157 -4
  36. package/templates/chat-app/lib/thread-utils.ts +23 -2
  37. package/templates/chat-app/package.json +1 -1
  38. package/templates/chat-app/providers/chat-input-provider.tsx +40 -2
  39. package/templates/chat-app/scripts/db-branch-delete.sh +7 -1
  40. package/templates/chat-app/scripts/db-branch-use.sh +7 -1
  41. package/templates/chat-app/scripts/with-db.sh +7 -1
  42. package/templates/chat-app/vitest.config.ts +2 -0
@@ -1,7 +1,7 @@
1
1
  "use client";
2
2
  import type { UseChatHelpers } from "@ai-sdk/react";
3
3
  import { useChatActions, useChatStoreApi } from "@ai-sdk-tools/store";
4
- import { useMutation } from "@tanstack/react-query";
4
+ import { useMutation, useQueryClient } from "@tanstack/react-query";
5
5
  import { CameraIcon, FileIcon, ImageIcon, PlusIcon } from "lucide-react";
6
6
  import type React from "react";
7
7
  import {
@@ -25,17 +25,22 @@ import {
25
25
  } from "@/components/ai-elements/prompt-input";
26
26
  import { ContextBar } from "@/components/context-bar";
27
27
  import { ContextUsageFromParent } from "@/components/context-usage";
28
- import { useSaveMessageMutation } from "@/hooks/chat-sync-hooks";
29
28
  import { useArtifact } from "@/hooks/use-artifact";
30
29
  import { useIsMobile } from "@/hooks/use-mobile";
31
30
  import type { AppModelId } from "@/lib/ai/app-model-id";
32
- import type { Attachment, ChatMessage, UiToolName } from "@/lib/ai/types";
31
+ import {
32
+ expandSelectedModelValue,
33
+ type Attachment,
34
+ type ChatMessage,
35
+ type SelectedModelValue,
36
+ type UiToolName,
37
+ } from "@/lib/ai/types";
33
38
  import { config } from "@/lib/config";
34
39
  import { processFilesForUpload } from "@/lib/files/upload-prep";
35
40
  import { useLastMessageId } from "@/lib/stores/hooks-base";
36
41
  import { useAddMessageToTree } from "@/lib/stores/hooks-threads";
37
42
  import { ANONYMOUS_LIMITS } from "@/lib/types/anonymous";
38
- import { cn, generateUUID } from "@/lib/utils";
43
+ import { cn, fetchWithErrorHandlers, generateUUID } from "@/lib/utils";
39
44
  import { useChatId } from "@/providers/chat-id-provider";
40
45
  import { useChatInput } from "@/providers/chat-input-provider";
41
46
  import { useChatModels } from "@/providers/chat-models-provider";
@@ -57,6 +62,16 @@ import { LimitDisplay } from "./upgrade-cta/limit-display";
57
62
  import { LoginPrompt } from "./upgrade-cta/login-prompt";
58
63
 
/** Matches a bare project page, `/project/<projectId>`; group 1 is the project id. */
const PROJECT_ROUTE_REGEX = /^\/project\/([^/]+)$/;
/**
 * Matches a project page or a chat nested under it —
 * `/project/<projectId>` or `/project/<projectId>/chat/<chatId>`;
 * group 1 is the project id.
 */
const PROJECT_CHAT_ROUTE_REGEX = /^\/project\/([^/]+)(?:\/chat\/[^/]+)?$/;

/** Describes one model's request within a parallel (multi-model) submission. */
interface ParallelRequestSpec {
  // Identity of the assistant message this request produces.
  assistantMessageId: string;
  createdAt: Date;

  // Which model answers, and whether it is the request at index 0.
  modelId: AppModelId;
  isPrimary: boolean;

  // Position within the shared parallel group.
  parallelGroupId: string;
  parallelIndex: number;
}
60
75
 
61
76
  /** Derive accept string for images only */
62
77
  function getAcceptImages(acceptedTypes: Record<string, string[]>): string {
@@ -100,8 +115,8 @@ function PureMultimodalInput({
100
115
  const { artifact, closeArtifact } = useArtifact();
101
116
  const { data: session } = useSession();
102
117
  const trpc = useTRPC();
118
+ const queryClient = useQueryClient();
103
119
  const isMobile = useIsMobile();
104
- const { mutate: saveChatMessage } = useSaveMessageMutation();
105
120
  const addMessageToTree = useAddMessageToTree();
106
121
  useChatId();
107
122
  const {
@@ -117,7 +132,9 @@ function PureMultimodalInput({
117
132
  attachments,
118
133
  setAttachments,
119
134
  selectedModelId,
135
+ selectedModelSelection,
120
136
  handleModelChange,
137
+ handleModelSelectionChange,
121
138
  getInputValue,
122
139
  handleInputChange,
123
140
  getInitialInput,
@@ -135,6 +152,18 @@ function PureMultimodalInput({
135
152
  const stopStreamMutation = useMutation(
136
153
  trpc.chat.stopStream.mutationOptions()
137
154
  );
// Static feature flag from site config.
const parallelResponsesEnabled = config.features.parallelResponses;
// The model value actually requested on submit. This is the multi-model
// selection only when parallel responses are enabled AND the selection expands
// to more than one model; otherwise it is the single selected model id. Gating
// on the flag keeps message metadata consistent with `isParallelModelRequest`,
// so a multi-model selection is never recorded for a single-model request.
const normalizedSelectedModel = useMemo<SelectedModelValue>(() => {
  if (!parallelResponsesEnabled) {
    return selectedModelId;
  }

  const expanded = expandSelectedModelValue(selectedModelSelection);

  return expanded.length > 1 ? selectedModelSelection : selectedModelId;
}, [parallelResponsesEnabled, selectedModelId, selectedModelSelection]);
// Concrete model ids the submission will fan out to (one entry when not parallel).
const requestedModelIds = useMemo(
  () => expandSelectedModelValue(normalizedSelectedModel),
  [normalizedSelectedModel]
);
const isParallelModelRequest =
  parallelResponsesEnabled && requestedModelIds.length > 1;
138
167
 
139
168
  // Attachment configuration from site config
140
169
  const { maxBytes, maxDimension, acceptedTypes } = config.attachments;
@@ -181,6 +210,18 @@ function PureMultimodalInput({
181
210
  const submission = useMemo(():
182
211
  | { enabled: false; message: string }
183
212
  | { enabled: true } => {
213
+ if (isParallelModelRequest && !session?.user) {
214
+ return {
215
+ enabled: false,
216
+ message: "Log in to use multiple models",
217
+ };
218
+ }
219
+ if (isParallelModelRequest && attachments.length > 0) {
220
+ return {
221
+ enabled: false,
222
+ message: "Multiple models with attachments are not supported yet",
223
+ };
224
+ }
184
225
  if (isModelDisallowedForAnonymous) {
185
226
  return { enabled: false, message: "Log in to use this model" };
186
227
  }
@@ -203,7 +244,15 @@ function PureMultimodalInput({
203
244
  };
204
245
  }
205
246
  return { enabled: true };
206
- }, [isEmpty, isModelDisallowedForAnonymous, status, uploadQueue.length]);
247
+ }, [
248
+ attachments.length,
249
+ isEmpty,
250
+ isModelDisallowedForAnonymous,
251
+ isParallelModelRequest,
252
+ session?.user,
253
+ status,
254
+ uploadQueue.length,
255
+ ]);
207
256
 
208
257
  // Helper function to process and validate files
209
258
  const processFiles = useCallback(
@@ -275,6 +324,11 @@ function PureMultimodalInput({
275
324
  [session?.user]
276
325
  );
277
326
 
// Returns the project id from the current URL (`/project/<id>` or
// `/project/<id>/chat/<chatId>`), or undefined when not on a project route.
// Reads window.location at call time, so it has no hook dependencies.
const getCurrentProjectId = useCallback(
  () => PROJECT_CHAT_ROUTE_REGEX.exec(window.location.pathname)?.[1],
  []
);
331
+
278
332
  // Trim messages in edit mode
279
333
  const trimMessagesInEditMode = useCallback(
280
334
  (parentId: string | null) => {
@@ -319,11 +373,122 @@ function PureMultimodalInput({
319
373
  ]
320
374
  );
321
375
 
// Invalidates the cached `chat.getChatMessages` query for this chat, so TanStack
// Query treats the persisted messages as stale (and refetches them if the query
// is active).
const invalidatePersistedMessages = useCallback(async () => {
  const chatMessagesKey = trpc.chat.getChatMessages.queryKey({ chatId });
  await queryClient.invalidateQueries({ queryKey: chatMessagesKey });
}, [chatId, queryClient, trpc]);
381
+
// Sends one secondary (non-primary) parallel request to /api/chat and reads the
// response stream to the end, discarding every chunk. The client never renders
// this stream; presumably the server persists the assistant message as it
// streams (TODO confirm). Errors from fetchWithErrorHandlers or from reading
// the stream propagate to the caller.
const drainSecondaryParallelRequest = useCallback(
  async ({
    message,
    requestSpec,
  }: {
    message: ChatMessage;
    requestSpec: ParallelRequestSpec;
  }) => {
    const payload = {
      id: chatId,
      message,
      prevMessages: [],
      projectId: getCurrentProjectId(),
      assistantMessageId: requestSpec.assistantMessageId,
      selectedModelId: requestSpec.modelId,
      parallelGroupId: requestSpec.parallelGroupId,
      parallelIndex: requestSpec.parallelIndex,
      isPrimaryParallel: false,
    };

    const response = await fetchWithErrorHandlers("/api/chat", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
    });

    const stream = response.body;
    if (!stream) {
      return;
    }

    // Pull chunks until the stream reports completion; their contents are unused.
    const reader = stream.getReader();
    let finished = false;
    while (!finished) {
      ({ done: finished } = await reader.read());
    }
  },
  [chatId, getCurrentProjectId]
);
424
+
// Runs the secondary (non-primary) requests of a parallel submission:
//  1. POST the user message to /api/chat/prepare;
//  2. send and drain one /api/chat request per secondary spec, concurrently;
//  3. for each request that failed, replace its "pending:" placeholder with an
//     empty assistant message whose activeStreamId is null, so the card stops
//     showing it as generating;
//  4. invalidate the persisted-messages query.
// If the prepare call fails, every secondary placeholder is settled the same
// way and the error is rethrown for the caller to report.
const runParallelSecondaryRequests = useCallback(
  async ({
    message,
    secondaryRequestSpecs,
  }: {
    message: ChatMessage;
    secondaryRequestSpecs: ParallelRequestSpec[];
  }) => {
    // Overwrite a spec's placeholder (same id) with a settled, empty response.
    const settleFailedRequest = (requestSpec: ParallelRequestSpec) => {
      addMessageToTree({
        id: requestSpec.assistantMessageId,
        parts: [],
        role: "assistant",
        metadata: {
          createdAt: requestSpec.createdAt,
          parentMessageId: message.id,
          parallelGroupId: requestSpec.parallelGroupId,
          parallelIndex: requestSpec.parallelIndex,
          isPrimaryParallel: requestSpec.isPrimary,
          selectedModel: requestSpec.modelId,
          activeStreamId: null,
          selectedTool: undefined,
        },
      });
    };

    try {
      await fetchWithErrorHandlers("/api/chat/prepare", {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          id: chatId,
          message,
          projectId: getCurrentProjectId(),
        }),
      });
    } catch (error) {
      // No secondary request is sent in this case. Without settling, their
      // placeholders could keep a pending stream id indefinitely: the caller's
      // catch only toasts and invalidates.
      for (const requestSpec of secondaryRequestSpecs) {
        settleFailedRequest(requestSpec);
      }
      throw error;
    }

    const results = await Promise.allSettled(
      secondaryRequestSpecs.map((requestSpec) =>
        drainSecondaryParallelRequest({ message, requestSpec })
      )
    );

    results.forEach((result, index) => {
      if (result.status === "fulfilled") {
        return;
      }

      const failedRequestSpec = secondaryRequestSpecs[index];
      if (!failedRequestSpec) {
        return;
      }

      // Keep the failure reason visible instead of discarding it.
      console.error("Parallel response request failed", result.reason);
      settleFailedRequest(failedRequestSpec);
    });

    await invalidatePersistedMessages();
  },
  [
    addMessageToTree,
    chatId,
    drainSecondaryParallelRequest,
    getCurrentProjectId,
    invalidatePersistedMessages,
  ]
);
488
+
322
489
  const coreSubmitLogic = useCallback(() => {
323
490
  const input = getInputValue();
324
491
 
325
- updateChatUrl(chatId);
326
-
327
492
  // Get the appropriate parent message ID
328
493
  const effectiveParentMessageId = isEditMode
329
494
  ? parentMessageId
@@ -334,6 +499,21 @@ function PureMultimodalInput({
334
499
  trimMessagesInEditMode(parentMessageId);
335
500
  }
336
501
 
502
+ const isParallelRequest = parallelResponsesEnabled && requestedModelIds.length > 1;
503
+ const parallelGroupId = isParallelRequest ? generateUUID() : null;
504
+ const requestSpecs = isParallelRequest
505
+ ? requestedModelIds.map(
506
+ (modelId, parallelIndex): ParallelRequestSpec => ({
507
+ assistantMessageId: generateUUID(),
508
+ createdAt: new Date(Date.now() + parallelIndex),
509
+ isPrimary: parallelIndex === 0,
510
+ modelId,
511
+ parallelGroupId: parallelGroupId || generateUUID(),
512
+ parallelIndex,
513
+ })
514
+ )
515
+ : [];
516
+
337
517
  const message: ChatMessage = {
338
518
  id: generateUUID(),
339
519
  parts: [
@@ -351,7 +531,10 @@ function PureMultimodalInput({
351
531
  metadata: {
352
532
  createdAt: new Date(),
353
533
  parentMessageId: effectiveParentMessageId,
354
- selectedModel: selectedModelId,
534
+ parallelGroupId,
535
+ parallelIndex: null,
536
+ isPrimaryParallel: null,
537
+ selectedModel: normalizedSelectedModel,
355
538
  activeStreamId: null,
356
539
  selectedTool: selectedTool || undefined,
357
540
  },
@@ -360,10 +543,53 @@ function PureMultimodalInput({
360
543
 
361
544
  onSendMessage?.(message);
362
545
 
363
- addMessageToTree(message);
364
- saveChatMessage({ message, chatId });
546
+ const primaryRequest = requestSpecs[0];
365
547
 
366
- sendMessage(message);
548
+ if (primaryRequest) {
549
+ sendMessage(message, {
550
+ body: {
551
+ assistantMessageId: primaryRequest.assistantMessageId,
552
+ selectedModelId: primaryRequest.modelId,
553
+ parallelGroupId: primaryRequest.parallelGroupId,
554
+ parallelIndex: primaryRequest.parallelIndex,
555
+ isPrimaryParallel: true,
556
+ },
557
+ });
558
+
559
+ addMessageToTree(message);
560
+ handleModelChange(primaryRequest.modelId);
561
+ for (const requestSpec of requestSpecs) {
562
+ addMessageToTree({
563
+ id: requestSpec.assistantMessageId,
564
+ parts: [],
565
+ role: "assistant",
566
+ metadata: {
567
+ createdAt: requestSpec.createdAt,
568
+ parentMessageId: message.id,
569
+ parallelGroupId: requestSpec.parallelGroupId,
570
+ parallelIndex: requestSpec.parallelIndex,
571
+ isPrimaryParallel: requestSpec.isPrimary,
572
+ selectedModel: requestSpec.modelId,
573
+ activeStreamId: `pending:${requestSpec.assistantMessageId}`,
574
+ selectedTool: undefined,
575
+ },
576
+ });
577
+ }
578
+
579
+ void runParallelSecondaryRequests({
580
+ message,
581
+ secondaryRequestSpecs: requestSpecs.slice(1),
582
+ }).catch((error) => {
583
+ console.error("Failed to complete parallel requests", error);
584
+ toast.error("Failed to complete all parallel responses");
585
+ void invalidatePersistedMessages();
586
+ });
587
+ } else {
588
+ sendMessage(message);
589
+ addMessageToTree(message);
590
+ }
591
+
592
+ updateChatUrl(chatId);
367
593
 
368
594
  // Refocus after submit
369
595
  if (!isMobile) {
@@ -374,15 +600,19 @@ function PureMultimodalInput({
374
600
  attachments,
375
601
  isMobile,
376
602
  chatId,
603
+ handleModelChange,
604
+ invalidatePersistedMessages,
377
605
  selectedTool,
378
606
  isEditMode,
379
607
  getInputValue,
380
- saveChatMessage,
381
608
  parentMessageId,
382
- selectedModelId,
609
+ normalizedSelectedModel,
383
610
  editorRef,
384
611
  lastMessageId,
385
612
  onSendMessage,
613
+ parallelResponsesEnabled,
614
+ requestedModelIds,
615
+ runParallelSecondaryRequests,
386
616
  sendMessage,
387
617
  updateChatUrl,
388
618
  trimMessagesInEditMode,
@@ -680,10 +910,11 @@ function PureMultimodalInput({
680
910
  acceptImages={acceptImages}
681
911
  attachmentsEnabled={attachmentsEnabled}
682
912
  fileInputRef={fileInputRef}
683
- onModelChange={handleModelChange}
913
+ onModelSelectionChange={handleModelSelectionChange}
684
914
  onStop={handleStop}
685
915
  parentMessageId={parentMessageId}
686
916
  selectedModelId={selectedModelId}
917
+ selectedModelSelection={selectedModelSelection}
687
918
  selectedTool={selectedTool}
688
919
  setSelectedTool={setSelectedTool}
689
920
  status={status}
@@ -831,7 +1062,8 @@ const AttachmentsButton = memo(PureAttachmentsButton);
831
1062
 
832
1063
  function PureChatInputBottomControls({
833
1064
  selectedModelId,
834
- onModelChange,
1065
+ selectedModelSelection,
1066
+ onModelSelectionChange,
835
1067
  selectedTool,
836
1068
  setSelectedTool,
837
1069
  fileInputRef,
@@ -846,7 +1078,8 @@ function PureChatInputBottomControls({
846
1078
  onStop,
847
1079
  }: {
848
1080
  selectedModelId: AppModelId;
849
- onModelChange: (modelId: AppModelId) => void;
1081
+ selectedModelSelection: SelectedModelValue;
1082
+ onModelSelectionChange: (selection: SelectedModelValue) => void;
850
1083
  selectedTool: UiToolName | null;
851
1084
  setSelectedTool: Dispatch<SetStateAction<UiToolName | null>>;
852
1085
  fileInputRef: React.MutableRefObject<HTMLInputElement | null>;
@@ -874,8 +1107,9 @@ function PureChatInputBottomControls({
874
1107
  )}
875
1108
  <ModelSelector
876
1109
  className="@[500px]:h-10 h-8 w-fit max-w-none shrink justify-start truncate @[500px]:px-3 px-2 @[500px]:text-sm text-xs"
877
- onModelChangeAction={onModelChange}
878
1110
  selectedModelId={selectedModelId}
1111
+ selectedModelSelection={selectedModelSelection}
1112
+ onModelSelectionChangeAction={onModelSelectionChange}
879
1113
  />
880
1114
  <ConnectorsDropdown />
881
1115
  <ResponsiveTools
@@ -0,0 +1,157 @@
1
+ "use client";
2
+
3
+ import { useMessageById } from "@ai-sdk-tools/store";
4
+ import { LoaderCircle } from "lucide-react";
5
+ import { memo, useMemo } from "react";
6
+ import { Button } from "@/components/ui/button";
7
+ import { useNavigateToMessage } from "@/hooks/use-navigate-to-message";
8
+ import {
9
+ type ChatMessage,
10
+ expandSelectedModelValue,
11
+ getPrimarySelectedModelId,
12
+ } from "@/lib/ai/types";
13
+ import { useParallelGroupInfo } from "@/lib/stores/hooks-threads";
14
+ import { cn } from "@/lib/utils";
15
+ import { useChatInput } from "@/providers/chat-input-provider";
16
+ import { useChatModels } from "@/providers/chat-models-provider";
17
+
/**
 * Shows one card per model requested by a multi-model user message, so the user
 * can switch between the parallel assistant responses.
 *
 * Renders nothing unless `messageId` resolves to a user message with a
 * `parallelGroupId` and a non-string (multi-model) `selectedModel` that expands
 * to at least two models. Cards are ordered by each model's position in the
 * chat-models list; models not in the list go last. Clicking a card that already
 * has a response navigates to that message and switches the input's model to it.
 */
function PureParallelResponseCards({
  messageId,
}: {
  messageId: string;
}) {
  const message = useMessageById<ChatMessage>(messageId);
  const parallelGroupInfo = useParallelGroupInfo(messageId);
  const navigateToMessage = useNavigateToMessage();
  const { handleModelChange } = useChatInput();
  const { getModelById, models } = useChatModels();

  // One slot per requested model, in request order (array index === parallelIndex).
  const cardSlots = useMemo(() => {
    if (
      !message ||
      message.role !== "user" ||
      !message.metadata.parallelGroupId ||
      typeof message.metadata.selectedModel === "string"
    ) {
      return [];
    }

    const requestedModelIds = expandSelectedModelValue(
      message.metadata.selectedModel
    );

    return requestedModelIds.map((requestedModelId, parallelIndex) => {
      const responseMessage =
        parallelGroupInfo?.messages.find(
          (candidate) => candidate.metadata.parallelIndex === parallelIndex
        ) ?? null;
      const responseModel = responseMessage?.metadata.selectedModel;

      return {
        modelId: requestedModelId,
        parallelIndex,
        message: responseMessage,
        // Model used for ordering and labelling: the one recorded on the
        // response once it exists, otherwise the one that was requested.
        // Computed once here rather than in both the comparator and render.
        displayModelId: responseModel
          ? getPrimarySelectedModelId(responseModel)
          : requestedModelId,
      };
    });
  }, [message, parallelGroupInfo]);

  const sortedCardSlots = useMemo(() => {
    // Index of each model id in `models`. The first occurrence wins, matching
    // the findIndex lookup this replaces, without a linear scan per comparison.
    const modelOrder = new Map<string, number>();
    models.forEach((model, index) => {
      if (!modelOrder.has(model.id)) {
        modelOrder.set(model.id, index);
      }
    });

    const orderOf = (slot: (typeof cardSlots)[number]) => {
      const position = slot.displayModelId
        ? modelOrder.get(slot.displayModelId)
        : undefined;
      return position ?? Infinity;
    };

    return [...cardSlots].sort((left, right) => {
      const leftOrder = orderOf(left);
      const rightOrder = orderOf(right);

      if (leftOrder !== rightOrder) {
        return leftOrder - rightOrder;
      }

      // Same position (e.g. the same model twice, or both unknown):
      // deterministic tie-break on message id, or a synthetic key.
      const leftMessageId =
        left.message?.id ?? `${left.modelId}:${left.parallelIndex}`;
      const rightMessageId =
        right.message?.id ?? `${right.modelId}:${right.parallelIndex}`;

      return leftMessageId.localeCompare(rightMessageId);
    });
  }, [cardSlots, models]);

  // parallelIndex of the currently selected response. Falls back to 0 when the
  // group reports no usable selection, or null when there are no slots.
  const selectedParallelIndex = useMemo(() => {
    if (parallelGroupInfo?.selectedMessageId) {
      const selectedMessage = parallelGroupInfo.messages.find(
        (candidate) => candidate.id === parallelGroupInfo.selectedMessageId
      );
      if (typeof selectedMessage?.metadata.parallelIndex === "number") {
        return selectedMessage.metadata.parallelIndex;
      }
    }

    return cardSlots.length > 0 ? 0 : null;
  }, [cardSlots.length, parallelGroupInfo]);

  if (!message || sortedCardSlots.length <= 1) {
    return null;
  }

  return (
    <div className="mt-3 flex flex-wrap justify-end gap-2">
      {sortedCardSlots.map((slot) => {
        const modelId = slot.displayModelId;
        const modelName = modelId ? getModelById(modelId)?.name ?? modelId : "Model";
        const isSelected = selectedParallelIndex === slot.parallelIndex;
        // A slot with no response message yet is treated as still generating.
        const isStreaming = slot.message
          ? slot.message.metadata.activeStreamId !== null
          : true;
        const statusLabel = isSelected
          ? "Selected"
          : isStreaming
            ? "Generating..."
            : "Task completed";

        return (
          <Button
            className={cn(
              "h-auto min-w-[160px] flex-col items-start gap-1 rounded-xl px-3 py-2 text-left",
              isSelected && "border-primary bg-primary/5 text-primary"
            )}
            disabled={!slot.message}
            key={`${message.id}-${slot.parallelIndex}`}
            onClick={() => {
              if (slot.message) {
                navigateToMessage(slot.message.id);
                if (modelId) {
                  handleModelChange(modelId);
                }
              }
            }}
            type="button"
            variant="outline"
          >
            <span className="font-medium text-sm">{modelName}</span>
            <span className="flex items-center gap-1 text-muted-foreground text-xs">
              {isStreaming ? <LoaderCircle className="size-3 animate-spin" /> : null}
              {statusLabel}
            </span>
          </Button>
        );
      })}
    </div>
  );
}

// Memoized on `messageId` only. Changes to message/group data reach the
// component through its own hooks (presumably store subscriptions — confirm),
// not through props.
export const ParallelResponseCards = memo(
  PureParallelResponseCards,
  (prevProps, nextProps) => prevProps.messageId === nextProps.messageId
);
@@ -4,9 +4,13 @@ import { memo } from "react";
4
4
  import { Response } from "../ai-elements/response";
5
5
 
/**
 * Renders a text message part with the animated `Response` component.
 * While `isLoading` is true the response renders in "streaming" mode and is
 * marked as animating; otherwise it renders in "static" mode.
 */
export const TextMessagePart = memo(function TextMessagePartView({
  text,
  isLoading,
}: {
  text: string;
  isLoading: boolean;
}) {
  const renderMode = isLoading ? "streaming" : "static";

  return (
    <Response animated isAnimating={isLoading} mode={renderMode}>
      {text}
    </Response>
  );
});
@@ -7,7 +7,7 @@ import { RefreshCcw } from "lucide-react";
7
7
  import { useCallback } from "react";
8
8
  import { toast } from "sonner";
9
9
  import { Action } from "@/components/ai-elements/actions";
10
- import type { ChatMessage } from "@/lib/ai/types";
10
+ import { getPrimarySelectedModelId, type ChatMessage } from "@/lib/ai/types";
11
11
 
12
12
  export function RetryButton({
13
13
  messageId,
@@ -28,16 +28,20 @@ export function RetryButton({
28
28
 
29
29
  // Find the current message (AI response) and its parent (user message)
30
30
  const currentMessages = chatStore.getState().messages;
31
- const currentMessageIdx = currentMessages.findIndex(
32
- (msg) => msg.id === messageId
33
- );
34
- if (currentMessageIdx === -1) {
31
+ const currentMessage = currentMessages.find((msg) => msg.id === messageId);
32
+ if (!currentMessage) {
35
33
  toast.error("Cannot find the message to retry");
36
34
  return;
37
35
  }
38
36
 
39
- // Find the parent user message (should be the message before the AI response)
40
- const parentMessageIdx = currentMessageIdx - 1;
37
+ const currentMessageIdx = currentMessages.findIndex(
38
+ (msg) => msg.id === messageId
39
+ );
40
+ const parentMessageId = currentMessage.metadata?.parentMessageId ?? null;
41
+ const parentMessageIdx = parentMessageId
42
+ ? currentMessages.findIndex((msg) => msg.id === parentMessageId)
43
+ : currentMessageIdx - 1;
44
+
41
45
  if (parentMessageIdx < 0) {
42
46
  toast.error("Cannot find the user message to retry");
43
47
  return;
@@ -48,6 +52,16 @@ export function RetryButton({
48
52
  toast.error("Parent message is not from user");
49
53
  return;
50
54
  }
55
+
56
+ const retryModelId =
57
+ getPrimarySelectedModelId(currentMessage.metadata?.selectedModel) ||
58
+ getPrimarySelectedModelId(parentMessage.metadata?.selectedModel);
59
+
60
+ if (!retryModelId) {
61
+ toast.error("Cannot determine which model to retry");
62
+ return;
63
+ }
64
+
51
65
  setMessages(currentMessages.slice(0, parentMessageIdx));
52
66
 
53
67
  // Resend the parent user message
@@ -57,7 +71,10 @@ export function RetryButton({
57
71
  metadata: {
58
72
  ...parentMessage.metadata,
59
73
  createdAt: parentMessage.metadata?.createdAt || new Date(),
60
- selectedModel: parentMessage.metadata?.selectedModel || "",
74
+ parallelGroupId: null,
75
+ parallelIndex: null,
76
+ isPrimaryParallel: null,
77
+ selectedModel: retryModelId,
61
78
  parentMessageId: parentMessage.metadata?.parentMessageId || null,
62
79
  },
63
80
  },