@librechat/agents 2.3.0 → 2.3.2
This diff compares the publicly released contents of the two package versions as published to their public registry. It is provided for informational purposes only.
- package/dist/cjs/graphs/Graph.cjs +6 -6
- package/dist/cjs/graphs/Graph.cjs.map +1 -1
- package/dist/cjs/llm/anthropic/llm.cjs +7 -7
- package/dist/cjs/llm/anthropic/llm.cjs.map +1 -1
- package/dist/cjs/llm/anthropic/utils/message_inputs.cjs +6 -6
- package/dist/cjs/llm/anthropic/utils/message_inputs.cjs.map +1 -1
- package/dist/cjs/llm/anthropic/utils/message_outputs.cjs +24 -24
- package/dist/cjs/llm/anthropic/utils/message_outputs.cjs.map +1 -1
- package/dist/cjs/llm/fake.cjs.map +1 -1
- package/dist/cjs/llm/text.cjs.map +1 -1
- package/dist/cjs/main.cjs +3 -0
- package/dist/cjs/main.cjs.map +1 -1
- package/dist/cjs/messages/core.cjs +5 -5
- package/dist/cjs/messages/core.cjs.map +1 -1
- package/dist/cjs/messages/format.cjs +11 -9
- package/dist/cjs/messages/format.cjs.map +1 -1
- package/dist/cjs/messages/prune.cjs +155 -181
- package/dist/cjs/messages/prune.cjs.map +1 -1
- package/dist/cjs/run.cjs.map +1 -1
- package/dist/cjs/stream.cjs +3 -4
- package/dist/cjs/stream.cjs.map +1 -1
- package/dist/cjs/tools/ToolNode.cjs +1 -1
- package/dist/cjs/tools/ToolNode.cjs.map +1 -1
- package/dist/cjs/utils/tokens.cjs +3 -3
- package/dist/cjs/utils/tokens.cjs.map +1 -1
- package/dist/esm/graphs/Graph.mjs +6 -6
- package/dist/esm/graphs/Graph.mjs.map +1 -1
- package/dist/esm/llm/anthropic/llm.mjs +7 -7
- package/dist/esm/llm/anthropic/llm.mjs.map +1 -1
- package/dist/esm/llm/anthropic/utils/message_inputs.mjs +6 -6
- package/dist/esm/llm/anthropic/utils/message_inputs.mjs.map +1 -1
- package/dist/esm/llm/anthropic/utils/message_outputs.mjs +24 -24
- package/dist/esm/llm/anthropic/utils/message_outputs.mjs.map +1 -1
- package/dist/esm/llm/fake.mjs.map +1 -1
- package/dist/esm/llm/text.mjs.map +1 -1
- package/dist/esm/main.mjs +1 -1
- package/dist/esm/messages/core.mjs +5 -5
- package/dist/esm/messages/core.mjs.map +1 -1
- package/dist/esm/messages/format.mjs +11 -9
- package/dist/esm/messages/format.mjs.map +1 -1
- package/dist/esm/messages/prune.mjs +153 -182
- package/dist/esm/messages/prune.mjs.map +1 -1
- package/dist/esm/run.mjs.map +1 -1
- package/dist/esm/stream.mjs +3 -4
- package/dist/esm/stream.mjs.map +1 -1
- package/dist/esm/tools/ToolNode.mjs +1 -1
- package/dist/esm/tools/ToolNode.mjs.map +1 -1
- package/dist/esm/utils/tokens.mjs +3 -3
- package/dist/esm/utils/tokens.mjs.map +1 -1
- package/dist/types/messages/format.d.ts +1 -2
- package/dist/types/messages/prune.d.ts +31 -2
- package/dist/types/types/stream.d.ts +2 -2
- package/dist/types/utils/tokens.d.ts +1 -1
- package/package.json +4 -3
- package/src/graphs/Graph.ts +8 -8
- package/src/llm/anthropic/llm.ts +7 -8
- package/src/llm/anthropic/types.ts +4 -4
- package/src/llm/anthropic/utils/message_inputs.ts +6 -6
- package/src/llm/anthropic/utils/message_outputs.ts +39 -39
- package/src/llm/fake.ts +2 -2
- package/src/llm/text.ts +1 -1
- package/src/messages/core.ts +6 -6
- package/src/messages/format.ts +43 -42
- package/src/messages/formatAgentMessages.test.ts +35 -35
- package/src/messages/formatAgentMessages.tools.test.ts +30 -30
- package/src/messages/prune.ts +182 -226
- package/src/messages/shiftIndexTokenCountMap.test.ts +18 -18
- package/src/mockStream.ts +1 -1
- package/src/run.ts +2 -2
- package/src/specs/prune.test.ts +89 -89
- package/src/specs/reasoning.test.ts +1 -1
- package/src/specs/thinking-prune.test.ts +291 -243
- package/src/specs/tool-error.test.ts +16 -17
- package/src/stream.ts +13 -14
- package/src/tools/ToolNode.ts +1 -1
- package/src/types/stream.ts +4 -3
- package/src/utils/tokens.ts +12 -12
package/src/messages/prune.ts
CHANGED
```diff
@@ -1,6 +1,8 @@
-import {
-import
+import { concat } from '@langchain/core/utils/stream';
+import { AIMessage, BaseMessage, UsageMetadata } from '@langchain/core/messages';
+import type { ThinkingContentText, MessageContentComplex } from '@/types/stream';
 import type { TokenCounter } from '@/types/run';
+import { ContentTypes } from '@/common';
 export type PruneMessagesFactoryParams = {
   maxTokens: number;
   startIndex: number;
@@ -11,20 +13,25 @@ export type PruneMessagesFactoryParams = {
 export type PruneMessagesParams = {
   messages: BaseMessage[];
   usageMetadata?: Partial<UsageMetadata>;
-
+  startType?: ReturnType<BaseMessage['getType']>;
+}
+
+function isIndexInContext(arrayA: BaseMessage[], arrayB: BaseMessage[], targetIndex: number): boolean {
+  const startingIndexInA = arrayA.length - arrayB.length;
+  return targetIndex >= startingIndexInA;
 }
 
 /**
  * Calculates the total tokens from a single usage object
- *
+ *
  * @param usage The usage metadata object containing token information
  * @returns An object containing the total input and output tokens
  */
-function calculateTotalTokens(usage: Partial<UsageMetadata>): UsageMetadata {
+export function calculateTotalTokens(usage: Partial<UsageMetadata>): UsageMetadata {
   const baseInputTokens = Number(usage.input_tokens) || 0;
   const cacheCreation = Number(usage.input_token_details?.cache_creation) || 0;
   const cacheRead = Number(usage.input_token_details?.cache_read) || 0;
-
+
   const totalInputTokens = baseInputTokens + cacheCreation + cacheRead;
   const totalOutputTokens = Number(usage.output_tokens) || 0;
 
@@ -38,273 +45,222 @@ function calculateTotalTokens(usage: Partial<UsageMetadata>): UsageMetadata {
 /**
  * Processes an array of messages and returns a context of messages that fit within a specified token limit.
  * It iterates over the messages from newest to oldest, adding them to the context until the token limit is reached.
- *
+ *
  * @param options Configuration options for processing messages
  * @returns Object containing the message context, remaining tokens, messages not included, and summary index
  */
-function getMessagesWithinTokenLimit({
+export function getMessagesWithinTokenLimit({
   messages: _messages,
   maxContextTokens,
   indexTokenCountMap,
-
+  startType: _startType,
   thinkingEnabled,
+  /** We may need to use this when recalculating */
   tokenCounter,
 }: {
   messages: BaseMessage[];
   maxContextTokens: number;
-  indexTokenCountMap: Record<string, number>;
-
+  indexTokenCountMap: Record<string, number | undefined>;
+  tokenCounter: TokenCounter;
+  startType?: string;
   thinkingEnabled?: boolean;
-  tokenCounter?: TokenCounter;
 }): {
   context: BaseMessage[];
   remainingContextTokens: number;
   messagesToRefine: BaseMessage[];
-  summaryIndex: number;
 } {
   // Every reply is primed with <|start|>assistant<|message|>, so we
   // start with 3 tokens for the label after all messages have been counted.
-  let summaryIndex = -1;
   let currentTokenCount = 3;
-  const instructions = _messages
-  const instructionsTokenCount = instructions != null ? indexTokenCountMap[0] : 0;
-
+  const instructions = _messages[0]?.getType() === 'system' ? _messages[0] : undefined;
+  const instructionsTokenCount = instructions != null ? indexTokenCountMap[0] ?? 0 : 0;
+  const initialContextTokens = maxContextTokens - instructionsTokenCount;
+  let remainingContextTokens = initialContextTokens;
+  let startType = _startType;
+  const originalLength = _messages.length;
   const messages = [..._messages];
+  /**
+   * IMPORTANT: this context array gets reversed at the end, since the latest messages get pushed first.
+   *
+   * This may be confusing to read, but it is done to ensure the context is in the correct order for the model.
+   * */
   let context: BaseMessage[] = [];
 
+  let thinkingStartIndex = -1;
+  let thinkingEndIndex = -1;
+  let thinkingBlock: ThinkingContentText | undefined;
+  const endIndex = instructions != null ? 1 : 0;
+  const prunedMemory: BaseMessage[] = [];
+
   if (currentTokenCount < remainingContextTokens) {
     let currentIndex = messages.length;
-    while (messages.length > 0 && currentTokenCount < remainingContextTokens && currentIndex >
+    while (messages.length > 0 && currentTokenCount < remainingContextTokens && currentIndex > endIndex) {
      currentIndex--;
      if (messages.length === 1 && instructions) {
        break;
      }
      const poppedMessage = messages.pop();
      if (!poppedMessage) continue;
-
-
+      const messageType = poppedMessage.getType();
+      if (thinkingEnabled === true && thinkingEndIndex === -1 && (currentIndex === (originalLength - 1)) && (messageType === 'ai' || messageType === 'tool')) {
+        thinkingEndIndex = currentIndex;
+      }
+      if (thinkingEndIndex > -1 && !thinkingBlock && thinkingStartIndex < 0 && messageType === 'ai' && Array.isArray(poppedMessage.content)) {
+        thinkingBlock = (poppedMessage.content.find((content) => content.type === ContentTypes.THINKING)) as ThinkingContentText | undefined;
+        thinkingStartIndex = thinkingBlock != null ? currentIndex : -1;
+      }
+      /** False start, the latest message was not part of a multi-assistant/tool sequence of messages */
+      if (
+        thinkingEndIndex > -1
+        && currentIndex === (thinkingEndIndex - 1)
+        && (messageType !== 'ai' && messageType !== 'tool')
+      ) {
+        thinkingEndIndex = -1;
+      }
+
+      const tokenCount = indexTokenCountMap[currentIndex] ?? 0;
 
-      if ((currentTokenCount + tokenCount) <= remainingContextTokens) {
+      if (prunedMemory.length === 0 && ((currentTokenCount + tokenCount) <= remainingContextTokens)) {
        context.push(poppedMessage);
        currentTokenCount += tokenCount;
      } else {
-
+        prunedMemory.push(poppedMessage);
+        if (thinkingEndIndex > -1) {
+          continue;
+        }
        break;
      }
    }
-
-
-
-
-
-    if (
-
+
+    if (thinkingEndIndex > -1 && context[context.length - 1].getType() === 'tool') {
+      startType = 'ai';
+    }
+
+    if (startType != null && startType && context.length > 0) {
+      const requiredTypeIndex = context.findIndex(msg => msg.getType() === startType);
+
+      if (requiredTypeIndex > 0) {
+        context = context.slice(requiredTypeIndex);
+      }
    }
  }
-
-
-  if (instructions && _messages.length > 0) {
+
+  if (instructions && originalLength > 0) {
    context.push(_messages[0] as BaseMessage);
    messages.shift();
  }
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        }
-      }
-      if (thinkingBlocks.length > 0) break; // Stop after finding one thinking block
-    }
-  }
-
-  // If we found thinking blocks, add them to the first assistant message
-  if (thinkingBlocks.length > 0) {
-    // Calculate token count of original message
-    const originalTokenCount = tokenCounter(firstAssistantMsg);
-
-    // Create a new content array with thinking blocks at the beginning
-    let newContent: any[];
-
-    if (Array.isArray(firstAssistantMsg.content)) {
-      // Keep the original content (excluding any existing thinking blocks)
-      const originalContent = firstAssistantMsg.content.filter(item =>
-        !(item && typeof item === 'object' && item.type === 'thinking'));
-
-      newContent = [...thinkingBlocks, ...originalContent];
-    } else if (typeof firstAssistantMsg.content === 'string') {
-      newContent = [
-        ...thinkingBlocks,
-        { type: 'text', text: firstAssistantMsg.content }
-      ];
-    } else {
-      newContent = thinkingBlocks;
-    }
-
-    // Create a new message with the updated content
-    const newMessage = new AIMessage({
-      content: newContent,
-      additional_kwargs: firstAssistantMsg.additional_kwargs,
-      response_metadata: firstAssistantMsg.response_metadata,
-    });
-
-    // Calculate token count of new message
-    const newTokenCount = tokenCounter(newMessage);
-
-    // Adjust current token count
-    currentTokenCount += (newTokenCount - originalTokenCount);
-
-    // Replace the first assistant message
-    context[firstAssistantIndex] = newMessage;
-
-    // If we've exceeded the token limit, we need to prune more messages
-    if (currentTokenCount > remainingContextTokens) {
-      // Build a map of tool call IDs to track AI <--> tool message correspondences
-      const toolCallIdMap = new Map<string, number>();
-
-      // Identify tool call IDs in the context
-      for (let i = 0; i < context.length; i++) {
-        const msg = context[i];
-
-        // Check for tool calls in AI messages
-        if (msg.getType() === 'ai' && Array.isArray(msg.content)) {
-          for (const item of msg.content) {
-            if (item && typeof item === 'object' && item.type === 'tool_use' && item.id) {
-              toolCallIdMap.set(item.id, i);
-            }
-          }
-        }
-
-        // Check for tool messages
-        if (msg.getType() === 'tool' && 'tool_call_id' in msg && typeof msg.tool_call_id === 'string') {
-          toolCallIdMap.set(msg.tool_call_id, i);
-        }
-      }
-
-      // Track which messages to remove
-      const indicesToRemove = new Set<number>();
-
-      // Start removing messages from the end, but preserve AI <--> tool message correspondences
-      let i = context.length - 1;
-      while (i > firstAssistantIndex && currentTokenCount > remainingContextTokens) {
-        const msgToRemove = context[i];
-
-        // Check if this is a tool message or has tool calls
-        let canRemove = true;
-
-        if (msgToRemove.getType() === 'tool' && 'tool_call_id' in msgToRemove && typeof msgToRemove.tool_call_id === 'string') {
-          // If this is a tool message, check if we need to keep its corresponding AI message
-          const aiIndex = toolCallIdMap.get(msgToRemove.tool_call_id);
-          if (aiIndex !== undefined && aiIndex !== i && !indicesToRemove.has(aiIndex)) {
-            // We need to remove both the tool message and its corresponding AI message
-            indicesToRemove.add(i);
-            indicesToRemove.add(aiIndex);
-            currentTokenCount -= (tokenCounter(msgToRemove) + tokenCounter(context[aiIndex]));
-            canRemove = false;
-          }
-        } else if (msgToRemove.getType() === 'ai' && Array.isArray(msgToRemove.content)) {
-          // If this is an AI message with tool calls, check if we need to keep its corresponding tool messages
-          for (const item of msgToRemove.content) {
-            if (item && typeof item === 'object' && item.type === 'tool_use' && item.id) {
-              const toolIndex = toolCallIdMap.get(item.id as string);
-              if (toolIndex !== undefined && toolIndex !== i && !indicesToRemove.has(toolIndex)) {
-                // We need to remove both the AI message and its corresponding tool message
-                indicesToRemove.add(i);
-                indicesToRemove.add(toolIndex);
-                currentTokenCount -= (tokenCounter(msgToRemove) + tokenCounter(context[toolIndex]));
-                canRemove = false;
-                break;
-              }
-            }
-          }
-        }
-
-        // If we can remove this message individually
-        if (canRemove && !indicesToRemove.has(i)) {
-          indicesToRemove.add(i);
-          currentTokenCount -= tokenCounter(msgToRemove);
-        }
-
-        i--;
-      }
-
-      // Remove messages in reverse order to avoid index shifting
-      const sortedIndices = Array.from(indicesToRemove).sort((a, b) => b - a);
-      for (const index of sortedIndices) {
-        context.splice(index, 1);
-      }
-
-      // Update remainingContextTokens to reflect the new token count
-      remainingContextTokens = maxContextTokens - currentTokenCount;
-    }
-  }
-}
+
+  remainingContextTokens -= currentTokenCount;
+  const result = {
+    remainingContextTokens,
+    context: [] as BaseMessage[],
+    messagesToRefine: prunedMemory,
+  };
+
+  if (prunedMemory.length === 0 || thinkingEndIndex < 0 || (thinkingStartIndex > -1 && isIndexInContext(_messages, context, thinkingStartIndex))) {
+    // we reverse at this step to ensure the context is in the correct order for the model, and we need to work backwards
+    result.context = context.reverse();
+    return result;
+  }
+
+  if (thinkingEndIndex > -1 && thinkingStartIndex < 0) {
+    throw new Error('The payload is malformed. There is a thinking sequence but no "AI" messages with thinking blocks.');
+  }
+
+  if (!thinkingBlock) {
+    throw new Error('The payload is malformed. There is a thinking sequence but no thinking block found.');
+  }
+
+  // Since we have a thinking sequence, we need to find the last assistant message
+  // in the latest AI/tool sequence to add the thinking block that falls outside of the current context
+  // Latest messages are ordered first.
+  let assistantIndex = -1;
+  for (let i = 0; i < context.length; i++) {
+    const currentMessage = context[i];
+    const type = currentMessage.getType();
+    if (type === 'ai') {
+      assistantIndex = i;
    }
-
-
-  // but maintain system message precedence
-  if (latestMessageIsAssistant && context.length > 0) {
-    // Find the first assistant message in the context
-    const assistantIndex = context.findIndex(msg => msg.getType() === 'ai');
-    if (assistantIndex > 0) {
-      // Check if there's a system message at the beginning
-      const hasSystemFirst = context[0].getType() === 'system';
-
-      // Move the assistant message to the appropriate position
-      const assistantMsg = context[assistantIndex];
-      context.splice(assistantIndex, 1);
-
-      if (hasSystemFirst) {
-        // Insert after the system message
-        context.splice(1, 0, assistantMsg);
-      } else {
-        // Insert at the beginning if no system message
-        context.unshift(assistantMsg);
-      }
-    }
+    if (assistantIndex > -1 && (type === 'human' || type === 'system')) {
+      break;
    }
  }
+
+  if (assistantIndex === -1) {
+    throw new Error('The payload is malformed. There is a thinking sequence but no "AI" messages to append thinking blocks to.');
  }
 
-
-
-  remainingContextTokens
+  thinkingStartIndex = originalLength - 1 - assistantIndex;
+  const thinkingTokenCount = tokenCounter(new AIMessage({ content: [thinkingBlock] }));
+  const newRemainingCount = remainingContextTokens - thinkingTokenCount;
 
-
-
-
-
-
-
+  const content: MessageContentComplex[] = Array.isArray(context[assistantIndex].content)
+    ? context[assistantIndex].content as MessageContentComplex[]
+    : [{
+      type: ContentTypes.TEXT,
+      text: context[assistantIndex].content,
+    }];
+  content.unshift(thinkingBlock);
+  context[assistantIndex].content = content;
+  if (newRemainingCount > 0) {
+    result.context = context.reverse();
+    return result;
+  }
+
+  const thinkingMessage: AIMessage = context[assistantIndex];
+  // now we need to an additional round of pruning but making the thinking block fit
+  const newThinkingMessageTokenCount = (indexTokenCountMap[thinkingStartIndex] ?? 0) + thinkingTokenCount;
+  remainingContextTokens = initialContextTokens - newThinkingMessageTokenCount;
+  currentTokenCount = 3;
+  let newContext: BaseMessage[] = [];
+  const secondRoundMessages = [..._messages];
+  let currentIndex = secondRoundMessages.length;
+  while (secondRoundMessages.length > 0 && currentTokenCount < remainingContextTokens && currentIndex > thinkingStartIndex) {
+    currentIndex--;
+    const poppedMessage = secondRoundMessages.pop();
+    if (!poppedMessage) continue;
+    const tokenCount = indexTokenCountMap[currentIndex] ?? 0;
+    if ((currentTokenCount + tokenCount) <= remainingContextTokens) {
+      newContext.push(poppedMessage);
+      currentTokenCount += tokenCount;
+    } else {
+      messages.push(poppedMessage);
+      break;
+    }
+  }
+
+  const firstMessage: AIMessage = newContext[newContext.length - 1];
+  const firstMessageType = newContext[newContext.length - 1].getType();
+  if (firstMessageType === 'tool') {
+    startType = 'ai';
+  }
+
+  if (startType != null && startType && newContext.length > 0) {
+    const requiredTypeIndex = newContext.findIndex(msg => msg.getType() === startType);
+    if (requiredTypeIndex > 0) {
+      newContext = newContext.slice(requiredTypeIndex);
+    }
+  }
+
+  if (firstMessageType === 'ai') {
+    newContext[newContext.length - 1] = new AIMessage({
+      content: concat(thinkingMessage.content as MessageContentComplex[], newContext[newContext.length - 1].content as MessageContentComplex[]),
+      tool_calls: concat(firstMessage.tool_calls, thinkingMessage.tool_calls),
+    });
+  } else {
+    newContext.push(thinkingMessage);
+  }
+
+  if (instructions && originalLength > 0) {
+    newContext.push(_messages[0] as BaseMessage);
+    secondRoundMessages.shift();
+  }
+
+  result.context = newContext.reverse();
+  return result;
 }
 
-function checkValidNumber(value: unknown): value is number {
+export function checkValidNumber(value: unknown): value is number {
   return typeof value === 'number' && !isNaN(value) && value > 0;
 }
 
@@ -312,7 +268,6 @@ export function createPruneMessages(factoryParams: PruneMessagesFactoryParams) {
   const indexTokenCountMap = { ...factoryParams.indexTokenCountMap };
   let lastTurnStartIndex = factoryParams.startIndex;
   let totalTokens = (Object.values(indexTokenCountMap)).reduce((a, b) => a + b, 0);
-
   return function pruneMessages(params: PruneMessagesParams): {
     context: BaseMessage[];
     indexTokenCountMap: Record<string, number>;
@@ -334,8 +289,10 @@ export function createPruneMessages(factoryParams: PruneMessagesFactoryParams) {
 
     for (let i = lastTurnStartIndex; i < params.messages.length; i++) {
       const message = params.messages[i];
+      // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
       if (i === lastTurnStartIndex && indexTokenCountMap[i] === undefined && currentUsage) {
         indexTokenCountMap[i] = currentUsage.output_tokens;
+      // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
       } else if (indexTokenCountMap[i] === undefined) {
         indexTokenCountMap[i] = factoryParams.tokenCounter(message);
         totalTokens += indexTokenCountMap[i];
@@ -359,16 +316,15 @@ export function createPruneMessages(factoryParams: PruneMessagesFactoryParams) {
       return { context: params.messages, indexTokenCountMap };
     }
 
-    // Pass the tokenCounter to getMessagesWithinTokenLimit for token recalculation
     const { context } = getMessagesWithinTokenLimit({
       maxContextTokens: factoryParams.maxTokens,
       messages: params.messages,
       indexTokenCountMap,
-
+      startType: params.startType,
       thinkingEnabled: factoryParams.thinkingEnabled,
       tokenCounter: factoryParams.tokenCounter,
     });
 
     return { context, indexTokenCountMap };
-  }
+  };
 }
```
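
The net effect of this rewrite: instead of splicing preserved thinking blocks back into the first assistant message after pruning (the removed 2.3.0 code above), pruning now tracks the thinking sequence while walking messages newest to oldest, re-attaches the out-of-context thinking block to the last assistant message of the latest AI/tool sequence, and runs a second pruning pass when that block no longer fits. A minimal usage sketch of the factory, assuming the parameter and return shapes shown in these hunks; the import path (whether these helpers are re-exported from the package root) is an assumption, and the four-characters-per-token counter is a crude stand-in for a real tokenizer:

```ts
import { HumanMessage, AIMessage } from '@langchain/core/messages';
import type { BaseMessage } from '@langchain/core/messages';
// Assumed import location; the diff only shows src/messages/prune.ts.
import { createPruneMessages } from '@librechat/agents';

// Stand-in counter: roughly 4 characters per token.
const tokenCounter = (message: BaseMessage): number =>
  Math.ceil(JSON.stringify(message.content).length / 4);

const pruneMessages = createPruneMessages({
  maxTokens: 4096,
  startIndex: 0,
  indexTokenCountMap: {}, // per-index token counts; missing entries are computed lazily
  tokenCounter,
  thinkingEnabled: true, // opts in to the thinking-block carry-over added in this release
});

const messages: BaseMessage[] = [
  new HumanMessage('Summarize the conversation so far.'),
  new AIMessage('Here is a summary...'),
];

// `startType` (new in this version) forces the pruned window to open on a given
// message type, e.g. so a context never starts on a dangling tool result.
const { context, indexTokenCountMap } = pruneMessages({ messages, startType: 'human' });
console.log(context.length, Object.keys(indexTokenCountMap).length);
```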
package/src/messages/shiftIndexTokenCountMap.test.ts
CHANGED

```diff
@@ -7,56 +7,56 @@ describe('shiftIndexTokenCountMap', () => {
       1: 20,
       2: 30
     };
-
+
     const systemMessageTokenCount = 15;
-
+
     const result = shiftIndexTokenCountMap(originalMap, systemMessageTokenCount);
-
+
     // Check that the system message token count is at index 0
     expect(result[0]).toBe(15);
-
+
     // Check that all other indices are shifted by 1
     expect(result[1]).toBe(10);
     expect(result[2]).toBe(20);
     expect(result[3]).toBe(30);
-
+
     // Check that the original map is not modified
     expect(originalMap[0]).toBe(10);
     expect(originalMap[1]).toBe(20);
     expect(originalMap[2]).toBe(30);
   });
-
+
   it('should handle an empty map', () => {
     const emptyMap: Record<number, number> = {};
     const systemMessageTokenCount = 15;
-
+
     const result = shiftIndexTokenCountMap(emptyMap, systemMessageTokenCount);
-
+
     // Check that only the system message token count is in the result
     expect(Object.keys(result).length).toBe(1);
     expect(result[0]).toBe(15);
   });
-
+
   it('should handle non-sequential indices', () => {
     const nonSequentialMap: Record<number, number> = {
       0: 10,
       2: 20,
       5: 30
     };
-
+
     const systemMessageTokenCount = 15;
-
+
     const result = shiftIndexTokenCountMap(nonSequentialMap, systemMessageTokenCount);
-
+
     // Check that the system message token count is at index 0
     expect(result[0]).toBe(15);
-
+
     // Check that all other indices are shifted by 1
     expect(result[1]).toBe(10);
     expect(result[3]).toBe(20);
     expect(result[6]).toBe(30);
   });
-
+
   it('should handle string keys', () => {
     // TypeScript will convert string keys to numbers when accessing the object
     const mapWithStringKeys: Record<string, number> = {
@@ -64,15 +64,15 @@ describe('shiftIndexTokenCountMap', () => {
       '1': 20,
       '2': 30
     };
-
+
     const systemMessageTokenCount = 15;
-
+
     // Cast to Record<number, number> to match the function signature
     const result = shiftIndexTokenCountMap(mapWithStringKeys as unknown as Record<number, number>, systemMessageTokenCount);
-
+
     // Check that the system message token count is at index 0
     expect(result[0]).toBe(15);
-
+
     // Check that all other indices are shifted by 1
     expect(result[1]).toBe(10);
     expect(result[2]).toBe(20);
```
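
These are whitespace-only changes; the assertions still pin down the shift behavior completely: index 0 receives the system message's token count, every existing index moves up by one, and the input map is left untouched. For reference, a sketch consistent with those tests (the package's actual implementation may differ):

```ts
// Consistent with the tests above, not necessarily the shipped source:
// prepend the system message's token count at index 0, shift every existing
// entry up by one, and return a new object so the input map is never mutated.
export function shiftIndexTokenCountMap(
  indexTokenCountMap: Record<number, number>,
  systemMessageTokenCount: number,
): Record<number, number> {
  const shifted: Record<number, number> = { 0: systemMessageTokenCount };
  for (const [index, count] of Object.entries(indexTokenCountMap)) {
    shifted[Number(index) + 1] = count;
  }
  return shifted;
}
```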
package/src/mockStream.ts
CHANGED
package/src/run.ts
CHANGED
```diff
@@ -1,5 +1,5 @@
 // src/run.ts
-import { zodToJsonSchema } from
+import { zodToJsonSchema } from 'zod-to-json-schema';
 import { PromptTemplate } from '@langchain/core/prompts';
 import { AzureChatOpenAI, ChatOpenAI } from '@langchain/openai';
 import { SystemMessage } from '@langchain/core/messages';
@@ -115,7 +115,7 @@ export class Run<T extends t.BaseGraphState> {
         if (!tool.schema) {
           return acc;
         }
-
+
         const jsonSchema = zodToJsonSchema(tool.schema.describe(tool.description ?? ''), tool.name);
         return acc + tokenCounter(new SystemMessage(JSON.stringify(jsonSchema)));
       }, 0) ?? 0) : 0;
```