@librechat/agents 2.2.1 → 2.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/cjs/graphs/Graph.cjs +56 -19
- package/dist/cjs/graphs/Graph.cjs.map +1 -1
- package/dist/cjs/main.cjs +18 -8
- package/dist/cjs/main.cjs.map +1 -1
- package/dist/cjs/{messages.cjs → messages/core.cjs} +2 -2
- package/dist/cjs/messages/core.cjs.map +1 -0
- package/dist/cjs/messages/format.cjs +334 -0
- package/dist/cjs/messages/format.cjs.map +1 -0
- package/dist/cjs/messages/prune.cjs +124 -0
- package/dist/cjs/messages/prune.cjs.map +1 -0
- package/dist/cjs/run.cjs +24 -0
- package/dist/cjs/run.cjs.map +1 -1
- package/dist/cjs/utils/tokens.cjs +64 -0
- package/dist/cjs/utils/tokens.cjs.map +1 -0
- package/dist/esm/graphs/Graph.mjs +51 -14
- package/dist/esm/graphs/Graph.mjs.map +1 -1
- package/dist/esm/main.mjs +3 -1
- package/dist/esm/main.mjs.map +1 -1
- package/dist/esm/{messages.mjs → messages/core.mjs} +2 -2
- package/dist/esm/messages/core.mjs.map +1 -0
- package/dist/esm/messages/format.mjs +326 -0
- package/dist/esm/messages/format.mjs.map +1 -0
- package/dist/esm/messages/prune.mjs +122 -0
- package/dist/esm/messages/prune.mjs.map +1 -0
- package/dist/esm/run.mjs +24 -0
- package/dist/esm/run.mjs.map +1 -1
- package/dist/esm/utils/tokens.mjs +62 -0
- package/dist/esm/utils/tokens.mjs.map +1 -0
- package/dist/types/graphs/Graph.d.ts +8 -1
- package/dist/types/messages/format.d.ts +120 -0
- package/dist/types/messages/index.d.ts +3 -0
- package/dist/types/messages/prune.d.ts +16 -0
- package/dist/types/types/run.d.ts +4 -0
- package/dist/types/utils/tokens.d.ts +2 -0
- package/package.json +1 -1
- package/src/graphs/Graph.ts +54 -16
- package/src/messages/format.ts +460 -0
- package/src/messages/formatAgentMessages.test.ts +628 -0
- package/src/messages/formatMessage.test.ts +277 -0
- package/src/messages/index.ts +3 -0
- package/src/messages/prune.ts +167 -0
- package/src/messages/shiftIndexTokenCountMap.test.ts +81 -0
- package/src/run.ts +26 -0
- package/src/scripts/code_exec_simple.ts +21 -8
- package/src/specs/prune.test.ts +444 -0
- package/src/types/run.ts +5 -0
- package/src/utils/tokens.ts +70 -0
- package/dist/cjs/messages.cjs.map +0 -1
- package/dist/esm/messages.mjs.map +0 -1
- /package/dist/types/{messages.d.ts → messages/core.d.ts} +0 -0
- /package/src/{messages.ts → messages/core.ts} +0 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import { Tiktoken } from 'js-tiktoken/lite';
|
|
2
|
+
import { ContentTypes } from '../common/enum.mjs';
|
|
3
|
+
|
|
4
|
+
/**
 * Estimates how many tokens a single message occupies.
 *
 * Every message starts from a fixed base of 3 tokens. `message.content` is then
 * walked:
 * - strings are counted as-is;
 * - numbers and booleans are counted via `toString()`;
 * - arrays are treated as lists of typed content parts.
 *
 * Content parts are handled as follows:
 * - parts that are falsy, untyped, `error` or `image_url` are ignored;
 * - a `tool_call` part with a non-null `tool_call` contributes its non-empty
 *   string `name`, `args` and `output`;
 * - any other part recurses into `part[part.type]` when that value is truthy.
 *
 * @param message - Message whose `content` is measured.
 * @param getTokenCount - Tokenizer callback returning the token count of a string.
 * @returns Total token estimate for the message.
 */
