@librechat/agents 2.2.1 → 2.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. package/dist/cjs/graphs/Graph.cjs +56 -19
  2. package/dist/cjs/graphs/Graph.cjs.map +1 -1
  3. package/dist/cjs/main.cjs +18 -8
  4. package/dist/cjs/main.cjs.map +1 -1
  5. package/dist/cjs/{messages.cjs → messages/core.cjs} +2 -2
  6. package/dist/cjs/messages/core.cjs.map +1 -0
  7. package/dist/cjs/messages/format.cjs +334 -0
  8. package/dist/cjs/messages/format.cjs.map +1 -0
  9. package/dist/cjs/messages/prune.cjs +124 -0
  10. package/dist/cjs/messages/prune.cjs.map +1 -0
  11. package/dist/cjs/run.cjs +24 -0
  12. package/dist/cjs/run.cjs.map +1 -1
  13. package/dist/cjs/utils/tokens.cjs +64 -0
  14. package/dist/cjs/utils/tokens.cjs.map +1 -0
  15. package/dist/esm/graphs/Graph.mjs +51 -14
  16. package/dist/esm/graphs/Graph.mjs.map +1 -1
  17. package/dist/esm/main.mjs +3 -1
  18. package/dist/esm/main.mjs.map +1 -1
  19. package/dist/esm/{messages.mjs → messages/core.mjs} +2 -2
  20. package/dist/esm/messages/core.mjs.map +1 -0
  21. package/dist/esm/messages/format.mjs +326 -0
  22. package/dist/esm/messages/format.mjs.map +1 -0
  23. package/dist/esm/messages/prune.mjs +122 -0
  24. package/dist/esm/messages/prune.mjs.map +1 -0
  25. package/dist/esm/run.mjs +24 -0
  26. package/dist/esm/run.mjs.map +1 -1
  27. package/dist/esm/utils/tokens.mjs +62 -0
  28. package/dist/esm/utils/tokens.mjs.map +1 -0
  29. package/dist/types/graphs/Graph.d.ts +8 -1
  30. package/dist/types/messages/format.d.ts +120 -0
  31. package/dist/types/messages/index.d.ts +3 -0
  32. package/dist/types/messages/prune.d.ts +16 -0
  33. package/dist/types/types/run.d.ts +4 -0
  34. package/dist/types/utils/tokens.d.ts +2 -0
  35. package/package.json +1 -1
  36. package/src/graphs/Graph.ts +54 -16
  37. package/src/messages/format.ts +460 -0
  38. package/src/messages/formatAgentMessages.test.ts +628 -0
  39. package/src/messages/formatMessage.test.ts +277 -0
  40. package/src/messages/index.ts +3 -0
  41. package/src/messages/prune.ts +167 -0
  42. package/src/messages/shiftIndexTokenCountMap.test.ts +81 -0
  43. package/src/run.ts +26 -0
  44. package/src/scripts/code_exec_simple.ts +21 -8
  45. package/src/specs/prune.test.ts +444 -0
  46. package/src/types/run.ts +5 -0
  47. package/src/utils/tokens.ts +70 -0
  48. package/dist/cjs/messages.cjs.map +0 -1
  49. package/dist/esm/messages.mjs.map +0 -1
  50. /package/dist/types/{messages.d.ts → messages/core.d.ts} +0 -0
  51. /package/src/{messages.ts → messages/core.ts} +0 -0
@@ -0,0 +1,62 @@
1
+ import { Tiktoken } from 'js-tiktoken/lite';
2
+ import { ContentTypes } from '../common/enum.mjs';
3
+
4
/**
 * Estimates the token count of a single message.
 *
 * The total starts at a fixed per-message overhead and then adds the token
 * count of every countable piece of `message.content`:
 *   - a string, number or boolean is counted directly (numbers and booleans
 *     via `toString()`);
 *   - an array is treated as a list of typed content parts. Parts that are
 *     falsy, have no `type`, or are of type ERROR / IMAGE_URL are ignored.
 *     TOOL_CALL parts contribute their `name`, `args` and `output`, but only
 *     when those fields are non-empty strings. Any other part contributes the
 *     value stored under the key equal to its own `type` (e.g. `part.text`
 *     for a `text` part), processed recursively.
 * Anything else (objects, null, undefined) contributes nothing.
 *
 * @param message - Object whose `content` property is counted.
 * @param getTokenCount - Callback returning the number of tokens in a string.
 * @returns The per-message overhead plus the tokens of all counted content.
 */
