@mastra/memory 0.2.6 → 0.2.7-alpha.2
- package/.turbo/turbo-build.log +12 -10
- package/CHANGELOG.md +40 -0
- package/README.md +6 -2
- package/dist/_tsup-dts-rollup.d.cts +58 -1
- package/dist/_tsup-dts-rollup.d.ts +58 -1
- package/dist/index.cjs +6 -5
- package/dist/index.js +6 -5
- package/dist/processors/index.cjs +161 -0
- package/dist/processors/index.d.cts +2 -0
- package/dist/processors/index.d.ts +2 -0
- package/dist/processors/index.js +154 -0
- package/package.json +14 -21
- package/src/index.ts +8 -5
- package/src/processors/index.test.ts +236 -0
- package/src/processors/index.ts +2 -0
- package/src/processors/token-limiter.ts +150 -0
- package/src/processors/tool-call-filter.ts +77 -0
package/.turbo/turbo-build.log
CHANGED
@@ -1,27 +1,29 @@
 
-> @mastra/memory@0.2.
-> pnpm run check && tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
+> @mastra/memory@0.2.7-alpha.2 build /home/runner/work/mastra/mastra/packages/memory
+> pnpm run check && tsup src/index.ts src/processors/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
 
 
-> @mastra/memory@0.2.
+> @mastra/memory@0.2.7-alpha.2 check /home/runner/work/mastra/mastra/packages/memory
 
 > tsc --noEmit
 
-CLI Building entry: src/index.ts
+CLI Building entry: src/index.ts, src/processors/index.ts
 CLI Using tsconfig: tsconfig.json
 CLI tsup v8.4.0
 TSC Build start
-TSC ⚡️ Build success in
+TSC ⚡️ Build success in 11238ms
 DTS Build start
 CLI Target: es2022
 Analysis will use the bundled TypeScript version 5.8.2
 Writing package typings: /home/runner/work/mastra/mastra/packages/memory/dist/_tsup-dts-rollup.d.ts
 Analysis will use the bundled TypeScript version 5.8.2
 Writing package typings: /home/runner/work/mastra/mastra/packages/memory/dist/_tsup-dts-rollup.d.cts
-DTS ⚡️ Build success in
+DTS ⚡️ Build success in 5847ms
 CLI Cleaning output folder
 ESM Build start
 CJS Build start
-[
-[
-ESM
-[
+ESM dist/index.js 13.60 KB
+ESM dist/processors/index.js 5.33 KB
+ESM ⚡️ Build success in 238ms
+CJS dist/index.cjs 13.62 KB
+CJS dist/processors/index.cjs 5.54 KB
+CJS ⚡️ Build success in 238ms
package/CHANGELOG.md
CHANGED
@@ -1,5 +1,45 @@
 # @mastra/memory
 
+## 0.2.7-alpha.2
+
+### Patch Changes
+
+- Updated dependencies [56c31b7]
+- Updated dependencies [dbbbf80]
+- Updated dependencies [99d43b9]
+  - @mastra/core@0.8.0-alpha.2
+  - @mastra/rag@0.1.15-alpha.2
+
+## 0.2.7-alpha.1
+
+### Patch Changes
+
+- a0967a0: Added new "Memory Processor" feature to @mastra/core and @mastra/memory, allowing devs to modify Mastra Memory before it's sent to the LLM
+- 0118361: Add resourceId to memory metadata
+- Updated dependencies [619c39d]
+- Updated dependencies [fe56be0]
+- Updated dependencies [a0967a0]
+- Updated dependencies [e47f529]
+- Updated dependencies [fca3b21]
+- Updated dependencies [0118361]
+- Updated dependencies [619c39d]
+  - @mastra/core@0.8.0-alpha.1
+  - @mastra/rag@0.1.15-alpha.1
+
+## 0.2.7-alpha.0
+
+### Patch Changes
+
+- 7599d77: fix(deps): update ai sdk to ^4.2.2
+- Updated dependencies [107bcfe]
+- Updated dependencies [5b4e19f]
+- Updated dependencies [7599d77]
+- Updated dependencies [cafae83]
+- Updated dependencies [8076ecf]
+- Updated dependencies [304397c]
+  - @mastra/core@0.7.1-alpha.0
+  - @mastra/rag@0.1.15-alpha.0
+
 ## 0.2.6
 
 ### Patch Changes
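The headline change in this range is the "Memory Processor" feature from 0.2.7-alpha.1, shipped here as the TokenLimiter and ToolCallFilter classes under a new processors entry point. A minimal sketch of how they might be wired up, assuming Memory accepts a processors array in its constructor config (that option lives in @mastra/core and is not shown in this diff):

import { Memory } from '@mastra/memory';
import { TokenLimiter, ToolCallFilter } from '@mastra/memory/processors';

// Hypothetical wiring: processors run over recalled messages before they
// reach the LLM, in array order.
const memory = new Memory({
  processors: [
    new ToolCallFilter({ exclude: ['weather'] }), // drop weather tool calls and their results first
    new TokenLimiter(127000),                     // then cap remembered history by token count
  ],
});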
package/dist/_tsup-dts-rollup.d.cts
CHANGED
@@ -3,10 +3,14 @@ import type { CoreMessage } from '@mastra/core';
 import type { CoreTool } from '@mastra/core';
 import { MastraMemory } from '@mastra/core/memory';
 import type { MemoryConfig } from '@mastra/core/memory';
+import { MemoryProcessor } from '@mastra/core/memory';
+import { MemoryProcessor as MemoryProcessor_2 } from '@mastra/core';
+import type { MemoryProcessorOpts } from '@mastra/core';
 import type { MessageType } from '@mastra/core/memory';
 import type { SharedMemoryConfig } from '@mastra/core/memory';
 import type { StorageGetMessagesArg } from '@mastra/core/storage';
 import type { StorageThreadType } from '@mastra/core/memory';
+import type { TiktokenBPE } from 'js-tiktoken/lite';
 
 /**
  * Concrete implementation of MastraMemory that adds support for thread configuration
@@ -15,7 +19,9 @@ import type { StorageThreadType } from '@mastra/core/memory';
 export declare class Memory extends MastraMemory {
     constructor(config?: SharedMemoryConfig);
     private validateThreadIsOwnedByResource;
-    query({ threadId, resourceId, selectBy, threadConfig, }: StorageGetMessagesArg
+    query({ threadId, resourceId, selectBy, threadConfig, }: StorageGetMessagesArg & {
+        threadConfig?: MemoryConfig;
+    }): Promise<{
         messages: CoreMessage[];
         uiMessages: AiMessageType[];
     }>;
@@ -66,6 +72,57 @@ export declare class Memory extends MastraMemory {
     getTools(config?: MemoryConfig): Record<string, CoreTool>;
 }
 
+/**
+ * Limits the total number of tokens in the messages.
+ * Uses js-tiktoken with o200k_base encoding by default for accurate token counting with modern models.
+ */
+declare class TokenLimiter extends MemoryProcessor {
+    private encoder;
+    private maxTokens;
+    TOKENS_PER_MESSAGE: number;
+    TOKENS_PER_TOOL: number;
+    TOKENS_PER_CONVERSATION: number;
+    /**
+     * Create a token limiter for messages.
+     * @param options Either a number (token limit) or a configuration object
+     */
+    constructor(options: number | TokenLimiterOptions);
+    process(messages: CoreMessage[], { systemMessage, memorySystemMessage, newMessages }?: MemoryProcessorOpts): CoreMessage[];
+    countTokens(message: string | CoreMessage): number;
+}
+export { TokenLimiter }
+export { TokenLimiter as TokenLimiter_alias_1 }
+
+/**
+ * Configuration options for TokenLimiter
+ */
+declare interface TokenLimiterOptions {
+    /** Maximum number of tokens to allow */
+    limit: number;
+    /** Optional encoding to use (defaults to o200k_base which is used by gpt-4o) */
+    encoding?: TiktokenBPE;
+}
+
+/**
+ * Filters out tool calls and results from messages.
+ * By default (with no arguments), excludes all tool calls and their results.
+ * Can be configured to exclude only specific tools by name.
+ */
+declare class ToolCallFilter extends MemoryProcessor_2 {
+    private exclude;
+    /**
+     * Create a filter for tool calls and results.
+     * @param options Configuration options
+     * @param options.exclude List of specific tool names to exclude. If not provided, all tool calls are excluded.
+     */
+    constructor(options?: {
+        exclude?: string[];
+    });
+    process(messages: CoreMessage[]): CoreMessage[];
+}
+export { ToolCallFilter }
+export { ToolCallFilter as ToolCallFilter_alias_1 }
+
 export declare const updateWorkingMemoryTool: CoreTool;
 
 export { }

package/dist/_tsup-dts-rollup.d.ts
CHANGED

@@ -3,10 +3,14 @@ import type { CoreMessage } from '@mastra/core';
 import type { CoreTool } from '@mastra/core';
 import { MastraMemory } from '@mastra/core/memory';
 import type { MemoryConfig } from '@mastra/core/memory';
+import { MemoryProcessor } from '@mastra/core/memory';
+import { MemoryProcessor as MemoryProcessor_2 } from '@mastra/core';
+import type { MemoryProcessorOpts } from '@mastra/core';
 import type { MessageType } from '@mastra/core/memory';
 import type { SharedMemoryConfig } from '@mastra/core/memory';
 import type { StorageGetMessagesArg } from '@mastra/core/storage';
 import type { StorageThreadType } from '@mastra/core/memory';
+import type { TiktokenBPE } from 'js-tiktoken/lite';
 
 /**
  * Concrete implementation of MastraMemory that adds support for thread configuration
@@ -15,7 +19,9 @@ import type { StorageThreadType } from '@mastra/core/memory';
 export declare class Memory extends MastraMemory {
     constructor(config?: SharedMemoryConfig);
     private validateThreadIsOwnedByResource;
-    query({ threadId, resourceId, selectBy, threadConfig, }: StorageGetMessagesArg
+    query({ threadId, resourceId, selectBy, threadConfig, }: StorageGetMessagesArg & {
+        threadConfig?: MemoryConfig;
+    }): Promise<{
         messages: CoreMessage[];
         uiMessages: AiMessageType[];
     }>;
@@ -66,6 +72,57 @@ export declare class Memory extends MastraMemory {
     getTools(config?: MemoryConfig): Record<string, CoreTool>;
 }
 
+/**
+ * Limits the total number of tokens in the messages.
+ * Uses js-tiktoken with o200k_base encoding by default for accurate token counting with modern models.
+ */
+declare class TokenLimiter extends MemoryProcessor {
+    private encoder;
+    private maxTokens;
+    TOKENS_PER_MESSAGE: number;
+    TOKENS_PER_TOOL: number;
+    TOKENS_PER_CONVERSATION: number;
+    /**
+     * Create a token limiter for messages.
+     * @param options Either a number (token limit) or a configuration object
+     */
+    constructor(options: number | TokenLimiterOptions);
+    process(messages: CoreMessage[], { systemMessage, memorySystemMessage, newMessages }?: MemoryProcessorOpts): CoreMessage[];
+    countTokens(message: string | CoreMessage): number;
+}
+export { TokenLimiter }
+export { TokenLimiter as TokenLimiter_alias_1 }
+
+/**
+ * Configuration options for TokenLimiter
+ */
+declare interface TokenLimiterOptions {
+    /** Maximum number of tokens to allow */
+    limit: number;
+    /** Optional encoding to use (defaults to o200k_base which is used by gpt-4o) */
+    encoding?: TiktokenBPE;
+}
+
+/**
+ * Filters out tool calls and results from messages.
+ * By default (with no arguments), excludes all tool calls and their results.
+ * Can be configured to exclude only specific tools by name.
+ */
+declare class ToolCallFilter extends MemoryProcessor_2 {
+    private exclude;
+    /**
+     * Create a filter for tool calls and results.
+     * @param options Configuration options
+     * @param options.exclude List of specific tool names to exclude. If not provided, all tool calls are excluded.
+     */
+    constructor(options?: {
+        exclude?: string[];
+    });
+    process(messages: CoreMessage[]): CoreMessage[];
+}
+export { ToolCallFilter }
+export { ToolCallFilter as ToolCallFilter_alias_1 }
+
 export declare const updateWorkingMemoryTool: CoreTool;
 
 export { }
package/dist/index.cjs
CHANGED
@@ -129,7 +129,7 @@ var Memory = class extends memory.MastraMemory {
         threadId
       };
     }
-    const
+    const messagesResult = await this.query({
       threadId,
       selectBy: {
         last: threadConfig.lastMessages,
@@ -137,11 +137,11 @@ var Memory = class extends memory.MastraMemory {
       },
       threadConfig: config
     });
-    this.logger.debug(`Remembered message history includes ${
+    this.logger.debug(`Remembered message history includes ${messagesResult.messages.length} messages.`);
     return {
       threadId,
-      messages:
-      uiMessages:
+      messages: messagesResult.messages,
+      uiMessages: messagesResult.uiMessages
     };
   }
   async getThreadById({ threadId }) {
@@ -214,7 +214,8 @@ var Memory = class extends memory.MastraMemory {
       vectors: embeddings,
       metadata: chunks.map(() => ({
         message_id: message.id,
-        thread_id: message.threadId
+        thread_id: message.threadId,
+        resource_id: message.resourceId
      }))
    });
  }
package/dist/index.js
CHANGED
@@ -127,7 +127,7 @@ var Memory = class extends MastraMemory {
         threadId
       };
     }
-    const
+    const messagesResult = await this.query({
       threadId,
       selectBy: {
         last: threadConfig.lastMessages,
@@ -135,11 +135,11 @@ var Memory = class extends MastraMemory {
       },
       threadConfig: config
     });
-    this.logger.debug(`Remembered message history includes ${
+    this.logger.debug(`Remembered message history includes ${messagesResult.messages.length} messages.`);
     return {
       threadId,
-      messages:
-      uiMessages:
+      messages: messagesResult.messages,
+      uiMessages: messagesResult.uiMessages
     };
   }
   async getThreadById({ threadId }) {
@@ -212,7 +212,8 @@ var Memory = class extends MastraMemory {
       vectors: embeddings,
       metadata: chunks.map(() => ({
         message_id: message.id,
-        thread_id: message.threadId
+        thread_id: message.threadId,
+        resource_id: message.resourceId
      }))
    });
  }

package/dist/processors/index.cjs
ADDED

@@ -0,0 +1,161 @@
+'use strict';
+
+var memory = require('@mastra/core/memory');
+var lite = require('js-tiktoken/lite');
+var o200k_base = require('js-tiktoken/ranks/o200k_base');
+var core = require('@mastra/core');
+
+function _interopDefault (e) { return e && e.__esModule ? e : { default: e }; }
+
+var o200k_base__default = /*#__PURE__*/_interopDefault(o200k_base);
+
+// src/processors/token-limiter.ts
+var TokenLimiter = class extends memory.MemoryProcessor {
+  encoder;
+  maxTokens;
+  // Token overheads per OpenAI's documentation
+  // See: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
+  // Every message follows <|start|>{role/name}\n{content}<|end|>
+  TOKENS_PER_MESSAGE = 3;
+  // tokens added for each message (start & end tokens)
+  TOKENS_PER_TOOL = 2;
+  // empirical adjustment for tool calls
+  TOKENS_PER_CONVERSATION = 25;
+  // fixed overhead for the conversation
+  /**
+   * Create a token limiter for messages.
+   * @param options Either a number (token limit) or a configuration object
+   */
+  constructor(options) {
+    super({
+      name: "TokenLimiter"
+    });
+    if (typeof options === "number") {
+      this.maxTokens = options;
+      this.encoder = new lite.Tiktoken(o200k_base__default.default);
+    } else {
+      this.maxTokens = options.limit;
+      this.encoder = new lite.Tiktoken(options.encoding || o200k_base__default.default);
+    }
+  }
+  process(messages, { systemMessage, memorySystemMessage, newMessages } = {}) {
+    let totalTokens = 0;
+    totalTokens += this.TOKENS_PER_CONVERSATION;
+    if (systemMessage) {
+      totalTokens += this.countTokens(systemMessage);
+      totalTokens += this.TOKENS_PER_MESSAGE;
+    }
+    if (memorySystemMessage) {
+      totalTokens += this.countTokens(memorySystemMessage);
+      totalTokens += this.TOKENS_PER_MESSAGE;
+    }
+    const allMessages = [...messages, ...newMessages || []];
+    const result = [];
+    for (let i = allMessages.length - 1; i >= 0; i--) {
+      const message = allMessages[i];
+      if (!message) continue;
+      const messageTokens = this.countTokens(message);
+      if (totalTokens + messageTokens <= this.maxTokens) {
+        result.unshift(message);
+        totalTokens += messageTokens;
+      } else {
+        this.logger.info(
+          `filtering ${allMessages.length - result.length}/${allMessages.length} messages, token limit of ${this.maxTokens} exceeded`
+        );
+        break;
+      }
+    }
+    return result;
+  }
+  countTokens(message) {
+    if (typeof message === `string`) {
+      return this.encoder.encode(message).length;
+    }
+    let tokenString = message.role;
+    if (typeof message.content === "string") {
+      tokenString += message.content;
+    } else if (Array.isArray(message.content)) {
+      for (const part of message.content) {
+        tokenString += part.type;
+        if (part.type === "text") {
+          tokenString += part.text;
+        } else if (part.type === "tool-call") {
+          tokenString += part.toolName;
+          if (part.args) {
+            tokenString += typeof part.args === "string" ? part.args : JSON.stringify(part.args);
+          }
+        } else if (part.type === "tool-result") {
+          if (part.result !== void 0) {
+            tokenString += typeof part.result === "string" ? part.result : JSON.stringify(part.result);
+          }
+        } else {
+          tokenString += JSON.stringify(part);
+        }
+      }
+    }
+    const messageOverhead = this.TOKENS_PER_MESSAGE;
+    let toolOverhead = 0;
+    if (Array.isArray(message.content)) {
+      for (const part of message.content) {
+        if (part.type === "tool-call" || part.type === "tool-result") {
+          toolOverhead += this.TOKENS_PER_TOOL;
+        }
+      }
+    }
+    const totalMessageOverhead = messageOverhead + toolOverhead;
+    return this.encoder.encode(tokenString).length + totalMessageOverhead;
+  }
+};
+var ToolCallFilter = class extends core.MemoryProcessor {
+  exclude;
+  /**
+   * Create a filter for tool calls and results.
+   * @param options Configuration options
+   * @param options.exclude List of specific tool names to exclude. If not provided, all tool calls are excluded.
+   */
+  constructor(options = {}) {
+    super({ name: "ToolCallFilter" });
+    if (!options || !options.exclude) {
+      this.exclude = "all";
+    } else {
+      this.exclude = Array.isArray(options.exclude) ? options.exclude : [];
+    }
+  }
+  process(messages) {
+    if (this.exclude === "all") {
+      return messages.filter((message) => {
+        if (Array.isArray(message.content)) {
+          return !message.content.some((part) => part.type === "tool-call" || part.type === "tool-result");
+        }
+        return true;
+      });
+    }
+    if (this.exclude.length > 0) {
+      const excludedToolCallIds = /* @__PURE__ */ new Set();
+      return messages.filter((message) => {
+        if (!Array.isArray(message.content)) return true;
+        if (message.role === "assistant") {
+          let shouldExclude = false;
+          for (const part of message.content) {
+            if (part.type === "tool-call" && this.exclude.includes(part.toolName)) {
+              excludedToolCallIds.add(part.toolCallId);
+              shouldExclude = true;
+            }
+          }
+          return !shouldExclude;
+        }
+        if (message.role === "tool") {
+          const shouldExclude = message.content.some(
+            (part) => part.type === "tool-result" && excludedToolCallIds.has(part.toolCallId)
+          );
+          return !shouldExclude;
+        }
+        return true;
+      });
+    }
+    return messages;
+  }
+};
+
+exports.TokenLimiter = TokenLimiter;
+exports.ToolCallFilter = ToolCallFilter;

package/dist/processors/index.js
ADDED

@@ -0,0 +1,154 @@
+import { MemoryProcessor } from '@mastra/core/memory';
+import { Tiktoken } from 'js-tiktoken/lite';
+import o200k_base from 'js-tiktoken/ranks/o200k_base';
+import { MemoryProcessor as MemoryProcessor$1 } from '@mastra/core';
+
+// src/processors/token-limiter.ts
+var TokenLimiter = class extends MemoryProcessor {
+  encoder;
+  maxTokens;
+  // Token overheads per OpenAI's documentation
+  // See: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
+  // Every message follows <|start|>{role/name}\n{content}<|end|>
+  TOKENS_PER_MESSAGE = 3;
+  // tokens added for each message (start & end tokens)
+  TOKENS_PER_TOOL = 2;
+  // empirical adjustment for tool calls
+  TOKENS_PER_CONVERSATION = 25;
+  // fixed overhead for the conversation
+  /**
+   * Create a token limiter for messages.
+   * @param options Either a number (token limit) or a configuration object
+   */
+  constructor(options) {
+    super({
+      name: "TokenLimiter"
+    });
+    if (typeof options === "number") {
+      this.maxTokens = options;
+      this.encoder = new Tiktoken(o200k_base);
+    } else {
+      this.maxTokens = options.limit;
+      this.encoder = new Tiktoken(options.encoding || o200k_base);
+    }
+  }
+  process(messages, { systemMessage, memorySystemMessage, newMessages } = {}) {
+    let totalTokens = 0;
+    totalTokens += this.TOKENS_PER_CONVERSATION;
+    if (systemMessage) {
+      totalTokens += this.countTokens(systemMessage);
+      totalTokens += this.TOKENS_PER_MESSAGE;
+    }
+    if (memorySystemMessage) {
+      totalTokens += this.countTokens(memorySystemMessage);
+      totalTokens += this.TOKENS_PER_MESSAGE;
+    }
+    const allMessages = [...messages, ...newMessages || []];
+    const result = [];
+    for (let i = allMessages.length - 1; i >= 0; i--) {
+      const message = allMessages[i];
+      if (!message) continue;
+      const messageTokens = this.countTokens(message);
+      if (totalTokens + messageTokens <= this.maxTokens) {
+        result.unshift(message);
+        totalTokens += messageTokens;
+      } else {
+        this.logger.info(
+          `filtering ${allMessages.length - result.length}/${allMessages.length} messages, token limit of ${this.maxTokens} exceeded`
+        );
+        break;
+      }
+    }
+    return result;
+  }
+  countTokens(message) {
+    if (typeof message === `string`) {
+      return this.encoder.encode(message).length;
+    }
+    let tokenString = message.role;
+    if (typeof message.content === "string") {
+      tokenString += message.content;
+    } else if (Array.isArray(message.content)) {
+      for (const part of message.content) {
+        tokenString += part.type;
+        if (part.type === "text") {
+          tokenString += part.text;
+        } else if (part.type === "tool-call") {
+          tokenString += part.toolName;
+          if (part.args) {
+            tokenString += typeof part.args === "string" ? part.args : JSON.stringify(part.args);
+          }
+        } else if (part.type === "tool-result") {
+          if (part.result !== void 0) {
+            tokenString += typeof part.result === "string" ? part.result : JSON.stringify(part.result);
+          }
+        } else {
+          tokenString += JSON.stringify(part);
+        }
+      }
+    }
+    const messageOverhead = this.TOKENS_PER_MESSAGE;
+    let toolOverhead = 0;
+    if (Array.isArray(message.content)) {
+      for (const part of message.content) {
+        if (part.type === "tool-call" || part.type === "tool-result") {
+          toolOverhead += this.TOKENS_PER_TOOL;
+        }
+      }
+    }
+    const totalMessageOverhead = messageOverhead + toolOverhead;
+    return this.encoder.encode(tokenString).length + totalMessageOverhead;
+  }
+};
+var ToolCallFilter = class extends MemoryProcessor$1 {
+  exclude;
+  /**
+   * Create a filter for tool calls and results.
+   * @param options Configuration options
+   * @param options.exclude List of specific tool names to exclude. If not provided, all tool calls are excluded.
+   */
+  constructor(options = {}) {
+    super({ name: "ToolCallFilter" });
+    if (!options || !options.exclude) {
+      this.exclude = "all";
+    } else {
+      this.exclude = Array.isArray(options.exclude) ? options.exclude : [];
+    }
+  }
+  process(messages) {
+    if (this.exclude === "all") {
+      return messages.filter((message) => {
+        if (Array.isArray(message.content)) {
+          return !message.content.some((part) => part.type === "tool-call" || part.type === "tool-result");
+        }
+        return true;
+      });
+    }
+    if (this.exclude.length > 0) {
+      const excludedToolCallIds = /* @__PURE__ */ new Set();
+      return messages.filter((message) => {
+        if (!Array.isArray(message.content)) return true;
+        if (message.role === "assistant") {
+          let shouldExclude = false;
+          for (const part of message.content) {
+            if (part.type === "tool-call" && this.exclude.includes(part.toolName)) {
+              excludedToolCallIds.add(part.toolCallId);
+              shouldExclude = true;
+            }
+          }
+          return !shouldExclude;
+        }
+        if (message.role === "tool") {
+          const shouldExclude = message.content.some(
+            (part) => part.type === "tool-result" && excludedToolCallIds.has(part.toolCallId)
+          );
+          return !shouldExclude;
+        }
+        return true;
+      });
+    }
+    return messages;
+  }
+};
+
+export { TokenLimiter, ToolCallFilter };
package/package.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@mastra/memory",
-  "version": "0.2.
+  "version": "0.2.7-alpha.2",
   "description": "",
   "type": "module",
   "main": "./dist/index.js",
@@ -16,24 +16,14 @@
       "default": "./dist/index.cjs"
     }
   },
-  "./
+  "./processors": {
     "import": {
-      "types": "./dist/
-      "default": "./dist/
+      "types": "./dist/processors/index.d.ts",
+      "default": "./dist/processors/index.js"
     },
     "require": {
-      "types": "./dist/
-      "default": "./dist/
-    }
-  },
-  "./postgres": {
-    "import": {
-      "types": "./dist/postgres/index.d.ts",
-      "default": "./dist/postgres/index.js"
-    },
-    "require": {
-      "types": "./dist/postgres/index.d.cts",
-      "default": "./dist/postgres/index.cjs"
+      "types": "./dist/processors/index.d.cts",
+      "default": "./dist/processors/index.cjs"
     }
   },
   "./package.json": "./package.json"
@@ -43,16 +33,18 @@
   "license": "ISC",
   "dependencies": {
     "@upstash/redis": "^1.34.5",
-    "
+    "js-tiktoken": "^1.0.19",
+    "ai": "^4.2.2",
     "pg": "^8.13.3",
     "pg-pool": "^3.7.1",
     "postgres": "^3.4.5",
     "redis": "^4.7.0",
     "zod": "^3.24.2",
-    "@mastra/core": "^0.
-    "@mastra/rag": "^0.1.
+    "@mastra/core": "^0.8.0-alpha.2",
+    "@mastra/rag": "^0.1.15-alpha.2"
   },
   "devDependencies": {
+    "@ai-sdk/openai": "^1.3.3",
     "@microsoft/api-extractor": "^7.52.1",
     "@types/node": "^20.17.27",
     "@types/pg": "^8.11.11",
@@ -65,10 +57,11 @@
   },
   "scripts": {
     "check": "tsc --noEmit",
-    "build": "pnpm run check && tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting",
+    "build": "pnpm run check && tsup src/index.ts src/processors/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting",
     "build:watch": "pnpm build --watch",
     "test:integration": "cd integration-tests && pnpm run test",
-    "test": "pnpm
+    "test:unit": "pnpm vitest run ./src/*",
+    "test": "pnpm test:integration && pnpm test:unit",
     "lint": "eslint ."
   }
 }
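The reworked exports map trades the stale subpath entries (one truncated above, plus "./postgres") for a single "./processors" entry with per-format typings: ESM resolves index.js with index.d.ts, CommonJS resolves index.cjs with index.d.cts. In practice, consumers on either module system import the same names:

// ESM
import { TokenLimiter, ToolCallFilter } from '@mastra/memory/processors';

// CommonJS
const { TokenLimiter, ToolCallFilter } = require('@mastra/memory/processors');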
package/src/index.ts
CHANGED
@@ -41,7 +41,9 @@ export class Memory extends MastraMemory {
     resourceId,
     selectBy,
     threadConfig,
-  }: StorageGetMessagesArg
+  }: StorageGetMessagesArg & {
+    threadConfig?: MemoryConfig;
+  }): Promise<{ messages: CoreMessage[]; uiMessages: AiMessageType[] }> {
     if (resourceId) await this.validateThreadIsOwnedByResource(threadId, resourceId);
 
     const vectorResults: {
@@ -148,7 +150,7 @@ export class Memory extends MastraMemory {
       };
     }
 
-    const
+    const messagesResult = await this.query({
       threadId,
       selectBy: {
         last: threadConfig.lastMessages,
@@ -157,11 +159,11 @@ export class Memory extends MastraMemory {
       threadConfig: config,
     });
 
-    this.logger.debug(`Remembered message history includes ${
+    this.logger.debug(`Remembered message history includes ${messagesResult.messages.length} messages.`);
     return {
       threadId,
-      messages:
-      uiMessages:
+      messages: messagesResult.messages,
+      uiMessages: messagesResult.uiMessages,
     };
   }
 
@@ -271,6 +273,7 @@ export class Memory extends MastraMemory {
       metadata: chunks.map(() => ({
         message_id: message.id,
         thread_id: message.threadId,
+        resource_id: message.resourceId,
      })),
    });
  }

package/src/processors/index.test.ts
ADDED

@@ -0,0 +1,236 @@
+import { openai } from '@ai-sdk/openai';
+import { createTool } from '@mastra/core';
+import type { CoreMessage, MessageType } from '@mastra/core';
+import { Agent } from '@mastra/core/agent';
+import cl100k_base from 'js-tiktoken/ranks/cl100k_base';
+import { describe, it, expect } from 'vitest';
+import { z } from 'zod';
+import { generateConversationHistory } from '../../integration-tests/src/test-utils';
+import { TokenLimiter, ToolCallFilter } from './index';
+
+describe('TokenLimiter', () => {
+  it('should limit messages to the specified token count', () => {
+    // Create messages with predictable token counts (approximately 25 tokens each)
+    const { messages } = generateConversationHistory({
+      threadId: '1',
+      messageCount: 5,
+      toolNames: [],
+      toolFrequency: 0,
+    });
+
+    const limiter = new TokenLimiter(200);
+    // @ts-ignore
+    const result = limiter.process(messages);
+
+    // Should prioritize newest messages (higher ids)
+    expect(result.length).toBe(2);
+    expect((result[0] as MessageType).id).toBe('message-8');
+    expect((result[1] as MessageType).id).toBe('message-9');
+  });
+
+  it('should handle empty messages array', () => {
+    const limiter = new TokenLimiter(1000);
+    const result = limiter.process([]);
+    expect(result).toEqual([]);
+  });
+
+  it('should use different encodings based on configuration', () => {
+    const { messages } = generateConversationHistory({
+      threadId: '6',
+      messageCount: 1,
+      toolNames: [],
+      toolFrequency: 0,
+    });
+
+    // Create limiters with different encoding settings
+    const defaultLimiter = new TokenLimiter(1000);
+    const customLimiter = new TokenLimiter({
+      limit: 1000,
+      encoding: cl100k_base,
+    });
+
+    // All should process messages successfully but potentially with different token counts
+    const defaultResult = defaultLimiter.process(messages as CoreMessage[]);
+    const customResult = customLimiter.process(messages as CoreMessage[]);
+
+    // Each should return the same messages but with potentially different token counts
+    expect(defaultResult.length).toBe(messages.length);
+    expect(customResult.length).toBe(messages.length);
+  });
+
+  function estimateTokens(messages: MessageType[]) {
+    // Create a TokenLimiter just for counting tokens
+    const testLimiter = new TokenLimiter(Infinity);
+
+    let estimatedTokens = testLimiter.TOKENS_PER_CONVERSATION;
+
+    // Count tokens for each message including all overheads
+    for (const message of messages) {
+      // Base token count from the countTokens method
+      estimatedTokens += testLimiter.countTokens(message as CoreMessage);
+    }
+
+    return estimatedTokens;
+  }
+
+  function percentDifference(a: number, b: number) {
+    const difference = Number(((Math.abs(a - b) / b) * 100).toFixed(2));
+    console.log(`${a} and ${b} are ${difference}% different`);
+    return difference;
+  }
+
+  async function expectTokenEstimate(config: Parameters<typeof generateConversationHistory>[0], agent: Agent) {
+    const { messages, counts } = generateConversationHistory(config);
+
+    const estimate = estimateTokens(messages);
+    const used = (await agent.generate(messages.slice(0, -1) as CoreMessage[])).usage.totalTokens;
+
+    console.log(`Estimated ${estimate} tokens, used ${used} tokens.\n`, counts);
+
+    // Check if within 2% margin
+    expect(percentDifference(estimate, used)).toBeLessThanOrEqual(2);
+  }
+
+  const calculatorTool = createTool({
+    id: 'calculator',
+    description: 'Perform a simple calculation',
+    inputSchema: z.object({
+      expression: z.string().describe('The mathematical expression to calculate'),
+    }),
+    execute: async ({ context: { expression } }) => {
+      return `The result of ${expression} is ${eval(expression)}`;
+    },
+  });
+
+  const agent = new Agent({
+    name: 'token estimate agent',
+    model: openai('gpt-4o-mini'),
+    instructions: ``,
+    tools: { calculatorTool },
+  });
+
+  describe.concurrent(`98% accuracy`, () => {
+    it(`20 messages, no tools`, async () => {
+      await expectTokenEstimate(
+        {
+          messageCount: 10,
+          toolFrequency: 0,
+          threadId: '2',
+        },
+        agent,
+      );
+    });
+
+    it(`60 messages, no tools`, async () => {
+      await expectTokenEstimate(
+        {
+          messageCount: 30,
+          toolFrequency: 0,
+          threadId: '3',
+        },
+        agent,
+      );
+    });
+
+    it(`4 messages, 0 tools`, async () => {
+      await expectTokenEstimate(
+        {
+          messageCount: 2,
+          toolFrequency: 0,
+          threadId: '3',
+        },
+        agent,
+      );
+    });
+
+    it(`20 messages, 2 tool messages`, async () => {
+      await expectTokenEstimate(
+        {
+          messageCount: 10,
+          toolFrequency: 5,
+          threadId: '3',
+        },
+        agent,
+      );
+    });
+
+    it(`40 messages, 6 tool messages`, async () => {
+      await expectTokenEstimate(
+        {
+          messageCount: 20,
+          toolFrequency: 5,
+          threadId: '4',
+        },
+        agent,
+      );
+    });
+
+    it(`100 messages, 24 tool messages`, async () => {
+      await expectTokenEstimate(
+        {
+          messageCount: 50,
+          toolFrequency: 4,
+          threadId: '5',
+        },
+        agent,
+      );
+    });
+
+    it(`101 messages, 49 tool calls`, async () => {
+      await expectTokenEstimate(
+        {
+          messageCount: 50,
+          toolFrequency: 1,
+          threadId: '5',
+        },
+        agent,
+      );
+    });
+  });
+});
+
+describe.concurrent('ToolCallFilter', () => {
+  it('should exclude all tool calls when created with no arguments', () => {
+    const { messages } = generateConversationHistory({
+      threadId: '3',
+      toolNames: ['weather', 'calculator', 'search'],
+      messageCount: 1,
+    });
+    const filter = new ToolCallFilter();
+    const result = filter.process(messages as CoreMessage[]) as MessageType[];
+
+    // Should only keep the text message and assistant res
+    expect(result.length).toBe(2);
+    expect(result[0].id).toBe('message-0');
+  });
+
+  it('should exclude specific tool calls by name', () => {
+    const { messages } = generateConversationHistory({
+      threadId: '4',
+      toolNames: ['weather', 'calculator'],
+      messageCount: 2,
+    });
+    const filter = new ToolCallFilter({ exclude: ['weather'] });
+    const result = filter.process(messages as CoreMessage[]) as MessageType[];
+
+    // Should keep text message, assistant reply, calculator tool call, and calculator result
+    expect(result.length).toBe(4);
+    expect(result[0].id).toBe('message-0');
+    expect(result[1].id).toBe('message-1');
+    expect(result[2].id).toBe('message-2');
+    expect(result[3].id).toBe('message-3');
+  });
+
+  it('should keep all messages when exclude list is empty', () => {
+    const { messages } = generateConversationHistory({
+      threadId: '5',
+      toolNames: ['weather', 'calculator'],
+    });
+
+    const filter = new ToolCallFilter({ exclude: [] });
+    const result = filter.process(messages as CoreMessage[]);
+
+    // Should keep all messages
+    expect(result.length).toBe(messages.length);
+  });
+});

package/src/processors/token-limiter.ts
ADDED

@@ -0,0 +1,150 @@
+import type { CoreMessage, MemoryProcessorOpts } from '@mastra/core';
+import { MemoryProcessor } from '@mastra/core/memory';
+import { Tiktoken } from 'js-tiktoken/lite';
+import type { TiktokenBPE } from 'js-tiktoken/lite';
+import o200k_base from 'js-tiktoken/ranks/o200k_base';
+
+/**
+ * Configuration options for TokenLimiter
+ */
+interface TokenLimiterOptions {
+  /** Maximum number of tokens to allow */
+  limit: number;
+  /** Optional encoding to use (defaults to o200k_base which is used by gpt-4o) */
+  encoding?: TiktokenBPE;
+}
+
+/**
+ * Limits the total number of tokens in the messages.
+ * Uses js-tiktoken with o200k_base encoding by default for accurate token counting with modern models.
+ */
+export class TokenLimiter extends MemoryProcessor {
+  private encoder: Tiktoken;
+  private maxTokens: number;
+
+  // Token overheads per OpenAI's documentation
+  // See: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
+  // Every message follows <|start|>{role/name}\n{content}<|end|>
+  public TOKENS_PER_MESSAGE = 3; // tokens added for each message (start & end tokens)
+  public TOKENS_PER_TOOL = 2; // empirical adjustment for tool calls
+  public TOKENS_PER_CONVERSATION = 25; // fixed overhead for the conversation
+
+  /**
+   * Create a token limiter for messages.
+   * @param options Either a number (token limit) or a configuration object
+   */
+  constructor(options: number | TokenLimiterOptions) {
+    super({
+      name: 'TokenLimiter',
+    });
+
+    if (typeof options === 'number') {
+      // Simple number format - just the token limit with default encoding
+      this.maxTokens = options;
+      this.encoder = new Tiktoken(o200k_base);
+    } else {
+      // Object format with limit and optional encoding
+      this.maxTokens = options.limit;
+      this.encoder = new Tiktoken(options.encoding || o200k_base);
+    }
+  }
+
+  process(
+    messages: CoreMessage[],
+    { systemMessage, memorySystemMessage, newMessages }: MemoryProcessorOpts = {},
+  ): CoreMessage[] {
+    // Messages are already chronologically ordered - take most recent ones up to the token limit
+    let totalTokens = 0;
+
+    // Start with the conversation overhead
+    totalTokens += this.TOKENS_PER_CONVERSATION;
+
+    if (systemMessage) {
+      totalTokens += this.countTokens(systemMessage);
+      totalTokens += this.TOKENS_PER_MESSAGE; // Add message overhead for system message
+    }
+
+    if (memorySystemMessage) {
+      totalTokens += this.countTokens(memorySystemMessage);
+      totalTokens += this.TOKENS_PER_MESSAGE; // Add message overhead for memory system message
+    }
+
+    const allMessages = [...messages, ...(newMessages || [])];
+
+    const result: CoreMessage[] = [];
+
+    // Process messages in reverse (newest first)
+    for (let i = allMessages.length - 1; i >= 0; i--) {
+      const message = allMessages[i];
+
+      // Skip undefined messages (shouldn't happen, but TypeScript is concerned)
+      if (!message) continue;
+
+      const messageTokens = this.countTokens(message);
+
+      if (totalTokens + messageTokens <= this.maxTokens) {
+        // Insert at the beginning to maintain chronological order
+        result.unshift(message);
+        totalTokens += messageTokens;
+      } else {
+        this.logger.info(
+          `filtering ${allMessages.length - result.length}/${allMessages.length} messages, token limit of ${this.maxTokens} exceeded`,
+        );
+        // If we can't fit the message, we stop
+        break;
+      }
+    }
+
+    return result;
+  }
+
+  public countTokens(message: string | CoreMessage): number {
+    if (typeof message === `string`) {
+      return this.encoder.encode(message).length;
+    }
+
+    let tokenString = message.role;
+
+    if (typeof message.content === 'string') {
+      tokenString += message.content;
+    } else if (Array.isArray(message.content)) {
+      // Calculate tokens for each content part
+      for (const part of message.content) {
+        tokenString += part.type;
+        if (part.type === 'text') {
+          tokenString += part.text;
+        } else if (part.type === 'tool-call') {
+          tokenString += part.toolName as any;
+          if (part.args) {
+            tokenString += typeof part.args === 'string' ? part.args : JSON.stringify(part.args);
+          }
+        } else if (part.type === 'tool-result') {
+          // Token cost for result if present
+          if (part.result !== undefined) {
+            tokenString += typeof part.result === 'string' ? part.result : JSON.stringify(part.result);
+          }
+        } else {
+          tokenString += JSON.stringify(part);
+        }
+      }
+    }
+
+    // Ensure we account for message formatting tokens
+    // See: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
+    const messageOverhead = this.TOKENS_PER_MESSAGE;
+
+    // Count tool calls for additional overhead
+    let toolOverhead = 0;
+    if (Array.isArray(message.content)) {
+      for (const part of message.content) {
+        if (part.type === 'tool-call' || part.type === 'tool-result') {
+          toolOverhead += this.TOKENS_PER_TOOL;
+        }
+      }
+    }
+
+    const totalMessageOverhead = messageOverhead + toolOverhead;
+
+    return this.encoder.encode(tokenString).length + totalMessageOverhead;
+  }
+}
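The constants above make an estimate checkable by hand: process charges a flat TOKENS_PER_CONVERSATION (25) once, and countTokens encodes the role concatenated with the content, then adds TOKENS_PER_MESSAGE (3) plus TOKENS_PER_TOOL (2) per tool-call or tool-result part. A count-only sketch using the published class:

import type { CoreMessage } from '@mastra/core';
import { TokenLimiter } from '@mastra/memory/processors';

// Infinity means nothing is ever trimmed, so the limiter is count-only here
// (the same trick the test file uses in its estimateTokens helper).
const counter = new TokenLimiter(Infinity);

const msg: CoreMessage = { role: 'user', content: 'hello there' };
// Encodes 'user' + 'hello there' with o200k_base, then adds the 3-token
// per-message overhead (no tool parts, so no tool overhead).
console.log(counter.countTokens(msg));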

package/src/processors/tool-call-filter.ts
ADDED

@@ -0,0 +1,77 @@
+import type { CoreMessage } from '@mastra/core';
+import { MemoryProcessor } from '@mastra/core';
+
+/**
+ * Filters out tool calls and results from messages.
+ * By default (with no arguments), excludes all tool calls and their results.
+ * Can be configured to exclude only specific tools by name.
+ */
+export class ToolCallFilter extends MemoryProcessor {
+  private exclude: string[] | 'all';
+
+  /**
+   * Create a filter for tool calls and results.
+   * @param options Configuration options
+   * @param options.exclude List of specific tool names to exclude. If not provided, all tool calls are excluded.
+   */
+  constructor(options: { exclude?: string[] } = {}) {
+    super({ name: 'ToolCallFilter' });
+    // If no options or exclude is provided, exclude all tools
+    if (!options || !options.exclude) {
+      this.exclude = 'all'; // Exclude all tools
+    } else {
+      // Exclude specific tools
+      this.exclude = Array.isArray(options.exclude) ? options.exclude : [];
+    }
+  }
+
+  process(messages: CoreMessage[]): CoreMessage[] {
+    // Case 1: Exclude all tool calls and tool results
+    if (this.exclude === 'all') {
+      return messages.filter(message => {
+        if (Array.isArray(message.content)) {
+          return !message.content.some(part => part.type === 'tool-call' || part.type === 'tool-result');
+        }
+        return true;
+      });
+    }
+
+    // Case 2: Exclude specific tools by name
+    if (this.exclude.length > 0) {
+      // Single pass approach - track excluded tool call IDs while filtering
+      const excludedToolCallIds = new Set<string>();
+
+      return messages.filter(message => {
+        if (!Array.isArray(message.content)) return true;
+
+        // For assistant messages, check for excluded tool calls and track their IDs
+        if (message.role === 'assistant') {
+          let shouldExclude = false;
+
+          for (const part of message.content) {
+            if (part.type === 'tool-call' && this.exclude.includes(part.toolName)) {
+              excludedToolCallIds.add(part.toolCallId);
+              shouldExclude = true;
+            }
+          }
+
+          return !shouldExclude;
+        }
+
+        // For tool messages, filter out results for excluded tool calls
+        if (message.role === 'tool') {
+          const shouldExclude = message.content.some(
+            part => part.type === 'tool-result' && excludedToolCallIds.has(part.toolCallId),
+          );
+
+          return !shouldExclude;
+        }
+
+        return true;
+      });
+    }
+
+    // Case 3: Empty exclude array, return original messages
+    return messages;
+  }
+}
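Note the single-pass design: because messages.filter visits the array in chronological order, an excluded assistant tool-call records its toolCallId before the matching tool-result message is reached, so call and result are dropped together without a second pass. Usage, per the constructor contract above:

import { ToolCallFilter } from '@mastra/memory/processors';

new ToolCallFilter();                         // no arguments: drop every tool call and result
new ToolCallFilter({ exclude: ['weather'] }); // drop only weather calls and their paired results
new ToolCallFilter({ exclude: [] });          // empty list: keep all messages unchanged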