function getTokenCountForMessage(message, getTokenCount) {
    const baseTokens = 3;
    let total = baseTokens;

    // Only non-empty strings are counted for tool-call fields; everything else is ignored.
    const countIfText = (text) => {
        if (typeof text === 'string' && text) {
            total += getTokenCount(text);
        }
    };

    const visit = (node) => {
        if (typeof node === 'string') {
            total += getTokenCount(node);
            return;
        }
        if (typeof node === 'number' || typeof node === 'boolean') {
            total += getTokenCount(node.toString());
            return;
        }
        if (!Array.isArray(node)) {
            return;
        }
        for (const part of node) {
            const ignored = !part ||
                !part.type ||
                part.type === ContentTypes.ERROR ||
                part.type === ContentTypes.IMAGE_URL;
            if (ignored) {
                continue;
            }
            if (part.type === ContentTypes.TOOL_CALL && part.tool_call != null) {
                countIfText(part.tool_call.name);
                countIfText(part.tool_call.args);
                countIfText(part.tool_call.output);
                continue;
            }
            // Typed parts keep their payload under a key equal to their type, e.g. { type: 'text', text }.
            const payload = part[part.type];
            if (payload) {
                visit(payload);
            }
        }
    };

    visit(message.content);
    return total;
}
|
|
51
|
+
/**
 * Downloads the `o200k_base` BPE rank data and returns a per-message token counter.
 *
 * The encoder is built exactly once and reused by every call of the returned
 * counter. Building a `Tiktoken` instance processes the whole rank table, so the
 * previous per-call construction made every count pay that cost again.
 *
 * @returns A function mapping a message to its estimated token count
 *   (see `getTokenCountForMessage`).
 * @throws {Error} When the rank data cannot be downloaded (non-2xx response).
 */