function getTokenCountForMessage(message, getTokenCount) {
  // Fixed overhead added to every message.
  // NOTE(review): presumably mirrors chat-format framing tokens — confirm.
  const tokensPerMessage = 3;
  let numTokens = tokensPerMessage;

  // Tool-call fields are only counted when they are non-empty strings;
  // structured (object) values are deliberately left uncounted here.
  const addIfNonEmptyString = (field) => {
    if (typeof field === 'string' && field !== '') {
      numTokens += getTokenCount(field);
    }
  };

  const processValue = (value) => {
    if (typeof value === 'string') {
      numTokens += getTokenCount(value);
      return;
    }
    if (typeof value === 'number' || typeof value === 'boolean') {
      numTokens += getTokenCount(value.toString());
      return;
    }
    if (!Array.isArray(value)) {
      return;
    }

    for (const item of value) {
      const isSkipped =
        !item ||
        !item.type ||
        item.type === ContentTypes.ERROR ||
        item.type === ContentTypes.IMAGE_URL;
      if (isSkipped) {
        continue;
      }

      if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) {
        addIfNonEmptyString(item.tool_call.name);
        addIfNonEmptyString(item.tool_call.args);
        addIfNonEmptyString(item.tool_call.output);
        continue;
      }

      // The payload of a content part lives under the key named by its type.
      const nestedValue = item[item.type];
      if (nestedValue) {
        processValue(nestedValue);
      }
    }
  };

  processValue(message.content);
  return numTokens;
}
51
/**
 * Creates a message token counter based on the `o200k_base` encoding.
 *
 * Downloads the encoding ranks from `tiktoken.pages.dev`, builds a single
 * `Tiktoken` encoder from them and returns a function that counts the tokens
 * of one message via `getTokenCountForMessage`.
 *
 * @returns A promise resolving to `(message) => number`.
 * @throws {Error} If the ranks file cannot be downloaded (non-2xx response),
 *   in addition to any network/JSON errors raised by `fetch` / `res.json()`.
 */
