@chat-js/cli 0.3.0 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.js +11 -6
- package/package.json +1 -1
- package/templates/chat-app/app/(chat)/api/chat/prepare/route.ts +94 -0
- package/templates/chat-app/app/(chat)/api/chat/route.ts +97 -14
- package/templates/chat-app/chat.config.ts +141 -124
- package/templates/chat-app/components/chat-sync.tsx +6 -3
- package/templates/chat-app/components/feedback-actions.tsx +7 -3
- package/templates/chat-app/components/message-editor.tsx +8 -3
- package/templates/chat-app/components/message-siblings.tsx +14 -1
- package/templates/chat-app/components/model-selector.tsx +669 -407
- package/templates/chat-app/components/multimodal-input.tsx +252 -18
- package/templates/chat-app/components/parallel-response-cards.tsx +157 -0
- package/templates/chat-app/components/part/text-message-part.tsx +9 -5
- package/templates/chat-app/components/retry-button.tsx +25 -8
- package/templates/chat-app/components/user-message.tsx +136 -125
- package/templates/chat-app/hooks/chat-sync-hooks.ts +11 -0
- package/templates/chat-app/hooks/use-navigate-to-message.ts +39 -0
- package/templates/chat-app/lib/ai/types.ts +74 -3
- package/templates/chat-app/lib/config-schema.ts +5 -0
- package/templates/chat-app/lib/db/migrations/0044_gray_red_shift.sql +5 -0
- package/templates/chat-app/lib/db/migrations/meta/0044_snapshot.json +1567 -0
- package/templates/chat-app/lib/db/migrations/meta/_journal.json +8 -1
- package/templates/chat-app/lib/db/queries.ts +84 -4
- package/templates/chat-app/lib/db/schema.ts +4 -1
- package/templates/chat-app/lib/message-conversion.ts +14 -2
- package/templates/chat-app/lib/stores/hooks-threads.ts +37 -1
- package/templates/chat-app/lib/stores/with-threads.test.ts +137 -0
- package/templates/chat-app/lib/stores/with-threads.ts +157 -4
- package/templates/chat-app/lib/thread-utils.ts +23 -2
- package/templates/chat-app/providers/chat-input-provider.tsx +40 -2
- package/templates/chat-app/scripts/db-branch-delete.sh +7 -1
- package/templates/chat-app/scripts/db-branch-use.sh +7 -1
- package/templates/chat-app/scripts/with-db.sh +7 -1
- package/templates/chat-app/vitest.config.ts +2 -0
|
@@ -17,6 +17,7 @@ import type {
|
|
|
17
17
|
ToolName,
|
|
18
18
|
ToolOutput,
|
|
19
19
|
} from "@/lib/ai/types";
|
|
20
|
+
import { isSelectedModelValue } from "@/lib/ai/types";
|
|
20
21
|
import { createModuleLogger } from "@/lib/logger";
|
|
21
22
|
import { chatMessageToDbMessage } from "@/lib/message-conversion";
|
|
22
23
|
|
|
@@ -79,6 +80,35 @@ export async function saveChat({
|
|
|
79
80
|
}
|
|
80
81
|
}
|
|
81
82
|
|
|
83
|
+
/**
 * Insert a chat row unless a chat with the same `id` already exists.
 *
 * Uses `ON CONFLICT DO NOTHING`, so calling this for an existing chat is a
 * no-op rather than an error and leaves the stored row untouched.
 *
 * @param id - Chat id to insert.
 * @param userId - Owner of the chat.
 * @param title - Initial chat title.
 * @param projectId - Optional project id; stored as `null` when omitted.
 * @returns The result of the drizzle insert.
 * @throws Re-throws any database error after logging it with context.
 */