const createTokenCounter = async () => {
    const res = await fetch(`https://tiktoken.pages.dev/js/o200k_base.json`);
    if (!res.ok) {
        // Fail loudly here instead of letting an error page reach res.json() / the encoder.
        throw new Error(`Failed to load o200k_base encoding: ${res.status} ${res.statusText}`);
    }
    const o200k_base = await res.json();
    // Single shared encoder; encode() is synchronous and the instance is reusable.
    const enc = new Tiktoken(o200k_base);
    const countTokens = (text) => enc.encode(text).length;
    return (message) => getTokenCountForMessage(message, countTokens);
};
|
|
60
|
+
|
|
61
|
+
export { createTokenCounter };
|
|
62
|
+
//# sourceMappingURL=tokens.mjs.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"tokens.mjs","sources":["../../../src/utils/tokens.ts"],"sourcesContent":["import { Tiktoken } from \"js-tiktoken/lite\";\nimport type { BaseMessage } from \"@langchain/core/messages\";\nimport { ContentTypes } from \"@/common/enum\";\n\nfunction getTokenCountForMessage(message: BaseMessage, getTokenCount: (text: string) => number): number {\n let tokensPerMessage = 3;\n\n const processValue = (value: unknown) => {\n if (Array.isArray(value)) {\n for (let item of value) {\n if (\n !item ||\n !item.type ||\n item.type === ContentTypes.ERROR ||\n item.type === ContentTypes.IMAGE_URL\n ) {\n continue;\n }\n\n if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) {\n const toolName = item.tool_call?.name || '';\n if (toolName != null && toolName && typeof toolName === 'string') {\n numTokens += getTokenCount(toolName);\n }\n\n const args = item.tool_call?.args || '';\n if (args != null && args && typeof args === 'string') {\n numTokens += getTokenCount(args);\n }\n\n const output = item.tool_call?.output || '';\n if (output != null && output && typeof output === 'string') {\n numTokens += getTokenCount(output);\n }\n continue;\n }\n\n const nestedValue = item[item.type];\n\n if (!nestedValue) {\n continue;\n }\n\n processValue(nestedValue);\n }\n } else if (typeof value === 'string') {\n numTokens += getTokenCount(value);\n } else if (typeof value === 'number') {\n numTokens += getTokenCount(value.toString());\n } else if (typeof value === 'boolean') {\n numTokens += getTokenCount(value.toString());\n }\n };\n\n let numTokens = tokensPerMessage;\n processValue(message.content);\n return numTokens;\n}\n\nexport const createTokenCounter = async () => {\n const res = await fetch(`https://tiktoken.pages.dev/js/o200k_base.json`);\n const o200k_base = await res.json();\n\n const countTokens = (text: string) => {\n const enc = new Tiktoken(o200k_base);\n return enc.encode(text).length;\n }\n \n return (message: BaseMessage) => 
getTokenCountForMessage(message, countTokens);\n}"],"names":[],"mappings":";;;AAIA,SAAS,uBAAuB,CAAC,OAAoB,EAAE,aAAuC,EAAA;IAC5F,IAAI,gBAAgB,GAAG,CAAC;AAExB,IAAA,MAAM,YAAY,GAAG,CAAC,KAAc,KAAI;AACtC,QAAA,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;AACxB,YAAA,KAAK,IAAI,IAAI,IAAI,KAAK,EAAE;AACtB,gBAAA,IACE,CAAC,IAAI;oBACL,CAAC,IAAI,CAAC,IAAI;AACV,oBAAA,IAAI,CAAC,IAAI,KAAK,YAAY,CAAC,KAAK;AAChC,oBAAA,IAAI,CAAC,IAAI,KAAK,YAAY,CAAC,SAAS,EACpC;oBACA;;AAGF,gBAAA,IAAI,IAAI,CAAC,IAAI,KAAK,YAAY,CAAC,SAAS,IAAI,IAAI,CAAC,SAAS,IAAI,IAAI,EAAE;oBAClE,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,EAAE,IAAI,IAAI,EAAE;oBAC3C,IAAI,QAAQ,IAAI,IAAI,IAAI,QAAQ,IAAI,OAAO,QAAQ,KAAK,QAAQ,EAAE;AAChE,wBAAA,SAAS,IAAI,aAAa,CAAC,QAAQ,CAAC;;oBAGtC,MAAM,IAAI,GAAG,IAAI,CAAC,SAAS,EAAE,IAAI,IAAI,EAAE;oBACvC,IAAI,IAAI,IAAI,IAAI,IAAI,IAAI,IAAI,OAAO,IAAI,KAAK,QAAQ,EAAE;AACpD,wBAAA,SAAS,IAAI,aAAa,CAAC,IAAI,CAAC;;oBAGlC,MAAM,MAAM,GAAG,IAAI,CAAC,SAAS,EAAE,MAAM,IAAI,EAAE;oBAC3C,IAAI,MAAM,IAAI,IAAI,IAAI,MAAM,IAAI,OAAO,MAAM,KAAK,QAAQ,EAAE;AAC1D,wBAAA,SAAS,IAAI,aAAa,CAAC,MAAM,CAAC;;oBAEpC;;gBAGF,MAAM,WAAW,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC;gBAEnC,IAAI,CAAC,WAAW,EAAE;oBAChB;;gBAGF,YAAY,CAAC,WAAW,CAAC;;;AAEtB,aAAA,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;AACpC,YAAA,SAAS,IAAI,aAAa,CAAC,KAAK,CAAC;;AAC5B,aAAA,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;YACpC,SAAS,IAAI,aAAa,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAC;;AACvC,aAAA,IAAI,OAAO,KAAK,KAAK,SAAS,EAAE;YACrC,SAAS,IAAI,aAAa,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAC;;AAEhD,KAAC;IAED,IAAI,SAAS,GAAG,gBAAgB;AAChC,IAAA,YAAY,CAAC,OAAO,CAAC,OAAO,CAAC;AAC7B,IAAA,OAAO,SAAS;AAClB;AAEa,MAAA,kBAAkB,GAAG,YAAW;AAC3C,IAAA,MAAM,GAAG,GAAG,MAAM,KAAK,CAAC,CAAA,6CAAA,CAA+C,CAAC;AACxE,IAAA,MAAM,UAAU,GAAG,MAAM,GAAG,CAAC,IAAI,EAAE;AAEnC,IAAA,MAAM,WAAW,GAAG,CAAC,IAAY,KAAI;AACnC,QAAA,MAAM,GAAG,GAAG,IAAI,QAAQ,CAAC,UAAU,CAAC;QACpC,OAAO,GAAG,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,MAAM;AAChC,KAAC;IAED,OAAO,CAAC,OAAoB,KAAK,uBAAuB,CAAC,OAAO,EAAE,WAAW,CAAC;AAChF;;;;"}
|
|
@@ -2,10 +2,11 @@ import { ToolNode } from '@langchain/langgraph/prebuilt';
|
|
|
2
2
|
import { START } from '@langchain/langgraph';
|
|
3
3
|
import { Runnable, RunnableConfig } from '@langchain/core/runnables';
|
|
4
4
|
import { SystemMessage } from '@langchain/core/messages';
|
|
5
|
-
import type { BaseMessage } from '@langchain/core/messages';
|
|
5
|
+
import type { BaseMessage, UsageMetadata } from '@langchain/core/messages';
|
|
6
6
|
import type * as t from '@/types';
|
|
7
7
|
import { Providers, GraphNodeKeys, Callback, ContentTypes } from '@/common';
|
|
8
8
|
import { ToolNode as CustomToolNode } from '@/tools/ToolNode';
|
|
9
|
+
import { createPruneMessages } from '@/messages';
|
|
9
10
|
import { HandlerRegistry } from '@/events';
|
|
10
11
|
export type GraphNode = GraphNodeKeys | typeof START;
|
|
11
12
|
export type ClientCallback<T extends unknown[]> = (graph: StandardGraph, ...args: T) => void;
|
|
@@ -49,8 +50,13 @@ export declare abstract class Graph<T extends t.BaseGraphState = t.BaseGraphStat
|
|
|
49
50
|
stepKeyIds: Map<string, string[]>;
|
|
50
51
|
contentIndexMap: Map<string, number>;
|
|
51
52
|
toolCallStepIds: Map<string, string>;
|
|
53
|
+
currentUsage: Partial<UsageMetadata> | undefined;
|
|
54
|
+
indexTokenCountMap: Record<string, number>;
|
|
55
|
+
maxContextTokens: number | undefined;
|
|
56
|
+
pruneMessages?: ReturnType<typeof createPruneMessages>;
|
|
52
57
|
/** The amount of time that should pass before another consecutive API call */
|
|
53
58
|
streamBuffer: number | undefined;
|
|
59
|
+
tokenCounter?: t.TokenCounter;
|
|
54
60
|
signal?: AbortSignal;
|
|
55
61
|
}
|
|
56
62
|
export declare class StandardGraph extends Graph<t.BaseGraphState, GraphNode> {
|
|
@@ -87,6 +93,7 @@ export declare class StandardGraph extends Graph<t.BaseGraphState, GraphNode> {
|
|
|
87
93
|
clientOptions?: t.ClientOptions;
|
|
88
94
|
omitOriginalOptions?: string[];
|
|
89
95
|
}): t.ChatModelInstance;
|
|
96
|
+
storeUsageMetadata(finalMessage?: BaseMessage): void;
|
|
90
97
|
createCallModel(): (state: t.BaseGraphState, config?: RunnableConfig) => Promise<Partial<t.BaseGraphState>>;
|
|
91
98
|
createWorkflow(): t.CompiledWorkflow<t.BaseGraphState>;
|
|
92
99
|
/**
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import { ToolMessage, BaseMessage } from '@langchain/core/messages';
|
|
2
|
+
import { HumanMessage, AIMessage, SystemMessage } from '@langchain/core/messages';
|
|
3
|
+
import { MessageContentImageUrl } from '@langchain/core/messages';
|
|
4
|
+
import type { MessageContentComplex } from '@/types';
|
|
5
|
+
import { Providers, ContentTypes } from '@/common';
|
|
6
|
+
/** Parameters accepted by `formatVisionMessage`. */
interface VisionMessageParams {
    /** The message to convert; the index signature allows arbitrary extra keys. */
    message: {
        role: string;
        /** Plain-text content; the formatted result replaces it with an array of content parts. */
        content: string;
        name?: string;
        [key: string]: any;
    };
    /** Image URL content parts to include in the formatted message (ordering relative to the text: TODO confirm in format.ts). */
    image_urls: MessageContentImageUrl[];
    /** Target provider; presumably affects the shape of the vision payload — confirm in format.ts. */
    endpoint?: Providers;
}
/**
 * Formats a message to OpenAI Vision API payload format.
 *
 * The input `message.content` is a string, while the returned object's `content`
 * is an array of `MessageContentComplex` parts.
 *
 * @param {VisionMessageParams} params - The parameters for formatting.
 * @returns {Object} - The formatted message.
 */