const createTokenCounter = async () => {
  const res = await fetch(`https://tiktoken.pages.dev/js/o200k_base.json`);
  // Fail fast with a clear message instead of trying to parse an error page.
  if (!res.ok) {
    throw new Error(`Failed to fetch o200k_base encoding: ${res.status} ${res.statusText}`);
  }
  const o200k_base = await res.json();

  // Build the encoder once; constructing it per call repeats the same
  // setup work for every string that gets counted.
  const enc = new Tiktoken(o200k_base);
  const countTokens = (text) => enc.encode(text).length;

  return (message) => getTokenCountForMessage(message, countTokens);
};
60
+
61
+ export { createTokenCounter };
62
+ //# sourceMappingURL=tokens.mjs.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"tokens.mjs","sources":["../../../src/utils/tokens.ts"],"sourcesContent":["import { Tiktoken } from \"js-tiktoken/lite\";\nimport type { BaseMessage } from \"@langchain/core/messages\";\nimport { ContentTypes } from \"@/common/enum\";\n\nfunction getTokenCountForMessage(message: BaseMessage, getTokenCount: (text: string) => number): number {\n let tokensPerMessage = 3;\n\n const processValue = (value: unknown) => {\n if (Array.isArray(value)) {\n for (let item of value) {\n if (\n !item ||\n !item.type ||\n item.type === ContentTypes.ERROR ||\n item.type === ContentTypes.IMAGE_URL\n ) {\n continue;\n }\n\n if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) {\n const toolName = item.tool_call?.name || '';\n if (toolName != null && toolName && typeof toolName === 'string') {\n numTokens += getTokenCount(toolName);\n }\n\n const args = item.tool_call?.args || '';\n if (args != null && args && typeof args === 'string') {\n numTokens += getTokenCount(args);\n }\n\n const output = item.tool_call?.output || '';\n if (output != null && output && typeof output === 'string') {\n numTokens += getTokenCount(output);\n }\n continue;\n }\n\n const nestedValue = item[item.type];\n\n if (!nestedValue) {\n continue;\n }\n\n processValue(nestedValue);\n }\n } else if (typeof value === 'string') {\n numTokens += getTokenCount(value);\n } else if (typeof value === 'number') {\n numTokens += getTokenCount(value.toString());\n } else if (typeof value === 'boolean') {\n numTokens += getTokenCount(value.toString());\n }\n };\n\n let numTokens = tokensPerMessage;\n processValue(message.content);\n return numTokens;\n}\n\nexport const createTokenCounter = async () => {\n const res = await fetch(`https://tiktoken.pages.dev/js/o200k_base.json`);\n const o200k_base = await res.json();\n\n const countTokens = (text: string) => {\n const enc = new Tiktoken(o200k_base);\n return enc.encode(text).length;\n }\n \n return (message: BaseMessage) => 
getTokenCountForMessage(message, countTokens);\n}"],"names":[],"mappings":";;;AAIA,SAAS,uBAAuB,CAAC,OAAoB,EAAE,aAAuC,EAAA;IAC5F,IAAI,gBAAgB,GAAG,CAAC;AAExB,IAAA,MAAM,YAAY,GAAG,CAAC,KAAc,KAAI;AACtC,QAAA,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;AACxB,YAAA,KAAK,IAAI,IAAI,IAAI,KAAK,EAAE;AACtB,gBAAA,IACE,CAAC,IAAI;oBACL,CAAC,IAAI,CAAC,IAAI;AACV,oBAAA,IAAI,CAAC,IAAI,KAAK,YAAY,CAAC,KAAK;AAChC,oBAAA,IAAI,CAAC,IAAI,KAAK,YAAY,CAAC,SAAS,EACpC;oBACA;;AAGF,gBAAA,IAAI,IAAI,CAAC,IAAI,KAAK,YAAY,CAAC,SAAS,IAAI,IAAI,CAAC,SAAS,IAAI,IAAI,EAAE;oBAClE,MAAM,QAAQ,GAAG,IAAI,CAAC,SAAS,EAAE,IAAI,IAAI,EAAE;oBAC3C,IAAI,QAAQ,IAAI,IAAI,IAAI,QAAQ,IAAI,OAAO,QAAQ,KAAK,QAAQ,EAAE;AAChE,wBAAA,SAAS,IAAI,aAAa,CAAC,QAAQ,CAAC;;oBAGtC,MAAM,IAAI,GAAG,IAAI,CAAC,SAAS,EAAE,IAAI,IAAI,EAAE;oBACvC,IAAI,IAAI,IAAI,IAAI,IAAI,IAAI,IAAI,OAAO,IAAI,KAAK,QAAQ,EAAE;AACpD,wBAAA,SAAS,IAAI,aAAa,CAAC,IAAI,CAAC;;oBAGlC,MAAM,MAAM,GAAG,IAAI,CAAC,SAAS,EAAE,MAAM,IAAI,EAAE;oBAC3C,IAAI,MAAM,IAAI,IAAI,IAAI,MAAM,IAAI,OAAO,MAAM,KAAK,QAAQ,EAAE;AAC1D,wBAAA,SAAS,IAAI,aAAa,CAAC,MAAM,CAAC;;oBAEpC;;gBAGF,MAAM,WAAW,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC;gBAEnC,IAAI,CAAC,WAAW,EAAE;oBAChB;;gBAGF,YAAY,CAAC,WAAW,CAAC;;;AAEtB,aAAA,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;AACpC,YAAA,SAAS,IAAI,aAAa,CAAC,KAAK,CAAC;;AAC5B,aAAA,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;YACpC,SAAS,IAAI,aAAa,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAC;;AACvC,aAAA,IAAI,OAAO,KAAK,KAAK,SAAS,EAAE;YACrC,SAAS,IAAI,aAAa,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAC;;AAEhD,KAAC;IAED,IAAI,SAAS,GAAG,gBAAgB;AAChC,IAAA,YAAY,CAAC,OAAO,CAAC,OAAO,CAAC;AAC7B,IAAA,OAAO,SAAS;AAClB;AAEa,MAAA,kBAAkB,GAAG,YAAW;AAC3C,IAAA,MAAM,GAAG,GAAG,MAAM,KAAK,CAAC,CAAA,6CAAA,CAA+C,CAAC;AACxE,IAAA,MAAM,UAAU,GAAG,MAAM,GAAG,CAAC,IAAI,EAAE;AAEnC,IAAA,MAAM,WAAW,GAAG,CAAC,IAAY,KAAI;AACnC,QAAA,MAAM,GAAG,GAAG,IAAI,QAAQ,CAAC,UAAU,CAAC;QACpC,OAAO,GAAG,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,MAAM;AAChC,KAAC;IAED,OAAO,CAAC,OAAoB,KAAK,uBAAuB,CAAC,OAAO,EAAE,WAAW,CAAC;AAChF;;;;"}
@@ -2,10 +2,11 @@ import { ToolNode } from '@langchain/langgraph/prebuilt';
2
2
  import { START } from '@langchain/langgraph';
