@librechat/agents 2.2.2 → 2.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/cjs/graphs/Graph.cjs +50 -14
- package/dist/cjs/graphs/Graph.cjs.map +1 -1
- package/dist/cjs/main.cjs +3 -4
- package/dist/cjs/main.cjs.map +1 -1
- package/dist/cjs/messages/format.cjs +21 -0
- package/dist/cjs/messages/format.cjs.map +1 -1
- package/dist/cjs/messages/prune.cjs +124 -0
- package/dist/cjs/messages/prune.cjs.map +1 -0
- package/dist/cjs/run.cjs +24 -0
- package/dist/cjs/run.cjs.map +1 -1
- package/dist/cjs/utils/tokens.cjs +64 -0
- package/dist/cjs/utils/tokens.cjs.map +1 -0
- package/dist/esm/graphs/Graph.mjs +50 -14
- package/dist/esm/graphs/Graph.mjs.map +1 -1
- package/dist/esm/main.mjs +2 -3
- package/dist/esm/main.mjs.map +1 -1
- package/dist/esm/messages/format.mjs +21 -1
- package/dist/esm/messages/format.mjs.map +1 -1
- package/dist/esm/messages/prune.mjs +122 -0
- package/dist/esm/messages/prune.mjs.map +1 -0
- package/dist/esm/run.mjs +24 -0
- package/dist/esm/run.mjs.map +1 -1
- package/dist/esm/utils/tokens.mjs +62 -0
- package/dist/esm/utils/tokens.mjs.map +1 -0
- package/dist/types/graphs/Graph.d.ts +8 -1
- package/dist/types/messages/format.d.ts +9 -0
- package/dist/types/messages/index.d.ts +1 -2
- package/dist/types/messages/prune.d.ts +16 -0
- package/dist/types/types/run.d.ts +4 -0
- package/dist/types/utils/tokens.d.ts +2 -0
- package/package.json +1 -1
- package/src/graphs/Graph.ts +54 -16
- package/src/messages/format.ts +27 -0
- package/src/messages/index.ts +1 -2
- package/src/messages/prune.ts +167 -0
- package/src/messages/shiftIndexTokenCountMap.test.ts +81 -0
- package/src/run.ts +26 -0
- package/src/scripts/code_exec_simple.ts +21 -8
- package/src/specs/prune.test.ts +444 -0
- package/src/types/run.ts +5 -0
- package/src/utils/tokens.ts +70 -0
- package/dist/cjs/messages/transformers.cjs +0 -318
- package/dist/cjs/messages/transformers.cjs.map +0 -1
- package/dist/cjs/messages/trimMessagesFactory.cjs +0 -129
- package/dist/cjs/messages/trimMessagesFactory.cjs.map +0 -1
- package/dist/esm/messages/transformers.mjs +0 -316
- package/dist/esm/messages/transformers.mjs.map +0 -1
- package/dist/esm/messages/trimMessagesFactory.mjs +0 -127
- package/dist/esm/messages/trimMessagesFactory.mjs.map +0 -1
- package/dist/types/messages/transformers.d.ts +0 -320
- package/dist/types/messages/trimMessagesFactory.d.ts +0 -37
- package/src/messages/transformers.ts +0 -786
- package/src/messages/trimMessagesFactory.test.ts +0 -331
- package/src/messages/trimMessagesFactory.ts +0 -140

package/src/specs/prune.test.ts
ADDED
@@ -0,0 +1,444 @@
+// src/specs/prune.test.ts
+import { config } from 'dotenv';
+config();
+import { HumanMessage, AIMessage, SystemMessage, BaseMessage } from '@langchain/core/messages';
+import type { RunnableConfig } from '@langchain/core/runnables';
+import type { UsageMetadata } from '@langchain/core/messages';
+import type * as t from '@/types';
+import { GraphEvents, Providers } from '@/common';
+import { getLLMConfig } from '@/utils/llmConfig';
+import { Run } from '@/run';
+import { createPruneMessages } from '@/messages/prune';
+
+// Create a simple token counter for testing
+const createTestTokenCounter = (): t.TokenCounter => {
+  // This simple token counter just counts characters as tokens for predictable testing
+  return (message: BaseMessage): number => {
+    // Use type assertion to help TypeScript understand the type
+    const content = message.content as string | Array<any> | undefined;
+
+    // Handle string content
+    if (typeof content === 'string') {
+      return content.length;
+    }
+
+    // Handle array content
+    if (Array.isArray(content)) {
+      let totalLength = 0;
+
+      for (const item of content) {
+        if (typeof item === 'string') {
+          totalLength += item.length;
+        } else if (item && typeof item === 'object') {
+          if ('text' in item && typeof item.text === 'string') {
+            totalLength += item.text.length;
+          }
+        }
+      }
+
+      return totalLength;
+    }
+
+    // Default case - if content is null, undefined, or any other type
+    return 0;
+  };
+};
+
+// Since the internal functions in prune.ts are not exported, we'll reimplement them here for testing
+// This is based on the implementation in src/messages/prune.ts
+function calculateTotalTokens(usage: Partial<UsageMetadata>): UsageMetadata {
+  const baseInputTokens = Number(usage.input_tokens) || 0;
+  const cacheCreation = Number(usage.input_token_details?.cache_creation) || 0;
+  const cacheRead = Number(usage.input_token_details?.cache_read) || 0;
+
+  const totalInputTokens = baseInputTokens + cacheCreation + cacheRead;
+  const totalOutputTokens = Number(usage.output_tokens) || 0;
+
+  return {
+    input_tokens: totalInputTokens,
+    output_tokens: totalOutputTokens,
+    total_tokens: totalInputTokens + totalOutputTokens
+  };
+}
+
+function getMessagesWithinTokenLimit({
+  messages: _messages,
+  maxContextTokens,
+  indexTokenCountMap,
+}: {
+  messages: BaseMessage[];
+  maxContextTokens: number;
+  indexTokenCountMap: Record<string, number>;
+}): {
+  context: BaseMessage[];
+  remainingContextTokens: number;
+  messagesToRefine: BaseMessage[];
+  summaryIndex: number;
+} {
+  // Every reply is primed with <|start|>assistant<|message|>, so we
+  // start with 3 tokens for the label after all messages have been counted.
+  let summaryIndex = -1;
+  let currentTokenCount = 3;
+  const instructions = _messages?.[0]?.getType() === 'system' ? _messages[0] : undefined;
+  const instructionsTokenCount = instructions != null ? indexTokenCountMap[0] : 0;
+  let remainingContextTokens = maxContextTokens - instructionsTokenCount;
+  const messages = [..._messages];
+  const context: BaseMessage[] = [];
+
+  if (currentTokenCount < remainingContextTokens) {
+    let currentIndex = messages.length;
+    while (messages.length > 0 && currentTokenCount < remainingContextTokens && currentIndex > 1) {
+      currentIndex--;
+      if (messages.length === 1 && instructions) {
+        break;
+      }
+      const poppedMessage = messages.pop();
+      if (!poppedMessage) continue;
+
+      const tokenCount = indexTokenCountMap[currentIndex] || 0;
+
+      if ((currentTokenCount + tokenCount) <= remainingContextTokens) {
+        context.push(poppedMessage);
+        currentTokenCount += tokenCount;
+      } else {
+        messages.push(poppedMessage);
+        break;
+      }
+    }
+  }
+
+  if (instructions && _messages.length > 0) {
+    context.push(_messages[0] as BaseMessage);
+    messages.shift();
+  }
+
+  const prunedMemory = messages;
+  summaryIndex = prunedMemory.length - 1;
+  remainingContextTokens -= currentTokenCount;
+
+  return {
+    summaryIndex,
+    remainingContextTokens,
+    context: context.reverse(),
+    messagesToRefine: prunedMemory,
+  };
+}
+
+function checkValidNumber(value: unknown): value is number {
+  return typeof value === 'number' && !isNaN(value) && value > 0;
+}
+
+describe('Prune Messages Tests', () => {
+  jest.setTimeout(30000);
+
+  describe('calculateTotalTokens', () => {
+    it('should calculate total tokens correctly with all fields present', () => {
+      const usage: Partial<UsageMetadata> = {
+        input_tokens: 100,
+        output_tokens: 50,
+        input_token_details: {
+          cache_creation: 10,
+          cache_read: 5
+        }
+      };
+
+      const result = calculateTotalTokens(usage);
+
+      expect(result.input_tokens).toBe(115); // 100 + 10 + 5
+      expect(result.output_tokens).toBe(50);
+      expect(result.total_tokens).toBe(165); // 115 + 50
+    });
+
+    it('should handle missing fields gracefully', () => {
+      const usage: Partial<UsageMetadata> = {
+        input_tokens: 100,
+        output_tokens: 50
+      };
+
+      const result = calculateTotalTokens(usage);
+
+      expect(result.input_tokens).toBe(100);
+      expect(result.output_tokens).toBe(50);
+      expect(result.total_tokens).toBe(150);
+    });
+
+    it('should handle empty usage object', () => {
+      const usage: Partial<UsageMetadata> = {};
+
+      const result = calculateTotalTokens(usage);
+
+      expect(result.input_tokens).toBe(0);
+      expect(result.output_tokens).toBe(0);
+      expect(result.total_tokens).toBe(0);
+    });
+  });
+
+  describe('getMessagesWithinTokenLimit', () => {
+    it('should include all messages when under token limit', () => {
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Hello'),
+        new AIMessage('Hi there')
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // "System instruction"
+        1: 5, // "Hello"
+        2: 8 // "Hi there"
+      };
+
+      const result = getMessagesWithinTokenLimit({
+        messages,
+        maxContextTokens: 100,
+        indexTokenCountMap
+      });
+
+      expect(result.context.length).toBe(3);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[0].getType()).toBe('system'); // System message
+      expect(result.remainingContextTokens).toBe(100 - 17 - 5 - 8 - 3); // -3 for the assistant label tokens
+      expect(result.messagesToRefine.length).toBe(0);
+    });
+
+    it('should prune oldest messages when over token limit', () => {
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Message 1'),
+        new AIMessage('Response 1'),
+        new HumanMessage('Message 2'),
+        new AIMessage('Response 2')
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // "System instruction"
+        1: 9, // "Message 1"
+        2: 10, // "Response 1"
+        3: 9, // "Message 2"
+        4: 10 // "Response 2"
+      };
+
+      // Set a limit that can only fit the system message and the last two messages
+      const result = getMessagesWithinTokenLimit({
+        messages,
+        maxContextTokens: 40,
+        indexTokenCountMap
+      });
+
+      // Should include system message and the last two messages
+      expect(result.context.length).toBe(3);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[0].getType()).toBe('system'); // System message
+      expect(result.context[1]).toBe(messages[3]); // Message 2
+      expect(result.context[2]).toBe(messages[4]); // Response 2
+
+      // Should have the first two messages in messagesToRefine
+      expect(result.messagesToRefine.length).toBe(2);
+      expect(result.messagesToRefine[0]).toBe(messages[1]); // Message 1
+      expect(result.messagesToRefine[1]).toBe(messages[2]); // Response 1
+    });
+
+    it('should always include system message even when at token limit', () => {
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Hello'),
+        new AIMessage('Hi there')
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // "System instruction"
+        1: 5, // "Hello"
+        2: 8 // "Hi there"
+      };
+
+      // Set a limit that can only fit the system message
+      const result = getMessagesWithinTokenLimit({
+        messages,
+        maxContextTokens: 20,
+        indexTokenCountMap
+      });
+
+      expect(result.context.length).toBe(1);
+      expect(result.context[0]).toBe(messages[0]); // System message
+
+      expect(result.messagesToRefine.length).toBe(2);
+    });
+  });
+
+  describe('checkValidNumber', () => {
+    it('should return true for valid positive numbers', () => {
+      expect(checkValidNumber(5)).toBe(true);
+      expect(checkValidNumber(1.5)).toBe(true);
+      expect(checkValidNumber(Number.MAX_SAFE_INTEGER)).toBe(true);
+    });
+
+    it('should return false for zero, negative numbers, and NaN', () => {
+      expect(checkValidNumber(0)).toBe(false);
+      expect(checkValidNumber(-5)).toBe(false);
+      expect(checkValidNumber(NaN)).toBe(false);
+    });
+
+    it('should return false for non-number types', () => {
+      expect(checkValidNumber('5')).toBe(false);
+      expect(checkValidNumber(null)).toBe(false);
+      expect(checkValidNumber(undefined)).toBe(false);
+      expect(checkValidNumber({})).toBe(false);
+      expect(checkValidNumber([])).toBe(false);
+    });
+  });
+
+  describe('createPruneMessages', () => {
+    it('should return all messages when under token limit', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Hello'),
+        new AIMessage('Hi there')
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0]),
+        1: tokenCounter(messages[1]),
+        2: tokenCounter(messages[2])
+      };
+
+      const pruneMessages = createPruneMessages({
+        maxTokens: 100,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap
+      });
+
+      const result = pruneMessages({ messages });
+
+      expect(result.context.length).toBe(3);
+      expect(result.context).toEqual(messages);
+    });
+
+    it('should prune messages when over token limit', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Message 1'),
+        new AIMessage('Response 1'),
+        new HumanMessage('Message 2'),
+        new AIMessage('Response 2')
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0]),
+        1: tokenCounter(messages[1]),
+        2: tokenCounter(messages[2]),
+        3: tokenCounter(messages[3]),
+        4: tokenCounter(messages[4])
+      };
+
+      // Set a limit that can only fit the system message and the last two messages
+      const pruneMessages = createPruneMessages({
+        maxTokens: 40,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap
+      });
+
+      const result = pruneMessages({ messages });
+
+      // Should include system message and the last two messages
+      expect(result.context.length).toBe(3);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[3]); // Message 2
+      expect(result.context[2]).toBe(messages[4]); // Response 2
+    });
+
+    it('should update token counts when usage metadata is provided', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Hello'),
+        new AIMessage('Hi there')
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0]),
+        1: tokenCounter(messages[1]),
+        2: tokenCounter(messages[2])
+      };
+
+      const pruneMessages = createPruneMessages({
+        maxTokens: 100,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap }
+      });
+
+      // Provide usage metadata that indicates different token counts
+      const usageMetadata: Partial<UsageMetadata> = {
+        input_tokens: 50,
+        output_tokens: 25,
+        total_tokens: 75
+      };
+
+      const result = pruneMessages({
+        messages,
+        usageMetadata
+      });
+
+      // The function should have updated the indexTokenCountMap based on the usage metadata
+      expect(result.indexTokenCountMap).not.toEqual(indexTokenCountMap);
+
+      // The total of all values in indexTokenCountMap should equal the total_tokens from usageMetadata
+      const totalTokens = Object.values(result.indexTokenCountMap).reduce((a, b) => a + b, 0);
+      expect(totalTokens).toBe(75);
+    });
+  });
+
+  describe('Integration with Run', () => {
+    it('should initialize Run with custom token counter and process messages', async () => {
+      const provider = Providers.OPENAI;
+      const llmConfig = getLLMConfig(provider);
+      const tokenCounter = createTestTokenCounter();
+
+      const run = await Run.create<t.IState>({
+        runId: 'test-prune-run',
+        graphConfig: {
+          type: 'standard',
+          llmConfig,
+          instructions: 'You are a helpful assistant.',
+        },
+        returnContent: true,
+      });
+
+      // Override the model to use a fake LLM
+      run.Graph?.overrideTestModel(['This is a test response'], 1);
+
+      const messages = [
+        new HumanMessage('Hello, how are you?')
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0])
+      };
+
+      const config: Partial<RunnableConfig> & { version: 'v1' | 'v2'; streamMode: string } = {
+        configurable: {
+          thread_id: 'test-thread',
+        },
+        streamMode: 'values',
+        version: 'v2' as const,
+      };
+
+      await run.processStream(
+        { messages },
+        config,
+        {
+          maxContextTokens: 1000,
+          indexTokenCountMap,
+          tokenCounter,
+        }
+      );
+
+      const finalMessages = run.getRunMessages();
+      expect(finalMessages).toBeDefined();
+      expect(finalMessages?.length).toBeGreaterThan(0);
+    });
+  });
+});
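
Together with the removal of the `transformers`/`trimMessagesFactory` modules listed above, these tests show the replacement API: `createPruneMessages` closes over a token budget, a `TokenCounter`, and a per-index token-count map, and returns a function that drops the oldest messages while always retaining a leading system message. A minimal sketch of the call shape, distilled from the tests above (the character-per-token counter is the test stub, not a real tokenizer):

// Sketch based on src/specs/prune.test.ts, not new API surface.
import { SystemMessage, HumanMessage, AIMessage, BaseMessage } from '@langchain/core/messages';
import { createPruneMessages } from '@/messages/prune';

// Stand-in counter: one "token" per character of string content.
const tokenCounter = (message: BaseMessage): number =>
  typeof message.content === 'string' ? message.content.length : 0;

const messages: BaseMessage[] = [
  new SystemMessage('System instruction'),
  new HumanMessage('Message 1'),
  new AIMessage('Response 1'),
  new HumanMessage('Message 2'),
  new AIMessage('Response 2'),
];

// Pre-computed token counts keyed by message index, as the factory expects.
const indexTokenCountMap = Object.fromEntries(messages.map((m, i) => [i, tokenCounter(m)]));

const pruneMessages = createPruneMessages({
  maxTokens: 40, // budget that fits only the system message plus the newest turns
  startIndex: 0,
  tokenCounter,
  indexTokenCountMap,
});

// context keeps [SystemMessage, 'Message 2', 'Response 2'] under this budget;
// the returned indexTokenCountMap reflects any usage-metadata updates.
const { context, indexTokenCountMap: updated } = pruneMessages({ messages });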

package/src/types/run.ts
CHANGED
@@ -55,7 +55,12 @@ export type RunConfig = {
 
 export type ProvidedCallbacks = (BaseCallbackHandler | CallbackHandlerMethods)[] | undefined;
 
+export type TokenCounter = (message: BaseMessage) => number;
 export type EventStreamOptions = {
   callbacks?: graph.ClientCallbacks;
   keepContent?: boolean;
+  /* Context Management */
+  maxContextTokens?: number;
+  tokenCounter?: TokenCounter;
+  indexTokenCountMap?: Record<string, number>;
 }
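
These three optional fields are the new context-management hooks on `EventStreamOptions`; the integration test above passes exactly this shape as the third argument to `run.processStream`. A minimal sketch of that wiring, assuming `run`, `messages`, and `config` are set up as in the test:

// Sketch mirroring the integration test above.
await run.processStream(
  { messages },
  config,
  {
    maxContextTokens: 1000, // token budget for the pruned context
    indexTokenCountMap,     // known token count per message index
    tokenCounter,           // TokenCounter: (message: BaseMessage) => number
  }
);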
package/src/utils/tokens.ts
ADDED
@@ -0,0 +1,70 @@
+import { Tiktoken } from "js-tiktoken/lite";
+import type { BaseMessage } from "@langchain/core/messages";
+import { ContentTypes } from "@/common/enum";
+
+function getTokenCountForMessage(message: BaseMessage, getTokenCount: (text: string) => number): number {
+  let tokensPerMessage = 3;
+
+  const processValue = (value: unknown) => {
+    if (Array.isArray(value)) {
+      for (let item of value) {
+        if (
+          !item ||
+          !item.type ||
+          item.type === ContentTypes.ERROR ||
+          item.type === ContentTypes.IMAGE_URL
+        ) {
+          continue;
+        }
+
+        if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) {
+          const toolName = item.tool_call?.name || '';
+          if (toolName != null && toolName && typeof toolName === 'string') {
+            numTokens += getTokenCount(toolName);
+          }
+
+          const args = item.tool_call?.args || '';
+          if (args != null && args && typeof args === 'string') {
+            numTokens += getTokenCount(args);
+          }
+
+          const output = item.tool_call?.output || '';
+          if (output != null && output && typeof output === 'string') {
+            numTokens += getTokenCount(output);
+          }
+          continue;
+        }
+
+        const nestedValue = item[item.type];
+
+        if (!nestedValue) {
+          continue;
+        }
+
+        processValue(nestedValue);
+      }
+    } else if (typeof value === 'string') {
+      numTokens += getTokenCount(value);
+    } else if (typeof value === 'number') {
+      numTokens += getTokenCount(value.toString());
+    } else if (typeof value === 'boolean') {
+      numTokens += getTokenCount(value.toString());
+    }
+  };
+
+  let numTokens = tokensPerMessage;
+  processValue(message.content);
+  return numTokens;
+}
+
+export const createTokenCounter = async () => {
+  const res = await fetch(`https://tiktoken.pages.dev/js/o200k_base.json`);
+  const o200k_base = await res.json();
+
+  const countTokens = (text: string) => {
+    const enc = new Tiktoken(o200k_base);
+    return enc.encode(text).length;
+  }
+
+  return (message: BaseMessage) => getTokenCountForMessage(message, countTokens);
+}