export declare const formatVisionMessage: ({ message, image_urls, endpoint }: VisionMessageParams) => {
    role: string;
    content: MessageContentComplex[];
    name?: string;
    [key: string]: any;
};
|
|
28
|
+
/**
 * Loosely-typed message accepted by `formatMessage`.
 * All fields are optional and the index signature allows extra keys.
 */
interface MessageInput {
    role?: string;
    /** NOTE(review): underscore-prefixed name field; semantics not visible here — confirm in format.ts. */
    _name?: string;
    /** NOTE(review): presumably an alternative to `role` for identifying the author — confirm. */
    sender?: string;
    /** NOTE(review): presumably an alternative to `content` — confirm precedence in format.ts. */
    text?: string;
    content?: string | MessageContentComplex[];
    /** Image URL parts; presumably routed through `formatVisionMessage` — confirm. */
    image_urls?: MessageContentImageUrl[];
    /** LangChain serialization id (string path), when the input came from a serialized LangChain message — TODO confirm. */
    lc_id?: string[];
    [key: string]: any;
}
/** Options for `formatMessage`. */
interface FormatMessageParams {
    message: MessageInput;
    userName?: string;
    assistantName?: string;
    endpoint?: Providers;
    /** Presumably selects a LangChain message instance over a plain `FormattedMessage` as the result — confirm. */
    langChain?: boolean;
}
/** Plain-object result of `formatMessage` (OpenAI-style payload). */
interface FormattedMessage {
    role: string;
    content: string | MessageContentComplex[];
    name?: string;
    [key: string]: any;
}
/**
 * Formats a message to OpenAI payload format based on the provided options.
 *
 * The return type is a union, so callers must narrow it before use.
 *
 * @param {FormatMessageParams} params - The parameters for formatting.
 * @returns {FormattedMessage | HumanMessage | AIMessage | SystemMessage} - The formatted message.
 */