3
3
  import { Runnable, RunnableConfig } from '@langchain/core/runnables';
4
4
  import { SystemMessage } from '@langchain/core/messages';
5
- import type { BaseMessage } from '@langchain/core/messages';
5
+ import type { BaseMessage, UsageMetadata } from '@langchain/core/messages';
6
6
  import type * as t from '@/types';
7
7
  import { Providers, GraphNodeKeys, Callback, ContentTypes } from '@/common';
8
8
  import { ToolNode as CustomToolNode } from '@/tools/ToolNode';
9
+ import { createPruneMessages } from '@/messages';
9
10
  import { HandlerRegistry } from '@/events';
10
11
  export type GraphNode = GraphNodeKeys | typeof START;
11
12
  export type ClientCallback<T extends unknown[]> = (graph: StandardGraph, ...args: T) => void;
@@ -49,8 +50,13 @@ export declare abstract class Graph<T extends t.BaseGraphState = t.BaseGraphStat
49
50
  stepKeyIds: Map<string, string[]>;
50
51
  contentIndexMap: Map<string, number>;
51
52
  toolCallStepIds: Map<string, string>;
53
+ currentUsage: Partial<UsageMetadata> | undefined;
54
+ indexTokenCountMap: Record<string, number>;
55
+ maxContextTokens: number | undefined;
56
+ pruneMessages?: ReturnType<typeof createPruneMessages>;
52
57
  /** The amount of time that should pass before another consecutive API call */
53
58
  streamBuffer: number | undefined;
59
+ tokenCounter?: t.TokenCounter;
54
60
  signal?: AbortSignal;
55
61
  }
56
62
  export declare class StandardGraph extends Graph<t.BaseGraphState, GraphNode> {
@@ -87,6 +93,7 @@ export declare class StandardGraph extends Graph<t.BaseGraphState, GraphNode> {
87
93
  clientOptions?: t.ClientOptions;
88
94
  omitOriginalOptions?: string[];
89
95
  }): t.ChatModelInstance;
96
+ storeUsageMetadata(finalMessage?: BaseMessage): void;
90
97
  createCallModel(): (state: t.BaseGraphState, config?: RunnableConfig) => Promise<Partial<t.BaseGraphState>>;
91
98
  createWorkflow(): t.CompiledWorkflow<t.BaseGraphState>;
