@mastra/memory 0.12.2 → 0.12.3-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,367 @@
1
+ import { spawn } from 'node:child_process';
2
+ import { randomUUID } from 'node:crypto';
3
+ import type { UUID } from 'node:crypto';
4
+ import { createServer } from 'node:net';
5
+ import path from 'node:path';
6
+ import { openai } from '@ai-sdk/openai';
7
+ import { useChat } from '@ai-sdk/react';
8
+ import { Mastra } from '@mastra/core';
9
+ import { Agent } from '@mastra/core/agent';
10
+ import { renderHook, act, waitFor } from '@testing-library/react';
11
+ import type { UIMessage } from 'ai';
12
+ import { DefaultChatTransport, isToolUIPart } from 'ai';
13
+ import { JSDOM } from 'jsdom';
14
+ import { describe, expect, it, beforeAll, afterAll } from 'vitest';
15
+ import { memory, weatherAgent } from './mastra/agents/weather';
16
+ import { weatherTool } from './mastra/tools/weather';
17
+
18
/**
 * Finds a TCP port that is free right now by binding an ephemeral port (port 0),
 * reading the number the OS assigned, and closing the probe listener again.
 *
 * The port is only known to be free at the moment of the check; another process
 * could still claim it before the caller binds it.
 *
 * @returns The port number the OS assigned to the probe listener.
 * @throws Rejects if the probe server emits an error or does not report a TCP address.
 */
async function getAvailablePort(): Promise<number> {
  return new Promise((resolve, reject) => {
    const server = createServer();
    server.on('error', reject);
    server.listen(0, () => {
      const address = server.address();
      // address() may be an AddressInfo object, a string (pipe/socket path) or null;
      // only the object form carries a port. Reject instead of throwing inside the callback.
      if (address === null || typeof address === 'string') {
        server.close(() => reject(new Error(`Unexpected server address: ${String(address)}`)));
        return;
      }
      const { port } = address;
      server.close(() => resolve(port));
    });
  });
}
29
+
30
// Set up a JSDOM environment so React hooks (renderHook / useChat) can run under Node.
const dom = new JSDOM('<!doctype html><html><body></body></html>', {
  url: 'http://localhost',
  pretendToBeVisual: true,
  resources: 'usable',
});

/**
 * Installs `value` on the global object under `name`.
 *
 * Object.defineProperty is used instead of plain assignment because some runtimes can
 * expose built-ins such as `navigator` as accessor properties, where assignment from a
 * strict-mode module may throw. It also sidesteps the mismatch between JSDOM's DOM types
 * and the global declarations without a blanket `@ts-ignore`.
 */
function installGlobal(name: 'window' | 'document' | 'navigator', value: unknown): void {
  // Mirror the attributes a plain assignment would create on a new global property.
  Object.defineProperty(globalThis, name, { value, configurable: true, enumerable: true, writable: true });
}

installGlobal('window', dom.window);
installGlobal('document', dom.window.document);
installGlobal('navigator', dom.window.navigator);