export declare const formatMessage: ({ message, userName, assistantName, endpoint, langChain }: FormatMessageParams) => FormattedMessage | HumanMessage | AIMessage | SystemMessage;
|
|
58
|
+
/**
 * Formats an array of messages for LangChain.
 *
 * The same `formatOptions` are applied to every message. `message` and
 * `langChain` are excluded from the options type, presumably because this
 * function supplies them itself — confirm in format.ts.
 *
 * @param {Array<MessageInput>} messages - The array of messages to format.
 * @param {Omit<FormatMessageParams, 'message' | 'langChain'>} formatOptions - The options for formatting each message.
 * @returns {Array<HumanMessage | AIMessage | SystemMessage>} - The array of formatted LangChain messages.
 */
export declare const formatLangChainMessages: (messages: Array<MessageInput>, formatOptions: Omit<FormatMessageParams, "message" | "langChain">) => Array<HumanMessage | AIMessage | SystemMessage>;
|
|
66
|
+
/**
 * A (possibly serialized) LangChain message.
 * Both the `lc_kwargs` and the `kwargs` shapes are accepted.
 */
interface LangChainMessage {
    lc_kwargs?: {
        additional_kwargs?: Record<string, any>;
        [key: string]: any;
    };
    kwargs?: {
        additional_kwargs?: Record<string, any>;
        [key: string]: any;
    };
    [key: string]: any;
}
/**
 * Formats a LangChain message object by merging properties from `lc_kwargs` or `kwargs` and `additional_kwargs`.
 *
 * @param {LangChainMessage} message - The message object to format.
 * @returns {Record<string, any>} The formatted LangChain message.
 */
export declare const formatFromLangChain: (message: LangChainMessage) => Record<string, any>;
|
|
84
|
+
/**
 * Agent-side message as stored by the caller.
 * `content` is either a plain string or an array of typed content parts.
 */
