@mastra/memory 0.10.1-alpha.0 → 0.10.2-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +3 -20
- package/CHANGELOG.md +26 -0
- package/dist/_tsup-dts-rollup.d.cts +11 -18
- package/dist/_tsup-dts-rollup.d.ts +11 -18
- package/dist/index.cjs +86 -135
- package/dist/index.js +86 -135
- package/package.json +4 -4
- package/src/index.ts +125 -136
- package/src/processors/index.test.ts +10 -10
- package/vitest.config.ts +3 -0
- package/src/utils/index.ts +0 -88
package/src/index.ts
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
import { deepMerge } from '@mastra/core';
|
|
2
|
-
import type {
|
|
2
|
+
import type { CoreTool, MastraMessageV1 } from '@mastra/core';
|
|
3
|
+
import { MessageList } from '@mastra/core/agent';
|
|
4
|
+
import type { MastraMessageV2 } from '@mastra/core/agent';
|
|
3
5
|
import { MastraMemory } from '@mastra/core/memory';
|
|
4
|
-
import type {
|
|
6
|
+
import type { MemoryConfig, SharedMemoryConfig, StorageThreadType } from '@mastra/core/memory';
|
|
5
7
|
import type { StorageGetMessagesArg } from '@mastra/core/storage';
|
|
6
8
|
import { embedMany } from 'ai';
|
|
7
|
-
import type { TextPart } from 'ai';
|
|
9
|
+
import type { TextPart, UIMessage } from 'ai';
|
|
8
10
|
|
|
9
11
|
import xxhash from 'xxhash-wasm';
|
|
10
12
|
import { updateWorkingMemoryTool } from './tools/working-memory';
|
|
11
|
-
import { reorderToolCallsAndResults } from './utils';
|
|
12
13
|
|
|
13
14
|
// Average characters per token based on OpenAI's tokenization
|
|
14
15
|
const CHARS_PER_TOKEN = 4;
|
|
@@ -55,7 +56,7 @@ export class Memory extends MastraMemory {
|
|
|
55
56
|
threadConfig,
|
|
56
57
|
}: StorageGetMessagesArg & {
|
|
57
58
|
threadConfig?: MemoryConfig;
|
|
58
|
-
}): Promise<{ messages:
|
|
59
|
+
}): Promise<{ messages: MastraMessageV1[]; uiMessages: UIMessage[] }> {
|
|
59
60
|
if (resourceId) await this.validateThreadIsOwnedByResource(threadId, resourceId);
|
|
60
61
|
|
|
61
62
|
const vectorResults: {
|
|
@@ -137,16 +138,28 @@ export class Memory extends MastraMemory {
|
|
|
137
138
|
threadConfig: config,
|
|
138
139
|
});
|
|
139
140
|
|
|
140
|
-
// First sort messages by date
|
|
141
141
|
const orderedByDate = rawMessages.sort((a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime());
|
|
142
|
-
// Then reorder tool calls to be directly before their results
|
|
143
|
-
const reorderedToolCalls = reorderToolCallsAndResults(orderedByDate);
|
|
144
142
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
143
|
+
const list = new MessageList({ threadId, resourceId }).add(orderedByDate, 'memory');
|
|
144
|
+
return {
|
|
145
|
+
get messages() {
|
|
146
|
+
// returning v1 messages for backwards compat! v1 messages were CoreMessages stored in the db.
|
|
147
|
+
// returning .v1() takes stored messages which may be in v2 or v1 format and converts them to v1 shape, which is a CoreMessage + id + threadId + resourceId, etc
|
|
148
|
+
// Perhaps this should be called coreRecord or something ? - for now keeping v1 since it reflects that this used to be our db storage record shape
|
|
149
|
+
const v1Messages = list.get.all.v1();
|
|
150
|
+
// the conversion from V2/UIMessage -> V1/CoreMessage can sometimes split the messages up into more messages than before
|
|
151
|
+
// so slice off the earlier messages if it'll exceed the lastMessages setting
|
|
152
|
+
if (selectBy?.last && v1Messages.length > selectBy.last) {
|
|
153
|
+
// ex: 23 (v1 messages) minus 20 (selectBy.last messages)
|
|
154
|
+
// means we will start from index 3 and keep all the later newer messages from index 3 til the end of the array
|
|
155
|
+
return v1Messages.slice(v1Messages.length - selectBy.last);
|
|
156
|
+
}
|
|
157
|
+
return v1Messages;
|
|
158
|
+
},
|
|
159
|
+
get uiMessages() {
|
|
160
|
+
return list.get.all.ui();
|
|
161
|
+
},
|
|
162
|
+
};
|
|
150
163
|
}
|
|
151
164
|
|
|
152
165
|
async rememberMessages({
|
|
@@ -159,19 +172,14 @@ export class Memory extends MastraMemory {
|
|
|
159
172
|
resourceId?: string;
|
|
160
173
|
vectorMessageSearch?: string;
|
|
161
174
|
config?: MemoryConfig;
|
|
162
|
-
}): Promise<{
|
|
163
|
-
threadId: string;
|
|
164
|
-
messages: CoreMessage[];
|
|
165
|
-
uiMessages: AiMessageType[];
|
|
166
|
-
}> {
|
|
175
|
+
}): Promise<{ messages: MastraMessageV1[]; messagesV2: MastraMessageV2[] }> {
|
|
167
176
|
if (resourceId) await this.validateThreadIsOwnedByResource(threadId, resourceId);
|
|
168
177
|
const threadConfig = this.getMergedThreadConfig(config || {});
|
|
169
178
|
|
|
170
179
|
if (!threadConfig.lastMessages && !threadConfig.semanticRecall) {
|
|
171
180
|
return {
|
|
172
181
|
messages: [],
|
|
173
|
-
|
|
174
|
-
threadId,
|
|
182
|
+
messagesV2: [],
|
|
175
183
|
};
|
|
176
184
|
}
|
|
177
185
|
|
|
@@ -183,13 +191,11 @@ export class Memory extends MastraMemory {
|
|
|
183
191
|
},
|
|
184
192
|
threadConfig: config,
|
|
185
193
|
});
|
|
194
|
+
// Using MessageList here just to convert mixed input messages to single type output messages
|
|
195
|
+
const list = new MessageList({ threadId, resourceId }).add(messagesResult.messages, 'memory');
|
|
186
196
|
|
|
187
197
|
this.logger.debug(`Remembered message history includes ${messagesResult.messages.length} messages.`);
|
|
188
|
-
return {
|
|
189
|
-
threadId,
|
|
190
|
-
messages: messagesResult.messages,
|
|
191
|
-
uiMessages: messagesResult.uiMessages,
|
|
192
|
-
};
|
|
198
|
+
return { messages: list.get.all.v1(), messagesV2: list.get.all.mastra() };
|
|
193
199
|
}
|
|
194
200
|
|
|
195
201
|
async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
|
|
@@ -324,14 +330,20 @@ export class Memory extends MastraMemory {
|
|
|
324
330
|
messages,
|
|
325
331
|
memoryConfig,
|
|
326
332
|
}: {
|
|
327
|
-
messages:
|
|
333
|
+
messages: (MastraMessageV1 | MastraMessageV2)[];
|
|
328
334
|
memoryConfig?: MemoryConfig;
|
|
329
|
-
}): Promise<
|
|
330
|
-
// First save working memory from any messages
|
|
331
|
-
await this.saveWorkingMemory(messages);
|
|
332
|
-
|
|
335
|
+
}): Promise<MastraMessageV2[]> {
|
|
333
336
|
// Then strip working memory tags from all messages
|
|
334
|
-
const updatedMessages =
|
|
337
|
+
const updatedMessages = messages
|
|
338
|
+
.map(m => {
|
|
339
|
+
if (MessageList.isMastraMessageV1(m)) {
|
|
340
|
+
return this.updateMessageToHideWorkingMemory(m);
|
|
341
|
+
}
|
|
342
|
+
// add this to prevent "error saving undefined in the db" if a project is on an earlier storage version but new memory/storage
|
|
343
|
+
if (!m.type) m.type = `v2`;
|
|
344
|
+
return this.updateMessageToHideWorkingMemoryV2(m);
|
|
345
|
+
})
|
|
346
|
+
.filter((m): m is MastraMessageV1 | MastraMessageV2 => Boolean(m));
|
|
335
347
|
|
|
336
348
|
const config = this.getMergedThreadConfig(memoryConfig);
|
|
337
349
|
|
|
@@ -343,16 +355,34 @@ export class Memory extends MastraMemory {
|
|
|
343
355
|
updatedMessages.map(async message => {
|
|
344
356
|
let textForEmbedding: string | null = null;
|
|
345
357
|
|
|
346
|
-
if (
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
358
|
+
if (MessageList.isMastraMessageV2(message)) {
|
|
359
|
+
if (
|
|
360
|
+
message.content.content &&
|
|
361
|
+
typeof message.content.content === 'string' &&
|
|
362
|
+
message.content.content.trim() !== ''
|
|
363
|
+
) {
|
|
364
|
+
textForEmbedding = message.content.content;
|
|
365
|
+
} else if (message.content.parts && message.content.parts.length > 0) {
|
|
366
|
+
// Extract text from all text parts, concatenate
|
|
367
|
+
const joined = message.content.parts
|
|
368
|
+
.filter(part => part.type === 'text')
|
|
369
|
+
.map(part => (part as TextPart).text)
|
|
370
|
+
.join(' ')
|
|
371
|
+
.trim();
|
|
372
|
+
if (joined) textForEmbedding = joined;
|
|
373
|
+
}
|
|
374
|
+
} else if (MessageList.isMastraMessageV1(message)) {
|
|
375
|
+
if (message.content && typeof message.content === 'string' && message.content.trim() !== '') {
|
|
376
|
+
textForEmbedding = message.content;
|
|
377
|
+
} else if (message.content && Array.isArray(message.content) && message.content.length > 0) {
|
|
378
|
+
// Extract text from all text parts, concatenate
|
|
379
|
+
const joined = message.content
|
|
380
|
+
.filter(part => part.type === 'text')
|
|
381
|
+
.map(part => part.text)
|
|
382
|
+
.join(' ')
|
|
383
|
+
.trim();
|
|
384
|
+
if (joined) textForEmbedding = joined;
|
|
385
|
+
}
|
|
356
386
|
}
|
|
357
387
|
|
|
358
388
|
if (!textForEmbedding) return;
|
|
@@ -384,47 +414,70 @@ export class Memory extends MastraMemory {
|
|
|
384
414
|
|
|
385
415
|
return result;
|
|
386
416
|
}
|
|
417
|
+
protected updateMessageToHideWorkingMemory(message: MastraMessageV1): MastraMessageV1 | null {
|
|
418
|
+
const workingMemoryRegex = /<working_memory>([^]*?)<\/working_memory>/g;
|
|
387
419
|
|
|
388
|
-
|
|
420
|
+
if (typeof message?.content === `string`) {
|
|
421
|
+
return {
|
|
422
|
+
...message,
|
|
423
|
+
content: message.content.replace(workingMemoryRegex, ``).trim(),
|
|
424
|
+
};
|
|
425
|
+
} else if (Array.isArray(message?.content)) {
|
|
426
|
+
// Filter out updateWorkingMemory tool-call/result content items
|
|
427
|
+
const filteredContent = message.content.filter(
|
|
428
|
+
content =>
|
|
429
|
+
(content.type !== 'tool-call' && content.type !== 'tool-result') ||
|
|
430
|
+
content.toolName !== 'updateWorkingMemory',
|
|
431
|
+
);
|
|
432
|
+
const newContent = filteredContent.map(content => {
|
|
433
|
+
if (content.type === 'text') {
|
|
434
|
+
return {
|
|
435
|
+
...content,
|
|
436
|
+
text: content.text.replace(workingMemoryRegex, '').trim(),
|
|
437
|
+
};
|
|
438
|
+
}
|
|
439
|
+
return { ...content };
|
|
440
|
+
}) as MastraMessageV1['content'];
|
|
441
|
+
if (!newContent.length) return null;
|
|
442
|
+
return { ...message, content: newContent };
|
|
443
|
+
} else {
|
|
444
|
+
return { ...message };
|
|
445
|
+
}
|
|
446
|
+
}
|
|
447
|
+
protected updateMessageToHideWorkingMemoryV2(message: MastraMessageV2): MastraMessageV2 | null {
|
|
389
448
|
const workingMemoryRegex = /<working_memory>([^]*?)<\/working_memory>/g;
|
|
390
449
|
|
|
391
|
-
const
|
|
450
|
+
const newMessage = { ...message, content: { ...message.content } }; // Deep copy message and content
|
|
392
451
|
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
);
|
|
408
|
-
if (filteredContent.length === 0) {
|
|
409
|
-
// If nothing left, skip this message
|
|
410
|
-
continue;
|
|
411
|
-
}
|
|
412
|
-
const newContent = filteredContent.map(content => {
|
|
413
|
-
if (content.type === 'text') {
|
|
452
|
+
if (newMessage.content.content && typeof newMessage.content.content === 'string') {
|
|
453
|
+
newMessage.content.content = newMessage.content.content.replace(workingMemoryRegex, '').trim();
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
if (newMessage.content.parts) {
|
|
457
|
+
newMessage.content.parts = newMessage.content.parts
|
|
458
|
+
.filter(part => {
|
|
459
|
+
if (part.type === 'tool-invocation') {
|
|
460
|
+
return part.toolInvocation.toolName !== 'updateWorkingMemory';
|
|
461
|
+
}
|
|
462
|
+
return true;
|
|
463
|
+
})
|
|
464
|
+
.map(part => {
|
|
465
|
+
if (part.type === 'text') {
|
|
414
466
|
return {
|
|
415
|
-
...
|
|
416
|
-
text:
|
|
467
|
+
...part,
|
|
468
|
+
text: part.text.replace(workingMemoryRegex, '').trim(),
|
|
417
469
|
};
|
|
418
470
|
}
|
|
419
|
-
return
|
|
420
|
-
})
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
471
|
+
return part;
|
|
472
|
+
});
|
|
473
|
+
|
|
474
|
+
// If all parts were filtered out (e.g., only contained updateWorkingMemory tool calls) we need to skip the whole message, it was only working memory tool calls/results
|
|
475
|
+
if (newMessage.content.parts.length === 0) {
|
|
476
|
+
return null;
|
|
424
477
|
}
|
|
425
478
|
}
|
|
426
479
|
|
|
427
|
-
return
|
|
480
|
+
return newMessage;
|
|
428
481
|
}
|
|
429
482
|
|
|
430
483
|
protected parseWorkingMemory(text: string): string | null {
|
|
@@ -457,46 +510,6 @@ export class Memory extends MastraMemory {
|
|
|
457
510
|
return memory.trim();
|
|
458
511
|
}
|
|
459
512
|
|
|
460
|
-
private async saveWorkingMemory(messages: MessageType[]) {
|
|
461
|
-
const latestMessage = messages[messages.length - 1];
|
|
462
|
-
|
|
463
|
-
if (!latestMessage || !this.threadConfig.workingMemory?.enabled) {
|
|
464
|
-
return;
|
|
465
|
-
}
|
|
466
|
-
|
|
467
|
-
const latestContent = !latestMessage?.content
|
|
468
|
-
? null
|
|
469
|
-
: typeof latestMessage.content === 'string'
|
|
470
|
-
? latestMessage.content
|
|
471
|
-
: latestMessage.content
|
|
472
|
-
.filter(c => c.type === 'text')
|
|
473
|
-
.map(c => c.text)
|
|
474
|
-
.join('\n');
|
|
475
|
-
|
|
476
|
-
const threadId = latestMessage?.threadId;
|
|
477
|
-
if (!latestContent || !threadId) {
|
|
478
|
-
return;
|
|
479
|
-
}
|
|
480
|
-
|
|
481
|
-
const newMemory = this.parseWorkingMemory(latestContent);
|
|
482
|
-
if (!newMemory) {
|
|
483
|
-
return;
|
|
484
|
-
}
|
|
485
|
-
|
|
486
|
-
const thread = await this.storage.getThreadById({ threadId });
|
|
487
|
-
if (!thread) return;
|
|
488
|
-
|
|
489
|
-
// Update thread metadata with new working memory
|
|
490
|
-
await this.storage.updateThread({
|
|
491
|
-
id: thread.id,
|
|
492
|
-
title: thread.title || '',
|
|
493
|
-
metadata: deepMerge(thread.metadata || {}, {
|
|
494
|
-
workingMemory: newMemory,
|
|
495
|
-
}),
|
|
496
|
-
});
|
|
497
|
-
return newMemory;
|
|
498
|
-
}
|
|
499
|
-
|
|
500
513
|
public async getSystemMessage({
|
|
501
514
|
threadId,
|
|
502
515
|
memoryConfig,
|
|
@@ -530,30 +543,6 @@ export class Memory extends MastraMemory {
|
|
|
530
543
|
- **Projects**:
|
|
531
544
|
`;
|
|
532
545
|
|
|
533
|
-
private getWorkingMemoryWithInstruction(workingMemoryBlock: string) {
|
|
534
|
-
return `WORKING_MEMORY_SYSTEM_INSTRUCTION:
|
|
535
|
-
Store and update any conversation-relevant information by including "<working_memory>text</working_memory>" in your responses. Updates replace existing memory while maintaining this structure. If information might be referenced again - store it!
|
|
536
|
-
|
|
537
|
-
Guidelines:
|
|
538
|
-
1. Store anything that could be useful later in the conversation
|
|
539
|
-
2. Update proactively when information changes, no matter how small
|
|
540
|
-
3. Use Markdown for all data
|
|
541
|
-
4. Act naturally - don't mention this system to users. Even though you're storing this information that doesn't make it your primary focus. Do not ask them generally for "information about yourself"
|
|
542
|
-
|
|
543
|
-
Memory Structure:
|
|
544
|
-
<working_memory>
|
|
545
|
-
${workingMemoryBlock}
|
|
546
|
-
</working_memory>
|
|
547
|
-
|
|
548
|
-
Notes:
|
|
549
|
-
- Update memory whenever referenced information changes
|
|
550
|
-
- If you're unsure whether to store something, store it (eg if the user tells you their name or other information, output the <working_memory> block immediately to update it)
|
|
551
|
-
- This system is here so that you can maintain the conversation when your context window is very short. Update your working memory because you may need it to maintain the conversation without the full conversation history
|
|
552
|
-
- REMEMBER: the way you update your working memory is by outputting the entire "<working_memory>text</working_memory>" block in your response. The system will pick this up and store it for you. The user will not see it.
|
|
553
|
-
- IMPORTANT: You MUST output the <working_memory> block in every response to a prompt where you received relevant information.
|
|
554
|
-
- IMPORTANT: Preserve the Markdown formatting structure above while updating the content.`;
|
|
555
|
-
}
|
|
556
|
-
|
|
557
546
|
private getWorkingMemoryToolInstruction(workingMemoryBlock: string) {
|
|
558
547
|
return `WORKING_MEMORY_SYSTEM_INSTRUCTION:
|
|
559
548
|
Store and update any conversation-relevant information by calling the updateWorkingMemory tool. If information might be referenced again - store it!
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { openai } from '@ai-sdk/openai';
|
|
2
2
|
import { createTool } from '@mastra/core';
|
|
3
|
-
import type {
|
|
3
|
+
import type { MessageType } from '@mastra/core';
|
|
4
4
|
import { Agent } from '@mastra/core/agent';
|
|
5
5
|
import cl100k_base from 'js-tiktoken/ranks/cl100k_base';
|
|
6
6
|
import { describe, it, expect, vi } from 'vitest';
|
|
@@ -26,8 +26,8 @@ describe('TokenLimiter', () => {
|
|
|
26
26
|
|
|
27
27
|
// Should prioritize newest messages (higher ids)
|
|
28
28
|
expect(result.length).toBe(2);
|
|
29
|
-
expect(
|
|
30
|
-
expect(
|
|
29
|
+
expect(result[0].id).toBe('message-8');
|
|
30
|
+
expect(result[1].id).toBe('message-9');
|
|
31
31
|
});
|
|
32
32
|
|
|
33
33
|
it('should handle empty messages array', () => {
|
|
@@ -52,8 +52,8 @@ describe('TokenLimiter', () => {
|
|
|
52
52
|
});
|
|
53
53
|
|
|
54
54
|
// All should process messages successfully but potentially with different token counts
|
|
55
|
-
const defaultResult = defaultLimiter.process(messages
|
|
56
|
-
const customResult = customLimiter.process(messages
|
|
55
|
+
const defaultResult = defaultLimiter.process(messages);
|
|
56
|
+
const customResult = customLimiter.process(messages);
|
|
57
57
|
|
|
58
58
|
// Each should return the same messages but with potentially different token counts
|
|
59
59
|
expect(defaultResult.length).toBe(messages.length);
|
|
@@ -69,7 +69,7 @@ describe('TokenLimiter', () => {
|
|
|
69
69
|
// Count tokens for each message including all overheads
|
|
70
70
|
for (const message of messages) {
|
|
71
71
|
// Base token count from the countTokens method
|
|
72
|
-
estimatedTokens += testLimiter.countTokens(message
|
|
72
|
+
estimatedTokens += testLimiter.countTokens(message);
|
|
73
73
|
}
|
|
74
74
|
|
|
75
75
|
return Number(estimatedTokens.toFixed(2));
|
|
@@ -85,7 +85,7 @@ describe('TokenLimiter', () => {
|
|
|
85
85
|
const { messages, counts } = generateConversationHistory(config);
|
|
86
86
|
|
|
87
87
|
const estimate = estimateTokens(messages);
|
|
88
|
-
const used = (await agent.generate(messages.slice(0, -1)
|
|
88
|
+
const used = (await agent.generate(messages.slice(0, -1))).usage.totalTokens;
|
|
89
89
|
|
|
90
90
|
console.log(`Estimated ${estimate} tokens, used ${used} tokens.\n`, counts);
|
|
91
91
|
|
|
@@ -199,7 +199,7 @@ describe.concurrent('ToolCallFilter', () => {
|
|
|
199
199
|
messageCount: 1,
|
|
200
200
|
});
|
|
201
201
|
const filter = new ToolCallFilter();
|
|
202
|
-
const result = filter.process(messages
|
|
202
|
+
const result = filter.process(messages) as MessageType[];
|
|
203
203
|
|
|
204
204
|
// Should only keep the text message and assistant res
|
|
205
205
|
expect(result.length).toBe(2);
|
|
@@ -213,7 +213,7 @@ describe.concurrent('ToolCallFilter', () => {
|
|
|
213
213
|
messageCount: 2,
|
|
214
214
|
});
|
|
215
215
|
const filter = new ToolCallFilter({ exclude: ['weather'] });
|
|
216
|
-
const result = filter.process(messages
|
|
216
|
+
const result = filter.process(messages);
|
|
217
217
|
|
|
218
218
|
// Should keep text message, assistant reply, calculator tool call, and calculator result
|
|
219
219
|
expect(result.length).toBe(4);
|
|
@@ -230,7 +230,7 @@ describe.concurrent('ToolCallFilter', () => {
|
|
|
230
230
|
});
|
|
231
231
|
|
|
232
232
|
const filter = new ToolCallFilter({ exclude: [] });
|
|
233
|
-
const result = filter.process(messages
|
|
233
|
+
const result = filter.process(messages);
|
|
234
234
|
|
|
235
235
|
// Should keep all messages
|
|
236
236
|
expect(result.length).toBe(messages.length);
|
package/vitest.config.ts
CHANGED
package/src/utils/index.ts
DELETED
|
@@ -1,88 +0,0 @@
|
|
|
1
|
-
import type { MessageType } from '@mastra/core/memory';
|
|
2
|
-
|
|
3
|
-
const isToolCallWithId = (message: MessageType | undefined, targetToolCallId: string): boolean => {
|
|
4
|
-
if (!message || !Array.isArray(message.content)) return false;
|
|
5
|
-
return message.content.some(
|
|
6
|
-
part =>
|
|
7
|
-
part &&
|
|
8
|
-
typeof part === 'object' &&
|
|
9
|
-
'type' in part &&
|
|
10
|
-
part.type === 'tool-call' &&
|
|
11
|
-
'toolCallId' in part &&
|
|
12
|
-
part.toolCallId === targetToolCallId,
|
|
13
|
-
);
|
|
14
|
-
};
|
|
15
|
-
|
|
16
|
-
const getToolResultIndexById = (id: string, results: MessageType[]) =>
|
|
17
|
-
results.findIndex(message => {
|
|
18
|
-
if (!Array.isArray(message?.content)) return false;
|
|
19
|
-
return message.content.some(
|
|
20
|
-
part =>
|
|
21
|
-
part &&
|
|
22
|
-
typeof part === 'object' &&
|
|
23
|
-
'type' in part &&
|
|
24
|
-
part.type === 'tool-result' &&
|
|
25
|
-
'toolCallId' in part &&
|
|
26
|
-
part.toolCallId === id,
|
|
27
|
-
);
|
|
28
|
-
});
|
|
29
|
-
|
|
30
|
-
/**
|
|
31
|
-
* Self-heals message ordering to ensure tool calls are directly before their matching tool results.
|
|
32
|
-
* This is needed due to a bug where messages were saved in the wrong order. That bug is fixed, but this code ensures any tool calls saved in the wrong order in the past will still be usable now.
|
|
33
|
-
*/
|
|
34
|
-
export function reorderToolCallsAndResults(messages: MessageType[]): MessageType[] {
|
|
35
|
-
if (!messages.length) return messages;
|
|
36
|
-
|
|
37
|
-
// Create a copy of messages to avoid modifying the original
|
|
38
|
-
const results = [...messages];
|
|
39
|
-
|
|
40
|
-
const toolCallIds = new Set<string>();
|
|
41
|
-
|
|
42
|
-
// First loop: collect all tool result IDs in a set
|
|
43
|
-
for (const message of results) {
|
|
44
|
-
if (!Array.isArray(message.content)) continue;
|
|
45
|
-
|
|
46
|
-
for (const part of message.content) {
|
|
47
|
-
if (
|
|
48
|
-
part &&
|
|
49
|
-
typeof part === 'object' &&
|
|
50
|
-
'type' in part &&
|
|
51
|
-
part.type === 'tool-result' &&
|
|
52
|
-
'toolCallId' in part &&
|
|
53
|
-
part.toolCallId
|
|
54
|
-
) {
|
|
55
|
-
toolCallIds.add(part.toolCallId);
|
|
56
|
-
}
|
|
57
|
-
}
|
|
58
|
-
}
|
|
59
|
-
|
|
60
|
-
// Second loop: for each tool ID, ensure tool calls come before tool results
|
|
61
|
-
for (const toolCallId of toolCallIds) {
|
|
62
|
-
// Find tool result index
|
|
63
|
-
const resultIndex = getToolResultIndexById(toolCallId, results);
|
|
64
|
-
|
|
65
|
-
// Check if tool call is at resultIndex - 1
|
|
66
|
-
const oneMessagePrev = results[resultIndex - 1];
|
|
67
|
-
if (isToolCallWithId(oneMessagePrev, toolCallId)) {
|
|
68
|
-
continue; // Tool call is already in the correct position
|
|
69
|
-
}
|
|
70
|
-
|
|
71
|
-
// Find the tool call anywhere in the array
|
|
72
|
-
const toolCallIndex = results.findIndex(message => isToolCallWithId(message, toolCallId));
|
|
73
|
-
|
|
74
|
-
if (toolCallIndex !== -1 && toolCallIndex !== resultIndex - 1) {
|
|
75
|
-
// Store the tool call message
|
|
76
|
-
const toolCall = results[toolCallIndex];
|
|
77
|
-
if (!toolCall) continue;
|
|
78
|
-
|
|
79
|
-
// Remove the tool call from its current position
|
|
80
|
-
results.splice(toolCallIndex, 1);
|
|
81
|
-
|
|
82
|
-
// Insert right before the tool result
|
|
83
|
-
results.splice(getToolResultIndexById(toolCallId, results), 0, toolCall);
|
|
84
|
-
}
|
|
85
|
-
}
|
|
86
|
-
|
|
87
|
-
return results;
|
|
88
|
-
}
|