92
99
  /**
@@ -0,0 +1,120 @@
1
+ import { ToolMessage, BaseMessage } from '@langchain/core/messages';
2
+ import { HumanMessage, AIMessage, SystemMessage } from '@langchain/core/messages';
3
+ import { MessageContentImageUrl } from '@langchain/core/messages';
4
+ import type { MessageContentComplex } from '@/types';
5
+ import { Providers, ContentTypes } from '@/common';
6
+ interface VisionMessageParams {
7
+ message: {
8
+ role: string;
9
+ content: string;
10
+ name?: string;
11
+ [key: string]: any;
12
+ };
13
+ image_urls: MessageContentImageUrl[];
14
+ endpoint?: Providers;
15
+ }
16
+ /**
17
+ * Formats a message to OpenAI Vision API payload format.
18
+ *
19
+ * @param {VisionMessageParams} params - The parameters for formatting.
20
+ * @returns {Object} - The formatted message.
21
+ */
22
+ export declare const formatVisionMessage: ({ message, image_urls, endpoint }: VisionMessageParams) => {
23
+ role: string;
24
+ content: MessageContentComplex[];
25
+ name?: string;
26
+ [key: string]: any;
27
+ };
28
+ interface MessageInput {
29
+ role?: string;
30
+ _name?: string;
31
+ sender?: string;
32
+ text?: string;
33
+ content?: string | MessageContentComplex[];
34
+ image_urls?: MessageContentImageUrl[];
35
+ lc_id?: string[];
36
+ [key: string]: any;
37
+ }
38
+ interface FormatMessageParams {
39
+ message: MessageInput;
40
+ userName?: string;
41
+ assistantName?: string;
42
+ endpoint?: Providers;
43
+ langChain?: boolean;
44
+ }
45
+ interface FormattedMessage {
46
+ role: string;
47
+ content: string | MessageContentComplex[];
48
+ name?: string;
49
+ [key: string]: any;
50
+ }
51
+ /**
52
+ * Formats a message to OpenAI payload format based on the provided options.
53
+ *
54
+ * @param {FormatMessageParams} params - The parameters for formatting.
55
+ * @returns {FormattedMessage | HumanMessage | AIMessage | SystemMessage} - The formatted message.
56
+ */
57
+ export declare const formatMessage: ({ message, userName, assistantName, endpoint, langChain }: FormatMessageParams) => FormattedMessage | HumanMessage | AIMessage | SystemMessage;
58
+ /**
59
+ * Formats an array of messages for LangChain.
60
+ *
61
+ * @param {Array<MessageInput>} messages - The array of messages to format.
62
+ * @param {Omit<FormatMessageParams, 'message' | 'langChain'>} formatOptions - The options for formatting each message.
63
+ * @returns {Array<HumanMessage | AIMessage | SystemMessage>} - The array of formatted LangChain messages.
64
+ */
65
+ export declare const formatLangChainMessages: (messages: Array<MessageInput>, formatOptions: Omit<FormatMessageParams, "message" | "langChain">) => Array<HumanMessage | AIMessage | SystemMessage>;
66
+ interface LangChainMessage {
67
+ lc_kwargs?: {
68
+ additional_kwargs?: Record<string, any>;
69
+ [key: string]: any;
70
+ };
71
+ kwargs?: {
72
+ additional_kwargs?: Record<string, any>;
73
+ [key: string]: any;
74
+ };
75
+ [key: string]: any;
76
+ }
77
+ /**
78
+ * Formats a LangChain message object by merging properties from `lc_kwargs` or `kwargs` and `additional_kwargs`.
79
+ *
80
+ * @param {LangChainMessage} message - The message object to format.
81
+ * @returns {Record<string, any>} The formatted LangChain message.
82
+ */
83
+ export declare const formatFromLangChain: (message: LangChainMessage) => Record<string, any>;
84
+ interface TMessage {
85
+ role?: string;
86
+ content?: string | Array<{
87
+ type: ContentTypes;
88
+ text?: string;
89
+ tool_call_ids?: string[];
90
+ [key: string]: any;
91
+ }>;
92
+ [key: string]: any;
93
+ }
94
+ /**
95
+ * Formats an array of messages for LangChain, handling tool calls and creating ToolMessage instances.
96
+ *
97
+ * @param {Array<Partial<TMessage>>} payload - The array of messages to format.
98
+ * @param {Record<number, number>} [indexTokenCountMap] - Optional map of message indices to token counts.
99
+ * @returns {Object} - Object containing formatted messages and updated indexTokenCountMap if provided.
100
+ */
101
+ export declare const formatAgentMessages: (payload: Array<Partial<TMessage>>, indexTokenCountMap?: Record<number, number>) => {
102
+ messages: Array<HumanMessage | AIMessage | SystemMessage | ToolMessage>;
103
+ indexTokenCountMap?: Record<number, number>;
104
+ };
105
+ /**
106
+ * Formats an array of messages for LangChain, making sure all content fields are strings
107
+ * @param {Array<HumanMessage | AIMessage | SystemMessage | ToolMessage>} payload - The array of messages to format.
108
+ * @returns {Array<HumanMessage | AIMessage | SystemMessage | ToolMessage>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
109
+ */
110
+ export declare const formatContentStrings: (payload: Array<BaseMessage>) => Array<BaseMessage>;
111
+ /**
112
+ * Adds a value at key 0 for system messages and shifts all key indices by one in an indexTokenCountMap.
113
+ * This is useful when adding a system message at the beginning of a conversation.
114
+ *
115
+ * @param indexTokenCountMap - The original map of message indices to token counts
116
+ * @param instructionsTokenCount - The token count for the system message to add at index 0
117
+ * @returns A new map with the system message at index 0 and all other indices shifted by 1
118
+ */
119
+ export declare function shiftIndexTokenCountMap(indexTokenCountMap: Record<number, number>, instructionsTokenCount: number): Record<number, number>;
120
+ export {};
@@ -0,0 +1,3 @@
1
+ export * from './core';
2
+ export * from './prune';
3
+ export * from './format';
@@ -0,0 +1,16 @@
1
+ import type { BaseMessage, UsageMetadata } from '@langchain/core/messages';
2
+ import type { TokenCounter } from '@/types/run';
3
+ export type PruneMessagesFactoryParams = {
4
+ maxTokens: number;
5
+ startIndex: number;
6
+ tokenCounter: TokenCounter;
7
+ indexTokenCountMap: Record<string, number>;
8
+ };
9
+ export type PruneMessagesParams = {
10
+ messages: BaseMessage[];
11
+ usageMetadata?: Partial<UsageMetadata>;
12
+ };
13
+ export declare function createPruneMessages(factoryParams: PruneMessagesFactoryParams): (params: PruneMessagesParams) => {
14
+ context: BaseMessage[];
15
+ indexTokenCountMap: Record<string, number>;
16
+ };
@@ -50,7 +50,11 @@ export type RunConfig = {
50
50
  returnContent?: boolean;
51
51
  };
