@librechat/agents 2.2.2 → 2.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/cjs/graphs/Graph.cjs +51 -14
- package/dist/cjs/graphs/Graph.cjs.map +1 -1
- package/dist/cjs/main.cjs +6 -4
- package/dist/cjs/main.cjs.map +1 -1
- package/dist/cjs/messages/format.cjs +21 -0
- package/dist/cjs/messages/format.cjs.map +1 -1
- package/dist/cjs/messages/prune.cjs +124 -0
- package/dist/cjs/messages/prune.cjs.map +1 -0
- package/dist/cjs/run.cjs +24 -0
- package/dist/cjs/run.cjs.map +1 -1
- package/dist/cjs/tools/ToolNode.cjs +1 -0
- package/dist/cjs/tools/ToolNode.cjs.map +1 -1
- package/dist/cjs/utils/tokens.cjs +65 -0
- package/dist/cjs/utils/tokens.cjs.map +1 -0
- package/dist/esm/graphs/Graph.mjs +51 -14
- package/dist/esm/graphs/Graph.mjs.map +1 -1
- package/dist/esm/main.mjs +3 -3
- package/dist/esm/messages/format.mjs +21 -1
- package/dist/esm/messages/format.mjs.map +1 -1
- package/dist/esm/messages/prune.mjs +122 -0
- package/dist/esm/messages/prune.mjs.map +1 -0
- package/dist/esm/run.mjs +24 -0
- package/dist/esm/run.mjs.map +1 -1
- package/dist/esm/tools/ToolNode.mjs +1 -0
- package/dist/esm/tools/ToolNode.mjs.map +1 -1
- package/dist/esm/utils/tokens.mjs +62 -0
- package/dist/esm/utils/tokens.mjs.map +1 -0
- package/dist/types/graphs/Graph.d.ts +8 -1
- package/dist/types/messages/format.d.ts +9 -0
- package/dist/types/messages/index.d.ts +1 -2
- package/dist/types/messages/prune.d.ts +16 -0
- package/dist/types/types/run.d.ts +4 -0
- package/dist/types/utils/index.d.ts +1 -0
- package/dist/types/utils/tokens.d.ts +3 -0
- package/package.json +1 -1
- package/src/graphs/Graph.ts +54 -16
- package/src/messages/format.ts +27 -0
- package/src/messages/index.ts +1 -2
- package/src/messages/prune.ts +167 -0
- package/src/messages/shiftIndexTokenCountMap.test.ts +81 -0
- package/src/run.ts +26 -0
- package/src/scripts/code_exec_simple.ts +21 -8
- package/src/specs/prune.test.ts +444 -0
- package/src/types/run.ts +5 -0
- package/src/utils/index.ts +2 -1
- package/src/utils/tokens.ts +70 -0
- package/dist/cjs/messages/transformers.cjs +0 -318
- package/dist/cjs/messages/transformers.cjs.map +0 -1
- package/dist/cjs/messages/trimMessagesFactory.cjs +0 -129
- package/dist/cjs/messages/trimMessagesFactory.cjs.map +0 -1
- package/dist/esm/messages/transformers.mjs +0 -316
- package/dist/esm/messages/transformers.mjs.map +0 -1
- package/dist/esm/messages/trimMessagesFactory.mjs +0 -127
- package/dist/esm/messages/trimMessagesFactory.mjs.map +0 -1
- package/dist/types/messages/transformers.d.ts +0 -320
- package/dist/types/messages/trimMessagesFactory.d.ts +0 -37
- package/src/messages/transformers.ts +0 -786
- package/src/messages/trimMessagesFactory.test.ts +0 -331
- package/src/messages/trimMessagesFactory.ts +0 -140
package/src/graphs/Graph.ts
CHANGED

@@ -8,12 +8,13 @@ import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
 import { Runnable, RunnableConfig } from '@langchain/core/runnables';
 import { dispatchCustomEvent } from '@langchain/core/callbacks/dispatch';
 import { AIMessageChunk, ToolMessage, SystemMessage } from '@langchain/core/messages';
-import type { BaseMessage, BaseMessageFields } from '@langchain/core/messages';
+import type { BaseMessage, BaseMessageFields, UsageMetadata } from '@langchain/core/messages';
 import type * as t from '@/types';
 import { Providers, GraphEvents, GraphNodeKeys, StepTypes, Callback, ContentTypes } from '@/common';
 import { getChatModelClass, manualToolStreamProviders } from '@/llm/providers';
 import { ToolNode as CustomToolNode, toolsCondition } from '@/tools/ToolNode';
 import {
+  createPruneMessages,
   modifyDeltaProperties,
   formatArtifactPayload,
   convertMessagesToContent,

@@ -74,8 +75,13 @@ export abstract class Graph<
   stepKeyIds: Map<string, string[]> = new Map<string, string[]>();
   contentIndexMap: Map<string, number> = new Map();
   toolCallStepIds: Map<string, string> = new Map();
+  currentUsage: Partial<UsageMetadata> | undefined;
+  indexTokenCountMap: Record<string, number> = {};
+  maxContextTokens: number | undefined;
+  pruneMessages?: ReturnType<typeof createPruneMessages>;
   /** The amount of time that should pass before another consecutive API call */
   streamBuffer: number | undefined;
+  tokenCounter?: t.TokenCounter;
   signal?: AbortSignal;
 }

@@ -166,6 +172,10 @@ export class StandardGraph extends Graph<
     this.currentTokenType = resetIfNotEmpty(this.currentTokenType, ContentTypes.TEXT);
     this.lastToken = resetIfNotEmpty(this.lastToken, undefined);
     this.tokenTypeSwitch = resetIfNotEmpty(this.tokenTypeSwitch, undefined);
+    this.indexTokenCountMap = resetIfNotEmpty(this.indexTokenCountMap, {});
+    this.currentUsage = resetIfNotEmpty(this.currentUsage, undefined);
+    this.tokenCounter = resetIfNotEmpty(this.tokenCounter, undefined);
+    this.maxContextTokens = resetIfNotEmpty(this.maxContextTokens, undefined);
   }

   /* Run Step Processing */

@@ -326,6 +336,12 @@ export class StandardGraph extends Graph<
     return new ChatModelClass(options);
   }

+  storeUsageMetadata(finalMessage?: BaseMessage): void {
+    if (finalMessage && 'usage_metadata' in finalMessage && finalMessage.usage_metadata) {
+      this.currentUsage = finalMessage.usage_metadata as Partial<UsageMetadata>;
+    }
+  }
+
   createCallModel() {
     return async (state: t.BaseGraphState, config?: RunnableConfig): Promise<Partial<t.BaseGraphState>> => {
       const { provider = '' } = (config?.configurable as t.GraphConfig | undefined) ?? {};

@@ -338,9 +354,27 @@ export class StandardGraph extends Graph<
       this.config = config;
       const { messages } = state;

-      … (3 removed lines; their content is not captured in this diff view)
+      let messagesToUse = messages;
+      if (!this.pruneMessages && this.tokenCounter && this.maxContextTokens && this.indexTokenCountMap[0] != null) {
+        this.pruneMessages = createPruneMessages({
+          indexTokenCountMap: this.indexTokenCountMap,
+          maxTokens: this.maxContextTokens,
+          tokenCounter: this.tokenCounter,
+          startIndex: this.startIndex,
+        });
+      }
+      if (this.pruneMessages) {
+        const { context, indexTokenCountMap } = this.pruneMessages({
+          messages,
+          usageMetadata: this.currentUsage,
+        });
+        this.indexTokenCountMap = indexTokenCountMap;
+        messagesToUse = context;
+      }
+
+      const finalMessages = messagesToUse;
+      const lastMessageX = finalMessages.length >= 2 ? finalMessages[finalMessages.length - 2] : null;
+      const lastMessageY = finalMessages.length >= 1 ? finalMessages[finalMessages.length - 1] : null;

       if (
         provider === Providers.BEDROCK

@@ -372,6 +406,7 @@ export class StandardGraph extends Graph<

       this.lastStreamCall = Date.now();

+      let result: Partial<t.BaseGraphState>;
       if ((this.tools?.length ?? 0) > 0 && manualToolStreamProviders.has(provider)) {
         const stream = await this.boundModel.stream(finalMessages, config);
         let finalChunk: AIMessageChunk | undefined;

@@ -385,19 +420,22 @@ export class StandardGraph extends Graph<
         }

         finalChunk = modifyDeltaProperties(this.provider, finalChunk);
-      … (11 removed lines; their content is not captured in this diff view)
+        result = { messages: [finalChunk as AIMessageChunk] };
+      } else {
+        const finalMessage = (await this.boundModel.invoke(finalMessages, config)) as AIMessageChunk;
+        if ((finalMessage.tool_calls?.length ?? 0) > 0) {
+          finalMessage.tool_calls = finalMessage.tool_calls?.filter((tool_call) => {
+            if (!tool_call.name) {
+              return false;
+            }
+            return true;
+          });
+        }
+        result = { messages: [finalMessage] };
       }
-      … (1 removed line; content not captured in this diff view)
+
+      this.storeUsageMetadata(result?.messages?.[0]);
+      return result;
     };
   }

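Taken together, these changes make the graph lazily build a pruneMessages function on the first model call, trim the running context before each model invocation, and feed provider-reported usage_metadata back in between turns via storeUsageMetadata. The following standalone sketch illustrates that turn-by-turn flow; the root import path, the toy token counter, and the message/usage values are assumptions for illustration, not the package's documented API:

import { HumanMessage, AIMessage } from '@langchain/core/messages';
import type { BaseMessage } from '@langchain/core/messages';
// Assumed re-export from the package root; in-repo the factory lives at src/messages/prune.ts.
import { createPruneMessages } from '@librechat/agents';

// Toy counter for illustration only; real callers should supply an accurate
// counter, e.g. the one returned by createTokenCounter() from utils/tokens.
const tokenCounter = (message: BaseMessage): number =>
  Math.ceil(JSON.stringify(message.content).length / 4);

const messages: BaseMessage[] = [new HumanMessage('Hello!')];

// Mirrors what createCallModel() now does lazily on its first invocation:
const pruneMessages = createPruneMessages({
  maxTokens: 8000,
  startIndex: 0,
  tokenCounter,
  indexTokenCountMap: { 0: tokenCounter(messages[0]) },
});

// Turn 1: the history fits the budget, so context === messages.
let { context, indexTokenCountMap } = pruneMessages({ messages });

// Between turns, storeUsageMetadata() captures the reply's usage_metadata;
// passing it back lets the map track provider-reported totals (values invented).
messages.push(new AIMessage('Hi there!'));
({ context, indexTokenCountMap } = pruneMessages({
  messages,
  usageMetadata: { input_tokens: 12, output_tokens: 4, total_tokens: 16 },
}));
console.log(context.length, indexTokenCountMap);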
package/src/messages/format.ts
CHANGED

@@ -431,3 +431,30 @@ export const formatContentStrings = (payload: Array<BaseMessage>): Array<BaseMes
 
   return result;
 };
+
+/**
+ * Adds a value at key 0 for system messages and shifts all key indices by one in an indexTokenCountMap.
+ * This is useful when adding a system message at the beginning of a conversation.
+ *
+ * @param indexTokenCountMap - The original map of message indices to token counts
+ * @param instructionsTokenCount - The token count for the system message to add at index 0
+ * @returns A new map with the system message at index 0 and all other indices shifted by 1
+ */
+export function shiftIndexTokenCountMap(
+  indexTokenCountMap: Record<number, number>,
+  instructionsTokenCount: number
+): Record<number, number> {
+  // Create a new map to avoid modifying the original
+  const shiftedMap: Record<number, number> = {};
+
+  // Add the system message token count at index 0
+  shiftedMap[0] = instructionsTokenCount;
+
+  // Shift all existing indices by 1
+  for (const [indexStr, tokenCount] of Object.entries(indexTokenCountMap)) {
+    const index = Number(indexStr);
+    shiftedMap[index + 1] = tokenCount;
+  }
+
+  return shiftedMap;
+}
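To make the shift concrete, a small worked example (the token values are made up):

// Per-message token counts before any system prompt exists: index -> tokens
const perMessageCounts = { 0: 12, 1: 40 };

// Prepending a 25-token system message moves every existing index up by one:
const shifted = shiftIndexTokenCountMap(perMessageCounts, 25);
// shifted => { 0: 25, 1: 12, 2: 40 }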
package/src/messages/prune.ts
ADDED

@@ -0,0 +1,167 @@
+import type { BaseMessage, UsageMetadata } from '@langchain/core/messages';
+import type { TokenCounter } from '@/types/run';
+export type PruneMessagesFactoryParams = {
+  maxTokens: number;
+  startIndex: number;
+  tokenCounter: TokenCounter;
+  indexTokenCountMap: Record<string, number>;
+};
+export type PruneMessagesParams = {
+  messages: BaseMessage[];
+  usageMetadata?: Partial<UsageMetadata>;
+}
+
+/**
+ * Calculates the total tokens from a single usage object
+ *
+ * @param usage The usage metadata object containing token information
+ * @returns An object containing the total input and output tokens
+ */
+function calculateTotalTokens(usage: Partial<UsageMetadata>): UsageMetadata {
+  const baseInputTokens = Number(usage.input_tokens) || 0;
+  const cacheCreation = Number(usage.input_token_details?.cache_creation) || 0;
+  const cacheRead = Number(usage.input_token_details?.cache_read) || 0;
+
+  const totalInputTokens = baseInputTokens + cacheCreation + cacheRead;
+  const totalOutputTokens = Number(usage.output_tokens) || 0;
+
+  return {
+    input_tokens: totalInputTokens,
+    output_tokens: totalOutputTokens,
+    total_tokens: totalInputTokens + totalOutputTokens
+  };
+}
+
+/**
+ * Processes an array of messages and returns a context of messages that fit within a specified token limit.
+ * It iterates over the messages from newest to oldest, adding them to the context until the token limit is reached.
+ *
+ * @param options Configuration options for processing messages
+ * @returns Object containing the message context, remaining tokens, messages not included, and summary index
+ */
+function getMessagesWithinTokenLimit({
+  messages: _messages,
+  maxContextTokens,
+  indexTokenCountMap,
+}: {
+  messages: BaseMessage[];
+  maxContextTokens: number;
+  indexTokenCountMap: Record<string, number>;
+}): {
+  context: BaseMessage[];
+  remainingContextTokens: number;
+  messagesToRefine: BaseMessage[];
+  summaryIndex: number;
+} {
+  // Every reply is primed with <|start|>assistant<|message|>, so we
+  // start with 3 tokens for the label after all messages have been counted.
+  let summaryIndex = -1;
+  let currentTokenCount = 3;
+  const instructions = _messages?.[0]?.getType() === 'system' ? _messages[0] : undefined;
+  const instructionsTokenCount = instructions != null ? indexTokenCountMap[0] : 0;
+  let remainingContextTokens = maxContextTokens - instructionsTokenCount;
+  const messages = [..._messages];
+  const context: BaseMessage[] = [];
+
+  if (currentTokenCount < remainingContextTokens) {
+    let currentIndex = messages.length;
+    while (messages.length > 0 && currentTokenCount < remainingContextTokens && currentIndex > 1) {
+      currentIndex--;
+      if (messages.length === 1 && instructions) {
+        break;
+      }
+      const poppedMessage = messages.pop();
+      if (!poppedMessage) continue;
+
+      const tokenCount = indexTokenCountMap[currentIndex] || 0;
+
+      if ((currentTokenCount + tokenCount) <= remainingContextTokens) {
+        context.push(poppedMessage);
+        currentTokenCount += tokenCount;
+      } else {
+        messages.push(poppedMessage);
+        break;
+      }
+    }
+  }
+
+  if (instructions && _messages.length > 0) {
+    context.push(_messages[0] as BaseMessage);
+    messages.shift();
+  }
+
+  const prunedMemory = messages;
+  summaryIndex = prunedMemory.length - 1;
+  remainingContextTokens -= currentTokenCount;
+
+  return {
+    summaryIndex,
+    remainingContextTokens,
+    context: context.reverse(),
+    messagesToRefine: prunedMemory,
+  };
+}
+
+function checkValidNumber(value: unknown): value is number {
+  return typeof value === 'number' && !isNaN(value) && value > 0;
+}
+
+export function createPruneMessages(factoryParams: PruneMessagesFactoryParams) {
+  const indexTokenCountMap = { ...factoryParams.indexTokenCountMap };
+  let lastTurnStartIndex = factoryParams.startIndex;
+  let totalTokens = (Object.values(indexTokenCountMap)).reduce((a, b) => a + b, 0);
+  return function pruneMessages(params: PruneMessagesParams): {
+    context: BaseMessage[];
+    indexTokenCountMap: Record<string, number>;
+  } {
+    let currentUsage: UsageMetadata | undefined;
+    if (params.usageMetadata && (
+      checkValidNumber(params.usageMetadata.input_tokens)
+      || (
+        checkValidNumber(params.usageMetadata.input_token_details)
+        && (
+          checkValidNumber(params.usageMetadata.input_token_details.cache_creation)
+          || checkValidNumber(params.usageMetadata.input_token_details.cache_read)
+        )
+      )
+    ) && checkValidNumber(params.usageMetadata.output_tokens)) {
+      currentUsage = calculateTotalTokens(params.usageMetadata);
+      totalTokens = currentUsage.total_tokens;
+    }
+
+    for (let i = lastTurnStartIndex; i < params.messages.length; i++) {
+      const message = params.messages[i];
+      if (i === lastTurnStartIndex && indexTokenCountMap[i] === undefined && currentUsage) {
+        indexTokenCountMap[i] = currentUsage.output_tokens;
+      } else if (indexTokenCountMap[i] === undefined) {
+        indexTokenCountMap[i] = factoryParams.tokenCounter(message);
+        totalTokens += indexTokenCountMap[i];
+      }
+    }
+
+    // If `currentUsage` is defined, we need to distribute the current total tokens to our `indexTokenCountMap`,
+    // for all message index keys before `lastTurnStartIndex`, as it has the most accurate count for those messages.
+    // We must distribute it in a weighted manner, so that the total token count is equal to `currentUsage.total_tokens`,
+    // relative to the manually counted tokens in `indexTokenCountMap`.
+    if (currentUsage) {
+      const totalIndexTokens = Object.values(indexTokenCountMap).reduce((a, b) => a + b, 0);
+      const ratio = currentUsage.total_tokens / totalIndexTokens;
+      for (const key in indexTokenCountMap) {
+        indexTokenCountMap[key] = Math.round(indexTokenCountMap[key] * ratio);
+      }
+    }
+
+    lastTurnStartIndex = params.messages.length;
+    if (totalTokens <= factoryParams.maxTokens) {
+      return { context: params.messages, indexTokenCountMap };
+    }
+
+    const { context } = getMessagesWithinTokenLimit({
+      maxContextTokens: factoryParams.maxTokens,
+      messages: params.messages,
+      indexTokenCountMap,
+    });
+
+    return { context, indexTokenCountMap };
+  }
+}
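One subtlety in createPruneMessages worth spelling out: when provider usage is available, every entry in the map is rescaled so that the map's sum matches the reported total. A standalone sketch of that redistribution step, with made-up numbers:

// Manually counted tokens per message index (sum = 200), but the provider
// reported total_tokens = 300; this mirrors the ratio loop in createPruneMessages.
const indexTokenCountMap: Record<string, number> = { 0: 50, 1: 100, 2: 50 };
const reportedTotal = 300; // stands in for currentUsage.total_tokens

const manualTotal = Object.values(indexTokenCountMap).reduce((a, b) => a + b, 0); // 200
const ratio = reportedTotal / manualTotal; // 1.5
for (const key in indexTokenCountMap) {
  indexTokenCountMap[key] = Math.round(indexTokenCountMap[key] * ratio);
}
// indexTokenCountMap => { 0: 75, 1: 150, 2: 75 }, which sums to 300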

package/src/messages/shiftIndexTokenCountMap.test.ts
ADDED

@@ -0,0 +1,81 @@
+import { shiftIndexTokenCountMap } from './format';
+
+describe('shiftIndexTokenCountMap', () => {
+  it('should add a system message token count at index 0 and shift all other indices', () => {
+    const originalMap: Record<number, number> = {
+      0: 10,
+      1: 20,
+      2: 30
+    };
+
+    const systemMessageTokenCount = 15;
+
+    const result = shiftIndexTokenCountMap(originalMap, systemMessageTokenCount);
+
+    // Check that the system message token count is at index 0
+    expect(result[0]).toBe(15);
+
+    // Check that all other indices are shifted by 1
+    expect(result[1]).toBe(10);
+    expect(result[2]).toBe(20);
+    expect(result[3]).toBe(30);
+
+    // Check that the original map is not modified
+    expect(originalMap[0]).toBe(10);
+    expect(originalMap[1]).toBe(20);
+    expect(originalMap[2]).toBe(30);
+  });
+
+  it('should handle an empty map', () => {
+    const emptyMap: Record<number, number> = {};
+    const systemMessageTokenCount = 15;
+
+    const result = shiftIndexTokenCountMap(emptyMap, systemMessageTokenCount);
+
+    // Check that only the system message token count is in the result
+    expect(Object.keys(result).length).toBe(1);
+    expect(result[0]).toBe(15);
+  });
+
+  it('should handle non-sequential indices', () => {
+    const nonSequentialMap: Record<number, number> = {
+      0: 10,
+      2: 20,
+      5: 30
+    };
+
+    const systemMessageTokenCount = 15;
+
+    const result = shiftIndexTokenCountMap(nonSequentialMap, systemMessageTokenCount);
+
+    // Check that the system message token count is at index 0
+    expect(result[0]).toBe(15);
+
+    // Check that all other indices are shifted by 1
+    expect(result[1]).toBe(10);
+    expect(result[3]).toBe(20);
+    expect(result[6]).toBe(30);
+  });
+
+  it('should handle string keys', () => {
+    // TypeScript will convert string keys to numbers when accessing the object
+    const mapWithStringKeys: Record<string, number> = {
+      '0': 10,
+      '1': 20,
+      '2': 30
+    };
+
+    const systemMessageTokenCount = 15;
+
+    // Cast to Record<number, number> to match the function signature
+    const result = shiftIndexTokenCountMap(mapWithStringKeys as unknown as Record<number, number>, systemMessageTokenCount);
+
+    // Check that the system message token count is at index 0
+    expect(result[0]).toBe(15);
+
+    // Check that all other indices are shifted by 1
+    expect(result[1]).toBe(10);
+    expect(result[2]).toBe(20);
+    expect(result[3]).toBe(30);
+  });
+});
package/src/run.ts
CHANGED

@@ -1,13 +1,17 @@
 // src/run.ts
+import { zodToJsonSchema } from "zod-to-json-schema";
 import { PromptTemplate } from '@langchain/core/prompts';
 import { AzureChatOpenAI, ChatOpenAI } from '@langchain/openai';
+import { SystemMessage } from '@langchain/core/messages';
 import type { BaseMessage, MessageContentComplex } from '@langchain/core/messages';
 import type { ClientCallbacks, SystemCallbacks } from '@/graphs/Graph';
 import type { RunnableConfig } from '@langchain/core/runnables';
 import type * as t from '@/types';
 import { GraphEvents, Providers, Callback } from '@/common';
 import { manualToolStreamProviders } from '@/llm/providers';
+import { shiftIndexTokenCountMap } from '@/messages/format';
 import { createTitleRunnable } from '@/utils/title';
+import { createTokenCounter } from '@/utils/tokens';
 import { StandardGraph } from '@/graphs/Graph';
 import { HandlerRegistry } from '@/events';
 import { isOpenAILike } from '@/utils/llm';

@@ -106,6 +110,28 @@ export class Run<T extends t.BaseGraphState> {
       throw new Error('Run ID not provided');
     }

+    const tokenCounter = streamOptions?.tokenCounter ?? (streamOptions?.indexTokenCountMap ? await createTokenCounter() : undefined);
+    const toolTokens = tokenCounter ? (this.Graph.tools?.reduce((acc, tool) => {
+      if (!tool.schema) {
+        return acc;
+      }
+
+      const jsonSchema = zodToJsonSchema(tool.schema.describe(tool.description ?? ''), tool.name);
+      return acc + tokenCounter(new SystemMessage(JSON.stringify(jsonSchema)));
+    }, 0) ?? 0) : 0;
+    let instructionTokens = toolTokens;
+    if (this.Graph.systemMessage && tokenCounter) {
+      instructionTokens += tokenCounter(this.Graph.systemMessage);
+    }
+    if (instructionTokens > 0) {
+      this.Graph.indexTokenCountMap = shiftIndexTokenCountMap(streamOptions?.indexTokenCountMap ?? {}, instructionTokens);
+    } else {
+      this.Graph.indexTokenCountMap = streamOptions?.indexTokenCountMap ?? {};
+    }
+
+    this.Graph.maxContextTokens = streamOptions?.maxContextTokens;
+    this.Graph.tokenCounter = tokenCounter;
+
     config.run_id = this.id;
     config.configurable = Object.assign(config.configurable ?? {}, { run_id: this.id, provider: this.provider });

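The new block above counts each tool by serializing its JSON schema and pricing it as if it were a system message. A rough standalone sketch of the per-tool computation (the zod schema, the tool name, and the root import of createTokenCounter are assumptions; top-level await assumes an ESM context):

import { z } from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { SystemMessage } from '@langchain/core/messages';
// Assumed re-export; in-repo the counter lives at src/utils/tokens.ts.
import { createTokenCounter } from '@librechat/agents';

const tokenCounter = await createTokenCounter();

// Hypothetical stand-ins for tool.schema, tool.name, and tool.description:
const schema = z.object({ query: z.string() }).describe('Search the web for a query');
const jsonSchema = zodToJsonSchema(schema, 'web_search');

// Each tool contributes the token cost of its serialized schema; the sum is
// folded into instructionTokens, which shiftIndexTokenCountMap places at index 0.
const toolTokens = tokenCounter(new SystemMessage(JSON.stringify(jsonSchema)));
console.log(toolTokens);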
package/src/scripts/code_exec_simple.ts
CHANGED

@@ -1,12 +1,13 @@
 // src/scripts/cli.ts
 import { config } from 'dotenv';
 config();
-import { HumanMessage,
+import { HumanMessage, BaseMessage } from '@langchain/core/messages';
 import { TavilySearchResults } from '@langchain/community/tools/tavily_search';
 import type * as t from '@/types';
 import { ChatModelStreamHandler, createContentAggregator } from '@/stream';
-import { ToolEndHandler, ModelEndHandler, createMetadataAggregator } from '@/events';
 import { createCodeExecutionTool } from '@/tools/CodeExecutor';
+import { ToolEndHandler, ModelEndHandler } from '@/events';
+import { createTokenCounter } from '@/utils/tokens';
 import { getLLMConfig } from '@/utils/llmConfig';
 import { getArgs } from '@/scripts/args';
 import { GraphEvents } from '@/common';

@@ -58,19 +59,22 @@ async function testCodeExecution(): Promise<void> {
   };

   const llmConfig = getLLMConfig(provider);
+  const instructions = 'You are a friendly AI assistant with coding capabilities. Always address the user by their name.';
+  const additional_instructions = `The user's name is ${userName} and they are located in ${location}.`;

-  const
+  const runConfig: t.RunConfig = {
     runId: 'message-num-1',
     graphConfig: {
       type: 'standard',
       llmConfig,
       tools: [new TavilySearchResults(), createCodeExecutionTool()],
-      instructions
-      additional_instructions
+      instructions,
+      additional_instructions,
     },
     returnContent: true,
     customHandlers,
-  }
+  };
+  const run = await Run.create<t.IState>(runConfig);

   const config = {
     configurable: {

@@ -86,13 +90,22 @@ async function testCodeExecution(): Promise<void> {
   // const userMessage1 = `how much memory is this (its in bytes) in MB? 31192000`;
   // const userMessage1 = `can you show me a good use case for rscript by running some code`;
   const userMessage1 = `Run hello world in french and in english, using python. please run 2 parallel code executions.`;
+  const humanMessage = new HumanMessage(userMessage1);
+  const tokenCounter = await createTokenCounter();
+  const indexTokenCountMap = {
+    0: tokenCounter(humanMessage),
+  };

-  conversationHistory.push(
+  conversationHistory.push(humanMessage);

   let inputs = {
     messages: conversationHistory,
   };
-  const finalContentParts1 = await run.processStream(inputs, config
+  const finalContentParts1 = await run.processStream(inputs, config, {
+    maxContextTokens: 8000,
+    indexTokenCountMap,
+    tokenCounter,
+  });
   const finalMessages1 = run.getRunMessages();
   if (finalMessages1) {
     conversationHistory.push(...finalMessages1);