@mastra/memory 0.10.1 → 0.10.2-alpha.1

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
package/src/index.ts CHANGED
@@ -1,14 +1,15 @@
1
1
  import { deepMerge } from '@mastra/core';
2
- import type { AiMessageType, CoreMessage, CoreTool } from '@mastra/core';
2
+ import type { CoreTool, MastraMessageV1 } from '@mastra/core';
3
+ import { MessageList } from '@mastra/core/agent';
4
+ import type { MastraMessageV2 } from '@mastra/core/agent';
3
5
  import { MastraMemory } from '@mastra/core/memory';
4
- import type { MessageType, MemoryConfig, SharedMemoryConfig, StorageThreadType } from '@mastra/core/memory';
6
+ import type { MemoryConfig, SharedMemoryConfig, StorageThreadType } from '@mastra/core/memory';
5
7
  import type { StorageGetMessagesArg } from '@mastra/core/storage';
6
8
  import { embedMany } from 'ai';
7
- import type { TextPart } from 'ai';
9
+ import type { TextPart, UIMessage } from 'ai';
8
10
 
9
11
  import xxhash from 'xxhash-wasm';
10
12
  import { updateWorkingMemoryTool } from './tools/working-memory';
11
- import { reorderToolCallsAndResults } from './utils';
12
13
 
13
14
  // Average characters per token based on OpenAI's tokenization
14
15
  const CHARS_PER_TOKEN = 4;
@@ -55,7 +56,7 @@ export class Memory extends MastraMemory {
55
56
  threadConfig,
56
57
  }: StorageGetMessagesArg & {
57
58
  threadConfig?: MemoryConfig;
58
- }): Promise<{ messages: CoreMessage[]; uiMessages: AiMessageType[] }> {
59
+ }): Promise<{ messages: MastraMessageV1[]; uiMessages: UIMessage[] }> {
59
60
  if (resourceId) await this.validateThreadIsOwnedByResource(threadId, resourceId);
60
61
 
61
62
  const vectorResults: {
@@ -137,16 +138,28 @@ export class Memory extends MastraMemory {
137
138
  threadConfig: config,
138
139
  });
139
140
 
140
- // First sort messages by date
141
141
  const orderedByDate = rawMessages.sort((a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime());
142
- // Then reorder tool calls to be directly before their results
143
- const reorderedToolCalls = reorderToolCallsAndResults(orderedByDate);
144
142
 
145
- // Parse and convert messages
146
- const messages = this.parseMessages(reorderedToolCalls);
147
- const uiMessages = this.convertToUIMessages(reorderedToolCalls);
148
-
149
- return { messages, uiMessages };
143
+ const list = new MessageList({ threadId, resourceId }).add(orderedByDate, 'memory');
144
+ return {
145
+ get messages() {
146
+ // returning v1 messages for backwards compat! v1 messages were CoreMessages stored in the db.
147
+ // returning .v1() takes stored messages which may be in v2 or v1 format and converts them to v1 shape, which is a CoreMessage + id + threadId + resourceId, etc
148
+ // Perhaps this should be called coreRecord or something ? - for now keeping v1 since it reflects that this used to be our db storage record shape
149
+ const v1Messages = list.get.all.v1();
150
+ // the conversion from V2/UIMessage -> V1/CoreMessage can sometimes split the messages up into more messages than before
151
+ // so slice off the earlier messages if it'll exceed the lastMessages setting
152
+ if (selectBy?.last && v1Messages.length > selectBy.last) {
153
+ // ex: 23 (v1 messages) minus 20 (selectBy.last messages)
154
+ // means we will start from index 3 and keep all the later newer messages from index 3 til the end of the array
155
+ return v1Messages.slice(v1Messages.length - selectBy.last);
156
+ }
157
+ return v1Messages;
158
+ },
159
+ get uiMessages() {
160
+ return list.get.all.ui();
161
+ },
162
+ };
150
163
  }
151
164
 
152
165
  async rememberMessages({
@@ -159,19 +172,14 @@ export class Memory extends MastraMemory {
159
172
  resourceId?: string;
160
173
  vectorMessageSearch?: string;
161
174
  config?: MemoryConfig;
162
- }): Promise<{
163
- threadId: string;
164
- messages: CoreMessage[];
165
- uiMessages: AiMessageType[];
166
- }> {
175
+ }): Promise<{ messages: MastraMessageV1[]; messagesV2: MastraMessageV2[] }> {
167
176
  if (resourceId) await this.validateThreadIsOwnedByResource(threadId, resourceId);
168
177
  const threadConfig = this.getMergedThreadConfig(config || {});
169
178
 
170
179
  if (!threadConfig.lastMessages && !threadConfig.semanticRecall) {
171
180
  return {
172
181
  messages: [],
173
- uiMessages: [],
174
- threadId,
182
+ messagesV2: [],
175
183
  };
176
184
  }
177
185
 
@@ -183,13 +191,11 @@ export class Memory extends MastraMemory {
183
191
  },
184
192
  threadConfig: config,
185
193
  });