52
52
  export type ProvidedCallbacks = (BaseCallbackHandler | CallbackHandlerMethods)[] | undefined;
53
+ export type TokenCounter = (message: BaseMessage) => number;
53
54
  export type EventStreamOptions = {
54
55
  callbacks?: graph.ClientCallbacks;
55
56
  keepContent?: boolean;
57
+ maxContextTokens?: number;
58
+ tokenCounter?: TokenCounter;
59
+ indexTokenCountMap?: Record<string, number>;
56
60
  };
@@ -0,0 +1,2 @@
1
+ import type { BaseMessage } from "@langchain/core/messages";
2
+ export declare const createTokenCounter: () => Promise<(message: BaseMessage) => number>;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@librechat/agents",
3
- "version": "2.2.1",
3
+ "version": "2.2.3",
4
4
  "main": "./dist/cjs/main.cjs",
5
5
  "module": "./dist/esm/main.mjs",
6
6
  "types": "./dist/types/index.d.ts",
@@ -8,12 +8,13 @@ import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
8
8
  import { Runnable, RunnableConfig } from '@langchain/core/runnables';
9
9
  import { dispatchCustomEvent } from '@langchain/core/callbacks/dispatch';
10
10
  import { AIMessageChunk, ToolMessage, SystemMessage } from '@langchain/core/messages';
11
- import type { BaseMessage, BaseMessageFields } from '@langchain/core/messages';
11
+ import type { BaseMessage, BaseMessageFields, UsageMetadata } from '@langchain/core/messages';
12
12
  import type * as t from '@/types';
13
13
  import { Providers, GraphEvents, GraphNodeKeys, StepTypes, Callback, ContentTypes } from '@/common';
14
14
  import { getChatModelClass, manualToolStreamProviders } from '@/llm/providers';
15
15
  import { ToolNode as CustomToolNode, toolsCondition } from '@/tools/ToolNode';
16
16
  import {
17
+ createPruneMessages,
17
18
  modifyDeltaProperties,
18
19
  formatArtifactPayload,
19
20
  convertMessagesToContent,
@@ -74,8 +75,13 @@ export abstract class Graph<
74
75
  stepKeyIds: Map<string, string[]> = new Map<string, string[]>();
75
76
  contentIndexMap: Map<string, number> = new Map();
76
77
  toolCallStepIds: Map<string, string> = new Map();
78
+ currentUsage: Partial<UsageMetadata> | undefined;
79
+ indexTokenCountMap: Record<string, number> = {};
80
+ maxContextTokens: number | undefined;
81
+ pruneMessages?: ReturnType<typeof createPruneMessages>;
77
82
  /** The amount of time that should pass before another consecutive API call */
78
83
  streamBuffer: number | undefined;
84
+ tokenCounter?: t.TokenCounter;
79
85
  signal?: AbortSignal;
80
86
  }
81
87
 
@@ -166,6 +172,10 @@ export class StandardGraph extends Graph<
166
172
  this.currentTokenType = resetIfNotEmpty(this.currentTokenType, ContentTypes.TEXT);
167
173
  this.lastToken = resetIfNotEmpty(this.lastToken, undefined);
168
174
  this.tokenTypeSwitch = resetIfNotEmpty(this.tokenTypeSwitch, undefined);
175
+ this.indexTokenCountMap = resetIfNotEmpty(this.indexTokenCountMap, {});
176
+ this.currentUsage = resetIfNotEmpty(this.currentUsage, undefined);
177
+ this.tokenCounter = resetIfNotEmpty(this.tokenCounter, undefined);
178
+ this.maxContextTokens = resetIfNotEmpty(this.maxContextTokens, undefined);
169
179
  }
