@mastra/memory 0.10.2 → 0.10.3-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,9 @@
 
-> @mastra/memory@0.10.2-alpha.2 build /home/runner/work/mastra/mastra/packages/memory
+> @mastra/memory@0.10.3-alpha.1 build /home/runner/work/mastra/mastra/packages/memory
 
 > pnpm run check && tsup --silent src/index.ts src/processors/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
 
 
-> @mastra/memory@0.10.2-alpha.2 check /home/runner/work/mastra/mastra/packages/memory
+> @mastra/memory@0.10.3-alpha.1 check /home/runner/work/mastra/mastra/packages/memory
 
 > tsc --noEmit
 
 Analysis will use the bundled TypeScript version 5.8.3
package/CHANGELOG.md CHANGED
@@ -1,5 +1,28 @@
 # @mastra/memory
 
+## 0.10.3-alpha.1
+
+### Patch Changes
+
+- 48eddb9: update filter logic in Memory class to support semantic recall search scope
+- Updated dependencies [48eddb9]
+  - @mastra/core@0.10.4-alpha.2
+
+## 0.10.3-alpha.0
+
+### Patch Changes
+
+- 1ccccff: dependencies updates:
+  - Updated dependency [`zod@^3.25.56` ↗︎](https://www.npmjs.com/package/zod/v/3.25.56) (from `^3.24.3`, in `dependencies`)
+- 1ccccff: dependencies updates:
+  - Updated dependency [`zod@^3.25.56` ↗︎](https://www.npmjs.com/package/zod/v/3.25.56) (from `^3.24.3`, in `dependencies`)
+- a382d3b: Fix token limiter estimations after recent MessageList work changed message structure
+- Updated dependencies [f6fd25f]
+- Updated dependencies [dffb67b]
+- Updated dependencies [f1309d3]
+- Updated dependencies [f7f8293]
+  - @mastra/core@0.10.4-alpha.1
+
 ## 0.10.2
 
 ### Patch Changes
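
The `48eddb9` entry is the user-facing headline here: semantic recall can now be scoped to a whole resource (e.g. a user) instead of a single thread. A minimal configuration sketch in TypeScript, assuming the usual storage/vector/embedder setup is passed alongside `options`; the `scope` value is the new knob this release validates, and the other numbers are arbitrary examples:

```ts
import { Memory } from '@mastra/memory';

// Sketch only: storage, vector store, and embedder are assumed to be
// configured elsewhere as in a typical Mastra setup.
const memory = new Memory({
  options: {
    semanticRecall: {
      scope: 'resource', // new: recall across all threads owned by the resource
      topK: 4,           // arbitrary example values
      messageRange: 2,
    },
  },
});
```

If the attached storage adapter doesn't advertise support for resource scope, the constructor now throws (see `checkStorageFeatureSupport` below).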
@@ -21,6 +21,7 @@ import type { UIMessage } from 'ai';
 export declare class Memory extends MastraMemory {
     constructor(config?: SharedMemoryConfig);
     private validateThreadIsOwnedByResource;
+    private checkStorageFeatureSupport;
     query({ threadId, resourceId, selectBy, threadConfig, }: StorageGetMessagesArg & {
         threadConfig?: MemoryConfig;
     }): Promise<{
@@ -91,7 +92,6 @@ declare class TokenLimiter extends MemoryProcessor {
     private encoder;
     private maxTokens;
     TOKENS_PER_MESSAGE: number;
-    TOKENS_PER_TOOL: number;
     TOKENS_PER_CONVERSATION: number;
    /**
     * Create a token limiter for messages.
@@ -21,6 +21,7 @@ import type { UIMessage } from 'ai';
 export declare class Memory extends MastraMemory {
     constructor(config?: SharedMemoryConfig);
     private validateThreadIsOwnedByResource;
+    private checkStorageFeatureSupport;
     query({ threadId, resourceId, selectBy, threadConfig, }: StorageGetMessagesArg & {
         threadConfig?: MemoryConfig;
     }): Promise<{
@@ -91,7 +92,6 @@ declare class TokenLimiter extends MemoryProcessor {
     private encoder;
     private maxTokens;
     TOKENS_PER_MESSAGE: number;
-    TOKENS_PER_TOOL: number;
     TOKENS_PER_CONVERSATION: number;
    /**
     * Create a token limiter for messages.
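
Because `TOKENS_PER_TOOL` was public, its removal is visible in these typings, not just an internal tweak. A short sketch of the surviving surface (import path per this package's processors entry point; the budget number is an arbitrary example):

```ts
import { TokenLimiter } from '@mastra/memory/processors';

const limiter = new TokenLimiter(127000); // arbitrary example token budget

// Still exposed in 0.10.3-alpha.1:
console.log(limiter.TOKENS_PER_MESSAGE);      // 3.8
console.log(limiter.TOKENS_PER_CONVERSATION); // 24 (was 25)
// limiter.TOKENS_PER_TOOL no longer exists; code that read or tuned it must change.
```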
package/dist/index.cjs CHANGED
@@ -56,6 +56,7 @@ var Memory = class extends memory.MastraMemory {
       }
     });
     this.threadConfig = mergedConfig;
+    this.checkStorageFeatureSupport(mergedConfig);
   }
   async validateThreadIsOwnedByResource(threadId, resourceId) {
     const thread = await this.storage.getThreadById({ threadId });
@@ -68,6 +69,13 @@ var Memory = class extends memory.MastraMemory {
       );
     }
   }
+  checkStorageFeatureSupport(config) {
+    if (typeof config.semanticRecall === `object` && config.semanticRecall.scope === `resource` && !this.storage.supports.selectByIncludeResourceScope) {
+      throw new Error(
+        `Memory error: Attached storage adapter "${this.storage.name || "unknown"}" doesn't support semanticRecall: { scope: "resource" } yet and currently only supports per-thread semantic recall.`
+      );
+    }
+  }
   async query({
     threadId,
     resourceId,
@@ -82,6 +90,7 @@ var Memory = class extends memory.MastraMemory {
       threadConfig
     });
     const config = this.getMergedThreadConfig(threadConfig || {});
+    this.checkStorageFeatureSupport(config);
     const defaultRange = DEFAULT_MESSAGE_RANGE;
     const defaultTopK = DEFAULT_TOP_K;
     const vectorConfig = typeof config?.semanticRecall === `boolean` ? {
@@ -91,7 +100,8 @@ var Memory = class extends memory.MastraMemory {
       topK: config?.semanticRecall?.topK ?? defaultTopK,
       messageRange: config?.semanticRecall?.messageRange ?? defaultRange
     };
-    if (config?.semanticRecall && selectBy?.vectorSearchString && this.vector && !!selectBy.vectorSearchString) {
+    const resourceScope = typeof config?.semanticRecall === "object" && config?.semanticRecall?.scope === `resource`;
+    if (config?.semanticRecall && selectBy?.vectorSearchString && this.vector) {
       const { embeddings, dimension } = await this.embedMessageContent(selectBy.vectorSearchString);
       const { indexName } = await this.createEmbeddingIndex(dimension);
       await Promise.all(
@@ -106,7 +116,9 @@ var Memory = class extends memory.MastraMemory {
             indexName,
             queryVector: embedding,
             topK: vectorConfig.topK,
-            filter: {
+            filter: resourceScope ? {
+              resource_id: resourceId
+            } : {
               thread_id: threadId
             }
           })
@@ -116,12 +128,14 @@ var Memory = class extends memory.MastraMemory {
     }
     const rawMessages = await this.storage.getMessages({
       threadId,
+      resourceId,
       format: "v2",
       selectBy: {
         ...selectBy,
         ...vectorResults?.length ? {
           include: vectorResults.map((r) => ({
             id: r.metadata?.message_id,
+            threadId: r.metadata?.thread_id,
             withNextMessages: typeof vectorConfig.messageRange === "number" ? vectorConfig.messageRange : vectorConfig.messageRange.after,
             withPreviousMessages: typeof vectorConfig.messageRange === "number" ? vectorConfig.messageRange : vectorConfig.messageRange.before
           }))
@@ -161,6 +175,7 @@ var Memory = class extends memory.MastraMemory {
       };
     }
     const messagesResult = await this.query({
+      resourceId,
       threadId,
       selectBy: {
         last: threadConfig.lastMessages,
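
The same guard now runs in both the constructor and `query`, so a `scope: "resource"` request fails fast with a clear error instead of silently falling back to per-thread recall. What the guard reads off the adapter, sketched with a hypothetical adapter object (only the fields the guard touches are shown):

```ts
// Hypothetical adapter fragment; real adapters implement the full
// storage interface. checkStorageFeatureSupport only inspects these fields.
const storage = {
  name: 'MyStore', // used in the error message, falls back to "unknown"
  supports: {
    selectByIncludeResourceScope: true, // must be true for scope: "resource"
  },
  // ...getThreadById, getMessages, etc. omitted from this sketch
};
```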
package/dist/index.js CHANGED
@@ -50,6 +50,7 @@ var Memory = class extends MastraMemory {
       }
     });
     this.threadConfig = mergedConfig;
+    this.checkStorageFeatureSupport(mergedConfig);
   }
   async validateThreadIsOwnedByResource(threadId, resourceId) {
     const thread = await this.storage.getThreadById({ threadId });
@@ -62,6 +63,13 @@ var Memory = class extends MastraMemory {
       );
     }
   }
+  checkStorageFeatureSupport(config) {
+    if (typeof config.semanticRecall === `object` && config.semanticRecall.scope === `resource` && !this.storage.supports.selectByIncludeResourceScope) {
+      throw new Error(
+        `Memory error: Attached storage adapter "${this.storage.name || "unknown"}" doesn't support semanticRecall: { scope: "resource" } yet and currently only supports per-thread semantic recall.`
+      );
+    }
+  }
   async query({
     threadId,
     resourceId,
@@ -76,6 +84,7 @@ var Memory = class extends MastraMemory {
       threadConfig
     });
     const config = this.getMergedThreadConfig(threadConfig || {});
+    this.checkStorageFeatureSupport(config);
     const defaultRange = DEFAULT_MESSAGE_RANGE;
     const defaultTopK = DEFAULT_TOP_K;
     const vectorConfig = typeof config?.semanticRecall === `boolean` ? {
@@ -85,7 +94,8 @@ var Memory = class extends MastraMemory {
       topK: config?.semanticRecall?.topK ?? defaultTopK,
       messageRange: config?.semanticRecall?.messageRange ?? defaultRange
     };
-    if (config?.semanticRecall && selectBy?.vectorSearchString && this.vector && !!selectBy.vectorSearchString) {
+    const resourceScope = typeof config?.semanticRecall === "object" && config?.semanticRecall?.scope === `resource`;
+    if (config?.semanticRecall && selectBy?.vectorSearchString && this.vector) {
       const { embeddings, dimension } = await this.embedMessageContent(selectBy.vectorSearchString);
       const { indexName } = await this.createEmbeddingIndex(dimension);
       await Promise.all(
@@ -100,7 +110,9 @@ var Memory = class extends MastraMemory {
             indexName,
             queryVector: embedding,
             topK: vectorConfig.topK,
-            filter: {
+            filter: resourceScope ? {
+              resource_id: resourceId
+            } : {
               thread_id: threadId
             }
           })
@@ -110,12 +122,14 @@ var Memory = class extends MastraMemory {
     }
     const rawMessages = await this.storage.getMessages({
       threadId,
+      resourceId,
       format: "v2",
       selectBy: {
         ...selectBy,
         ...vectorResults?.length ? {
           include: vectorResults.map((r) => ({
             id: r.metadata?.message_id,
+            threadId: r.metadata?.thread_id,
             withNextMessages: typeof vectorConfig.messageRange === "number" ? vectorConfig.messageRange : vectorConfig.messageRange.after,
             withPreviousMessages: typeof vectorConfig.messageRange === "number" ? vectorConfig.messageRange : vectorConfig.messageRange.before
           }))
@@ -155,6 +169,7 @@ var Memory = class extends MastraMemory {
       };
     }
     const messagesResult = await this.query({
+      resourceId,
       threadId,
       selectBy: {
         last: threadConfig.lastMessages,
@@ -18,9 +18,7 @@ var TokenLimiter = class extends memory.MemoryProcessor {
   // Every message follows <|start|>{role/name}\n{content}<|end|>
   TOKENS_PER_MESSAGE = 3.8;
   // tokens added for each message (start & end tokens)
-  TOKENS_PER_TOOL = 2.2;
-  // empirical adjustment for tool calls
-  TOKENS_PER_CONVERSATION = 25;
+  TOKENS_PER_CONVERSATION = 24;
   // fixed overhead for the conversation
   /**
    * Create a token limiter for messages.
@@ -74,38 +72,41 @@ var TokenLimiter = class extends memory.MemoryProcessor {
       return this.encoder.encode(message).length;
     }
     let tokenString = message.role;
-    if (typeof message.content === "string") {
+    let overhead = 0;
+    if (typeof message.content === "string" && message.content) {
       tokenString += message.content;
     } else if (Array.isArray(message.content)) {
       for (const part of message.content) {
-        tokenString += part.type;
         if (part.type === "text") {
           tokenString += part.text;
-        } else if (part.type === "tool-call") {
-          tokenString += part.toolName;
-          if (part.args) {
-            tokenString += typeof part.args === "string" ? part.args : JSON.stringify(part.args);
+        } else if (part.type === "tool-call" || part.type === `tool-result`) {
+          if (`args` in part && part.args && part.type === `tool-call`) {
+            tokenString += part.toolName;
+            if (typeof part.args === "string") {
+              tokenString += part.args;
+            } else {
+              tokenString += JSON.stringify(part.args);
+              overhead -= 12;
+            }
           }
-        } else if (part.type === "tool-result") {
-          if (part.result !== void 0) {
-            tokenString += typeof part.result === "string" ? part.result : JSON.stringify(part.result);
+          if (`result` in part && part.result !== void 0 && part.type === `tool-result`) {
+            if (typeof part.result === "string") {
+              tokenString += part.result;
+            } else {
+              tokenString += JSON.stringify(part.result);
+              overhead -= 12;
+            }
           }
         } else {
           tokenString += JSON.stringify(part);
         }
       }
     }
-    const messageOverhead = this.TOKENS_PER_MESSAGE;
-    let toolOverhead = 0;
-    if (Array.isArray(message.content)) {
-      for (const part of message.content) {
-        if (part.type === "tool-call" || part.type === "tool-result") {
-          toolOverhead += this.TOKENS_PER_TOOL;
-        }
-      }
+    if (typeof message.content === `string` || // if the message included non-tool parts, add our message overhead
+    message.content.some((p) => p.type !== `tool-call` && p.type !== `tool-result`)) {
+      overhead += this.TOKENS_PER_MESSAGE;
     }
-    const totalMessageOverhead = messageOverhead + toolOverhead;
-    return this.encoder.encode(tokenString).length + totalMessageOverhead;
+    return this.encoder.encode(tokenString).length + overhead;
   }
 };
 var ToolCallFilter = class extends core.MemoryProcessor {
@@ -12,9 +12,7 @@ var TokenLimiter = class extends MemoryProcessor {
   // Every message follows <|start|>{role/name}\n{content}<|end|>
   TOKENS_PER_MESSAGE = 3.8;
   // tokens added for each message (start & end tokens)
-  TOKENS_PER_TOOL = 2.2;
-  // empirical adjustment for tool calls
-  TOKENS_PER_CONVERSATION = 25;
+  TOKENS_PER_CONVERSATION = 24;
   // fixed overhead for the conversation
   /**
    * Create a token limiter for messages.
@@ -68,38 +66,41 @@ var TokenLimiter = class extends MemoryProcessor {
       return this.encoder.encode(message).length;
     }
     let tokenString = message.role;
-    if (typeof message.content === "string") {
+    let overhead = 0;
+    if (typeof message.content === "string" && message.content) {
       tokenString += message.content;
     } else if (Array.isArray(message.content)) {
       for (const part of message.content) {
-        tokenString += part.type;
         if (part.type === "text") {
           tokenString += part.text;
-        } else if (part.type === "tool-call") {
-          tokenString += part.toolName;
-          if (part.args) {
-            tokenString += typeof part.args === "string" ? part.args : JSON.stringify(part.args);
+        } else if (part.type === "tool-call" || part.type === `tool-result`) {
+          if (`args` in part && part.args && part.type === `tool-call`) {
+            tokenString += part.toolName;
+            if (typeof part.args === "string") {
+              tokenString += part.args;
+            } else {
+              tokenString += JSON.stringify(part.args);
+              overhead -= 12;
+            }
          }
-        } else if (part.type === "tool-result") {
-          if (part.result !== void 0) {
-            tokenString += typeof part.result === "string" ? part.result : JSON.stringify(part.result);
+          if (`result` in part && part.result !== void 0 && part.type === `tool-result`) {
+            if (typeof part.result === "string") {
+              tokenString += part.result;
+            } else {
+              tokenString += JSON.stringify(part.result);
+              overhead -= 12;
+            }
          }
        } else {
          tokenString += JSON.stringify(part);
        }
      }
    }
-    const messageOverhead = this.TOKENS_PER_MESSAGE;
-    let toolOverhead = 0;
-    if (Array.isArray(message.content)) {
-      for (const part of message.content) {
-        if (part.type === "tool-call" || part.type === "tool-result") {
-          toolOverhead += this.TOKENS_PER_TOOL;
-        }
-      }
+    if (typeof message.content === `string` || // if the message included non-tool parts, add our message overhead
+    message.content.some((p) => p.type !== `tool-call` && p.type !== `tool-result`)) {
+      overhead += this.TOKENS_PER_MESSAGE;
    }
-    const totalMessageOverhead = messageOverhead + toolOverhead;
-    return this.encoder.encode(tokenString).length + totalMessageOverhead;
+    return this.encoder.encode(tokenString).length + overhead;
  }
};
 var ToolCallFilter = class extends MemoryProcessor$1 {
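
Compared to the old estimator, the output above drops the flat `TOKENS_PER_TOOL = 2.2` bump per tool part and instead (a) subtracts roughly 12 tokens whenever tool args or results are JSON-stringified, and (b) adds the 3.8-token per-message overhead only when the message contains non-tool parts. A standalone sketch of that arithmetic, with a stand-in encoder (the shipped class uses a tiktoken encoder, and skips the -12 when the payload is already a string):

```ts
// Toy restatement of the new overhead rules; not the shipped class.
type Part =
  | { type: 'text'; text: string }
  | { type: 'tool-call'; toolName: string; args: unknown }
  | { type: 'tool-result'; result: unknown };

const encode = (s: string) => s.split(/\s+/).filter(Boolean); // stand-in encoder

function estimateMessageTokens(role: string, parts: Part[]): number {
  let tokenString = role;
  let overhead = 0;
  for (const part of parts) {
    if (part.type === 'text') {
      tokenString += part.text;
    } else if (part.type === 'tool-call') {
      tokenString += part.toolName + JSON.stringify(part.args);
      overhead -= 12; // JSON punctuation over-counts relative to real usage
    } else if (part.type === 'tool-result') {
      tokenString += JSON.stringify(part.result);
      overhead -= 12;
    }
  }
  // The 3.8 per-message overhead applies only to messages with non-tool parts.
  if (parts.some(p => p.type !== 'tool-call' && p.type !== 'tool-result')) {
    overhead += 3.8; // TOKENS_PER_MESSAGE
  }
  return encode(tokenString).length + overhead;
}
```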
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@mastra/memory",
-  "version": "0.10.2",
+  "version": "0.10.3-alpha.1",
   "description": "",
   "type": "module",
   "main": "./dist/index.js",
@@ -40,20 +40,20 @@
     "postgres": "^3.4.5",
     "redis": "^4.7.0",
     "xxhash-wasm": "^1.1.0",
-    "zod": "^3.24.3"
+    "zod": "^3.25.56"
   },
   "devDependencies": {
     "@ai-sdk/openai": "^1.3.3",
-    "@microsoft/api-extractor": "^7.52.5",
-    "@types/node": "^20.17.27",
+    "@microsoft/api-extractor": "^7.52.8",
+    "@types/node": "^20.17.57",
     "@types/pg": "^8.11.11",
-    "eslint": "^9.23.0",
-    "tsup": "^8.4.0",
+    "eslint": "^9.28.0",
+    "tsup": "^8.5.0",
     "typescript": "^5.8.2",
     "typescript-eslint": "^8.26.1",
-    "vitest": "^3.1.2",
-    "@internal/lint": "0.0.8",
-    "@mastra/core": "0.10.2"
+    "vitest": "^3.2.2",
+    "@internal/lint": "0.0.10",
+    "@mastra/core": "0.10.4-alpha.2"
   },
   "peerDependencies": {
     "@mastra/core": "^0.10.2-alpha.0"
package/src/index.ts CHANGED
@@ -35,6 +35,8 @@ export class Memory extends MastraMemory {
       },
     });
     this.threadConfig = mergedConfig;
+
+    this.checkStorageFeatureSupport(mergedConfig);
   }
 
   private async validateThreadIsOwnedByResource(threadId: string, resourceId: string) {
@@ -49,6 +51,18 @@ export class Memory extends MastraMemory {
     }
   }
 
+  private checkStorageFeatureSupport(config: MemoryConfig) {
+    if (
+      typeof config.semanticRecall === `object` &&
+      config.semanticRecall.scope === `resource` &&
+      !this.storage.supports.selectByIncludeResourceScope
+    ) {
+      throw new Error(
+        `Memory error: Attached storage adapter "${this.storage.name || 'unknown'}" doesn't support semanticRecall: { scope: "resource" } yet and currently only supports per-thread semantic recall.`,
+      );
+    }
+  }
+
   async query({
     threadId,
     resourceId,
@@ -74,6 +88,8 @@ export class Memory extends MastraMemory {
 
     const config = this.getMergedThreadConfig(threadConfig || {});
 
+    this.checkStorageFeatureSupport(config);
+
     const defaultRange = DEFAULT_MESSAGE_RANGE;
     const defaultTopK = DEFAULT_TOP_K;
 
@@ -88,7 +104,9 @@ export class Memory extends MastraMemory {
       messageRange: config?.semanticRecall?.messageRange ?? defaultRange,
     };
 
-    if (config?.semanticRecall && selectBy?.vectorSearchString && this.vector && !!selectBy.vectorSearchString) {
+    const resourceScope = typeof config?.semanticRecall === 'object' && config?.semanticRecall?.scope === `resource`;
+
+    if (config?.semanticRecall && selectBy?.vectorSearchString && this.vector) {
       const { embeddings, dimension } = await this.embedMessageContent(selectBy.vectorSearchString!);
       const { indexName } = await this.createEmbeddingIndex(dimension);
 
@@ -105,9 +123,13 @@ export class Memory extends MastraMemory {
               indexName,
               queryVector: embedding,
               topK: vectorConfig.topK,
-              filter: {
-                thread_id: threadId,
-              },
+              filter: resourceScope
+                ? {
+                    resource_id: resourceId,
+                  }
+                : {
+                    thread_id: threadId,
+                  },
             })),
           );
         }),
@@ -117,6 +139,7 @@ export class Memory extends MastraMemory {
     // Get raw messages from storage
     const rawMessages = await this.storage.getMessages({
       threadId,
+      resourceId,
       format: 'v2',
       selectBy: {
         ...selectBy,
@@ -124,6 +147,7 @@ export class Memory extends MastraMemory {
           ? {
              include: vectorResults.map(r => ({
                id: r.metadata?.message_id,
+                threadId: r.metadata?.thread_id,
                withNextMessages:
                  typeof vectorConfig.messageRange === 'number'
                    ? vectorConfig.messageRange
@@ -188,6 +212,7 @@ export class Memory extends MastraMemory {
     }
 
     const messagesResult = await this.query({
+      resourceId,
       threadId,
       selectBy: {
         last: threadConfig.lastMessages,
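
With `resourceId` now forwarded into both the vector-store filter and `storage.getMessages`, a scoped recall query looks like this (a sketch against the `query` signature above; the `memory` instance and IDs are placeholders):

```ts
// Placeholders: assumes a configured `memory` with resource-scope recall enabled.
const { messages } = await memory.query({
  threadId: 'thread-123',
  resourceId: 'user-456', // matched against resource_id in vector metadata
  selectBy: { vectorSearchString: 'what did we decide about pricing?' },
});
```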
@@ -85,7 +85,7 @@ describe('TokenLimiter', () => {
     const { messages, fakeCore, counts } = generateConversationHistory(config);
 
     const estimate = estimateTokens(messages);
-    const used = (await agent.generate(fakeCore.slice(0, -1))).usage.totalTokens;
+    const used = (await agent.generate(fakeCore)).usage.promptTokens;
 
     console.log(`Estimated ${estimate} tokens, used ${used} tokens.\n`, counts);
 
@@ -100,7 +100,8 @@
       expression: z.string().describe('The mathematical expression to calculate'),
     }),
     execute: async ({ context: { expression } }) => {
-      return `The result of ${expression} is ${eval(expression)}`;
+      // Don't actually eval the expression. The model is dumb and sometimes passes "banana" as the expression because that's one of the sample tokens we're using in input messages lmao
+      return `The result of ${expression} is 10`;
     },
   });
 
@@ -178,16 +179,23 @@
       );
     });
 
-    it(`101 messages, 49 tool calls`, async () => {
-      await expectTokenEstimate(
-        {
-          messageCount: 50,
-          toolFrequency: 1,
-          threadId: '5',
-        },
-        agent,
-      );
-    });
+    it(
+      `101 messages, 49 tool calls`,
+      async () => {
+        await expectTokenEstimate(
+          {
+            messageCount: 50,
+            toolFrequency: 1,
+            threadId: '5',
+          },
+          agent,
+        );
+      },
+      {
+        // for some reason AI SDK randomly returns 2x token count here
+        retry: 3,
+      },
+    );
   });
 });
 
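Two test fixes worth noting: the calculator tool no longer `eval`s model-supplied input, and the estimate is now checked against `usage.promptTokens` for the full history rather than `usage.totalTokens` with the last message sliced off. Prompt tokens are the input-side quantity the limiter actually models; `totalTokens` also counts the completion. A sketch of the distinction, assuming an AI SDK-style usage object (the `result` here is a placeholder):

```ts
const { usage } = result;    // as returned by agent.generate(...)
usage.promptTokens;          // input tokens — what TokenLimiter estimates
usage.completionTokens;      // tokens generated by the model
usage.totalTokens;           // promptTokens + completionTokens
```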
@@ -26,8 +26,7 @@ export class TokenLimiter extends MemoryProcessor {
   // See: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
   // Every message follows <|start|>{role/name}\n{content}<|end|>
   public TOKENS_PER_MESSAGE = 3.8; // tokens added for each message (start & end tokens)
-  public TOKENS_PER_TOOL = 2.2; // empirical adjustment for tool calls
-  public TOKENS_PER_CONVERSATION = 25; // fixed overhead for the conversation
+  public TOKENS_PER_CONVERSATION = 24; // fixed overhead for the conversation
 
   /**
    * Create a token limiter for messages.
@@ -107,24 +106,35 @@ export class TokenLimiter extends MemoryProcessor {
     }
 
     let tokenString = message.role;
+    let overhead = 0;
 
-    if (typeof message.content === 'string') {
+    if (typeof message.content === 'string' && message.content) {
       tokenString += message.content;
     } else if (Array.isArray(message.content)) {
       // Calculate tokens for each content part
       for (const part of message.content) {
-        tokenString += part.type;
         if (part.type === 'text') {
           tokenString += part.text;
-        } else if (part.type === 'tool-call') {
-          tokenString += part.toolName as any;
-          if (part.args) {
-            tokenString += typeof part.args === 'string' ? part.args : JSON.stringify(part.args);
+        } else if (part.type === 'tool-call' || part.type === `tool-result`) {
+          if (`args` in part && part.args && part.type === `tool-call`) {
+            tokenString += part.toolName as any;
+            if (typeof part.args === 'string') {
+              tokenString += part.args;
+            } else {
+              tokenString += JSON.stringify(part.args);
+              // minus some tokens for JSON
+              overhead -= 12;
+            }
           }
-        } else if (part.type === 'tool-result') {
           // Token cost for result if present
-          if (part.result !== undefined) {
-            tokenString += typeof part.result === 'string' ? part.result : JSON.stringify(part.result);
+          if (`result` in part && part.result !== undefined && part.type === `tool-result`) {
+            if (typeof part.result === 'string') {
+              tokenString += part.result;
+            } else {
+              tokenString += JSON.stringify(part.result);
+              // minus some tokens for JSON
+              overhead -= 12;
+            }
           }
         } else {
           tokenString += JSON.stringify(part);
@@ -132,22 +142,16 @@ export class TokenLimiter extends MemoryProcessor {
       }
     }
 
-    // Ensure we account for message formatting tokens
-    // See: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
-    const messageOverhead = this.TOKENS_PER_MESSAGE;
-
-    // Count tool calls for additional overhead
-    let toolOverhead = 0;
-    if (Array.isArray(message.content)) {
-      for (const part of message.content) {
-        if (part.type === 'tool-call' || part.type === 'tool-result') {
-          toolOverhead += this.TOKENS_PER_TOOL;
-        }
-      }
+    if (
+      typeof message.content === `string` ||
+      // if the message included non-tool parts, add our message overhead
+      message.content.some(p => p.type !== `tool-call` && p.type !== `tool-result`)
+    ) {
+      // Ensure we account for message formatting tokens
+      // See: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
+      overhead += this.TOKENS_PER_MESSAGE;
     }
 
-    const totalMessageOverhead = messageOverhead + toolOverhead;
-
-    return this.encoder.encode(tokenString).length + totalMessageOverhead;
+    return this.encoder.encode(tokenString).length + overhead;
   }
 }
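
For context on where this estimator sits: `TokenLimiter` is a memory processor that trims recalled messages to a token budget before they reach the model. A minimal hookup sketch (import path and constructor argument as documented for this package; the budget is an arbitrary example):

```ts
import { Memory } from '@mastra/memory';
import { TokenLimiter } from '@mastra/memory/processors';

// Recalled history is trimmed so its estimated token count stays under budget.
const memory = new Memory({
  processors: [new TokenLimiter(20000)], // arbitrary example budget
});
```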