interface TMessage {
    role?: string;
    content?: string | Array<{
        type: ContentTypes;
        text?: string;
        /** NOTE(review): ids of tool calls associated with this part — semantics not visible here, confirm in format.ts. */
        tool_call_ids?: string[];
        [key: string]: any;
    }>;
    [key: string]: any;
}
/**
 * Formats an array of messages for LangChain, handling tool calls and creating ToolMessage instances.
 *
 * When `indexTokenCountMap` is passed, an updated map is returned alongside the
 * messages. Presumably it is re-keyed because one payload entry can expand into
 * several LangChain messages (e.g. AI + Tool messages) — confirm in format.ts.
 *
 * @param {Array<Partial<TMessage>>} payload - The array of messages to format.
 * @param {Record<number, number>} [indexTokenCountMap] - Optional map of message indices to token counts.
 * @returns {Object} - Object containing formatted messages and updated indexTokenCountMap if provided.
 */
export declare const formatAgentMessages: (payload: Array<Partial<TMessage>>, indexTokenCountMap?: Record<number, number>) => {
    messages: Array<HumanMessage | AIMessage | SystemMessage | ToolMessage>;
    indexTokenCountMap?: Record<number, number>;
};
|
|
105
|
+
/**
 * Formats an array of messages for LangChain, making sure all content fields are strings.
 *
 * NOTE(review): the previous `@returns` text also said ToolMessages are created
 * for tool calls. That looks copied from `formatAgentMessages`; the signature
 * only maps `BaseMessage[]` to `BaseMessage[]`, so confirm in format.ts before
 * relying on it.
 *
 * @param {Array<BaseMessage>} payload - The array of messages to format.
 * @returns {Array<BaseMessage>} - The array of formatted LangChain messages.
 */
export declare const formatContentStrings: (payload: Array<BaseMessage>) => Array<BaseMessage>;
|
|
111
|
+
/**
 * Adds a value at key 0 for system messages and shifts all key indices by one in an indexTokenCountMap.
 * This is useful when adding a system message at the beginning of a conversation.
 *
 * @example
 * // { 0: 10, 1: 20 } with instructionsTokenCount = 5  →  { 0: 5, 1: 10, 2: 20 }
 *
 * @param indexTokenCountMap - The original map of message indices to token counts
 * @param instructionsTokenCount - The token count for the system message to add at index 0
 * @returns A new map with the system message at index 0 and all other indices shifted by 1
 */
export declare function shiftIndexTokenCountMap(indexTokenCountMap: Record<number, number>, instructionsTokenCount: number): Record<number, number>;
export {};
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import type { BaseMessage, UsageMetadata } from '@langchain/core/messages';
|
|
2
|
+
import type { TokenCounter } from '@/types/run';
|
|
3
|
+
/** Configuration fixed once when building a pruning function with `createPruneMessages`. */
export type PruneMessagesFactoryParams = {
    /** Token budget for the pruned context (StandardGraph passes its `maxContextTokens`). */
    maxTokens: number;
    /** StandardGraph passes its `startIndex`; presumably the index of the first message of the current run — confirm in prune.ts. */
    startIndex: number;
    /** Counts tokens of a single message. */
    tokenCounter: TokenCounter;
    /** Initial per-message token counts, keyed by message index. */
    indexTokenCountMap: Record<string, number>;
};
/** Per-call input to the function returned by `createPruneMessages`. */
export type PruneMessagesParams = {
    /** Full message history to prune. */
    messages: BaseMessage[];
    /** Usage reported by the previous model call (StandardGraph passes `currentUsage`). */
    usageMetadata?: Partial<UsageMetadata>;
};
/**
 * Creates a stateful message-pruning function.
 *
 * The returned function yields `context`, the messages to send to the model,
 * and an updated `indexTokenCountMap`. StandardGraph stores that map back on
 * itself after each call.
 */
