illuma-agents 1.0.2 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +25 -21
- package/dist/cjs/agents/AgentContext.cjs +222 -0
- package/dist/cjs/agents/AgentContext.cjs.map +1 -0
- package/dist/cjs/common/enum.cjs +5 -4
- package/dist/cjs/common/enum.cjs.map +1 -1
- package/dist/cjs/events.cjs +7 -5
- package/dist/cjs/events.cjs.map +1 -1
- package/dist/cjs/graphs/Graph.cjs +328 -207
- package/dist/cjs/graphs/Graph.cjs.map +1 -1
- package/dist/cjs/graphs/MultiAgentGraph.cjs +507 -0
- package/dist/cjs/graphs/MultiAgentGraph.cjs.map +1 -0
- package/dist/cjs/llm/anthropic/index.cjs.map +1 -1
- package/dist/cjs/llm/google/index.cjs.map +1 -1
- package/dist/cjs/llm/ollama/index.cjs.map +1 -1
- package/dist/cjs/llm/openai/index.cjs +35 -0
- package/dist/cjs/llm/openai/index.cjs.map +1 -1
- package/dist/cjs/llm/openai/utils/index.cjs +3 -1
- package/dist/cjs/llm/openai/utils/index.cjs.map +1 -1
- package/dist/cjs/llm/openrouter/index.cjs.map +1 -1
- package/dist/cjs/llm/providers.cjs +0 -2
- package/dist/cjs/llm/providers.cjs.map +1 -1
- package/dist/cjs/llm/vertexai/index.cjs.map +1 -1
- package/dist/cjs/main.cjs +12 -1
- package/dist/cjs/main.cjs.map +1 -1
- package/dist/cjs/messages/cache.cjs +123 -0
- package/dist/cjs/messages/cache.cjs.map +1 -0
- package/dist/cjs/messages/content.cjs +53 -0
- package/dist/cjs/messages/content.cjs.map +1 -0
- package/dist/cjs/messages/format.cjs +17 -29
- package/dist/cjs/messages/format.cjs.map +1 -1
- package/dist/cjs/run.cjs +119 -74
- package/dist/cjs/run.cjs.map +1 -1
- package/dist/cjs/stream.cjs +77 -73
- package/dist/cjs/stream.cjs.map +1 -1
- package/dist/cjs/tools/Calculator.cjs +45 -0
- package/dist/cjs/tools/Calculator.cjs.map +1 -0
- package/dist/cjs/tools/CodeExecutor.cjs +22 -22
- package/dist/cjs/tools/CodeExecutor.cjs.map +1 -1
- package/dist/cjs/tools/ToolNode.cjs +5 -3
- package/dist/cjs/tools/ToolNode.cjs.map +1 -1
- package/dist/cjs/tools/handlers.cjs +20 -20
- package/dist/cjs/tools/handlers.cjs.map +1 -1
- package/dist/cjs/utils/events.cjs +31 -0
- package/dist/cjs/utils/events.cjs.map +1 -0
- package/dist/cjs/utils/handlers.cjs +70 -0
- package/dist/cjs/utils/handlers.cjs.map +1 -0
- package/dist/cjs/utils/tokens.cjs +54 -7
- package/dist/cjs/utils/tokens.cjs.map +1 -1
- package/dist/esm/agents/AgentContext.mjs +220 -0
- package/dist/esm/agents/AgentContext.mjs.map +1 -0
- package/dist/esm/common/enum.mjs +5 -4
- package/dist/esm/common/enum.mjs.map +1 -1
- package/dist/esm/events.mjs +7 -5
- package/dist/esm/events.mjs.map +1 -1
- package/dist/esm/graphs/Graph.mjs +330 -209
- package/dist/esm/graphs/Graph.mjs.map +1 -1
- package/dist/esm/graphs/MultiAgentGraph.mjs +505 -0
- package/dist/esm/graphs/MultiAgentGraph.mjs.map +1 -0
- package/dist/esm/llm/anthropic/index.mjs.map +1 -1
- package/dist/esm/llm/google/index.mjs.map +1 -1
- package/dist/esm/llm/ollama/index.mjs.map +1 -1
- package/dist/esm/llm/openai/index.mjs +35 -0
- package/dist/esm/llm/openai/index.mjs.map +1 -1
- package/dist/esm/llm/openai/utils/index.mjs +3 -1
- package/dist/esm/llm/openai/utils/index.mjs.map +1 -1
- package/dist/esm/llm/openrouter/index.mjs.map +1 -1
- package/dist/esm/llm/providers.mjs +0 -2
- package/dist/esm/llm/providers.mjs.map +1 -1
- package/dist/esm/llm/vertexai/index.mjs.map +1 -1
- package/dist/esm/main.mjs +7 -2
- package/dist/esm/main.mjs.map +1 -1
- package/dist/esm/messages/cache.mjs +120 -0
- package/dist/esm/messages/cache.mjs.map +1 -0
- package/dist/esm/messages/content.mjs +51 -0
- package/dist/esm/messages/content.mjs.map +1 -0
- package/dist/esm/messages/format.mjs +18 -29
- package/dist/esm/messages/format.mjs.map +1 -1
- package/dist/esm/run.mjs +119 -74
- package/dist/esm/run.mjs.map +1 -1
- package/dist/esm/stream.mjs +77 -73
- package/dist/esm/stream.mjs.map +1 -1
- package/dist/esm/tools/Calculator.mjs +24 -0
- package/dist/esm/tools/Calculator.mjs.map +1 -0
- package/dist/esm/tools/CodeExecutor.mjs +22 -22
- package/dist/esm/tools/CodeExecutor.mjs.map +1 -1
- package/dist/esm/tools/ToolNode.mjs +5 -3
- package/dist/esm/tools/ToolNode.mjs.map +1 -1
- package/dist/esm/tools/handlers.mjs +20 -20
- package/dist/esm/tools/handlers.mjs.map +1 -1
- package/dist/esm/utils/events.mjs +29 -0
- package/dist/esm/utils/events.mjs.map +1 -0
- package/dist/esm/utils/handlers.mjs +68 -0
- package/dist/esm/utils/handlers.mjs.map +1 -0
- package/dist/esm/utils/tokens.mjs +54 -8
- package/dist/esm/utils/tokens.mjs.map +1 -1
- package/dist/types/agents/AgentContext.d.ts +94 -0
- package/dist/types/common/enum.d.ts +7 -5
- package/dist/types/events.d.ts +3 -3
- package/dist/types/graphs/Graph.d.ts +60 -66
- package/dist/types/graphs/MultiAgentGraph.d.ts +47 -0
- package/dist/types/graphs/index.d.ts +1 -0
- package/dist/types/index.d.ts +1 -0
- package/dist/types/llm/openai/index.d.ts +10 -0
- package/dist/types/messages/cache.d.ts +20 -0
- package/dist/types/messages/content.d.ts +7 -0
- package/dist/types/messages/format.d.ts +1 -7
- package/dist/types/messages/index.d.ts +2 -0
- package/dist/types/messages/reducer.d.ts +9 -0
- package/dist/types/run.d.ts +16 -10
- package/dist/types/stream.d.ts +4 -3
- package/dist/types/tools/Calculator.d.ts +8 -0
- package/dist/types/tools/ToolNode.d.ts +1 -1
- package/dist/types/tools/handlers.d.ts +9 -7
- package/dist/types/tools/search/tool.d.ts +4 -4
- package/dist/types/types/graph.d.ts +124 -11
- package/dist/types/types/llm.d.ts +13 -9
- package/dist/types/types/messages.d.ts +4 -0
- package/dist/types/types/run.d.ts +46 -8
- package/dist/types/types/stream.d.ts +3 -2
- package/dist/types/utils/events.d.ts +6 -0
- package/dist/types/utils/handlers.d.ts +34 -0
- package/dist/types/utils/index.d.ts +1 -0
- package/dist/types/utils/tokens.d.ts +24 -0
- package/package.json +162 -145
- package/src/agents/AgentContext.ts +323 -0
- package/src/common/enum.ts +177 -176
- package/src/events.ts +197 -191
- package/src/graphs/Graph.ts +1058 -846
- package/src/graphs/MultiAgentGraph.ts +598 -0
- package/src/graphs/index.ts +2 -1
- package/src/index.ts +25 -24
- package/src/llm/anthropic/index.ts +413 -413
- package/src/llm/google/index.ts +222 -222
- package/src/llm/google/utils/zod_to_genai_parameters.ts +86 -88
- package/src/llm/ollama/index.ts +92 -92
- package/src/llm/openai/index.ts +894 -853
- package/src/llm/openai/utils/index.ts +920 -918
- package/src/llm/openrouter/index.ts +60 -60
- package/src/llm/providers.ts +55 -57
- package/src/llm/vertexai/index.ts +360 -360
- package/src/messages/cache.test.ts +461 -0
- package/src/messages/cache.ts +151 -0
- package/src/messages/content.test.ts +362 -0
- package/src/messages/content.ts +63 -0
- package/src/messages/format.ts +611 -625
- package/src/messages/formatAgentMessages.test.ts +1144 -917
- package/src/messages/index.ts +6 -4
- package/src/messages/reducer.ts +80 -0
- package/src/run.ts +447 -381
- package/src/scripts/abort.ts +157 -138
- package/src/scripts/ant_web_search.ts +158 -158
- package/src/scripts/cli.ts +172 -167
- package/src/scripts/cli2.ts +133 -125
- package/src/scripts/cli3.ts +184 -178
- package/src/scripts/cli4.ts +191 -184
- package/src/scripts/cli5.ts +191 -184
- package/src/scripts/code_exec.ts +213 -214
- package/src/scripts/code_exec_simple.ts +147 -129
- package/src/scripts/content.ts +138 -120
- package/src/scripts/handoff-test.ts +135 -0
- package/src/scripts/multi-agent-chain.ts +278 -0
- package/src/scripts/multi-agent-conditional.ts +220 -0
- package/src/scripts/multi-agent-document-review-chain.ts +197 -0
- package/src/scripts/multi-agent-hybrid-flow.ts +310 -0
- package/src/scripts/multi-agent-parallel.ts +343 -0
- package/src/scripts/multi-agent-sequence.ts +212 -0
- package/src/scripts/multi-agent-supervisor.ts +364 -0
- package/src/scripts/multi-agent-test.ts +186 -0
- package/src/scripts/search.ts +146 -150
- package/src/scripts/simple.ts +225 -225
- package/src/scripts/stream.ts +140 -122
- package/src/scripts/test-custom-prompt-key.ts +145 -0
- package/src/scripts/test-handoff-input.ts +170 -0
- package/src/scripts/test-multi-agent-list-handoff.ts +261 -0
- package/src/scripts/test-tools-before-handoff.ts +222 -0
- package/src/scripts/tools.ts +153 -155
- package/src/specs/agent-handoffs.test.ts +889 -0
- package/src/specs/anthropic.simple.test.ts +320 -317
- package/src/specs/azure.simple.test.ts +325 -316
- package/src/specs/openai.simple.test.ts +311 -316
- package/src/specs/openrouter.simple.test.ts +107 -0
- package/src/specs/prune.test.ts +758 -763
- package/src/specs/reasoning.test.ts +201 -165
- package/src/specs/thinking-prune.test.ts +769 -703
- package/src/specs/token-memoization.test.ts +39 -0
- package/src/stream.ts +664 -651
- package/src/tools/Calculator.test.ts +278 -0
- package/src/tools/Calculator.ts +25 -0
- package/src/tools/CodeExecutor.ts +220 -220
- package/src/tools/ToolNode.ts +170 -170
- package/src/tools/handlers.ts +341 -336
- package/src/types/graph.ts +372 -185
- package/src/types/llm.ts +141 -140
- package/src/types/messages.ts +4 -0
- package/src/types/run.ts +128 -89
- package/src/types/stream.ts +401 -400
- package/src/utils/events.ts +32 -0
- package/src/utils/handlers.ts +107 -0
- package/src/utils/index.ts +6 -5
- package/src/utils/llmConfig.ts +183 -183
- package/src/utils/tokens.ts +129 -70
- package/dist/types/scripts/abort.d.ts +0 -1
- package/dist/types/scripts/ant_web_search.d.ts +0 -1
- package/dist/types/scripts/args.d.ts +0 -7
- package/dist/types/scripts/caching.d.ts +0 -1
- package/dist/types/scripts/cli.d.ts +0 -1
- package/dist/types/scripts/cli2.d.ts +0 -1
- package/dist/types/scripts/cli3.d.ts +0 -1
- package/dist/types/scripts/cli4.d.ts +0 -1
- package/dist/types/scripts/cli5.d.ts +0 -1
- package/dist/types/scripts/code_exec.d.ts +0 -1
- package/dist/types/scripts/code_exec_files.d.ts +0 -1
- package/dist/types/scripts/code_exec_simple.d.ts +0 -1
- package/dist/types/scripts/content.d.ts +0 -1
- package/dist/types/scripts/empty_input.d.ts +0 -1
- package/dist/types/scripts/image.d.ts +0 -1
- package/dist/types/scripts/memory.d.ts +0 -1
- package/dist/types/scripts/search.d.ts +0 -1
- package/dist/types/scripts/simple.d.ts +0 -1
- package/dist/types/scripts/stream.d.ts +0 -1
- package/dist/types/scripts/thinking.d.ts +0 -1
- package/dist/types/scripts/tools.d.ts +0 -1
- package/dist/types/specs/spec.utils.d.ts +0 -1
- package/dist/types/tools/example.d.ts +0 -78
- package/src/tools/example.ts +0 -129
package/src/specs/prune.test.ts
CHANGED
|
@@ -1,763 +1,758 @@
|
|
|
1
|
-
// src/specs/prune.test.ts
|
|
2
|
-
import { config } from 'dotenv';
|
|
3
|
-
config();
|
|
4
|
-
import {
|
|
5
|
-
HumanMessage,
|
|
6
|
-
AIMessage,
|
|
7
|
-
SystemMessage,
|
|
8
|
-
BaseMessage,
|
|
9
|
-
ToolMessage,
|
|
10
|
-
} from '@langchain/core/messages';
|
|
11
|
-
import type { RunnableConfig } from '@langchain/core/runnables';
|
|
12
|
-
import type { UsageMetadata } from '@langchain/core/messages';
|
|
13
|
-
import type * as t from '@/types';
|
|
14
|
-
import { createPruneMessages } from '@/messages/prune';
|
|
15
|
-
import { getLLMConfig } from '@/utils/llmConfig';
|
|
16
|
-
import { Providers } from '@/common';
|
|
17
|
-
import { Run } from '@/run';
|
|
18
|
-
|
|
19
|
-
// Create a simple token counter for testing
|
|
20
|
-
const createTestTokenCounter = (): t.TokenCounter => {
|
|
21
|
-
// This simple token counter just counts characters as tokens for predictable testing
|
|
22
|
-
return (message: BaseMessage): number => {
|
|
23
|
-
// Use type assertion to help TypeScript understand the type
|
|
24
|
-
const content = message.content as
|
|
25
|
-
| string
|
|
26
|
-
| Array<t.MessageContentComplex | string>
|
|
27
|
-
| undefined;
|
|
28
|
-
|
|
29
|
-
// Handle string content
|
|
30
|
-
if (typeof content === 'string') {
|
|
31
|
-
return content.length;
|
|
32
|
-
}
|
|
33
|
-
|
|
34
|
-
// Handle array content
|
|
35
|
-
if (Array.isArray(content)) {
|
|
36
|
-
let totalLength = 0;
|
|
37
|
-
|
|
38
|
-
for (const item of content) {
|
|
39
|
-
if (typeof item === 'string') {
|
|
40
|
-
totalLength += item.length;
|
|
41
|
-
} else if (typeof item === 'object') {
|
|
42
|
-
if ('text' in item && typeof item.text === 'string') {
|
|
43
|
-
totalLength += item.text.length;
|
|
44
|
-
}
|
|
45
|
-
}
|
|
46
|
-
}
|
|
47
|
-
|
|
48
|
-
return totalLength;
|
|
49
|
-
}
|
|
50
|
-
|
|
51
|
-
// Default case - if content is null, undefined, or any other type
|
|
52
|
-
return 0;
|
|
53
|
-
};
|
|
54
|
-
};
|
|
55
|
-
|
|
56
|
-
// Since the internal functions in prune.ts are not exported, we'll reimplement them here for testing
|
|
57
|
-
// This is based on the implementation in src/messages/prune.ts
|
|
58
|
-
function calculateTotalTokens(usage: Partial<UsageMetadata>): UsageMetadata {
|
|
59
|
-
const baseInputTokens = Number(usage.input_tokens) || 0;
|
|
60
|
-
const cacheCreation = Number(usage.input_token_details?.cache_creation) || 0;
|
|
61
|
-
const cacheRead = Number(usage.input_token_details?.cache_read) || 0;
|
|
62
|
-
|
|
63
|
-
const totalInputTokens = baseInputTokens + cacheCreation + cacheRead;
|
|
64
|
-
const totalOutputTokens = Number(usage.output_tokens) || 0;
|
|
65
|
-
|
|
66
|
-
return {
|
|
67
|
-
input_tokens: totalInputTokens,
|
|
68
|
-
output_tokens: totalOutputTokens,
|
|
69
|
-
total_tokens: totalInputTokens + totalOutputTokens,
|
|
70
|
-
};
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
function getMessagesWithinTokenLimit({
|
|
74
|
-
messages: _messages,
|
|
75
|
-
maxContextTokens,
|
|
76
|
-
indexTokenCountMap,
|
|
77
|
-
startType,
|
|
78
|
-
}: {
|
|
79
|
-
messages: BaseMessage[];
|
|
80
|
-
maxContextTokens: number;
|
|
81
|
-
indexTokenCountMap: Record<string, number>;
|
|
82
|
-
startType?: string;
|
|
83
|
-
}): {
|
|
84
|
-
context: BaseMessage[];
|
|
85
|
-
remainingContextTokens: number;
|
|
86
|
-
messagesToRefine: BaseMessage[];
|
|
87
|
-
summaryIndex: number;
|
|
88
|
-
} {
|
|
89
|
-
// Every reply is primed with <|start|>assistant<|message|>, so we
|
|
90
|
-
// start with 3 tokens for the label after all messages have been counted.
|
|
91
|
-
let summaryIndex = -1;
|
|
92
|
-
let currentTokenCount = 3;
|
|
93
|
-
const instructions =
|
|
94
|
-
_messages[0]?.getType() === 'system' ? _messages[0] : undefined;
|
|
95
|
-
const instructionsTokenCount =
|
|
96
|
-
instructions != null ? indexTokenCountMap[0] : 0;
|
|
97
|
-
let remainingContextTokens = maxContextTokens - instructionsTokenCount;
|
|
98
|
-
const messages = [..._messages];
|
|
99
|
-
const context: BaseMessage[] = [];
|
|
100
|
-
|
|
101
|
-
if (currentTokenCount < remainingContextTokens) {
|
|
102
|
-
let currentIndex = messages.length;
|
|
103
|
-
while (
|
|
104
|
-
messages.length > 0 &&
|
|
105
|
-
currentTokenCount < remainingContextTokens &&
|
|
106
|
-
currentIndex > 1
|
|
107
|
-
) {
|
|
108
|
-
currentIndex--;
|
|
109
|
-
if (messages.length === 1 && instructions) {
|
|
110
|
-
break;
|
|
111
|
-
}
|
|
112
|
-
const poppedMessage = messages.pop();
|
|
113
|
-
if (!poppedMessage) continue;
|
|
114
|
-
|
|
115
|
-
const tokenCount = indexTokenCountMap[currentIndex] || 0;
|
|
116
|
-
|
|
117
|
-
if (currentTokenCount + tokenCount <= remainingContextTokens) {
|
|
118
|
-
context.push(poppedMessage);
|
|
119
|
-
currentTokenCount += tokenCount;
|
|
120
|
-
} else {
|
|
121
|
-
messages.push(poppedMessage);
|
|
122
|
-
break;
|
|
123
|
-
}
|
|
124
|
-
}
|
|
125
|
-
|
|
126
|
-
// If startType is specified, discard messages until we find one of the required type
|
|
127
|
-
if (startType != null && startType && context.length > 0) {
|
|
128
|
-
const requiredTypeIndex = context.findIndex(
|
|
129
|
-
(msg) => msg.getType() === startType
|
|
130
|
-
);
|
|
131
|
-
|
|
132
|
-
if (requiredTypeIndex > 0) {
|
|
133
|
-
// If we found a message of the required type, discard all messages before it
|
|
134
|
-
const remainingMessages = context.slice(requiredTypeIndex);
|
|
135
|
-
context.length = 0; // Clear the array
|
|
136
|
-
context.push(...remainingMessages);
|
|
137
|
-
}
|
|
138
|
-
}
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
if (instructions && _messages.length > 0) {
|
|
142
|
-
context.push(_messages[0] as BaseMessage);
|
|
143
|
-
messages.shift();
|
|
144
|
-
}
|
|
145
|
-
|
|
146
|
-
const prunedMemory = messages;
|
|
147
|
-
summaryIndex = prunedMemory.length - 1;
|
|
148
|
-
remainingContextTokens -= currentTokenCount;
|
|
149
|
-
|
|
150
|
-
return {
|
|
151
|
-
summaryIndex,
|
|
152
|
-
remainingContextTokens,
|
|
153
|
-
context: context.reverse(),
|
|
154
|
-
messagesToRefine: prunedMemory,
|
|
155
|
-
};
|
|
156
|
-
}
|
|
157
|
-
|
|
158
|
-
function checkValidNumber(value: unknown): value is number {
|
|
159
|
-
return typeof value === 'number' && !isNaN(value) && value > 0;
|
|
160
|
-
}
|
|
161
|
-
|
|
162
|
-
describe('Prune Messages Tests', () => {
|
|
163
|
-
jest.setTimeout(30000);
|
|
164
|
-
|
|
165
|
-
describe('calculateTotalTokens', () => {
|
|
166
|
-
it('should calculate total tokens correctly with all fields present', () => {
|
|
167
|
-
const usage: Partial<UsageMetadata> = {
|
|
168
|
-
input_tokens: 100,
|
|
169
|
-
output_tokens: 50,
|
|
170
|
-
input_token_details: {
|
|
171
|
-
cache_creation: 10,
|
|
172
|
-
cache_read: 5,
|
|
173
|
-
},
|
|
174
|
-
};
|
|
175
|
-
|
|
176
|
-
const result = calculateTotalTokens(usage);
|
|
177
|
-
|
|
178
|
-
expect(result.input_tokens).toBe(115); // 100 + 10 + 5
|
|
179
|
-
expect(result.output_tokens).toBe(50);
|
|
180
|
-
expect(result.total_tokens).toBe(165); // 115 + 50
|
|
181
|
-
});
|
|
182
|
-
|
|
183
|
-
it('should handle missing fields gracefully', () => {
|
|
184
|
-
const usage: Partial<UsageMetadata> = {
|
|
185
|
-
input_tokens: 100,
|
|
186
|
-
output_tokens: 50,
|
|
187
|
-
};
|
|
188
|
-
|
|
189
|
-
const result = calculateTotalTokens(usage);
|
|
190
|
-
|
|
191
|
-
expect(result.input_tokens).toBe(100);
|
|
192
|
-
expect(result.output_tokens).toBe(50);
|
|
193
|
-
expect(result.total_tokens).toBe(150);
|
|
194
|
-
});
|
|
195
|
-
|
|
196
|
-
it('should handle empty usage object', () => {
|
|
197
|
-
const usage: Partial<UsageMetadata> = {};
|
|
198
|
-
|
|
199
|
-
const result = calculateTotalTokens(usage);
|
|
200
|
-
|
|
201
|
-
expect(result.input_tokens).toBe(0);
|
|
202
|
-
expect(result.output_tokens).toBe(0);
|
|
203
|
-
expect(result.total_tokens).toBe(0);
|
|
204
|
-
});
|
|
205
|
-
});
|
|
206
|
-
|
|
207
|
-
describe('getMessagesWithinTokenLimit', () => {
|
|
208
|
-
it('should include all messages when under token limit', () => {
|
|
209
|
-
const messages = [
|
|
210
|
-
new SystemMessage('System instruction'),
|
|
211
|
-
new HumanMessage('Hello'),
|
|
212
|
-
new AIMessage('Hi there'),
|
|
213
|
-
];
|
|
214
|
-
|
|
215
|
-
const indexTokenCountMap = {
|
|
216
|
-
0: 17, // "System instruction"
|
|
217
|
-
1: 5, // "Hello"
|
|
218
|
-
2: 8, // "Hi there"
|
|
219
|
-
};
|
|
220
|
-
|
|
221
|
-
const result = getMessagesWithinTokenLimit({
|
|
222
|
-
messages,
|
|
223
|
-
maxContextTokens: 100,
|
|
224
|
-
indexTokenCountMap,
|
|
225
|
-
});
|
|
226
|
-
|
|
227
|
-
expect(result.context.length).toBe(3);
|
|
228
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
229
|
-
expect(result.context[0].getType()).toBe('system'); // System message
|
|
230
|
-
expect(result.remainingContextTokens).toBe(100 - 17 - 5 - 8 - 3); // -3 for the assistant label tokens
|
|
231
|
-
expect(result.messagesToRefine.length).toBe(0);
|
|
232
|
-
});
|
|
233
|
-
|
|
234
|
-
it('should prune oldest messages when over token limit', () => {
|
|
235
|
-
const messages = [
|
|
236
|
-
new SystemMessage('System instruction'),
|
|
237
|
-
new HumanMessage('Message 1'),
|
|
238
|
-
new AIMessage('Response 1'),
|
|
239
|
-
new HumanMessage('Message 2'),
|
|
240
|
-
new AIMessage('Response 2'),
|
|
241
|
-
];
|
|
242
|
-
|
|
243
|
-
const indexTokenCountMap = {
|
|
244
|
-
0: 17, // "System instruction"
|
|
245
|
-
1: 9, // "Message 1"
|
|
246
|
-
2: 10, // "Response 1"
|
|
247
|
-
3: 9, // "Message 2"
|
|
248
|
-
4: 10, // "Response 2"
|
|
249
|
-
};
|
|
250
|
-
|
|
251
|
-
// Set a limit that can only fit the system message and the last two messages
|
|
252
|
-
const result = getMessagesWithinTokenLimit({
|
|
253
|
-
messages,
|
|
254
|
-
maxContextTokens: 40,
|
|
255
|
-
indexTokenCountMap,
|
|
256
|
-
});
|
|
257
|
-
|
|
258
|
-
// Should include system message and the last two messages
|
|
259
|
-
expect(result.context.length).toBe(3);
|
|
260
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
261
|
-
expect(result.context[0].getType()).toBe('system'); // System message
|
|
262
|
-
expect(result.context[1]).toBe(messages[3]); // Message 2
|
|
263
|
-
expect(result.context[2]).toBe(messages[4]); // Response 2
|
|
264
|
-
|
|
265
|
-
// Should have the first two messages in messagesToRefine
|
|
266
|
-
expect(result.messagesToRefine.length).toBe(2);
|
|
267
|
-
expect(result.messagesToRefine[0]).toBe(messages[1]); // Message 1
|
|
268
|
-
expect(result.messagesToRefine[1]).toBe(messages[2]); // Response 1
|
|
269
|
-
});
|
|
270
|
-
|
|
271
|
-
it('should always include system message even when at token limit', () => {
|
|
272
|
-
const messages = [
|
|
273
|
-
new SystemMessage('System instruction'),
|
|
274
|
-
new HumanMessage('Hello'),
|
|
275
|
-
new AIMessage('Hi there'),
|
|
276
|
-
];
|
|
277
|
-
|
|
278
|
-
const indexTokenCountMap = {
|
|
279
|
-
0: 17, // "System instruction"
|
|
280
|
-
1: 5, // "Hello"
|
|
281
|
-
2: 8, // "Hi there"
|
|
282
|
-
};
|
|
283
|
-
|
|
284
|
-
// Set a limit that can only fit the system message
|
|
285
|
-
const result = getMessagesWithinTokenLimit({
|
|
286
|
-
messages,
|
|
287
|
-
maxContextTokens: 20,
|
|
288
|
-
indexTokenCountMap,
|
|
289
|
-
});
|
|
290
|
-
|
|
291
|
-
expect(result.context.length).toBe(1);
|
|
292
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
293
|
-
|
|
294
|
-
expect(result.messagesToRefine.length).toBe(2);
|
|
295
|
-
});
|
|
296
|
-
|
|
297
|
-
it('should start context with a specific message type when startType is specified', () => {
|
|
298
|
-
const messages = [
|
|
299
|
-
new SystemMessage('System instruction'),
|
|
300
|
-
new AIMessage('AI message 1'),
|
|
301
|
-
new HumanMessage('Human message 1'),
|
|
302
|
-
new AIMessage('AI message 2'),
|
|
303
|
-
new HumanMessage('Human message 2'),
|
|
304
|
-
];
|
|
305
|
-
|
|
306
|
-
const indexTokenCountMap = {
|
|
307
|
-
0: 17, // "System instruction"
|
|
308
|
-
1: 12, // "AI message 1"
|
|
309
|
-
2: 15, // "Human message 1"
|
|
310
|
-
3: 12, // "AI message 2"
|
|
311
|
-
4: 15, // "Human message 2"
|
|
312
|
-
};
|
|
313
|
-
|
|
314
|
-
// Set a limit that can fit all messages
|
|
315
|
-
const result = getMessagesWithinTokenLimit({
|
|
316
|
-
messages,
|
|
317
|
-
maxContextTokens: 100,
|
|
318
|
-
indexTokenCountMap,
|
|
319
|
-
startType: 'human',
|
|
320
|
-
});
|
|
321
|
-
|
|
322
|
-
// All messages should be included since we're under the token limit
|
|
323
|
-
expect(result.context.length).toBe(5);
|
|
324
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
325
|
-
expect(result.context[1]).toBe(messages[1]); // AI message 1
|
|
326
|
-
expect(result.context[2]).toBe(messages[2]); // Human message 1
|
|
327
|
-
expect(result.context[3]).toBe(messages[3]); // AI message 2
|
|
328
|
-
expect(result.context[4]).toBe(messages[4]); // Human message 2
|
|
329
|
-
|
|
330
|
-
// All messages should be included since we're under the token limit
|
|
331
|
-
expect(result.messagesToRefine.length).toBe(0);
|
|
332
|
-
});
|
|
333
|
-
|
|
334
|
-
it('should keep all messages if no message of required type is found', () => {
|
|
335
|
-
const messages = [
|
|
336
|
-
new SystemMessage('System instruction'),
|
|
337
|
-
new AIMessage('AI message 1'),
|
|
338
|
-
new AIMessage('AI message 2'),
|
|
339
|
-
];
|
|
340
|
-
|
|
341
|
-
const indexTokenCountMap = {
|
|
342
|
-
0: 17, // "System instruction"
|
|
343
|
-
1: 12, // "AI message 1"
|
|
344
|
-
2: 12, // "AI message 2"
|
|
345
|
-
};
|
|
346
|
-
|
|
347
|
-
// Set a limit that can fit all messages
|
|
348
|
-
const result = getMessagesWithinTokenLimit({
|
|
349
|
-
messages,
|
|
350
|
-
maxContextTokens: 100,
|
|
351
|
-
indexTokenCountMap,
|
|
352
|
-
startType: 'human',
|
|
353
|
-
});
|
|
354
|
-
|
|
355
|
-
// Should include all messages since no human messages exist to start from
|
|
356
|
-
expect(result.context.length).toBe(3);
|
|
357
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
358
|
-
expect(result.context[1]).toBe(messages[1]); // AI message 1
|
|
359
|
-
expect(result.context[2]).toBe(messages[2]); // AI message 2
|
|
360
|
-
|
|
361
|
-
expect(result.messagesToRefine.length).toBe(0);
|
|
362
|
-
});
|
|
363
|
-
});
|
|
364
|
-
|
|
365
|
-
describe('checkValidNumber', () => {
|
|
366
|
-
it('should return true for valid positive numbers', () => {
|
|
367
|
-
expect(checkValidNumber(5)).toBe(true);
|
|
368
|
-
expect(checkValidNumber(1.5)).toBe(true);
|
|
369
|
-
expect(checkValidNumber(Number.MAX_SAFE_INTEGER)).toBe(true);
|
|
370
|
-
});
|
|
371
|
-
|
|
372
|
-
it('should return false for zero, negative numbers, and NaN', () => {
|
|
373
|
-
expect(checkValidNumber(0)).toBe(false);
|
|
374
|
-
expect(checkValidNumber(-5)).toBe(false);
|
|
375
|
-
expect(checkValidNumber(NaN)).toBe(false);
|
|
376
|
-
});
|
|
377
|
-
|
|
378
|
-
it('should return false for non-number types', () => {
|
|
379
|
-
expect(checkValidNumber('5')).toBe(false);
|
|
380
|
-
expect(checkValidNumber(null)).toBe(false);
|
|
381
|
-
expect(checkValidNumber(undefined)).toBe(false);
|
|
382
|
-
expect(checkValidNumber({})).toBe(false);
|
|
383
|
-
expect(checkValidNumber([])).toBe(false);
|
|
384
|
-
});
|
|
385
|
-
});
|
|
386
|
-
|
|
387
|
-
describe('createPruneMessages', () => {
|
|
388
|
-
it('should return all messages when under token limit', () => {
|
|
389
|
-
const tokenCounter = createTestTokenCounter();
|
|
390
|
-
const messages = [
|
|
391
|
-
new SystemMessage('System instruction'),
|
|
392
|
-
new HumanMessage('Hello'),
|
|
393
|
-
new AIMessage('Hi there'),
|
|
394
|
-
];
|
|
395
|
-
|
|
396
|
-
const indexTokenCountMap = {
|
|
397
|
-
0: tokenCounter(messages[0]),
|
|
398
|
-
1: tokenCounter(messages[1]),
|
|
399
|
-
2: tokenCounter(messages[2]),
|
|
400
|
-
};
|
|
401
|
-
|
|
402
|
-
const pruneMessages = createPruneMessages({
|
|
403
|
-
maxTokens: 100,
|
|
404
|
-
startIndex: 0,
|
|
405
|
-
tokenCounter,
|
|
406
|
-
indexTokenCountMap,
|
|
407
|
-
});
|
|
408
|
-
|
|
409
|
-
const result = pruneMessages({ messages });
|
|
410
|
-
|
|
411
|
-
expect(result.context.length).toBe(3);
|
|
412
|
-
expect(result.context).toEqual(messages);
|
|
413
|
-
});
|
|
414
|
-
|
|
415
|
-
it('should prune messages when over token limit', () => {
|
|
416
|
-
const tokenCounter = createTestTokenCounter();
|
|
417
|
-
const messages = [
|
|
418
|
-
new SystemMessage('System instruction'),
|
|
419
|
-
new HumanMessage('Message 1'),
|
|
420
|
-
new AIMessage('Response 1'),
|
|
421
|
-
new HumanMessage('Message 2'),
|
|
422
|
-
new AIMessage('Response 2'),
|
|
423
|
-
];
|
|
424
|
-
|
|
425
|
-
const indexTokenCountMap = {
|
|
426
|
-
0: tokenCounter(messages[0]),
|
|
427
|
-
1: tokenCounter(messages[1]),
|
|
428
|
-
2: tokenCounter(messages[2]),
|
|
429
|
-
3: tokenCounter(messages[3]),
|
|
430
|
-
4: tokenCounter(messages[4]),
|
|
431
|
-
};
|
|
432
|
-
|
|
433
|
-
// Set a limit that can only fit the system message and the last two messages
|
|
434
|
-
const pruneMessages = createPruneMessages({
|
|
435
|
-
maxTokens: 40,
|
|
436
|
-
startIndex: 0,
|
|
437
|
-
tokenCounter,
|
|
438
|
-
indexTokenCountMap,
|
|
439
|
-
});
|
|
440
|
-
|
|
441
|
-
const result = pruneMessages({ messages });
|
|
442
|
-
|
|
443
|
-
// Should include system message and the last two messages
|
|
444
|
-
expect(result.context.length).toBe(3);
|
|
445
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
446
|
-
expect(result.context[1]).toBe(messages[3]); // Message 2
|
|
447
|
-
expect(result.context[2]).toBe(messages[4]); // Response 2
|
|
448
|
-
});
|
|
449
|
-
|
|
450
|
-
it('should respect startType parameter', () => {
|
|
451
|
-
const tokenCounter = createTestTokenCounter();
|
|
452
|
-
const messages = [
|
|
453
|
-
new SystemMessage('System instruction'),
|
|
454
|
-
new AIMessage('AI message 1'),
|
|
455
|
-
new HumanMessage('Human message 1'),
|
|
456
|
-
new AIMessage('AI message 2'),
|
|
457
|
-
new HumanMessage('Human message 2'),
|
|
458
|
-
];
|
|
459
|
-
|
|
460
|
-
const indexTokenCountMap = {
|
|
461
|
-
0: tokenCounter(messages[0]),
|
|
462
|
-
1: tokenCounter(messages[1]),
|
|
463
|
-
2: tokenCounter(messages[2]),
|
|
464
|
-
3: tokenCounter(messages[3]),
|
|
465
|
-
4: tokenCounter(messages[4]),
|
|
466
|
-
};
|
|
467
|
-
|
|
468
|
-
// Set a limit that can fit all messages
|
|
469
|
-
const pruneMessages = createPruneMessages({
|
|
470
|
-
maxTokens: 100,
|
|
471
|
-
startIndex: 0,
|
|
472
|
-
tokenCounter,
|
|
473
|
-
indexTokenCountMap: { ...indexTokenCountMap },
|
|
474
|
-
});
|
|
475
|
-
|
|
476
|
-
const result = pruneMessages({
|
|
477
|
-
messages,
|
|
478
|
-
startType: 'human',
|
|
479
|
-
});
|
|
480
|
-
|
|
481
|
-
// All messages should be included since we're under the token limit
|
|
482
|
-
expect(result.context.length).toBe(5);
|
|
483
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
484
|
-
expect(result.context[1]).toBe(messages[1]); // AI message 1
|
|
485
|
-
expect(result.context[2]).toBe(messages[2]); // Human message 1
|
|
486
|
-
expect(result.context[3]).toBe(messages[3]); // AI message 2
|
|
487
|
-
expect(result.context[4]).toBe(messages[4]); // Human message 2
|
|
488
|
-
});
|
|
489
|
-
|
|
490
|
-
it('should update token counts when usage metadata is provided', () => {
|
|
491
|
-
const tokenCounter = createTestTokenCounter();
|
|
492
|
-
const messages = [
|
|
493
|
-
new SystemMessage('System instruction'),
|
|
494
|
-
new HumanMessage('Hello'),
|
|
495
|
-
new AIMessage('Hi there'),
|
|
496
|
-
];
|
|
497
|
-
|
|
498
|
-
const indexTokenCountMap = {
|
|
499
|
-
0: tokenCounter(messages[0]),
|
|
500
|
-
1: tokenCounter(messages[1]),
|
|
501
|
-
2: tokenCounter(messages[2]),
|
|
502
|
-
};
|
|
503
|
-
|
|
504
|
-
const pruneMessages = createPruneMessages({
|
|
505
|
-
maxTokens: 100,
|
|
506
|
-
startIndex: 0,
|
|
507
|
-
tokenCounter,
|
|
508
|
-
indexTokenCountMap: { ...indexTokenCountMap },
|
|
509
|
-
});
|
|
510
|
-
|
|
511
|
-
// Provide usage metadata that indicates different token counts
|
|
512
|
-
const usageMetadata: Partial<UsageMetadata> = {
|
|
513
|
-
input_tokens: 50,
|
|
514
|
-
output_tokens: 25,
|
|
515
|
-
total_tokens: 75,
|
|
516
|
-
};
|
|
517
|
-
|
|
518
|
-
const result = pruneMessages({
|
|
519
|
-
messages,
|
|
520
|
-
usageMetadata,
|
|
521
|
-
});
|
|
522
|
-
|
|
523
|
-
// The function should have updated the indexTokenCountMap based on the usage metadata
|
|
524
|
-
expect(result.indexTokenCountMap).not.toEqual(indexTokenCountMap);
|
|
525
|
-
|
|
526
|
-
// The total of all values in indexTokenCountMap should equal the total_tokens from usageMetadata
|
|
527
|
-
const totalTokens = Object.values(result.indexTokenCountMap).reduce(
|
|
528
|
-
(a = 0, b = 0) => a + b,
|
|
529
|
-
0
|
|
530
|
-
);
|
|
531
|
-
expect(totalTokens).toBe(75);
|
|
532
|
-
});
|
|
533
|
-
});
|
|
534
|
-
|
|
535
|
-
describe('Tool Message Handling', () => {
|
|
536
|
-
it('should ensure context does not start with a tool message by finding an AI message', () => {
|
|
537
|
-
const tokenCounter = createTestTokenCounter();
|
|
538
|
-
const messages = [
|
|
539
|
-
new SystemMessage('System instruction'),
|
|
540
|
-
new AIMessage('AI message 1'),
|
|
541
|
-
new ToolMessage({ content: 'Tool result 1', tool_call_id: 'tool1' }),
|
|
542
|
-
new AIMessage('AI message 2'),
|
|
543
|
-
new ToolMessage({ content: 'Tool result 2', tool_call_id: 'tool2' }),
|
|
544
|
-
];
|
|
545
|
-
|
|
546
|
-
const indexTokenCountMap = {
|
|
547
|
-
0: 17, // System instruction
|
|
548
|
-
1: 12, // AI message 1
|
|
549
|
-
2: 13, // Tool result 1
|
|
550
|
-
3: 12, // AI message 2
|
|
551
|
-
4: 13, // Tool result 2
|
|
552
|
-
};
|
|
553
|
-
|
|
554
|
-
// Create a pruneMessages function with a token limit that will only include the last few messages
|
|
555
|
-
const pruneMessages = createPruneMessages({
|
|
556
|
-
maxTokens: 58, // Only enough for system + last 3 messages + 3, but should not include a parent-less tool message
|
|
557
|
-
startIndex: 0,
|
|
558
|
-
tokenCounter,
|
|
559
|
-
indexTokenCountMap: { ...indexTokenCountMap },
|
|
560
|
-
});
|
|
561
|
-
|
|
562
|
-
const result = pruneMessages({ messages });
|
|
563
|
-
|
|
564
|
-
// The context should include the system message, AI message 2, and Tool result 2
|
|
565
|
-
// It should NOT start with Tool result 2 alone
|
|
566
|
-
expect(result.context.length).toBe(3);
|
|
567
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
568
|
-
expect(result.context[1]).toBe(messages[3]); // AI message 2
|
|
569
|
-
expect(result.context[2]).toBe(messages[4]); // Tool result 2
|
|
570
|
-
});
|
|
571
|
-
|
|
572
|
-
it('should ensure context does not start with a tool message by finding a human message', () => {
|
|
573
|
-
const tokenCounter = createTestTokenCounter();
|
|
574
|
-
const messages = [
|
|
575
|
-
new SystemMessage('System instruction'),
|
|
576
|
-
new HumanMessage('Human message 1'),
|
|
577
|
-
new AIMessage('AI message 1'),
|
|
578
|
-
new ToolMessage({ content: 'Tool result 1', tool_call_id: 'tool1' }),
|
|
579
|
-
new HumanMessage('Human message 2'),
|
|
580
|
-
new ToolMessage({ content: 'Tool result 2', tool_call_id: 'tool2' }),
|
|
581
|
-
];
|
|
582
|
-
|
|
583
|
-
const indexTokenCountMap = {
|
|
584
|
-
0: 17, // System instruction
|
|
585
|
-
1: 15, // Human message 1
|
|
586
|
-
2: 12, // AI message 1
|
|
587
|
-
3: 13, // Tool result 1
|
|
588
|
-
4: 15, // Human message 2
|
|
589
|
-
5: 13, // Tool result 2
|
|
590
|
-
};
|
|
591
|
-
|
|
592
|
-
// Create a pruneMessages function with a token limit that will only include the last few messages
|
|
593
|
-
const pruneMessages = createPruneMessages({
|
|
594
|
-
maxTokens: 48, // Only enough for system + last 2 messages
|
|
595
|
-
startIndex: 0,
|
|
596
|
-
tokenCounter,
|
|
597
|
-
indexTokenCountMap: { ...indexTokenCountMap },
|
|
598
|
-
});
|
|
599
|
-
|
|
600
|
-
const result = pruneMessages({ messages });
|
|
601
|
-
|
|
602
|
-
// The context should include the system message, Human message 2, and Tool result 2
|
|
603
|
-
// It should NOT start with Tool result 2 alone
|
|
604
|
-
expect(result.context.length).toBe(3);
|
|
605
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
606
|
-
expect(result.context[1]).toBe(messages[4]); // Human message 2
|
|
607
|
-
expect(result.context[2]).toBe(messages[5]); // Tool result 2
|
|
608
|
-
});
|
|
609
|
-
|
|
610
|
-
it('should handle the case where a tool message is followed by an AI message', () => {
|
|
611
|
-
const tokenCounter = createTestTokenCounter();
|
|
612
|
-
const messages = [
|
|
613
|
-
new SystemMessage('System instruction'),
|
|
614
|
-
new HumanMessage('Human message'),
|
|
615
|
-
new AIMessage('AI message with tool use'),
|
|
616
|
-
new ToolMessage({ content: 'Tool result', tool_call_id: 'tool1' }),
|
|
617
|
-
new AIMessage('AI message after tool'),
|
|
618
|
-
];
|
|
619
|
-
|
|
620
|
-
const indexTokenCountMap = {
|
|
621
|
-
0: 17, // System instruction
|
|
622
|
-
1: 13, // Human message
|
|
623
|
-
2: 22, // AI message with tool use
|
|
624
|
-
3: 11, // Tool result
|
|
625
|
-
4: 19, // AI message after tool
|
|
626
|
-
};
|
|
627
|
-
|
|
628
|
-
const pruneMessages = createPruneMessages({
|
|
629
|
-
maxTokens: 50,
|
|
630
|
-
startIndex: 0,
|
|
631
|
-
tokenCounter,
|
|
632
|
-
indexTokenCountMap: { ...indexTokenCountMap },
|
|
633
|
-
});
|
|
634
|
-
|
|
635
|
-
const result = pruneMessages({ messages });
|
|
636
|
-
|
|
637
|
-
expect(result.context.length).toBe(2);
|
|
638
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
639
|
-
expect(result.context[1]).toBe(messages[4]); // AI message after tool
|
|
640
|
-
});
|
|
641
|
-
|
|
642
|
-
it('should handle the case where a tool message is followed by a human message', () => {
|
|
643
|
-
const tokenCounter = createTestTokenCounter();
|
|
644
|
-
const messages = [
|
|
645
|
-
new SystemMessage('System instruction'),
|
|
646
|
-
new HumanMessage('Human message 1'),
|
|
647
|
-
new AIMessage('AI message with tool use'),
|
|
648
|
-
new ToolMessage({ content: 'Tool result', tool_call_id: 'tool1' }),
|
|
649
|
-
new HumanMessage('Human message 2'),
|
|
650
|
-
];
|
|
651
|
-
|
|
652
|
-
const indexTokenCountMap = {
|
|
653
|
-
0: 17, // System instruction
|
|
654
|
-
1: 15, // Human message 1
|
|
655
|
-
2: 22, // AI message with tool use
|
|
656
|
-
3: 11, // Tool result
|
|
657
|
-
4: 15, // Human message 2
|
|
658
|
-
};
|
|
659
|
-
|
|
660
|
-
const pruneMessages = createPruneMessages({
|
|
661
|
-
maxTokens: 46,
|
|
662
|
-
startIndex: 0,
|
|
663
|
-
tokenCounter,
|
|
664
|
-
indexTokenCountMap: { ...indexTokenCountMap },
|
|
665
|
-
});
|
|
666
|
-
|
|
667
|
-
const result = pruneMessages({ messages });
|
|
668
|
-
|
|
669
|
-
expect(result.context.length).toBe(2);
|
|
670
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
671
|
-
expect(result.context[1]).toBe(messages[4]); // Human message 2
|
|
672
|
-
});
|
|
673
|
-
|
|
674
|
-
it('should handle complex sequence with multiple tool messages', () => {
|
|
675
|
-
const tokenCounter = createTestTokenCounter();
|
|
676
|
-
const messages = [
|
|
677
|
-
new SystemMessage('System instruction'),
|
|
678
|
-
new HumanMessage('Human message 1'),
|
|
679
|
-
new AIMessage('AI message 1 with tool use'),
|
|
680
|
-
new ToolMessage({ content: 'Tool result 1', tool_call_id: 'tool1' }),
|
|
681
|
-
new AIMessage('AI message 2 with tool use'),
|
|
682
|
-
new ToolMessage({ content: 'Tool result 2', tool_call_id: 'tool2' }),
|
|
683
|
-
new AIMessage('AI message 3 with tool use'),
|
|
684
|
-
new ToolMessage({ content: 'Tool result 3', tool_call_id: 'tool3' }),
|
|
685
|
-
];
|
|
686
|
-
|
|
687
|
-
const indexTokenCountMap = {
|
|
688
|
-
0: 17, // System instruction
|
|
689
|
-
1: 15, // Human message 1
|
|
690
|
-
2: 26, // AI message 1 with tool use
|
|
691
|
-
3: 13, // Tool result 1
|
|
692
|
-
4: 26, // AI message 2 with tool use
|
|
693
|
-
5: 13, // Tool result 2
|
|
694
|
-
6: 26, // AI message 3 with tool use
|
|
695
|
-
7: 13, // Tool result 3
|
|
696
|
-
};
|
|
697
|
-
|
|
698
|
-
const pruneMessages = createPruneMessages({
|
|
699
|
-
maxTokens: 111,
|
|
700
|
-
startIndex: 0,
|
|
701
|
-
tokenCounter,
|
|
702
|
-
indexTokenCountMap: { ...indexTokenCountMap },
|
|
703
|
-
});
|
|
704
|
-
|
|
705
|
-
const result = pruneMessages({ messages });
|
|
706
|
-
|
|
707
|
-
expect(result.context.length).toBe(5);
|
|
708
|
-
expect(result.context[0]).toBe(messages[0]); // System message
|
|
709
|
-
expect(result.context[1]).toBe(messages[4]); // AI message 2 with tool use
|
|
710
|
-
expect(result.context[2]).toBe(messages[5]); // Tool result 2
|
|
711
|
-
expect(result.context[3]).toBe(messages[6]); // AI message 3 with tool use
|
|
712
|
-
expect(result.context[4]).toBe(messages[7]); // Tool result 3
|
|
713
|
-
});
|
|
714
|
-
});
|
|
715
|
-
|
|
716
|
-
describe('Integration with Run', () => {
|
|
717
|
-
it('should initialize Run with custom token counter and process messages', async () => {
|
|
718
|
-
const provider = Providers.OPENAI;
|
|
719
|
-
const llmConfig = getLLMConfig(provider);
|
|
720
|
-
const tokenCounter = createTestTokenCounter();
|
|
721
|
-
|
|
722
|
-
const run = await Run.create<t.IState>({
|
|
723
|
-
runId: 'test-prune-run',
|
|
724
|
-
graphConfig: {
|
|
725
|
-
type: 'standard',
|
|
726
|
-
llmConfig,
|
|
727
|
-
instructions: 'You are a helpful assistant.',
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
expect(finalMessages).toBeDefined();
|
|
760
|
-
expect(finalMessages?.length).toBeGreaterThan(0);
|
|
761
|
-
});
|
|
762
|
-
});
|
|
763
|
-
});
|
|
1
|
+
// src/specs/prune.test.ts
|
|
2
|
+
import { config } from 'dotenv';
|
|
3
|
+
config();
|
|
4
|
+
import {
|
|
5
|
+
HumanMessage,
|
|
6
|
+
AIMessage,
|
|
7
|
+
SystemMessage,
|
|
8
|
+
BaseMessage,
|
|
9
|
+
ToolMessage,
|
|
10
|
+
} from '@langchain/core/messages';
|
|
11
|
+
import type { RunnableConfig } from '@langchain/core/runnables';
|
|
12
|
+
import type { UsageMetadata } from '@langchain/core/messages';
|
|
13
|
+
import type * as t from '@/types';
|
|
14
|
+
import { createPruneMessages } from '@/messages/prune';
|
|
15
|
+
import { getLLMConfig } from '@/utils/llmConfig';
|
|
16
|
+
import { Providers } from '@/common';
|
|
17
|
+
import { Run } from '@/run';
|
|
18
|
+
|
|
19
|
+
// Create a simple token counter for testing
|
|
20
|
+
const createTestTokenCounter = (): t.TokenCounter => {
|
|
21
|
+
// This simple token counter just counts characters as tokens for predictable testing
|
|
22
|
+
return (message: BaseMessage): number => {
|
|
23
|
+
// Use type assertion to help TypeScript understand the type
|
|
24
|
+
const content = message.content as
|
|
25
|
+
| string
|
|
26
|
+
| Array<t.MessageContentComplex | string>
|
|
27
|
+
| undefined;
|
|
28
|
+
|
|
29
|
+
// Handle string content
|
|
30
|
+
if (typeof content === 'string') {
|
|
31
|
+
return content.length;
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
// Handle array content
|
|
35
|
+
if (Array.isArray(content)) {
|
|
36
|
+
let totalLength = 0;
|
|
37
|
+
|
|
38
|
+
for (const item of content) {
|
|
39
|
+
if (typeof item === 'string') {
|
|
40
|
+
totalLength += item.length;
|
|
41
|
+
} else if (typeof item === 'object') {
|
|
42
|
+
if ('text' in item && typeof item.text === 'string') {
|
|
43
|
+
totalLength += item.text.length;
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
return totalLength;
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
// Default case - if content is null, undefined, or any other type
|
|
52
|
+
return 0;
|
|
53
|
+
};
|
|
54
|
+
};
|
|
55
|
+
|
|
56
|
+
// Since the internal functions in prune.ts are not exported, we'll reimplement them here for testing
|
|
57
|
+
// This is based on the implementation in src/messages/prune.ts
|
|
58
|
+
function calculateTotalTokens(usage: Partial<UsageMetadata>): UsageMetadata {
|
|
59
|
+
const baseInputTokens = Number(usage.input_tokens) || 0;
|
|
60
|
+
const cacheCreation = Number(usage.input_token_details?.cache_creation) || 0;
|
|
61
|
+
const cacheRead = Number(usage.input_token_details?.cache_read) || 0;
|
|
62
|
+
|
|
63
|
+
const totalInputTokens = baseInputTokens + cacheCreation + cacheRead;
|
|
64
|
+
const totalOutputTokens = Number(usage.output_tokens) || 0;
|
|
65
|
+
|
|
66
|
+
return {
|
|
67
|
+
input_tokens: totalInputTokens,
|
|
68
|
+
output_tokens: totalOutputTokens,
|
|
69
|
+
total_tokens: totalInputTokens + totalOutputTokens,
|
|
70
|
+
};
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
function getMessagesWithinTokenLimit({
|
|
74
|
+
messages: _messages,
|
|
75
|
+
maxContextTokens,
|
|
76
|
+
indexTokenCountMap,
|
|
77
|
+
startType,
|
|
78
|
+
}: {
|
|
79
|
+
messages: BaseMessage[];
|
|
80
|
+
maxContextTokens: number;
|
|
81
|
+
indexTokenCountMap: Record<string, number>;
|
|
82
|
+
startType?: string;
|
|
83
|
+
}): {
|
|
84
|
+
context: BaseMessage[];
|
|
85
|
+
remainingContextTokens: number;
|
|
86
|
+
messagesToRefine: BaseMessage[];
|
|
87
|
+
summaryIndex: number;
|
|
88
|
+
} {
|
|
89
|
+
// Every reply is primed with <|start|>assistant<|message|>, so we
|
|
90
|
+
// start with 3 tokens for the label after all messages have been counted.
|
|
91
|
+
let summaryIndex = -1;
|
|
92
|
+
let currentTokenCount = 3;
|
|
93
|
+
const instructions =
|
|
94
|
+
_messages[0]?.getType() === 'system' ? _messages[0] : undefined;
|
|
95
|
+
const instructionsTokenCount =
|
|
96
|
+
instructions != null ? indexTokenCountMap[0] : 0;
|
|
97
|
+
let remainingContextTokens = maxContextTokens - instructionsTokenCount;
|
|
98
|
+
const messages = [..._messages];
|
|
99
|
+
const context: BaseMessage[] = [];
|
|
100
|
+
|
|
101
|
+
if (currentTokenCount < remainingContextTokens) {
|
|
102
|
+
let currentIndex = messages.length;
|
|
103
|
+
while (
|
|
104
|
+
messages.length > 0 &&
|
|
105
|
+
currentTokenCount < remainingContextTokens &&
|
|
106
|
+
currentIndex > 1
|
|
107
|
+
) {
|
|
108
|
+
currentIndex--;
|
|
109
|
+
if (messages.length === 1 && instructions) {
|
|
110
|
+
break;
|
|
111
|
+
}
|
|
112
|
+
const poppedMessage = messages.pop();
|
|
113
|
+
if (!poppedMessage) continue;
|
|
114
|
+
|
|
115
|
+
const tokenCount = indexTokenCountMap[currentIndex] || 0;
|
|
116
|
+
|
|
117
|
+
if (currentTokenCount + tokenCount <= remainingContextTokens) {
|
|
118
|
+
context.push(poppedMessage);
|
|
119
|
+
currentTokenCount += tokenCount;
|
|
120
|
+
} else {
|
|
121
|
+
messages.push(poppedMessage);
|
|
122
|
+
break;
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
// If startType is specified, discard messages until we find one of the required type
|
|
127
|
+
if (startType != null && startType && context.length > 0) {
|
|
128
|
+
const requiredTypeIndex = context.findIndex(
|
|
129
|
+
(msg) => msg.getType() === startType
|
|
130
|
+
);
|
|
131
|
+
|
|
132
|
+
if (requiredTypeIndex > 0) {
|
|
133
|
+
// If we found a message of the required type, discard all messages before it
|
|
134
|
+
const remainingMessages = context.slice(requiredTypeIndex);
|
|
135
|
+
context.length = 0; // Clear the array
|
|
136
|
+
context.push(...remainingMessages);
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
if (instructions && _messages.length > 0) {
|
|
142
|
+
context.push(_messages[0] as BaseMessage);
|
|
143
|
+
messages.shift();
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
const prunedMemory = messages;
|
|
147
|
+
summaryIndex = prunedMemory.length - 1;
|
|
148
|
+
remainingContextTokens -= currentTokenCount;
|
|
149
|
+
|
|
150
|
+
return {
|
|
151
|
+
summaryIndex,
|
|
152
|
+
remainingContextTokens,
|
|
153
|
+
context: context.reverse(),
|
|
154
|
+
messagesToRefine: prunedMemory,
|
|
155
|
+
};
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
function checkValidNumber(value: unknown): value is number {
|
|
159
|
+
return typeof value === 'number' && !isNaN(value) && value > 0;
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
describe('Prune Messages Tests', () => {
|
|
163
|
+
jest.setTimeout(30000);
|
|
164
|
+
|
|
165
|
+
describe('calculateTotalTokens', () => {
|
|
166
|
+
it('should calculate total tokens correctly with all fields present', () => {
|
|
167
|
+
const usage: Partial<UsageMetadata> = {
|
|
168
|
+
input_tokens: 100,
|
|
169
|
+
output_tokens: 50,
|
|
170
|
+
input_token_details: {
|
|
171
|
+
cache_creation: 10,
|
|
172
|
+
cache_read: 5,
|
|
173
|
+
},
|
|
174
|
+
};
|
|
175
|
+
|
|
176
|
+
const result = calculateTotalTokens(usage);
|
|
177
|
+
|
|
178
|
+
expect(result.input_tokens).toBe(115); // 100 + 10 + 5
|
|
179
|
+
expect(result.output_tokens).toBe(50);
|
|
180
|
+
expect(result.total_tokens).toBe(165); // 115 + 50
|
|
181
|
+
});
|
|
182
|
+
|
|
183
|
+
it('should handle missing fields gracefully', () => {
|
|
184
|
+
const usage: Partial<UsageMetadata> = {
|
|
185
|
+
input_tokens: 100,
|
|
186
|
+
output_tokens: 50,
|
|
187
|
+
};
|
|
188
|
+
|
|
189
|
+
const result = calculateTotalTokens(usage);
|
|
190
|
+
|
|
191
|
+
expect(result.input_tokens).toBe(100);
|
|
192
|
+
expect(result.output_tokens).toBe(50);
|
|
193
|
+
expect(result.total_tokens).toBe(150);
|
|
194
|
+
});
|
|
195
|
+
|
|
196
|
+
it('should handle empty usage object', () => {
|
|
197
|
+
const usage: Partial<UsageMetadata> = {};
|
|
198
|
+
|
|
199
|
+
const result = calculateTotalTokens(usage);
|
|
200
|
+
|
|
201
|
+
expect(result.input_tokens).toBe(0);
|
|
202
|
+
expect(result.output_tokens).toBe(0);
|
|
203
|
+
expect(result.total_tokens).toBe(0);
|
|
204
|
+
});
|
|
205
|
+
});
|
|
206
|
+
|
|
207
|
+
describe('getMessagesWithinTokenLimit', () => {
|
|
208
|
+
it('should include all messages when under token limit', () => {
|
|
209
|
+
const messages = [
|
|
210
|
+
new SystemMessage('System instruction'),
|
|
211
|
+
new HumanMessage('Hello'),
|
|
212
|
+
new AIMessage('Hi there'),
|
|
213
|
+
];
|
|
214
|
+
|
|
215
|
+
const indexTokenCountMap = {
|
|
216
|
+
0: 17, // "System instruction"
|
|
217
|
+
1: 5, // "Hello"
|
|
218
|
+
2: 8, // "Hi there"
|
|
219
|
+
};
|
|
220
|
+
|
|
221
|
+
const result = getMessagesWithinTokenLimit({
|
|
222
|
+
messages,
|
|
223
|
+
maxContextTokens: 100,
|
|
224
|
+
indexTokenCountMap,
|
|
225
|
+
});
|
|
226
|
+
|
|
227
|
+
expect(result.context.length).toBe(3);
|
|
228
|
+
expect(result.context[0]).toBe(messages[0]); // System message
|
|
229
|
+
expect(result.context[0].getType()).toBe('system'); // System message
|
|
230
|
+
expect(result.remainingContextTokens).toBe(100 - 17 - 5 - 8 - 3); // -3 for the assistant label tokens
|
|
231
|
+
expect(result.messagesToRefine.length).toBe(0);
|
|
232
|
+
});
|
|
233
|
+
|
|
234
|
+
it('should prune oldest messages when over token limit', () => {
|
|
235
|
+
const messages = [
|
|
236
|
+
new SystemMessage('System instruction'),
|
|
237
|
+
new HumanMessage('Message 1'),
|
|
238
|
+
new AIMessage('Response 1'),
|
|
239
|
+
new HumanMessage('Message 2'),
|
|
240
|
+
new AIMessage('Response 2'),
|
|
241
|
+
];
|
|
242
|
+
|
|
243
|
+
const indexTokenCountMap = {
|
|
244
|
+
0: 17, // "System instruction"
|
|
245
|
+
1: 9, // "Message 1"
|
|
246
|
+
2: 10, // "Response 1"
|
|
247
|
+
3: 9, // "Message 2"
|
|
248
|
+
4: 10, // "Response 2"
|
|
249
|
+
};
|
|
250
|
+
|
|
251
|
+
// Set a limit that can only fit the system message and the last two messages
|
|
252
|
+
const result = getMessagesWithinTokenLimit({
|
|
253
|
+
messages,
|
|
254
|
+
maxContextTokens: 40,
|
|
255
|
+
indexTokenCountMap,
|
|
256
|
+
});
|
|
257
|
+
|
|
258
|
+
// Should include system message and the last two messages
|
|
259
|
+
expect(result.context.length).toBe(3);
|
|
260
|
+
expect(result.context[0]).toBe(messages[0]); // System message
|
|
261
|
+
expect(result.context[0].getType()).toBe('system'); // System message
|
|
262
|
+
expect(result.context[1]).toBe(messages[3]); // Message 2
|
|
263
|
+
expect(result.context[2]).toBe(messages[4]); // Response 2
|
|
264
|
+
|
|
265
|
+
// Should have the first two messages in messagesToRefine
|
|
266
|
+
expect(result.messagesToRefine.length).toBe(2);
|
|
267
|
+
expect(result.messagesToRefine[0]).toBe(messages[1]); // Message 1
|
|
268
|
+
expect(result.messagesToRefine[1]).toBe(messages[2]); // Response 1
|
|
269
|
+
});
|
|
270
|
+
|
|
271
|
+
it('should always include system message even when at token limit', () => {
|
|
272
|
+
const messages = [
|
|
273
|
+
new SystemMessage('System instruction'),
|
|
274
|
+
new HumanMessage('Hello'),
|
|
275
|
+
new AIMessage('Hi there'),
|
|
276
|
+
];
|
|
277
|
+
|
|
278
|
+
const indexTokenCountMap = {
|
|
279
|
+
0: 17, // "System instruction"
|
|
280
|
+
1: 5, // "Hello"
|
|
281
|
+
2: 8, // "Hi there"
|
|
282
|
+
};
|
|
283
|
+
|
|
284
|
+
// Set a limit that can only fit the system message
|
|
285
|
+
const result = getMessagesWithinTokenLimit({
|
|
286
|
+
messages,
|
|
287
|
+
maxContextTokens: 20,
|
|
288
|
+
indexTokenCountMap,
|
|
289
|
+
});
|
|
290
|
+
|
|
291
|
+
expect(result.context.length).toBe(1);
|
|
292
|
+
expect(result.context[0]).toBe(messages[0]); // System message
|
|
293
|
+
|
|
294
|
+
expect(result.messagesToRefine.length).toBe(2);
|
|
295
|
+
});
|
|
296
|
+
|
|
297
|
+
it('should start context with a specific message type when startType is specified', () => {
|
|
298
|
+
const messages = [
|
|
299
|
+
new SystemMessage('System instruction'),
|
|
300
|
+
new AIMessage('AI message 1'),
|
|
301
|
+
new HumanMessage('Human message 1'),
|
|
302
|
+
new AIMessage('AI message 2'),
|
|
303
|
+
new HumanMessage('Human message 2'),
|
|
304
|
+
];
|
|
305
|
+
|
|
306
|
+
const indexTokenCountMap = {
|
|
307
|
+
0: 17, // "System instruction"
|
|
308
|
+
1: 12, // "AI message 1"
|
|
309
|
+
2: 15, // "Human message 1"
|
|
310
|
+
3: 12, // "AI message 2"
|
|
311
|
+
4: 15, // "Human message 2"
|
|
312
|
+
};
|
|
313
|
+
|
|
314
|
+
// Set a limit that can fit all messages
|
|
315
|
+
const result = getMessagesWithinTokenLimit({
|
|
316
|
+
messages,
|
|
317
|
+
maxContextTokens: 100,
|
|
318
|
+
indexTokenCountMap,
|
|
319
|
+
startType: 'human',
|
|
320
|
+
});
|
|
321
|
+
|
|
322
|
+
// All messages should be included since we're under the token limit
|
|
323
|
+
expect(result.context.length).toBe(5);
|
|
324
|
+
expect(result.context[0]).toBe(messages[0]); // System message
|
|
325
|
+
expect(result.context[1]).toBe(messages[1]); // AI message 1
|
|
326
|
+
expect(result.context[2]).toBe(messages[2]); // Human message 1
|
|
327
|
+
expect(result.context[3]).toBe(messages[3]); // AI message 2
|
|
328
|
+
expect(result.context[4]).toBe(messages[4]); // Human message 2
|
|
329
|
+
|
|
330
|
+
// All messages should be included since we're under the token limit
|
|
331
|
+
expect(result.messagesToRefine.length).toBe(0);
|
|
332
|
+
});
|
|
333
|
+
|
|
334
|
+
it('should keep all messages if no message of required type is found', () => {
|
|
335
|
+
const messages = [
|
|
336
|
+
new SystemMessage('System instruction'),
|
|
337
|
+
new AIMessage('AI message 1'),
|
|
338
|
+
new AIMessage('AI message 2'),
|
|
339
|
+
];
|
|
340
|
+
|
|
341
|
+
const indexTokenCountMap = {
|
|
342
|
+
0: 17, // "System instruction"
|
|
343
|
+
1: 12, // "AI message 1"
|
|
344
|
+
2: 12, // "AI message 2"
|
|
345
|
+
};
|
|
346
|
+
|
|
347
|
+
// Set a limit that can fit all messages
|
|
348
|
+
const result = getMessagesWithinTokenLimit({
|
|
349
|
+
messages,
|
|
350
|
+
maxContextTokens: 100,
|
|
351
|
+
indexTokenCountMap,
|
|
352
|
+
startType: 'human',
|
|
353
|
+
});
|
|
354
|
+
|
|
355
|
+
// Should include all messages since no human messages exist to start from
|
|
356
|
+
expect(result.context.length).toBe(3);
|
|
357
|
+
expect(result.context[0]).toBe(messages[0]); // System message
|
|
358
|
+
expect(result.context[1]).toBe(messages[1]); // AI message 1
|
|
359
|
+
expect(result.context[2]).toBe(messages[2]); // AI message 2
|
|
360
|
+
|
|
361
|
+
expect(result.messagesToRefine.length).toBe(0);
|
|
362
|
+
});
|
|
363
|
+
});
|
|
364
|
+
|
|
365
|
+
describe('checkValidNumber', () => {
|
|
366
|
+
it('should return true for valid positive numbers', () => {
|
|
367
|
+
expect(checkValidNumber(5)).toBe(true);
|
|
368
|
+
expect(checkValidNumber(1.5)).toBe(true);
|
|
369
|
+
expect(checkValidNumber(Number.MAX_SAFE_INTEGER)).toBe(true);
|
|
370
|
+
});
|
|
371
|
+
|
|
372
|
+
it('should return false for zero, negative numbers, and NaN', () => {
|
|
373
|
+
expect(checkValidNumber(0)).toBe(false);
|
|
374
|
+
expect(checkValidNumber(-5)).toBe(false);
|
|
375
|
+
expect(checkValidNumber(NaN)).toBe(false);
|
|
376
|
+
});
|
|
377
|
+
|
|
378
|
+
it('should return false for non-number types', () => {
|
|
379
|
+
expect(checkValidNumber('5')).toBe(false);
|
|
380
|
+
expect(checkValidNumber(null)).toBe(false);
|
|
381
|
+
expect(checkValidNumber(undefined)).toBe(false);
|
|
382
|
+
expect(checkValidNumber({})).toBe(false);
|
|
383
|
+
expect(checkValidNumber([])).toBe(false);
|
|
384
|
+
});
|
|
385
|
+
});
+
+  describe('createPruneMessages', () => {
+    it('should return all messages when under token limit', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Hello'),
+        new AIMessage('Hi there'),
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0]),
+        1: tokenCounter(messages[1]),
+        2: tokenCounter(messages[2]),
+      };
+
+      const pruneMessages = createPruneMessages({
+        maxTokens: 100,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap,
+      });
+
+      const result = pruneMessages({ messages });
+
+      expect(result.context.length).toBe(3);
+      expect(result.context).toEqual(messages);
+    });
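createTestTokenCounter is a suite-local helper whose definition falls outside this hunk. Judging by the hard-coded counts used elsewhere in the file (e.g. 12 for 'AI message 1', which is 12 characters), a plausible stand-in is a rough per-character counter like the hypothetical sketch below; the suite's real helper evidently uses a slightly different heuristic, since 'System instruction' maps to 17 rather than 18.

import type { BaseMessage } from '@langchain/core/messages';

// Hypothetical stand-in for the suite's createTestTokenCounter helper;
// the real heuristic may differ slightly from plain character counting.
function createTestTokenCounter() {
  return (message: BaseMessage): number => {
    const content =
      typeof message.content === 'string'
        ? message.content
        : JSON.stringify(message.content);
    return content.length; // roughly one token per character
  };
}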
+
+    it('should prune messages when over token limit', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Message 1'),
+        new AIMessage('Response 1'),
+        new HumanMessage('Message 2'),
+        new AIMessage('Response 2'),
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0]),
+        1: tokenCounter(messages[1]),
+        2: tokenCounter(messages[2]),
+        3: tokenCounter(messages[3]),
+        4: tokenCounter(messages[4]),
+      };
+
+      // Set a limit that can only fit the system message and the last two messages
+      const pruneMessages = createPruneMessages({
+        maxTokens: 40,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap,
+      });
+
+      const result = pruneMessages({ messages });
+
+      // Should include system message and the last two messages
+      expect(result.context.length).toBe(3);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[3]); // Message 2
+      expect(result.context[2]).toBe(messages[4]); // Response 2
+    });
+
+    it('should respect startType parameter', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new AIMessage('AI message 1'),
+        new HumanMessage('Human message 1'),
+        new AIMessage('AI message 2'),
+        new HumanMessage('Human message 2'),
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0]),
+        1: tokenCounter(messages[1]),
+        2: tokenCounter(messages[2]),
+        3: tokenCounter(messages[3]),
+        4: tokenCounter(messages[4]),
+      };
+
+      // Set a limit that can fit all messages
+      const pruneMessages = createPruneMessages({
+        maxTokens: 100,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap },
+      });
+
+      const result = pruneMessages({
+        messages,
+        startType: 'human',
+      });
+
+      // All messages should be included since we're under the token limit
+      expect(result.context.length).toBe(5);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[1]); // AI message 1
+      expect(result.context[2]).toBe(messages[2]); // Human message 1
+      expect(result.context[3]).toBe(messages[3]); // AI message 2
+      expect(result.context[4]).toBe(messages[4]); // Human message 2
+    });
+
+    it('should update token counts when usage metadata is provided', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Hello'),
+        new AIMessage('Hi there'),
+      ];
+
+      const indexTokenCountMap = {
+        0: tokenCounter(messages[0]),
+        1: tokenCounter(messages[1]),
+        2: tokenCounter(messages[2]),
+      };
+
+      const pruneMessages = createPruneMessages({
+        maxTokens: 100,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap },
+      });
+
+      // Provide usage metadata that indicates different token counts
+      const usageMetadata: Partial<UsageMetadata> = {
+        input_tokens: 50,
+        output_tokens: 25,
+        total_tokens: 75,
+      };
+
+      const result = pruneMessages({
+        messages,
+        usageMetadata,
+      });
+
+      // The function should have updated the indexTokenCountMap based on the usage metadata
+      expect(result.indexTokenCountMap).not.toEqual(indexTokenCountMap);
+
+      // The total of all values in indexTokenCountMap should equal the total_tokens from usageMetadata
+      const totalTokens = Object.values(result.indexTokenCountMap).reduce(
+        (a = 0, b = 0) => a + b,
+        0
+      );
+      expect(totalTokens).toBe(75);
+    });
+  });
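The last test pins down only two properties of the usage-metadata path: the map changes, and its values sum to total_tokens. One way to satisfy both is a proportional rescale like the sketch below. This is an assumption for illustration only; rescaleTokenCounts is a hypothetical name, not an actual export of the package, which may instead, say, assign output_tokens to the final AI message and split input_tokens across the rest.

// Illustrative only; assumes a non-empty map with a positive current total.
function rescaleTokenCounts(
  map: Record<number, number>,
  totalTokens: number
): Record<number, number> {
  const currentTotal = Object.values(map).reduce((a, b) => a + b, 0);
  const indices = Object.keys(map).map(Number);
  const scaled: Record<number, number> = {};
  let assigned = 0;
  for (const index of indices.slice(0, -1)) {
    scaled[index] = Math.round((map[index] / currentTotal) * totalTokens);
    assigned += scaled[index];
  }
  // Fold the rounding remainder into the last entry so the values
  // sum exactly to totalTokens, as the test asserts.
  scaled[indices[indices.length - 1]] = totalTokens - assigned;
  return scaled;
}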
+
+  describe('Tool Message Handling', () => {
+    it('should ensure context does not start with a tool message by finding an AI message', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new AIMessage('AI message 1'),
+        new ToolMessage({ content: 'Tool result 1', tool_call_id: 'tool1' }),
+        new AIMessage('AI message 2'),
+        new ToolMessage({ content: 'Tool result 2', tool_call_id: 'tool2' }),
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // System instruction
+        1: 12, // AI message 1
+        2: 13, // Tool result 1
+        3: 12, // AI message 2
+        4: 13, // Tool result 2
+      };
+
+      // Create a pruneMessages function with a token limit that will only include the last few messages
+      const pruneMessages = createPruneMessages({
+        maxTokens: 58, // Enough for the system message plus the last 3 messages (with 3 tokens to spare), but a parent-less tool message must still be excluded
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap },
+      });
+
+      const result = pruneMessages({ messages });
+
+      // The context should include the system message, AI message 2, and Tool result 2
+      // It should NOT start with Tool result 2 alone
+      expect(result.context.length).toBe(3);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[3]); // AI message 2
+      expect(result.context[2]).toBe(messages[4]); // Tool result 2
+    });
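Why 58 works out to a three-message context, assuming pruning walks backwards from the newest message and reserves the system message first (a reconstruction of the arithmetic, not quoted from the source):

// Budget: 58 total, 17 reserved for the system message -> 41 remain.
//   'Tool result 2' (13) -> running total 13, kept
//   'AI message 2' (12)  -> running total 25, kept (parent of tool2's call)
//   'Tool result 1' (13) -> running total 38, fits numerically, but its
//     parent 'AI message 1' (12) would bring the total to 50 > 41, so the
//     now-orphaned tool result is excluded as well
// Final context: [System instruction, AI message 2, Tool result 2]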
+
+    it('should ensure context does not start with a tool message by finding a human message', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Human message 1'),
+        new AIMessage('AI message 1'),
+        new ToolMessage({ content: 'Tool result 1', tool_call_id: 'tool1' }),
+        new HumanMessage('Human message 2'),
+        new ToolMessage({ content: 'Tool result 2', tool_call_id: 'tool2' }),
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // System instruction
+        1: 15, // Human message 1
+        2: 12, // AI message 1
+        3: 13, // Tool result 1
+        4: 15, // Human message 2
+        5: 13, // Tool result 2
+      };
+
+      // Create a pruneMessages function with a token limit that will only include the last few messages
+      const pruneMessages = createPruneMessages({
+        maxTokens: 48, // Only enough for system + last 2 messages
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap },
+      });
+
+      const result = pruneMessages({ messages });
+
+      // The context should include the system message, Human message 2, and Tool result 2
+      // It should NOT start with Tool result 2 alone
+      expect(result.context.length).toBe(3);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[4]); // Human message 2
+      expect(result.context[2]).toBe(messages[5]); // Tool result 2
+    });
+
+    it('should handle the case where a tool message is followed by an AI message', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Human message'),
+        new AIMessage('AI message with tool use'),
+        new ToolMessage({ content: 'Tool result', tool_call_id: 'tool1' }),
+        new AIMessage('AI message after tool'),
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // System instruction
+        1: 13, // Human message
+        2: 22, // AI message with tool use
+        3: 11, // Tool result
+        4: 19, // AI message after tool
+      };
+
+      const pruneMessages = createPruneMessages({
+        maxTokens: 50,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap },
+      });
+
+      const result = pruneMessages({ messages });
+
+      expect(result.context.length).toBe(2);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[4]); // AI message after tool
+    });
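The same arithmetic pattern explains the two-message result here, under the same backwards-walk assumption:

// Budget: 50 total, 17 reserved for the system message -> 33 remain.
//   'AI message after tool' (19) -> running total 19, kept
//   'Tool result' (11)           -> running total 30, fits numerically, but
//     its parent 'AI message with tool use' (22) would bring the total to
//     52 > 33, so the orphaned tool result is dropped
// Final context: [System instruction, AI message after tool], length 2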
+
+    it('should handle the case where a tool message is followed by a human message', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Human message 1'),
+        new AIMessage('AI message with tool use'),
+        new ToolMessage({ content: 'Tool result', tool_call_id: 'tool1' }),
+        new HumanMessage('Human message 2'),
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // System instruction
+        1: 15, // Human message 1
+        2: 22, // AI message with tool use
+        3: 11, // Tool result
+        4: 15, // Human message 2
+      };
+
+      const pruneMessages = createPruneMessages({
+        maxTokens: 46,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap },
+      });
+
+      const result = pruneMessages({ messages });
+
+      expect(result.context.length).toBe(2);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[4]); // Human message 2
+    });
+
+    it('should handle complex sequence with multiple tool messages', () => {
+      const tokenCounter = createTestTokenCounter();
+      const messages = [
+        new SystemMessage('System instruction'),
+        new HumanMessage('Human message 1'),
+        new AIMessage('AI message 1 with tool use'),
+        new ToolMessage({ content: 'Tool result 1', tool_call_id: 'tool1' }),
+        new AIMessage('AI message 2 with tool use'),
+        new ToolMessage({ content: 'Tool result 2', tool_call_id: 'tool2' }),
+        new AIMessage('AI message 3 with tool use'),
+        new ToolMessage({ content: 'Tool result 3', tool_call_id: 'tool3' }),
+      ];
+
+      const indexTokenCountMap = {
+        0: 17, // System instruction
+        1: 15, // Human message 1
+        2: 26, // AI message 1 with tool use
+        3: 13, // Tool result 1
+        4: 26, // AI message 2 with tool use
+        5: 13, // Tool result 2
+        6: 26, // AI message 3 with tool use
+        7: 13, // Tool result 3
+      };
+
+      const pruneMessages = createPruneMessages({
+        maxTokens: 111,
+        startIndex: 0,
+        tokenCounter,
+        indexTokenCountMap: { ...indexTokenCountMap },
+      });
+
+      const result = pruneMessages({ messages });
+
+      expect(result.context.length).toBe(5);
+      expect(result.context[0]).toBe(messages[0]); // System message
+      expect(result.context[1]).toBe(messages[4]); // AI message 2 with tool use
+      expect(result.context[2]).toBe(messages[5]); // Tool result 2
+      expect(result.context[3]).toBe(messages[6]); // AI message 3 with tool use
+      expect(result.context[4]).toBe(messages[7]); // Tool result 3
+    });
+  });
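The invariant all five tests enforce: a pruned context must never begin (after the optional system message) with a ToolMessage, because a tool result is only meaningful directly after the AI message that issued the tool call. When the parent AI message fits the budget it is kept alongside the result; when it does not, the orphaned result is dropped too. A minimal sketch of the drop step, assuming LangChain message classes (dropOrphanedToolMessages is a hypothetical name, not the package's actual source):

import { BaseMessage, SystemMessage, ToolMessage } from '@langchain/core/messages';

// Sketch: remove tool results whose parent AI message was pruned away,
// so the context never opens with an orphaned ToolMessage.
function dropOrphanedToolMessages(context: BaseMessage[]): BaseMessage[] {
  const result = [...context];
  // A leading system message is preserved; check the first conversational slot
  const start = result[0] instanceof SystemMessage ? 1 : 0;
  while (result[start] instanceof ToolMessage) {
    result.splice(start, 1); // parent AI call was pruned, so the result goes too
  }
  return result;
}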
+
+  describe('Integration with Run', () => {
+    it('should initialize Run with custom token counter and process messages', async () => {
+      const provider = Providers.OPENAI;
+      const llmConfig = getLLMConfig(provider);
+      const tokenCounter = createTestTokenCounter();
+
+      const run = await Run.create<t.IState>({
+        runId: 'test-prune-run',
+        graphConfig: {
+          type: 'standard',
+          llmConfig,
+          instructions: 'You are a helpful assistant.',
+          maxContextTokens: 1000,
+        },
+        returnContent: true,
+        tokenCounter,
+        indexTokenCountMap: {},
+      });
+
+      // Override the model to use a fake LLM
+      run.Graph?.overrideTestModel(['This is a test response'], 1);
+
+      const messages = [new HumanMessage('Hello, how are you?')];
+
+      const config: Partial<RunnableConfig> & {
+        version: 'v1' | 'v2';
+        streamMode: string;
+      } = {
+        configurable: {
+          thread_id: 'test-thread',
+        },
+        streamMode: 'values',
+        version: 'v2' as const,
+      };
+
+      await run.processStream({ messages }, config);
+
+      const finalMessages = run.getRunMessages();
+      expect(finalMessages).toBeDefined();
+      expect(finalMessages?.length).toBeGreaterThan(0);
+    });
+  });
+});
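For readers evaluating the upgrade, the integration test doubles as a usage reference. Below is a consumer-side sketch built only from the calls exercised above, with the fake-model override removed. Run and Providers come from the package as used in the test; getLLMConfig, tokenCounter, and t.IState mirror the suite's own helpers and types, so substitute your equivalents.

import { HumanMessage } from '@langchain/core/messages';

// (inside an async function)
const run = await Run.create<t.IState>({
  runId: 'my-run',
  graphConfig: {
    type: 'standard',
    llmConfig: getLLMConfig(Providers.OPENAI), // or your own provider config
    instructions: 'You are a helpful assistant.',
    maxContextTokens: 1000,
  },
  returnContent: true,
  tokenCounter,           // any (message) => number counter, e.g. the sketch earlier
  indexTokenCountMap: {},
});

await run.processStream(
  { messages: [new HumanMessage('Hello, how are you?')] },
  {
    configurable: { thread_id: 'my-thread' },
    streamMode: 'values',
    version: 'v2',
  }
);

const finalMessages = run.getRunMessages();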
|