export async function saveChatIfNotExists({
  id,
  userId,
  title,
  projectId,
}: {
  id: string;
  userId: string;
  title: string;
  projectId?: string;
}) {
  // Use one timestamp for both columns: two separate `new Date()` calls can
  // land on different milliseconds, leaving a brand-new chat with
  // createdAt !== updatedAt.
  const now = new Date();
  try {
    return await db
      .insert(chat)
      .values({
        id,
        createdAt: now,
        updatedAt: now,
        userId,
        title,
        projectId: projectId ?? null,
      })
      .onConflictDoNothing();
  } catch (error) {
    // Structured log with identifying context, matching saveMessageIfNotExists.
    logger.error({ error, id, userId }, "Failed to save chat in database");
    throw error;
  }
}
|
|
111
|
+
|
|
82
112
|
export async function deleteChatById({ id }: { id: string }) {
|
|
83
113
|
try {
|
|
84
114
|
// Get all messages for this chat to clean up their attachments
|
|
@@ -341,6 +371,43 @@ export async function saveMessage({
|
|
|
341
371
|
}
|
|
342
372
|
}
|
|
343
373
|
|
|
374
|
+
/**
 * Insert a message and its parts unless a message with `id` already exists.
 *
 * The message is converted with `chatMessageToDbMessage`, forced to use the
 * supplied `id`, and inserted with `ON CONFLICT DO NOTHING`. When the row
 * already existed, no parts are written and the chat timestamp is left alone.
 * The message row and its parts are written in one transaction.
 *
 * @param id - Id to store the message under (overrides the converted id).
 * @param chatId - Chat the message belongs to.
 * @param message - The UI chat message whose parts are persisted.
 * @throws Re-throws any database error after logging it.
 */
export async function saveMessageIfNotExists({
  id,
  chatId,
  message: chatMessage,
}: {
  id: string;
  chatId: string;
  message: ChatMessage;
}) {
  try {
    const inserted = await db.transaction(async (tx) => {
      const dbMessage = chatMessageToDbMessage(chatMessage, chatId);
      dbMessage.id = id;

      const insertedMessages = await tx
        .insert(message)
        .values(dbMessage)
        .onConflictDoNothing()
        .returning({ id: message.id });

      // Conflict: the message is already stored, so leave its parts alone.
      if (insertedMessages.length === 0) {
        return false;
      }

      const mappedDBParts = mapUIMessagePartsToDBParts(chatMessage.parts, id);
      if (mappedDBParts.length > 0) {
        await tx.insert(part).values(mappedDBParts);
      }
      return true;
    });

    // updateChatUpdatedAt is not handed `tx`, so it presumably runs on its own
    // pooled connection. Awaiting it inside the transaction callback would hold
    // the transaction's connection while waiting for a second one (a deadlock
    // risk with small pools) without gaining atomicity. Run it after commit,
    // and only when a new message was actually written.
    if (inserted) {
      await updateChatUpdatedAt({ chatId });
    }
  } catch (error) {
    logger.error({ error, chatId, id }, "saveMessageIfNotExists failed");
    throw error;
  }
}
|
|
410
|
+
|
|
344
411
|
export async function saveChatMessages({
|
|
345
412
|
messages,
|
|
346
413
|
}: {
|
|
@@ -413,6 +480,11 @@ export async function updateMessage({
|
|
|
413
480
|
attachments: dbMessage.attachments,
|
|
414
481
|
createdAt: dbMessage.createdAt,
|
|
415
482
|
parentMessageId: dbMessage.parentMessageId,
|
|
483
|
+
selectedModel: dbMessage.selectedModel,
|
|
484
|
+
selectedTool: dbMessage.selectedTool,
|
|
485
|
+
parallelGroupId: dbMessage.parallelGroupId,
|
|
486
|
+
parallelIndex: dbMessage.parallelIndex,
|
|
487
|
+
isPrimaryParallel: dbMessage.isPrimaryParallel,
|
|
416
488
|
lastContext: dbMessage.lastContext,
|
|
417
489
|
activeStreamId: dbMessage.activeStreamId,
|
|
418
490
|
})
|
|
@@ -492,8 +564,12 @@ export async function getAllMessagesByChatId({
|
|
|
492
564
|
createdAt: msg.createdAt,
|
|
493
565
|
activeStreamId: msg.activeStreamId,
|
|
494
566
|
parentMessageId: msg.parentMessageId,
|
|
495
|
-
|
|
496
|
-
|
|
567
|
+
parallelGroupId: msg.parallelGroupId,
|
|
568
|
+
parallelIndex: msg.parallelIndex,
|
|
569
|
+
isPrimaryParallel: msg.isPrimaryParallel,
|
|
570
|
+
selectedModel: isSelectedModelValue(msg.selectedModel)
|
|
571
|
+
? msg.selectedModel
|
|
572
|
+
: ("" as ChatMessage["metadata"]["selectedModel"]),
|
|
497
573
|
selectedTool: (msg.selectedTool ||
|
|
498
574
|
undefined) as ChatMessage["metadata"]["selectedTool"],
|
|
499
575
|
usage: msg.lastContext as ChatMessage["metadata"]["usage"],
|
|
@@ -797,8 +873,12 @@ export async function getChatMessageWithPartsById({
|
|
|
797
873
|
createdAt: dbMessage.createdAt,
|
|
798
874
|
activeStreamId: dbMessage.activeStreamId,
|
|
799
875
|
parentMessageId: dbMessage.parentMessageId,
|
|
800
|
-
|
|
801
|
-
|
|
876
|
+
parallelGroupId: dbMessage.parallelGroupId,
|
|
877
|
+
parallelIndex: dbMessage.parallelIndex,
|
|
878
|
+
isPrimaryParallel: dbMessage.isPrimaryParallel,
|
|
879
|
+
selectedModel: isSelectedModelValue(dbMessage.selectedModel)
|
|
880
|
+
? dbMessage.selectedModel
|
|
881
|
+
: ("" as ChatMessage["metadata"]["selectedModel"]),
|
|
802
882
|
selectedTool: (dbMessage.selectedTool ||
|
|
803
883
|
undefined) as ChatMessage["metadata"]["selectedTool"],
|
|
804
884
|
usage: dbMessage.lastContext as ChatMessage["metadata"]["usage"],
|
|
@@ -113,8 +113,11 @@ export const message = pgTable("Message", {
|
|
|
113
113
|
attachments: json("attachments").notNull(),
|
|
114
114
|
createdAt: timestamp("createdAt").notNull(),
|
|
115
115
|
annotations: json("annotations"),
|
|
116
|
-
selectedModel:
|
|
116
|
+
selectedModel: json("selectedModel"),
|
|
117
117
|
selectedTool: varchar("selectedTool", { length: 256 }).default(""),
|
|
118
|
+
parallelGroupId: uuid("parallelGroupId"),
|
|
119
|
+
parallelIndex: integer("parallelIndex"),
|
|
120
|
+
isPrimaryParallel: boolean("isPrimaryParallel"),
|
|
118
121
|
lastContext: json("lastContext"),
|
|
119
122
|
activeStreamId: varchar("activeStreamId", { length: 64 }),
|
|
120
123
|
/** Timestamp when this message's stream was canceled by the user. Null means not canceled. */
|
|
@@ -1,7 +1,11 @@
|
|
|
1
1
|
import type { ModelId } from "@/lib/ai/app-models";
|
|
2
2
|
import type { Chat, DBMessage } from "@/lib/db/schema";
|
|
3
3
|
import type { UIChat } from "@/lib/types/ui-chat";
|
|
4
|
-
import
|
|
4
|
+
import {
|
|
5
|
+
isSelectedModelValue,
|
|
6
|
+
type ChatMessage,
|
|
7
|
+
type UiToolName,
|
|
8
|
+
} from "./ai/types";
|
|
5
9
|
|
|
6
10
|
// Helper functions for type conversion
|
|
7
11
|
export function dbChatToUIChat(chat: Chat): UIChat {
|
|
@@ -29,7 +33,12 @@ function _dbMessageToChatMessage(message: DBMessage): ChatMessage {
|
|
|
29
33
|
createdAt: message.createdAt,
|
|
30
34
|
activeStreamId: message.activeStreamId,
|
|
31
35
|
parentMessageId: message.parentMessageId,
|
|
32
|
-
|
|
36
|
+
parallelGroupId: message.parallelGroupId,
|
|
37
|
+
parallelIndex: message.parallelIndex,
|
|
38
|
+
isPrimaryParallel: message.isPrimaryParallel,
|
|
39
|
+
selectedModel: isSelectedModelValue(message.selectedModel)
|
|
40
|
+
? message.selectedModel
|
|
41
|
+
: ("" as ModelId),
|
|
33
42
|
selectedTool: (message.selectedTool as UiToolName | null) || undefined,
|
|
34
43
|
usage: message.lastContext as ChatMessage["metadata"]["usage"],
|
|
35
44
|
},
|
|
@@ -66,6 +75,9 @@ export function chatMessageToDbMessage(
|
|
|
66
75
|
parentMessageId,
|
|
67
76
|
selectedModel,
|
|
68
77
|
selectedTool: message.metadata?.selectedTool || null,
|
|
78
|
+
parallelGroupId: message.metadata?.parallelGroupId || null,
|
|
79
|
+
parallelIndex: message.metadata?.parallelIndex ?? null,
|
|
80
|
+
isPrimaryParallel: message.metadata?.isPrimaryParallel ?? null,
|
|
69
81
|
activeStreamId: message.metadata?.activeStreamId || null,
|
|
70
82
|
canceledAt: null,
|
|
71
83
|
};
|
|
@@ -7,7 +7,7 @@ import {
|
|
|
7
7
|
type CustomChatStoreState,
|
|
8
8
|
useCustomChatStoreApi,
|
|
9
9
|
} from "./custom-store-provider";
|
|
10
|
-
import type { MessageSiblingInfo } from "./with-threads";
|
|
10
|
+
import type { MessageSiblingInfo, ParallelGroupInfo } from "./with-threads";
|
|
11
11
|
|
|
12
12
|
function useThreadStore<T>(
|
|
13
13
|
selector: (store: CustomChatStoreState<ChatMessage>) => T,
|
|
@@ -95,3 +95,39 @@ export const useSwitchToSibling = () => {
|
|
|
95
95
|
[store]
|
|
96
96
|
);
|
|
97
97
|
};
|
|
98
|
+
|
|
99
|
+
/**
 * Subscribe to the parallel-response group containing `messageId`.
 *
 * Yields `null` when the store reports no group for the message. The hook
 * re-renders only when the group id, the selected member, the member ids, or
 * any member's `activeStreamId` changes.
 */
export function useParallelGroupInfo(
  messageId: string
): ParallelGroupInfo<ChatMessage> | null {
  const isSameGroup = (
    prev: ParallelGroupInfo<ChatMessage> | null,
    next: ParallelGroupInfo<ChatMessage> | null
  ): boolean => {
    // Equal only when both are null; one null and one not counts as changed.
    if (prev === null || next === null) {
      return prev === next;
    }

    if (
      prev.parallelGroupId !== next.parallelGroupId ||
      prev.selectedMessageId !== next.selectedMessageId ||
      prev.messages.length !== next.messages.length
    ) {
      return false;
    }

    return prev.messages.every((member, position) => {
      const counterpart = next.messages[position];
      return (
        member.id === counterpart?.id &&
        member.metadata?.activeStreamId ===
          counterpart?.metadata?.activeStreamId
      );
    });
  };

  return useThreadStore(
    (state) => state.getParallelGroupInfo(messageId),
    isSameGroup
  );
}
|
|
126
|
+
|
|
127
|
+
/**
 * Returns a memoized callback that forwards to the store's `switchToMessage`
 * action. The callback returns the newly visible thread, or `null` when the
 * message id is unknown to the store.
 */
export const useSwitchToMessage = () => {
  const chatStore = useCustomChatStoreApi<ChatMessage>();

  const jumpToMessage = useCallback(
    (messageId: string) => {
      const currentState = chatStore.getState();
      return currentState.switchToMessage(messageId);
    },
    [chatStore]
  );

  return jumpToMessage;
};
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
import assert from "node:assert/strict";
|
|
2
|
+
import type { StoreState as BaseChatStoreState } from "@ai-sdk-tools/store";
|
|
3
|
+
import { describe, it } from "vitest";
|
|
4
|
+
import { createStore } from "zustand/vanilla";
|
|
5
|
+
import type { ChatMessage } from "../ai/types";
|
|
6
|
+
import { withThreads, type ThreadAugmentedState } from "./with-threads";
|
|
7
|
+
|
|
8
|
+
/**
 * Build a minimal ChatMessage fixture with no parts.
 *
 * `isPrimaryParallel` is derived from `parallelIndex`: `null` when the message
 * is not in a parallel group, otherwise `true` only for index 0.
 */
function createMessage({
  id,
  role,
  createdAt,
  parentMessageId = null,
  parallelGroupId = null,
  parallelIndex = null,
  activeStreamId = null,
}: {
  id: string;
  role: ChatMessage["role"];
  createdAt: string;
  parentMessageId?: string | null;
  parallelGroupId?: string | null;
  parallelIndex?: number | null;
  activeStreamId?: string | null;
}): ChatMessage {
  const isPrimaryParallel =
    parallelIndex === null ? null : parallelIndex === 0;

  return {
    id,
    role,
    parts: [],
    metadata: {
      createdAt: new Date(createdAt),
      parentMessageId,
      parallelGroupId,
      parallelIndex,
      isPrimaryParallel,
      selectedModel: "openai/gpt-4o-mini",
      activeStreamId,
      selectedTool: undefined,
    },
  };
}
|
|
41
|
+
|
|
42
|
+
/**
 * Create a vanilla zustand store wrapped in `withThreads`. Its base slice
 * holds only `messages` and a plain `setMessages` setter.
 */
function createThreadStore(initialMessages: ChatMessage[]) {
  const createThreadedStore =
    createStore<ThreadAugmentedState<ChatMessage>>();

  return createThreadedStore(
    withThreads<ChatMessage, BaseChatStoreState<ChatMessage>>((set) => {
      const baseSlice = {
        messages: initialMessages,
        setMessages: (messages: ChatMessage[]) => set({ messages }),
      };
      return baseSlice as BaseChatStoreState<ChatMessage>;
    })
  );
}
|
|
53
|
+
|
|
54
|
+
describe("withThreads", () => {
  it("preserves local-only optimistic branch nodes across server syncs", () => {
    // Tree under test:
    //   user-root
    //   ├─ assistant-a        (group-root, index 0)
    //   │   └─ user-nested
    //   │       ├─ assistant-nested-a (group-nested, index 0, pending stream)
    //   │       └─ assistant-nested-b (group-nested, index 1, pending stream)
    //   └─ assistant-b        (group-root, index 1)
    const root = createMessage({
      id: "user-root",
      role: "user",
      createdAt: "2024-01-01T00:00:00.000Z",
    });
    const answerA = createMessage({
      id: "assistant-a",
      role: "assistant",
      createdAt: "2024-01-01T00:00:01.000Z",
      parentMessageId: root.id,
      parallelGroupId: "group-root",
      parallelIndex: 0,
    });
    const answerB = createMessage({
      id: "assistant-b",
      role: "assistant",
      createdAt: "2024-01-01T00:00:02.000Z",
      parentMessageId: root.id,
      parallelGroupId: "group-root",
      parallelIndex: 1,
    });
    const followUp = createMessage({
      id: "user-nested",
      role: "user",
      createdAt: "2024-01-01T00:00:03.000Z",
      parentMessageId: answerA.id,
    });
    const followUpAnswerA = createMessage({
      id: "assistant-nested-a",
      role: "assistant",
      createdAt: "2024-01-01T00:00:04.000Z",
      parentMessageId: followUp.id,
      parallelGroupId: "group-nested",
      parallelIndex: 0,
      activeStreamId: "pending:assistant-nested-a",
    });
    const followUpAnswerB = createMessage({
      id: "assistant-nested-b",
      role: "assistant",
      createdAt: "2024-01-01T00:00:05.000Z",
      parentMessageId: followUp.id,
      parallelGroupId: "group-nested",
      parallelIndex: 1,
      activeStreamId: "pending:assistant-nested-b",
    });

    const store = createThreadStore([root, answerA, followUp, followUpAnswerA]);

    // Optimistically add the second branch of each parallel group locally.
    for (const localOnly of [answerB, followUpAnswerB]) {
      store.getState().addMessageToTree(localOnly);
    }

    // Show branch B, then sync a server snapshot that does not include the
    // nested follow-up turn at all.
    store.getState().setMessagesWithEpoch([root, answerB]);
    store.getState().setAllMessages([root, answerA, answerB]);

    const treeIds = store
      .getState()
      .allMessages.map((node: ChatMessage) => node.id);
    assert.deepEqual(treeIds, [
      "user-root",
      "assistant-a",
      "assistant-b",
      "user-nested",
      "assistant-nested-a",
      "assistant-nested-b",
    ]);

    const thread = store.getState().switchToMessage(answerA.id);
    assert.deepEqual(
      thread?.map((node: ChatMessage) => node.id),
      ["user-root", "assistant-a", "user-nested", "assistant-nested-b"]
    );
    assert.equal(
      thread?.at(-1)?.metadata.activeStreamId,
      "pending:assistant-nested-b"
    );
  });
});
|
|
@@ -19,6 +19,12 @@ export interface MessageSiblingInfo<UM> {
|
|
|
19
19
|
siblings: UM[];
|
|
20
20
|
}
|
|
21
21
|
|
|
22
|
+
/**
 * Snapshot of a group of parallel responses that share a `parallelGroupId`.
 */
export interface ParallelGroupInfo<UM> {
  /** Identifier shared by every member of the group. */
  parallelGroupId: string;
  /** Members of the group (ordered by `parallelIndex` by the store's lookup). */
  messages: UM[];
  /** Id of the member currently on the visible thread, or `null` if none is. */
  selectedMessageId: string | null;
}
|
|
27
|
+
|
|
22
28
|
export type ThreadAugmentedState<UM extends UIMessage> =
|
|
23
29
|
BaseChatStoreState<UM> & {
|
|
24
30
|
threadEpoch: number;
|
|
@@ -41,6 +47,7 @@ export type ThreadAugmentedState<UM extends UIMessage> =
|
|
|
41
47
|
addMessageToTree: (message: UM) => void;
|
|
42
48
|
/** Look up sibling info for a message. */
|
|
43
49
|
getMessageSiblingInfo: (messageId: string) => MessageSiblingInfo<UM> | null;
|
|
50
|
+
getParallelGroupInfo: (messageId: string) => ParallelGroupInfo<UM> | null;
|
|
44
51
|
/**
|
|
45
52
|
* Switch to a sibling thread. Returns the new thread array,
|
|
46
53
|
* or null if no switch was possible.
|
|
@@ -49,6 +56,7 @@ export type ThreadAugmentedState<UM extends UIMessage> =
|
|
|
49
56
|
messageId: string,
|
|
50
57
|
direction: "prev" | "next"
|
|
51
58
|
) => UM[] | null;
|
|
59
|
+
switchToMessage: (messageId: string) => UM[] | null;
|
|
52
60
|
};
|
|
53
61
|
|
|
54
62
|
export const withThreads =
|
|
@@ -64,6 +72,33 @@ export const withThreads =
|
|
|
64
72
|
const rebuildMap = (msgs: UI_MESSAGE[]) =>
|
|
65
73
|
buildChildrenMap(msgs as (UI_MESSAGE & MessageNode)[]);
|
|
66
74
|
|
|
75
|
+
/**
 * Merge a server snapshot into the locally known message tree.
 *
 * Precedence, lowest to highest:
 * 1. server messages;
 * 2. existing tree nodes whose id the server did not return, kept as-is;
 * 3. the currently visible thread, which overwrites any entry with the same id.
 *
 * Output order follows first appearance: server order first, then the
 * local-only nodes in tree order.
 */
const mergeTreeMessages = (
  serverMessages: UI_MESSAGE[],
  existingTreeMessages: UI_MESSAGE[],
  currentVisibleMessages: UI_MESSAGE[]
): UI_MESSAGE[] => {
  const byId = new Map<string, UI_MESSAGE>(
    serverMessages.map((msg): [string, UI_MESSAGE] => [msg.id, msg])
  );

  // Keep every local-only node (not just pending assistant shells) so that
  // optimistic user messages are not orphaned when switching away from an
  // in-flight branch mid-stream.
  for (const localNode of existingTreeMessages) {
    if (byId.has(localNode.id)) {
      continue;
    }
    byId.set(localNode.id, localNode);
  }

  for (const visibleNode of currentVisibleMessages) {
    byId.set(visibleNode.id, visibleNode);
  }

  return Array.from(byId.values());
};
|
|
101
|
+
|
|
67
102
|
return {
|
|
68
103
|
...base,
|
|
69
104
|
threadEpoch: 0,
|
|
@@ -96,10 +131,45 @@ export const withThreads =
|
|
|
96
131
|
},
|
|
97
132
|
|
|
98
133
|
setAllMessages: (messages: UI_MESSAGE[]) => {
  const snapshot = get();
  const visibleThread = snapshot.messages;
  const mergedMessages = mergeTreeMessages(
    messages,
    snapshot.allMessages,
    visibleThread
  );

  const streamInFlight =
    snapshot.status === "submitted" || snapshot.status === "streaming";

  if (streamInFlight) {
    // Mid-stream, touching the visible thread would mix the SDK's
    // client-generated message id with the server's assistantMessageId. The
    // SDK would then push a second assistant message on the next chunk,
    // bumping the epoch and remounting ChatSync. Refresh only the tree index
    // and leave the visible update to post-stream invalidation.
    set((prev) => ({
      ...prev,
      allMessages: mergedMessages,
      childrenMap: rebuildMap(mergedMessages),
    }));
    return;
  }

  // Re-derive the visible thread from the current leaf against the merged
  // tree; with no visible messages, keep the (empty) visible thread.
  const leafId = visibleThread.at(-1)?.id;
  let nextVisibleThread: UI_MESSAGE[] = visibleThread;
  if (leafId) {
    nextVisibleThread = buildThreadFromLeaf(
      mergedMessages as (UI_MESSAGE & MessageNode)[],
      leafId
    ) as UI_MESSAGE[];
  }

  originalSetMessages(nextVisibleThread);
  set((prev) => ({
    ...prev,
    messages: nextVisibleThread,
    threadInitialMessages: nextVisibleThread,
    allMessages: mergedMessages,
    childrenMap: rebuildMap(mergedMessages),
  }));
},
|
|
105
175
|
|
|
@@ -135,6 +205,66 @@ export const withThreads =
|
|
|
135
205
|
return { siblings, siblingIndex };
|
|
136
206
|
},
|
|
137
207
|
|
|
208
|
+
getParallelGroupInfo: (
  messageId: string
): ParallelGroupInfo<UI_MESSAGE> | null => {
  const { allMessages, childrenMap, messages: visibleThread } = get();

  const target = allMessages.find((node) => node.id === messageId) as
    | (UI_MESSAGE & MessageNode)
    | undefined;
  if (!target) {
    return null;
  }

  const groupId = target.metadata?.parallelGroupId || null;
  // A user message anchors the group formed by its own children; any other
  // message looks for its group among its parent's children.
  const anchorId =
    target.role === "user"
      ? target.id
      : target.metadata?.parentMessageId || null;

  if (!anchorId || !groupId) {
    return null;
  }

  // Messages without a parallelIndex sort after indexed ones.
  const orderOf = (node: UI_MESSAGE): number =>
    (node as UI_MESSAGE & MessageNode).metadata?.parallelIndex ??
    Number.MAX_SAFE_INTEGER;

  const anchorChildren = (childrenMap.get(anchorId) ?? []) as UI_MESSAGE[];
  const groupMessages = anchorChildren
    .filter(
      (node) =>
        (node as UI_MESSAGE & MessageNode).metadata?.parallelGroupId ===
        groupId
    )
    .sort((left, right) => {
      const leftOrder = orderOf(left);
      const rightOrder = orderOf(right);
      return leftOrder === rightOrder ? 0 : leftOrder - rightOrder;
    });

  // A lone message is not a parallel group worth surfacing.
  if (groupMessages.length < 2) {
    return null;
  }

  const visibleIds = new Set(visibleThread.map((node) => node.id));
  const selected = groupMessages.find((node) => visibleIds.has(node.id));

  return {
    messages: groupMessages,
    parallelGroupId: groupId,
    selectedMessageId: selected?.id ?? null,
  };
},
|
|
267
|
+
|
|
138
268
|
switchToSibling: (
|
|
139
269
|
messageId: string,
|
|
140
270
|
direction: "prev" | "next"
|
|
@@ -170,6 +300,29 @@ export const withThreads =
|
|
|
170
300
|
return newThread;
|
|
171
301
|
},
|
|
172
302
|
|
|
303
|
+
switchToMessage: (messageId: string): UI_MESSAGE[] | null => {
  const current = get();

  const known = current.allMessages.some((node) => node.id === messageId);
  if (!known) {
    return null;
  }

  // Descend from the target along right-most children to a leaf, falling
  // back to the target itself when no leaf is found.
  const deepestLeaf = findLeafDfsToRightFromMessageId(
    current.childrenMap as Map<string | null, (UI_MESSAGE & MessageNode)[]>,
    messageId
  );
  const leafId = deepestLeaf?.id ?? messageId;

  const thread = buildThreadFromLeaf(
    current.allMessages as (UI_MESSAGE & MessageNode)[],
    leafId
  ) as UI_MESSAGE[];

  current.setMessagesWithEpoch(thread);
  return thread;
},
|
|
325
|
+
|
|
173
326
|
// Override setMessages to auto-bump epoch when thread changes
|
|
174
327
|
setMessages: (messages: UI_MESSAGE[]) => {
|
|
175
328
|
const currentMessages = get().messages;
|
|
@@ -3,6 +3,9 @@ export interface MessageNode {
|
|
|
3
3
|
id: string;
|
|
4
4
|
metadata?: {
|
|
5
5
|
parentMessageId: string | null;
|
|
6
|
+
parallelGroupId?: string | null;
|
|
7
|
+
parallelIndex?: number | null;
|
|
8
|
+
activeStreamId?: string | null;
|
|
6
9
|
createdAt: Date;
|
|
7
10
|
};
|
|
8
11
|
}
|
|
@@ -102,8 +105,26 @@ export function buildChildrenMap<T extends MessageNode>(
|
|
|
102
105
|
}
|
|
103
106
|
for (const siblings of map.values()) {
|
|
104
107
|
siblings.sort(
|
|
105
|
-
(a, b) =>
|
|
106
|
-
|
|
108
|
+
(a, b) => {
|
|
109
|
+
const aParallelIndex = a.metadata?.parallelIndex;
|
|
110
|
+
const bParallelIndex = b.metadata?.parallelIndex;
|
|
111
|
+
const sameParallelGroup =
|
|
112
|
+
a.metadata?.parallelGroupId &&
|
|
113
|
+
a.metadata?.parallelGroupId === b.metadata?.parallelGroupId;
|
|
114
|
+
|
|
115
|
+
if (
|
|
116
|
+
sameParallelGroup &&
|
|
117
|
+
typeof aParallelIndex === "number" &&
|
|
118
|
+
typeof bParallelIndex === "number" &&
|
|
119
|
+
aParallelIndex !== bParallelIndex
|
|
120
|
+
) {
|
|
121
|
+
return aParallelIndex - bParallelIndex;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
return (
|
|
125
|
+
toTimestamp(a.metadata?.createdAt) - toTimestamp(b.metadata?.createdAt)
|
|
126
|
+
);
|
|
127
|
+
}
|
|
107
128
|
);
|
|
108
129
|
}
|
|
109
130
|
return map;
|