export declare function createPruneMessages(factoryParams: PruneMessagesFactoryParams): (params: PruneMessagesParams) => {
    context: BaseMessage[];
    indexTokenCountMap: Record<string, number>;
};
|
|
@@ -50,7 +50,11 @@ export type RunConfig = {
|
|
|
50
50
|
returnContent?: boolean;
|
|
51
51
|
};
|
|
52
52
|
export type ProvidedCallbacks = (BaseCallbackHandler | CallbackHandlerMethods)[] | undefined;
|
|
53
|
+
/** Returns the number of tokens a single message contributes to the context. */
export type TokenCounter = (message: BaseMessage) => number;
export type EventStreamOptions = {
    callbacks?: graph.ClientCallbacks;
    keepContent?: boolean;
    /** Context token budget; presumably forwarded to the graph to enable pruning — confirm in run.ts. */
    maxContextTokens?: number;
    /** Per-message token counter used for pruning. */
    tokenCounter?: TokenCounter;
    /** Pre-computed token counts keyed by message index. */
    indexTokenCountMap?: Record<string, number>;
};
|
package/package.json
CHANGED
package/src/graphs/Graph.ts
CHANGED
|
@@ -8,12 +8,13 @@ import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
|
|
|
8
8
|
import { Runnable, RunnableConfig } from '@langchain/core/runnables';
|
|
9
9
|
import { dispatchCustomEvent } from '@langchain/core/callbacks/dispatch';
|
|
10
10
|
import { AIMessageChunk, ToolMessage, SystemMessage } from '@langchain/core/messages';
|
|
11
|
-
import type { BaseMessage, BaseMessageFields } from '@langchain/core/messages';
|
|
11
|
+
import type { BaseMessage, BaseMessageFields, UsageMetadata } from '@langchain/core/messages';
|
|
12
12
|
import type * as t from '@/types';
|
|
13
13
|
import { Providers, GraphEvents, GraphNodeKeys, StepTypes, Callback, ContentTypes } from '@/common';
|
|
14
14
|
import { getChatModelClass, manualToolStreamProviders } from '@/llm/providers';
|
|
15
15
|
import { ToolNode as CustomToolNode, toolsCondition } from '@/tools/ToolNode';
|
|
16
16
|
import {
|
|
17
|
+
createPruneMessages,
|
|
17
18
|
modifyDeltaProperties,
|
|
18
19
|
formatArtifactPayload,
|
|
19
20
|
convertMessagesToContent,
|
|
@@ -74,8 +75,13 @@ export abstract class Graph<
|
|
|
74
75
|
stepKeyIds: Map<string, string[]> = new Map<string, string[]>();
|
|
75
76
|
contentIndexMap: Map<string, number> = new Map();
|
|
76
77
|
toolCallStepIds: Map<string, string> = new Map();
|
|
78
|
+
currentUsage: Partial<UsageMetadata> | undefined;
|
|
79
|
+
indexTokenCountMap: Record<string, number> = {};
|
|
80
|
+
maxContextTokens: number | undefined;
|
|
81
|
+
pruneMessages?: ReturnType<typeof createPruneMessages>;
|
|
77
82
|
/** The amount of time that should pass before another consecutive API call */
|
|
78
83
|
streamBuffer: number | undefined;
|
|
84
|
+
tokenCounter?: t.TokenCounter;
|
|
79
85
|
signal?: AbortSignal;
|
|
80
86
|
}
|
|
81
87
|
|
|
@@ -166,6 +172,10 @@ export class StandardGraph extends Graph<
|
|
|
166
172
|
this.currentTokenType = resetIfNotEmpty(this.currentTokenType, ContentTypes.TEXT);
|
|
167
173
|
this.lastToken = resetIfNotEmpty(this.lastToken, undefined);
|
|
168
174
|
this.tokenTypeSwitch = resetIfNotEmpty(this.tokenTypeSwitch, undefined);
|
|
175
|
+
this.indexTokenCountMap = resetIfNotEmpty(this.indexTokenCountMap, {});
|
|
176
|
+
this.currentUsage = resetIfNotEmpty(this.currentUsage, undefined);
|
|
177
|
+
this.tokenCounter = resetIfNotEmpty(this.tokenCounter, undefined);
|
|
178
|
+
this.maxContextTokens = resetIfNotEmpty(this.maxContextTokens, undefined);
|
|
169
179
|
}
|
|
170
180
|
|
|
171
181
|
/* Run Step Processing */
|
|
@@ -326,6 +336,12 @@ export class StandardGraph extends Graph<
|
|
|
326
336
|
return new ChatModelClass(options);
|
|
327
337
|
}
|
|
328
338
|
|
|
339
|
+
storeUsageMetadata(finalMessage?: BaseMessage): void {
|
|
340
|
+
if (finalMessage && 'usage_metadata' in finalMessage && finalMessage.usage_metadata) {
|
|
341
|
+
this.currentUsage = finalMessage.usage_metadata as Partial<UsageMetadata>;
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
|
|
329
345
|
createCallModel() {
|
|
330
346
|
return async (state: t.BaseGraphState, config?: RunnableConfig): Promise<Partial<t.BaseGraphState>> => {
|
|
331
347
|
const { provider = '' } = (config?.configurable as t.GraphConfig | undefined) ?? {} ;
|
|
@@ -338,9 +354,27 @@ export class StandardGraph extends Graph<
|
|
|
338
354
|
this.config = config;
|
|
339
355
|
const { messages } = state;
|
|
340
356
|
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
357
|
+
let messagesToUse = messages;
|
|
358
|
+
if (!this.pruneMessages && this.tokenCounter && this.maxContextTokens && this.indexTokenCountMap[0] != null) {
|
|
359
|
+
this.pruneMessages = createPruneMessages({
|
|
360
|
+
indexTokenCountMap: this.indexTokenCountMap,
|
|
361
|
+
maxTokens: this.maxContextTokens,
|
|
362
|
+
tokenCounter: this.tokenCounter,
|
|
363
|
+
startIndex: this.startIndex,
|
|
364
|
+
});
|
|
365
|
+
}
|
|
366
|
+
if (this.pruneMessages) {
|
|
367
|
+
const { context, indexTokenCountMap } = this.pruneMessages({
|
|
368
|
+
messages,
|
|
369
|
+
usageMetadata: this.currentUsage,
|
|
370
|
+
});
|
|
371
|
+
this.indexTokenCountMap = indexTokenCountMap;
|
|
372
|
+
messagesToUse = context;
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
const finalMessages = messagesToUse;
|
|
376
|
+
const lastMessageX = finalMessages.length >= 2 ? finalMessages[finalMessages.length - 2] : null;
|
|
377
|
+
const lastMessageY = finalMessages.length >= 1 ? finalMessages[finalMessages.length - 1] : null;
|
|
344
378
|
|
|
345
379
|
if (
|
|
346
380
|
provider === Providers.BEDROCK
|
|
@@ -372,6 +406,7 @@ export class StandardGraph extends Graph<
|
|
|
372
406
|
|
|
373
407
|
this.lastStreamCall = Date.now();
|
|
374
408
|
|
|
409
|
+
let result: Partial<t.BaseGraphState>;
|
|
375
410
|
if ((this.tools?.length ?? 0) > 0 && manualToolStreamProviders.has(provider)) {
|
|
376
411
|
const stream = await this.boundModel.stream(finalMessages, config);
|
|
377
412
|
let finalChunk: AIMessageChunk | undefined;
|
|
@@ -385,19 +420,22 @@ export class StandardGraph extends Graph<
|
|
|
385
420
|
}
|
|
386
421
|
|
|
387
422
|
finalChunk = modifyDeltaProperties(this.provider, finalChunk);
|
|
388
|
-
|
|
389
|
-
}
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
}
|
|
423
|
+
result = { messages: [finalChunk as AIMessageChunk] };
|
|
424
|
+
} else {
|
|
425
|
+
const finalMessage = (await this.boundModel.invoke(finalMessages, config)) as AIMessageChunk;
|
|
426
|
+
if ((finalMessage.tool_calls?.length ?? 0) > 0) {
|
|
427
|
+
finalMessage.tool_calls = finalMessage.tool_calls?.filter((tool_call) => {
|
|
428
|
+
if (!tool_call.name) {
|
|
429
|
+
return false;
|
|
430
|
+
}
|
|
431
|
+
return true;
|
|
432
|
+
});
|
|
433
|
+
}
|
|
434
|
+
result = { messages: [finalMessage] };
|
|
399
435
|
}
|
|
400
|
-
|
|
436
|
+
|
|
437
|
+
this.storeUsageMetadata(result?.messages?.[0]);
|
|
438
|
+
return result;
|
|
401
439
|
};
|
|
402
440
|
}
|
|
403
441
|
|