// The chat tests below make HTTP requests and presumably rely on a global fetch
// (built into Node 18+). Fail early with a clear message if it is missing.
if (typeof globalThis.fetch !== 'function') {
  throw new Error('Memory streaming tests require a global fetch implementation (Node 18+).');
}
41
+
42
+ describe('Memory Streaming Tests', () => {
43
it('should handle multiple tool calls in memory thread history', async () => {
  // Agent backed by the shared memory instance and the weather tool.
  const agent = new Agent({
    name: 'test',
    instructions:
      'You are a weather agent. When asked about weather in any city, use the get_weather tool with the city name as the postal code. Respond in a pirate accent and dont use the degrees symbol, print the word degrees when needed.',
    model: openai('gpt-4o'),
    memory,
    tools: { get_weather: weatherTool },
  });

  // Both turns use the same thread/resource, so the second turn runs with the first in history.
  const threadId = randomUUID();
  const resourceId = 'test-resource';

  // Turn 1: default stream format; text-delta chunks carry their text under `payload.text`.
  const laStream = await agent.streamVNext('what is the weather in LA?', {
    threadId,
    resourceId,
  });
  const laDeltas: string[] = [];
  for await (const part of laStream.fullStream) {
    if (part.type !== `text-delta`) continue;
    laDeltas.push(part.payload.text);
  }
  const laReply = laDeltas.join('');

  expect(laDeltas.length).toBeGreaterThan(0);
  expect(laReply).toContain('70 degrees');

  // Turn 2: request the AI SDK output format instead; its text-delta chunks expose `text` directly.
  const seattleStream = await agent.streamVNext('what is the weather in Seattle?', {
    threadId,
    resourceId,
    format: 'aisdk',
  });
  const seattleDeltas: string[] = [];
  for await (const part of seattleStream.fullStream) {
    if (part.type !== `text-delta`) continue;
    seattleDeltas.push(part.text);
  }
  const seattleReply = seattleDeltas.join('');

  expect(seattleDeltas.length).toBeGreaterThan(0);
  expect(seattleReply).toContain('Seattle');
  expect(seattleReply).toContain('70 degrees');
});
95
+
96
it('should use custom mastra ID generator for messages in memory', async () => {
  const agent = new Agent({
    name: 'test-msg-id',
    instructions: 'you are a helpful assistant.',
    model: openai('gpt-4o'),
    memory,
  });

  const threadId = randomUUID();
  const resourceId = 'test-resource-msg-id';
  // Records every id handed out by the custom generator so stored message ids can be checked against it.
  const customIds: UUID[] = [];

  // Constructed only for its side effect of registering the agent with a Mastra instance;
  // the test expects the agent to pick up this instance's idGenerator.
  new Mastra({
    idGenerator: () => {
      const id = randomUUID();
      customIds.push(id);
      return id;
    },
    agents: {
      agent: agent,
    },
  });

  await agent.generateVNext('Hello, world!', {
    threadId,
    resourceId,
  });

  const agentMemory = await agent.getMemory();
  if (!agentMemory) {
    throw new Error('Expected agent to have memory configured');
  }
  const { messages } = await agentMemory.query({ threadId });

  console.log('Custom IDs: ', customIds);
  console.log('Messages: ', messages);

  // Two stored messages are expected, and the generator must have been called more
  // times than there are stored messages.
  expect(messages).toHaveLength(2);
  expect(messages.length).toBeLessThan(customIds.length);
  for (const message of messages) {
    if (!(`id` in message)) {
      throw new Error(`Expected message.id`);
    }
    expect(customIds).toContain(message.id);
  }
});
139
+
140
+ describe('should stream via useChat after tool call', () => {
141
// Suite-wide state: the dev server process and its port are assigned in beforeAll;
// threadId/resourceId are sent with every chat request made by the tests below.
let port: number;
let mastraServer: ReturnType<typeof spawn>;
const resourceId = 'test-resource';
const threadId = randomUUID();
145
+
146
beforeAll(async () => {
  port = await getAvailablePort();

  // Launch the locally built CLI's dev server on the free port.
  const cliEntry = path.resolve(import.meta.dirname, `..`, `..`, `..`, `cli`, `dist`, `index.js`);
  mastraServer = spawn('pnpm', [cliEntry, 'dev', '--port', port.toString()], {
    stdio: 'pipe',
    detached: true, // Run in a new process group so we can kill it and children
  });
  const server = mastraServer;

  // Wait until the server prints its local URL. Fail fast if the process cannot be
  // spawned (an 'error' event with no listener would otherwise crash the run) or exits
  // before it becomes ready, instead of waiting for the timeout.
  await new Promise<void>((resolve, reject) => {
    let output = '';
    let settled = false;

    const timer = setTimeout(() => settle(new Error('Mastra server failed to start')), 10000);

    // Settles the startup promise exactly once and cancels the pending timeout.
    function settle(err?: Error): void {
      if (settled) return;
      settled = true;
      clearTimeout(timer);
      if (err) {
        reject(err);
      } else {
        resolve();
      }
    }

    server.stdout?.on('data', data => {
      const text = data.toString();
      // Log each chunk once rather than re-printing the whole accumulated output.
      console.log(text);
      if (settled) return; // Stop accumulating once startup has been decided.
      output += text;
      if (output.includes('http://localhost:')) {
        settle();
      }
    });
    server.stderr?.on('data', data => {
      console.error('Mastra server error:', data.toString());
    });
    server.once('error', err => settle(err));
    server.once('exit', (code, signal) =>
      settle(new Error(`Mastra server exited before becoming ready (code: ${code}, signal: ${signal})`)),
    );
  });
});
180
+
181
afterAll(() => {
  const pid = mastraServer?.pid;
  // Nothing to clean up if the server was never spawned (or has no pid).
  if (!pid) return;

  try {
    // A negative pid signals the whole process group created by `detached: true`,
    // so children started by the dev server are terminated too.
    process.kill(-pid, 'SIGTERM');
  } catch (e) {
    console.error('Failed to kill Mastra server:', e);
  }
});
191
+
192
it('should stream via useChat after tool call', async () => {
  let error: Error | null = null;

  // Only the newest message is posted; earlier turns are presumably restored on the server
  // from memory for this threadId/resourceId.
  const { result } = renderHook(() =>
    useChat({
      transport: new DefaultChatTransport({
        api: `http://localhost:${port}/api/agents/test/stream/vnext/ui`,
        prepareSendMessagesRequest: ({ messages }) => ({
          body: {
            messages: [messages.at(-1)],
            threadId,
            resourceId,
          },
        }),
      }),
      onFinish(message) {
        console.log('useChat finished', message);
      },
      onError(e) {
        error = e;
        console.error('useChat error:', error);
      },
    }),
  );

  let turns = 0;

  // Sends `message`, then waits until this turn's assistant reply contains every expected substring.
  async function expectResponse({ message, responseContains }: { message: string; responseContains: string[] }) {
    turns += 1;
    await act(async () => {
      await result.current.sendMessage({
        role: 'user',
        parts: [{ type: 'text', text: message }],
      });
    });

    // Each turn appends a user message and an assistant reply, so turn N's reply is at index 2N - 1.
    const replyIndex = turns * 2 - 1;
    await waitFor(
      () => {
        expect(error).toBeNull();
        expect(result.current.messages).toHaveLength(turns * 2);
        responseContains.forEach(expected => {
          const replyText = result.current.messages[replyIndex].parts
            .map(part => (`text` in part ? part.text : ''))
            .join(``);
          expect(replyText).toContain(expected);
        });
      },
      { timeout: 1000 },
    );
  }

  await expectResponse({
    message: 'what is the weather in Los Angeles?',
    responseContains: ['Los Angeles', '70'],
  });

  await expectResponse({
    message: 'what is the weather in Seattle?',
    responseContains: ['Seattle', '70'],
  });
});
253
+
254
it('should stream useChat with client side tool calling', async () => {
  let error: Error | null = null;
  // Fresh thread for this test (shadows the suite-level threadId).
  const threadId = randomUUID();

  // Seed the thread through the in-process agent so the chat starts with existing history.
  await weatherAgent.generateVNext(`hi`, {
    threadId,
    resourceId,
  });
  await weatherAgent.generateVNext(`LA weather`, { threadId, resourceId });

  const agentMemory = await weatherAgent.getMemory();
  if (!agentMemory) {
    throw new Error('Expected weatherAgent to have memory configured');
  }
  const initialMessages = (await agentMemory.query({ threadId })).uiMessages;
  // Mutable clipboard contents served by the client-side `clipboard` tool handler below.
  const state = { clipboard: '' };
  const { result } = renderHook(() => {
    const chat = useChat({
      transport: new DefaultChatTransport({
        api: `http://localhost:${port}/api/agents/test/stream/vnext/ui`,
        prepareSendMessagesRequest({ messages }) {
          return {
            body: {
              messages: [messages.at(-1)],
              threadId,
              resourceId,
            },
          };
        },
      }),
      messages: initialMessages as UIMessage[],
      onFinish(message) {
        console.log('useChat finished', message);
      },
      onError(e) {
        error = e;
        console.error('useChat error:', error);
      },
      onToolCall: async ({ toolCall }) => {
        console.log(toolCall);
        if (toolCall.toolName === `clipboard`) {
          await new Promise(res => setTimeout(res, 10));
          // NOTE(review): the `as any as void` cast shows onToolCall is typed to return void, so this
          // value is presumably ignored rather than used as the tool output. AI SDK 5 documents
          // `addToolResult` for supplying client-side tool results. TODO confirm and switch. Until
          // then, clipboard turns that end on the tool call return early in expectResponse without
          // checking the clipboard contents.
          return state.clipboard as any as void;
        }
      },
    });
    return chat;
  });

  // Sends `message`, waits for new messages, then checks the latest one for every expected substring.
  async function expectResponse({ message, responseContains }: { message: string; responseContains: string[] }) {
    const messageCountBefore = result.current.messages.length;
    await act(async () => {
      await result.current.sendMessage({
        role: 'user',
        parts: [{ type: 'text', text: message }],
      });
    });

    // Wait for message count to increase
    await waitFor(
      () => {
        expect(error).toBeNull();
        expect(result.current.messages.length).toBeGreaterThan(messageCountBefore);
      },
      { timeout: 2000 },
    );

    // Get fresh reference to messages after all waits complete
    const latestMessage = result.current.messages.at(-1);
    if (!latestMessage) throw new Error(`No latest message`);

    // The turn ended on a client-side clipboard tool call: there is no assistant text to check.
    if (
      latestMessage.role === `assistant` &&
      latestMessage.parts.length === 2 &&
      latestMessage.parts[1].type === `tool-clipboard`
    ) {
      return;
    }

    // Build the searchable text once: all text-bearing parts joined, plus every tool part
    // serialized as JSON so values that only appear in tool inputs/outputs are found too.
    const segments = [latestMessage.parts.map(p => (`text` in p ? p.text : ``)).join(``)];
    for (const part of latestMessage.parts) {
      if (isToolUIPart(part)) {
        segments.push(JSON.stringify(part));
      }
    }
    const searchString = segments.join(`\n`);

    for (const should of responseContains) {
      expect(searchString).toContain(should);
    }
  }

  state.clipboard = `test 1!`;
  await expectResponse({
    message: 'whats in my clipboard?',
    responseContains: [state.clipboard],
  });
  await expectResponse({
    message: 'weather in Las Vegas',
    responseContains: ['Las Vegas', '70'],
  });
  state.clipboard = `test 2!`;
  await expectResponse({
    message: 'whats in my clipboard?',
    responseContains: [state.clipboard],
  });
  state.clipboard = `test 3!`;
  await expectResponse({
    message: 'whats in my clipboard now?',
    responseContains: [state.clipboard],
  });
});
366
+ });
367
+ });
@@ -0,0 +1,146 @@
1
+ import type { CoreMessage, MastraMessageV1 } from '@mastra/core';
2
+ import { MessageList } from '@mastra/core/agent';
3
+ import type { MastraMessageV2 } from '@mastra/core/agent';
4
+
5
/** Canned arguments attached to each generated tool invocation, keyed by tool name. */
const toolArgs = {
  weather: {
    location: 'New York',
  },
  calculator: {
    expression: '2+2',
  },
  search: {
    query: 'latest AI developments',
  },
};
10
+
11
/**
 * Canned results attached to each generated tool invocation, keyed by tool name
 * (same keys as `toolArgs`).
 */