194
+ // Using MessageList here just to convert mixed input messages to single type output messages
195
+ const list = new MessageList({ threadId, resourceId }).add(messagesResult.messages, 'memory');
186
196
 
187
197
  this.logger.debug(`Remembered message history includes ${messagesResult.messages.length} messages.`);
188
- return {
189
- threadId,
190
- messages: messagesResult.messages,
191
- uiMessages: messagesResult.uiMessages,
192
- };
198
+ return { messages: list.get.all.v1(), messagesV2: list.get.all.mastra() };
193
199
  }
194
200
 
195
201
  async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
@@ -324,14 +330,20 @@ export class Memory extends MastraMemory {
324
330
  messages,
325
331
  memoryConfig,
326
332
  }: {
327
- messages: MessageType[];
333
+ messages: (MastraMessageV1 | MastraMessageV2)[];
328
334
  memoryConfig?: MemoryConfig;
329
- }): Promise<MessageType[]> {
330
- // First save working memory from any messages
331
- await this.saveWorkingMemory(messages);
332
-
335
+ }): Promise<MastraMessageV2[]> {
333
336
  // Then strip working memory tags from all messages
334
- const updatedMessages = this.updateMessagesToHideWorkingMemory(messages);
337
+ const updatedMessages = messages
338
+ .map(m => {
339
+ if (MessageList.isMastraMessageV1(m)) {
340
+ return this.updateMessageToHideWorkingMemory(m);
341
+ }
342
+ // add this to prevent "error saving undefined in the db" if a project is on an earlier storage version but new memory/storage
343
+ if (!m.type) m.type = `v2`;
344
+ return this.updateMessageToHideWorkingMemoryV2(m);
345
+ })
346
+ .filter((m): m is MastraMessageV1 | MastraMessageV2 => Boolean(m));
335
347
 
336
348
  const config = this.getMergedThreadConfig(memoryConfig);
337
349
 
@@ -343,16 +355,34 @@ export class Memory extends MastraMemory {
343
355
  updatedMessages.map(async message => {
344
356
  let textForEmbedding: string | null = null;
345
357
 
346
- if (typeof message.content === 'string' && message.content.trim() !== '') {
347
- textForEmbedding = message.content;
348
- } else if (Array.isArray(message.content)) {
349
- // Extract text from all text parts, concatenate
350
- const joined = message.content
351
- .filter(part => part && part.type === 'text' && typeof part.text === 'string')
352
- .map(part => (part as TextPart).text)
353
- .join(' ')
354
- .trim();
355
- if (joined) textForEmbedding = joined;
358
+ if (MessageList.isMastraMessageV2(message)) {
359
+ if (
360
+ message.content.content &&
361
+ typeof message.content.content === 'string' &&
362
+ message.content.content.trim() !== ''
363
+ ) {
364
+ textForEmbedding = message.content.content;
365
+ } else if (message.content.parts && message.content.parts.length > 0) {
366
+ // Extract text from all text parts, concatenate
367
+ const joined = message.content.parts
368
+ .filter(part => part.type === 'text')
369
+ .map(part => (part as TextPart).text)
370
+ .join(' ')
371
+ .trim();
372
+ if (joined) textForEmbedding = joined;
373
+ }
374
+ } else if (MessageList.isMastraMessageV1(message)) {
375
+ if (message.content && typeof message.content === 'string' && message.content.trim() !== '') {
376
+ textForEmbedding = message.content;
377
+ } else if (message.content && Array.isArray(message.content) && message.content.length > 0) {
378
+ // Extract text from all text parts, concatenate
379
+ const joined = message.content
380
+ .filter(part => part.type === 'text')
381
+ .map(part => part.text)
382
+ .join(' ')
383
+ .trim();
384
+ if (joined) textForEmbedding = joined;
385
+ }
356
386
  }
357
387
 
358
388
  if (!textForEmbedding) return;
@@ -384,47 +414,70 @@ export class Memory extends MastraMemory {
384
414
 
385
415
  return result;
386
416
  }
417
+ protected updateMessageToHideWorkingMemory(message: MastraMessageV1): MastraMessageV1 | null {
418
+ const workingMemoryRegex = /<working_memory>([^]*?)<\/working_memory>/g;
387
419
 
388
- protected updateMessagesToHideWorkingMemory(messages: MessageType[]): MessageType[] {
420
+ if (typeof message?.content === `string`) {
421
+ return {
422
+ ...message,
423
+ content: message.content.replace(workingMemoryRegex, ``).trim(),
424
+ };
425
+ } else if (Array.isArray(message?.content)) {
426
+ // Filter out updateWorkingMemory tool-call/result content items
427
+ const filteredContent = message.content.filter(
428
+ content =>
429
+ (content.type !== 'tool-call' && content.type !== 'tool-result') ||
430
+ content.toolName !== 'updateWorkingMemory',
431
+ );
432
+ const newContent = filteredContent.map(content => {
433
+ if (content.type === 'text') {
434
+ return {
435
+ ...content,
436
+ text: content.text.replace(workingMemoryRegex, '').trim(),
437
+ };
438
+ }
439
+ return { ...content };
440
+ }) as MastraMessageV1['content'];
441
+ if (!newContent.length) return null;
442
+ return { ...message, content: newContent };
443
+ } else {
444
+ return { ...message };
445
+ }
446
+ }
447
+ protected updateMessageToHideWorkingMemoryV2(message: MastraMessageV2): MastraMessageV2 | null {
389
448
  const workingMemoryRegex = /<working_memory>([^]*?)<\/working_memory>/g;
390
449
 
391
- const updatedMessages: MessageType[] = [];
450
+ const newMessage = { ...message, content: { ...message.content } }; // Deep copy message and content
392
451
 
393
- for (const message of messages) {
394
- if (typeof message?.content === `string`) {
395
- updatedMessages.push({
396
- ...message,
397
- content: message.content.replace(workingMemoryRegex, ``).trim(),
398
- });
399
- } else if (Array.isArray(message?.content)) {
400
- // Filter out updateWorkingMemory tool-call/result content items
401
- const filteredContent = message.content.filter(
402
- content =>
403
- !(
404
- (content.type === 'tool-call' || content.type === 'tool-result') &&
405
- content.toolName === 'updateWorkingMemory'
406
- ),
407
- );
408
- if (filteredContent.length === 0) {
409
- // If nothing left, skip this message
410
- continue;
411
- }
412
- const newContent = filteredContent.map(content => {
413
- if (content.type === 'text') {
452
+ if (newMessage.content.content && typeof newMessage.content.content === 'string') {
453
+ newMessage.content.content = newMessage.content.content.replace(workingMemoryRegex, '').trim();
454
+ }
455
+
456
+ if (newMessage.content.parts) {
457
+ newMessage.content.parts = newMessage.content.parts
458
+ .filter(part => {
459
+ if (part.type === 'tool-invocation') {
460
+ return part.toolInvocation.toolName !== 'updateWorkingMemory';
461
+ }
462
+ return true;
463
+ })
464
+ .map(part => {
465
+ if (part.type === 'text') {
414
466
  return {
415
- ...content,
416
- text: content.text.replace(workingMemoryRegex, '').trim(),
467
+ ...part,
468
+ text: part.text.replace(workingMemoryRegex, '').trim(),
417
469
  };
418
470
  }
419
- return { ...content };
420
- }) as MessageType['content'];
421
- updatedMessages.push({ ...message, content: newContent });
422
- } else {
423
- updatedMessages.push({ ...message });
471
+ return part;
472
+ });
473
+
474
+ // If all parts were filtered out (e.g., only contained updateWorkingMemory tool calls) we need to skip the whole message, it was only working memory tool calls/results
475
+ if (newMessage.content.parts.length === 0) {
476
+ return null;
424
477
  }
425
478
  }
426
479
 
427
- return updatedMessages;
480
+ return newMessage;
428
481
  }
429
482
 
430
483
  protected parseWorkingMemory(text: string): string | null {
@@ -457,46 +510,6 @@ export class Memory extends MastraMemory {
457
510
  return memory.trim();
458
511
  }
459
512
 
460
- private async saveWorkingMemory(messages: MessageType[]) {
461
- const latestMessage = messages[messages.length - 1];
462
-
463
- if (!latestMessage || !this.threadConfig.workingMemory?.enabled) {
464
- return;
465
- }
466
-
467
- const latestContent = !latestMessage?.content
468
- ? null
469
- : typeof latestMessage.content === 'string'
470
- ? latestMessage.content
471
- : latestMessage.content
472
- .filter(c => c.type === 'text')
473
- .map(c => c.text)
474
- .join('\n');
475
-
476
- const threadId = latestMessage?.threadId;
477
- if (!latestContent || !threadId) {
478
- return;
479
- }
480
-
481
- const newMemory = this.parseWorkingMemory(latestContent);
482
- if (!newMemory) {
483
- return;
484
- }
485
-
486
- const thread = await this.storage.getThreadById({ threadId });
487
- if (!thread) return;
488
-
489
- // Update thread metadata with new working memory
490
- await this.storage.updateThread({
491
- id: thread.id,
492
- title: thread.title || '',
493
- metadata: deepMerge(thread.metadata || {}, {
494
- workingMemory: newMemory,
495
- }),
496
- });
497
- return newMemory;
498
- }
499
-
500
513
  public async getSystemMessage({
501
514
  threadId,
502
515
  memoryConfig,
@@ -530,30 +543,6 @@ export class Memory extends MastraMemory {
530
543
  - **Projects**:
531
544
  `;
532
545
 
533
- private getWorkingMemoryWithInstruction(workingMemoryBlock: string) {
534
- return `WORKING_MEMORY_SYSTEM_INSTRUCTION:
535
- Store and update any conversation-relevant information by including "<working_memory>text</working_memory>" in your responses. Updates replace existing memory while maintaining this structure. If information might be referenced again - store it!
536
-
537
- Guidelines:
538
- 1. Store anything that could be useful later in the conversation
539
- 2. Update proactively when information changes, no matter how small
540
- 3. Use Markdown for all data
541
- 4. Act naturally - don't mention this system to users. Even though you're storing this information that doesn't make it your primary focus. Do not ask them generally for "information about yourself"
542
-
543
- Memory Structure:
544
- <working_memory>
545
- ${workingMemoryBlock}
546
- </working_memory>
547
-
548
- Notes:
549
- - Update memory whenever referenced information changes
550
- - If you're unsure whether to store something, store it (eg if the user tells you their name or other information, output the <working_memory> block immediately to update it)
551
- - This system is here so that you can maintain the conversation when your context window is very short. Update your working memory because you may need it to maintain the conversation without the full conversation history
552
- - REMEMBER: the way you update your working memory is by outputting the entire "<working_memory>text</working_memory>" block in your response. The system will pick this up and store it for you. The user will not see it.
553
- - IMPORTANT: You MUST output the <working_memory> block in every response to a prompt where you received relevant information.
554
- - IMPORTANT: Preserve the Markdown formatting structure above while updating the content.`;
555
- }
556
-
557
546
  private getWorkingMemoryToolInstruction(workingMemoryBlock: string) {
558
547
  return `WORKING_MEMORY_SYSTEM_INSTRUCTION:
559
548
  Store and update any conversation-relevant information by calling the updateWorkingMemory tool. If information might be referenced again - store it!
@@ -1,6 +1,6 @@
1
1
  import { openai } from '@ai-sdk/openai';
2
2
  import { createTool } from '@mastra/core';
3
- import type { CoreMessage, MessageType } from '@mastra/core';
3
+ import type { MessageType } from '@mastra/core';
4
4
  import { Agent } from '@mastra/core/agent';
5
5
  import cl100k_base from 'js-tiktoken/ranks/cl100k_base';
6
6
  import { describe, it, expect, vi } from 'vitest';
@@ -26,8 +26,8 @@ describe('TokenLimiter', () => {
26
26
 
27
27
  // Should prioritize newest messages (higher ids)
28
28
  expect(result.length).toBe(2);
29
- expect((result[0] as MessageType).id).toBe('message-8');
30
- expect((result[1] as MessageType).id).toBe('message-9');
29
+ expect(result[0].id).toBe('message-8');
30
+ expect(result[1].id).toBe('message-9');
31
31
  });
32
32
 
33
33
  it('should handle empty messages array', () => {
@@ -52,8 +52,8 @@ describe('TokenLimiter', () => {
52
52
  });
53
53
 
54
54
  // All should process messages successfully but potentially with different token counts
55
- const defaultResult = defaultLimiter.process(messages as CoreMessage[]);
56
- const customResult = customLimiter.process(messages as CoreMessage[]);
55
+ const defaultResult = defaultLimiter.process(messages);
56
+ const customResult = customLimiter.process(messages);
57
57
 
58
58
  // Each should return the same messages but with potentially different token counts
59
59
  expect(defaultResult.length).toBe(messages.length);
@@ -69,7 +69,7 @@ describe('TokenLimiter', () => {
69
69
  // Count tokens for each message including all overheads
70
70
  for (const message of messages) {
71
71
  // Base token count from the countTokens method
72
- estimatedTokens += testLimiter.countTokens(message as CoreMessage);
72
+ estimatedTokens += testLimiter.countTokens(message);
73
73
  }
74
74
 
75
75
  return Number(estimatedTokens.toFixed(2));
@@ -85,7 +85,7 @@ describe('TokenLimiter', () => {
85
85
  const { messages, counts } = generateConversationHistory(config);
86
86
 
87
87
  const estimate = estimateTokens(messages);
88
- const used = (await agent.generate(messages.slice(0, -1) as CoreMessage[])).usage.totalTokens;
88
+ const used = (await agent.generate(messages.slice(0, -1))).usage.totalTokens;
89
89
 
90
90
  console.log(`Estimated ${estimate} tokens, used ${used} tokens.\n`, counts);
91
91
 
@@ -199,7 +199,7 @@ describe.concurrent('ToolCallFilter', () => {
199
199
  messageCount: 1,
200
200
  });
201
201
  const filter = new ToolCallFilter();
202
- const result = filter.process(messages as CoreMessage[]) as MessageType[];
202
+ const result = filter.process(messages) as MessageType[];
203
203
 
204
204
  // Should only keep the text message and assistant res
205
205
  expect(result.length).toBe(2);
@@ -213,7 +213,7 @@ describe.concurrent('ToolCallFilter', () => {
213
213
  messageCount: 2,
214
214
  });
215
215
  const filter = new ToolCallFilter({ exclude: ['weather'] });
216
- const result = filter.process(messages as CoreMessage[]) as MessageType[];
216
+ const result = filter.process(messages);
217
217
 
218
218
  // Should keep text message, assistant reply, calculator tool call, and calculator result
219
219
  expect(result.length).toBe(4);
@@ -230,7 +230,7 @@ describe.concurrent('ToolCallFilter', () => {
230
230
  });
231
231
 
232
232
  const filter = new ToolCallFilter({ exclude: [] });
233
- const result = filter.process(messages as CoreMessage[]);
233
+ const result = filter.process(messages);
234
234
 
235
235
  // Should keep all messages
236
236
  expect(result.length).toBe(messages.length);
package/vitest.config.ts CHANGED
@@ -4,5 +4,8 @@ export default defineConfig({
4
4
  test: {
5
5
  environment: 'node',
6
6
  include: ['src/**/*.test.ts'],
7
+ // smaller output to save token space when LLMs run tests
8
+ reporters: 'dot',
9
+ bail: 1,
7
10
  },
8
11
  });
@@ -1,88 +0,0 @@
1
- import type { MessageType } from '@mastra/core/memory';
2
-
3
- const isToolCallWithId = (message: MessageType | undefined, targetToolCallId: string): boolean => {
4
- if (!message || !Array.isArray(message.content)) return false;
5
- return message.content.some(
6
- part =>
7
- part &&
8
- typeof part === 'object' &&
9
- 'type' in part &&
10
- part.type === 'tool-call' &&
11
- 'toolCallId' in part &&
12
- part.toolCallId === targetToolCallId,
13
- );
14
- };
15
-
16
- const getToolResultIndexById = (id: string, results: MessageType[]) =>
17
- results.findIndex(message => {
18
- if (!Array.isArray(message?.content)) return false;
19
- return message.content.some(
20
- part =>
21
- part &&
22
- typeof part === 'object' &&
23
- 'type' in part &&
24
- part.type === 'tool-result' &&
25
- 'toolCallId' in part &&
26
- part.toolCallId === id,
27
- );
28
- });
29
-
30
- /**
31
- * Self-heals message ordering to ensure tool calls are directly before their matching tool results.
32
- * This is needed due to a bug where messages were saved in the wrong order. That bug is fixed, but this code ensures any tool calls saved in the wrong order in the past will still be usable now.
33
- */
34
- export function reorderToolCallsAndResults(messages: MessageType[]): MessageType[] {
35
- if (!messages.length) return messages;
36
-
37
- // Create a copy of messages to avoid modifying the original
38
- const results = [...messages];
39
-
40
- const toolCallIds = new Set<string>();
41
-
42
- // First loop: collect all tool result IDs in a set
43
- for (const message of results) {
44
- if (!Array.isArray(message.content)) continue;
45
-
46
- for (const part of message.content) {
47
- if (
48
- part &&
49
- typeof part === 'object' &&
50
- 'type' in part &&
51
- part.type === 'tool-result' &&
52
- 'toolCallId' in part &&
53
- part.toolCallId
54
- ) {
55
- toolCallIds.add(part.toolCallId);
56
- }
57
- }
58
- }
59
-
60
- // Second loop: for each tool ID, ensure tool calls come before tool results
61
- for (const toolCallId of toolCallIds) {
62
- // Find tool result index
63
- const resultIndex = getToolResultIndexById(toolCallId, results);
64
-
65
- // Check if tool call is at resultIndex - 1
66
- const oneMessagePrev = results[resultIndex - 1];
67
- if (isToolCallWithId(oneMessagePrev, toolCallId)) {
68
- continue; // Tool call is already in the correct position
69
- }
70
-
71
- // Find the tool call anywhere in the array
72
- const toolCallIndex = results.findIndex(message => isToolCallWithId(message, toolCallId));
73
-
74
- if (toolCallIndex !== -1 && toolCallIndex !== resultIndex - 1) {
75
- // Store the tool call message
76
- const toolCall = results[toolCallIndex];
77
- if (!toolCall) continue;
78
-
79
- // Remove the tool call from its current position
80
- results.splice(toolCallIndex, 1);
81
-
82
- // Insert right before the tool result
83
- results.splice(getToolResultIndexById(toolCallId, results), 0, toolCall);
84
- }
85
- }
86
-
87
- return results;
88
- }