170
180
 
171
181
  /* Run Step Processing */
@@ -326,6 +336,12 @@ export class StandardGraph extends Graph<
326
336
  return new ChatModelClass(options);
327
337
  }
328
338
 
339
+ storeUsageMetadata(finalMessage?: BaseMessage): void {
340
+ if (finalMessage && 'usage_metadata' in finalMessage && finalMessage.usage_metadata) {
341
+ this.currentUsage = finalMessage.usage_metadata as Partial<UsageMetadata>;
342
+ }
343
+ }
344
+
329
345
  createCallModel() {
330
346
  return async (state: t.BaseGraphState, config?: RunnableConfig): Promise<Partial<t.BaseGraphState>> => {
331
347
  const { provider = '' } = (config?.configurable as t.GraphConfig | undefined) ?? {} ;
@@ -338,9 +354,27 @@ export class StandardGraph extends Graph<
338
354
  this.config = config;
339
355
  const { messages } = state;
340
356
 
341
- const finalMessages = messages;
342
- const lastMessageX = finalMessages[finalMessages.length - 2];
343
- const lastMessageY = finalMessages[finalMessages.length - 1];
357
+ let messagesToUse = messages;
358
+ if (!this.pruneMessages && this.tokenCounter && this.maxContextTokens && this.indexTokenCountMap[0] != null) {
359
+ this.pruneMessages = createPruneMessages({
360
+ indexTokenCountMap: this.indexTokenCountMap,
361
+ maxTokens: this.maxContextTokens,
362
+ tokenCounter: this.tokenCounter,
363
+ startIndex: this.startIndex,
364
+ });
365
+ }
366
+ if (this.pruneMessages) {
367
+ const { context, indexTokenCountMap } = this.pruneMessages({
368
+ messages,
369
+ usageMetadata: this.currentUsage,
370
+ });
371
+ this.indexTokenCountMap = indexTokenCountMap;
372
+ messagesToUse = context;
373
+ }
374
+
375
+ const finalMessages = messagesToUse;
376
+ const lastMessageX = finalMessages.length >= 2 ? finalMessages[finalMessages.length - 2] : null;
377
+ const lastMessageY = finalMessages.length >= 1 ? finalMessages[finalMessages.length - 1] : null;
344
378
 
345
379
  if (
346
380
  provider === Providers.BEDROCK
@@ -372,6 +406,7 @@ export class StandardGraph extends Graph<
372
406
 
373
407
  this.lastStreamCall = Date.now();
374
408
 
409
+ let result: Partial<t.BaseGraphState>;
375
410
  if ((this.tools?.length ?? 0) > 0 && manualToolStreamProviders.has(provider)) {
376
411
  const stream = await this.boundModel.stream(finalMessages, config);
377
412
  let finalChunk: AIMessageChunk | undefined;
@@ -385,19 +420,22 @@ export class StandardGraph extends Graph<
385
420
  }
386
421
 
387
422
  finalChunk = modifyDeltaProperties(this.provider, finalChunk);
388
- return { messages: [finalChunk as AIMessageChunk] };
389
- }
390
-
391
- const finalMessage = (await this.boundModel.invoke(finalMessages, config)) as AIMessageChunk;
392
- if ((finalMessage.tool_calls?.length ?? 0) > 0) {
393
- finalMessage.tool_calls = finalMessage.tool_calls?.filter((tool_call) => {
394
- if (!tool_call.name) {
395
- return false;
396
- }
397
- return true;
398
- });
423
+ result = { messages: [finalChunk as AIMessageChunk] };
424
+ } else {
425
+ const finalMessage = (await this.boundModel.invoke(finalMessages, config)) as AIMessageChunk;
426
+ if ((finalMessage.tool_calls?.length ?? 0) > 0) {
427
+ finalMessage.tool_calls = finalMessage.tool_calls?.filter((tool_call) => {
428
+ if (!tool_call.name) {
429
+ return false;
430
+ }
431
+ return true;
432
+ });
433
+ }
434
+ result = { messages: [finalMessage] };
399
435
  }
400
- return { messages: [finalMessage] };
436
+
437
+ this.storeUsageMetadata(result?.messages?.[0]);
438
+ return result;
401
439
  };
402
440
  }
403
441