const toolResults = {
  weather: 'Pretty hot', // result for the weather lookup
  calculator: '4', // result of evaluating '2+2'
  search: 'Anthropic blah blah blah', // placeholder search summary
};
16
+
17
/**
 * Creates a simulated conversation history of alternating user/assistant turns with
 * occasional tool calls, then runs it through a `MessageList` so callers can consume
 * the same history in several message formats.
 *
 * If the last generated turn ends with a tool invocation, one extra user message is
 * appended so the history does not end on a tool call.
 *
 * @param options.threadId Thread ID stamped on every generated message.
 * @param options.resourceId Resource ID stamped on every generated message (default `'test-resource'`).
 * @param options.messageCount Number of turn pairs (user + assistant) to generate (default 5).
 * @param options.toolFrequency How often an assistant turn carries a tool call: turn `i` does when
 *   `i > 0 && i % toolFrequency === 0` (default 3, i.e. every 3rd turn, never the first).
 * @param options.toolNames Tools to rotate through for those tool calls. An empty list disables
 *   tool calls.
 * @returns The history as V1 messages, as V2 messages, and as the V1 list cast to `CoreMessage[]`
 *   (`fakeCore`), plus counts of generated messages, tool calls and tool results.
 */
export function generateConversationHistory({
  threadId,
  resourceId = 'test-resource',
  messageCount = 5,
  toolFrequency = 3,
  toolNames = ['weather', 'calculator', 'search'],
}: {
  threadId: string;
  resourceId?: string;
  messageCount?: number;
  toolFrequency?: number;
  toolNames?: (keyof typeof toolArgs)[];
}): {
  messages: MastraMessageV1[];
  messagesV2: MastraMessageV2[];
  fakeCore: CoreMessage[];
  counts: { messages: number; toolCalls: number; toolResults: number };
} {
  const counts = { messages: 0, toolCalls: 0, toolResults: 0 };
  // Words that will each be about one token, repeated to give predictably sized messages.
  const words = ['apple', 'banana', 'orange', 'grape'];
  // The filler text is the same for every turn, so build it once.
  const userContent = Array(25).fill(words).flat().join(' '); // ~100 tokens
  const assistantContent = Array(15).fill(words).flat().join(' '); // ~60 tokens

  const messages: MastraMessageV2[] = [];
  const startTime = Date.now();

  // Generate message pairs (user message followed by assistant response)
  for (let i = 0; i < messageCount; i++) {
    messages.push({
      role: 'user',
      content: { format: 2, parts: [{ type: 'text', text: userContent }] },
      id: `message-${i * 2}`,
      threadId,
      resourceId,
      createdAt: new Date(startTime + i * 2000), // Each pair 2 seconds apart
    });
    counts.messages++;

    // Without the length check, an empty toolNames would yield a NaN index and an undefined tool name.
    const includeTool = toolNames.length > 0 && i > 0 && i % toolFrequency === 0;

    if (includeTool) {
      // i is a multiple of toolFrequency here, so successive tool turns rotate through toolNames.
      const toolName = toolNames[(i / toolFrequency) % toolNames.length];
      messages.push({
        role: 'assistant',
        content: {
          format: 2,
          parts: [
            { type: 'text', text: `Using ${toolName} tool:` },
            {
              type: 'tool-invocation',
              toolInvocation: {
                state: 'result',
                toolCallId: `tool-${i}`,
                toolName,
                args: toolArgs[toolName as keyof typeof toolArgs] || {},
                result: toolResults[toolName as keyof typeof toolResults] || {},
              },
            },
          ],
        },
        id: `tool-call-${i * 2 + 1}`,
        threadId,
        resourceId,
        createdAt: new Date(startTime + i * 2000 + 1000), // 1 second after user message
      });
      counts.messages++;
      counts.toolCalls++;
      counts.toolResults++;
    } else {
      // Regular assistant text message
      messages.push({
        role: 'assistant',
        content: { format: 2, parts: [{ type: 'text', text: assistantContent }] },
        id: `message-${i * 2 + 1}`,
        threadId,
        resourceId,
        createdAt: new Date(startTime + i * 2000 + 1000), // 1 second after user message
      });
      counts.messages++;
    }
  }

  // `messages` is empty when messageCount is 0, so no non-null assertion here.
  const latestMessage = messages.at(-1);
  if (latestMessage?.role === `assistant` && latestMessage.content.parts.at(-1)?.type === `tool-invocation`) {
    // Continue the id/timestamp scheme as if this were turn `messageCount`, so its timestamp is
    // later than the final assistant message's (startTime + (messageCount - 1) * 2000 + 1000).
    messages.push({
      role: 'user',
      content: { format: 2, parts: [{ type: 'text', text: userContent }] },
      id: `message-${messageCount * 2}`,
      threadId,
      resourceId,
      createdAt: new Date(startTime + messageCount * 2000),
    });
    counts.messages++;
  }

  const list = new MessageList().add(messages, 'memory');
  return {
    fakeCore: list.get.all.v1() as CoreMessage[],
    messages: list.get.all.v1(),
    messagesV2: list.get.all.v2(),
    counts,
  };
}
136
+
137
/**
 * Returns the messages whose content is a part array containing a `tool-call` part for
 * the tool named `name`. Messages with string content never match.
 */
export function filterToolCallsByName(messages: CoreMessage[], name: string) {
  const callsNamedTool = (message: CoreMessage): boolean => {
    if (!Array.isArray(message.content)) {
      return false;
    }
    return message.content.some(part => part.type === 'tool-call' && part.toolName === name);
  };
  return messages.filter(callsNamedTool);
}
142

/**
 * Returns the messages whose content is a part array containing a `tool-result` part for
 * the tool named `name`. Messages with string content never match.
 */
export function filterToolResultsByName(messages: CoreMessage[], name: string) {
  const hasNamedResult = (message: CoreMessage): boolean => {
    if (!Array.isArray(message.content)) {
      return false;
    }
    return message.content.some(part => part.type === 'tool-result' && part.toolName === name);
  };
  return messages.filter(hasNamedResult);
}