@mastra/memory 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,256 @@
1
+ import { MastraMemory, MessageType, ThreadType } from '@mastra/core';
2
+ import { randomUUID } from 'crypto';
3
+ import pg from 'pg';
4
+
5
const { Pool } = pg;

/**
 * Column list for `mastra_threads`, aliased to the keys of `ThreadType`.
 * The aliases are double-quoted because Postgres folds unquoted identifiers to
 * lower case (`AS createdAt` would come back as `createdat`).
 */
const THREAD_COLUMNS = 'id, title, created_at AS "createdAt", updated_at AS "updatedAt", metadata';

/** Column list for `mastra_messages`, aliased (quoted, as above) to the keys of `MessageType`. */
const MESSAGE_COLUMNS = 'id, content, role, created_at AS "createdAt", thread_id AS "threadId"';

/** Postgres SQLSTATE `invalid_text_representation`, e.g. a non-UUID string compared with a UUID column. */
const PG_INVALID_TEXT_REPRESENTATION = '22P02';

/** Returns true when `error` carries the given Postgres SQLSTATE in its `code` property. */
function hasPgErrorCode(error: unknown, code: string): boolean {
  return typeof error === 'object' && error !== null && (error as { code?: unknown }).code === code;
}

/**
 * PostgreSQL-backed {@link MastraMemory}.
 *
 * Threads are stored in `mastra_threads` and messages in `mastra_messages`,
 * whose `thread_id` is a foreign key to `mastra_threads.id`. Both tables are
 * created on first use by {@link PgMemory.ensureTablesExist}.
 */
export class PgMemory extends MastraMemory {
  private pool: pg.Pool;

  /** Memoized table setup; reset if setup fails so that a later call retries. */
  private tablesReady: Promise<void> | undefined;

  /** @param connectionString - node-postgres connection string for the target database. */
  constructor(connectionString: string) {
    super();
    this.pool = new Pool({ connectionString });
  }

  /**
   * Deletes every message and thread, then ends the connection pool.
   *
   * The tables themselves are kept. The instance cannot be used afterwards
   * because its pool has been shut down.
   */
  async drop(): Promise<void> {
    const client = await this.pool.connect();
    try {
      // Messages first: mastra_messages.thread_id references mastra_threads.id.
      await client.query('DELETE FROM mastra_messages');
      await client.query('DELETE FROM mastra_threads');
    } finally {
      client.release();
    }
    await this.pool.end();
  }

  /**
   * Creates `mastra_threads` and `mastra_messages` if they do not exist yet.
   *
   * Setup runs at most once per instance. Concurrent callers share the same
   * in-flight promise.
   */
  async ensureTablesExist(): Promise<void> {
    if (!this.tablesReady) {
      this.tablesReady = this.createTables().catch((error: unknown) => {
        // Forget the failed attempt so the next call tries again.
        this.tablesReady = undefined;
        throw error;
      });
    }
    return this.tablesReady;
  }

  /**
   * Issues the DDL for both tables.
   *
   * `CREATE TABLE IF NOT EXISTS` looks for the table in the schema it would be
   * created in. An unqualified `information_schema.tables` lookup also matches
   * same-named tables in unrelated schemas.
   */
  private async createTables(): Promise<void> {
    await this.pool.query(`
      CREATE TABLE IF NOT EXISTS mastra_threads (
        id UUID PRIMARY KEY,
        title TEXT,
        created_at TIMESTAMP WITH TIME ZONE NOT NULL,
        updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
        metadata JSONB
      );
    `);
    // Must run after mastra_threads exists because of the foreign key.
    await this.pool.query(`
      CREATE TABLE IF NOT EXISTS mastra_messages (
        id UUID PRIMARY KEY,
        content TEXT NOT NULL,
        role VARCHAR(20) NOT NULL,
        created_at TIMESTAMP WITH TIME ZONE NOT NULL,
        thread_id UUID NOT NULL,
        FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
      );
    `);
  }

  /**
   * Runs `work` on a single pooled client between BEGIN and COMMIT.
   *
   * On any error the transaction is rolled back and the original error is
   * rethrown. If the ROLLBACK itself fails, the client is discarded instead of
   * being returned to the pool.
   */
  private async withTransaction<T>(work: (client: pg.PoolClient) => Promise<T>): Promise<T> {
    const client = await this.pool.connect();
    let brokenConnection: Error | undefined;
    try {
      await client.query('BEGIN');
      const result = await work(client);
      await client.query('COMMIT');
      return result;
    } catch (error) {
      try {
        await client.query('ROLLBACK');
      } catch (rollbackError) {
        // Keep the original error; just make sure this connection is not reused.
        brokenConnection = rollbackError instanceof Error ? rollbackError : new Error(String(rollbackError));
      }
      throw error;
    } finally {
      // Passing an error to release() makes node-postgres destroy the client.
      client.release(brokenConnection);
    }
  }

  /**
   * Replaces a thread's title and metadata and sets `updated_at` to now.
   *
   * @returns The updated thread.
   * @throws Error if no thread with `id` exists.
   */
  async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
    await this.ensureTablesExist();
    const result = await this.pool.query<ThreadType>(
      `
      UPDATE mastra_threads
      SET title = $1, metadata = $2, updated_at = NOW()
      WHERE id = $3
      RETURNING ${THREAD_COLUMNS}
      `,
      [title, JSON.stringify(metadata), id],
    );
    const updated = result.rows[0];
    if (!updated) {
      throw new Error(`Thread with ID ${id} does not exist.`);
    }
    return updated;
  }

  /** Deletes a thread and all of its messages in one transaction; a missing thread is a no-op. */
  async deleteThread(id: string): Promise<void> {
    await this.ensureTablesExist();
    await this.withTransaction(async client => {
      await client.query('DELETE FROM mastra_messages WHERE thread_id = $1', [id]);
      await client.query('DELETE FROM mastra_threads WHERE id = $1', [id]);
    });
  }

  /** Deletes one message by id; a missing message is a no-op. */
  async deleteMessage(id: string): Promise<void> {
    await this.ensureTablesExist();
    await this.pool.query('DELETE FROM mastra_messages WHERE id = $1', [id]);
  }

  /**
   * Looks up a thread by id.
   *
   * @returns The thread, or `null` if it does not exist. This includes the case
   *   where `threadId` is not valid UUID input, which Postgres rejects with
   *   SQLSTATE 22P02.
   */
  async getThreadById(threadId: string): Promise<ThreadType | null> {
    await this.ensureTablesExist();
    try {
      const result = await this.pool.query<ThreadType>(
        `SELECT ${THREAD_COLUMNS} FROM mastra_threads WHERE id = $1`,
        [threadId],
      );
      return result.rows[0] ?? null;
    } catch (error) {
      if (hasPgErrorCode(error, PG_INVALID_TEXT_REPRESENTATION)) {
        return null;
      }
      throw error;
    }
  }

  /**
   * Inserts a thread. If `thread.id` already exists, only its title,
   * updated_at and metadata are overwritten; created_at is kept.
   *
   * @returns The thread as stored.
   */
  async saveThread(thread: ThreadType): Promise<ThreadType> {
    await this.ensureTablesExist();
    const { id, title, createdAt, updatedAt, metadata } = thread;
    const result = await this.pool.query<ThreadType>(
      `
      INSERT INTO mastra_threads (id, title, created_at, updated_at, metadata)
      VALUES ($1, $2, $3, $4, $5)
      ON CONFLICT (id) DO UPDATE SET title = $2, updated_at = $4, metadata = $5
      RETURNING ${THREAD_COLUMNS}
      `,
      [id, title, createdAt, updatedAt, JSON.stringify(metadata)],
    );
    const saved = result.rows[0];
    if (!saved) {
      // Not expected: an upsert with RETURNING yields exactly one row.
      throw new Error(`Failed to save thread ${id}.`);
    }
    return saved;
  }

  /**
   * Persists messages atomically.
   *
   * Re-saving an existing id overwrites its content, role and created_at
   * (matching saveThread's upsert) instead of failing on the primary key.
   *
   * @returns The input array.
   */
  async saveMessages(messages: MessageType[]): Promise<MessageType[]> {
    if (messages.length === 0) {
      return messages;
    }
    await this.ensureTablesExist();
    await this.withTransaction(async client => {
      for (const { id, content, role, createdAt, threadId } of messages) {
        await client.query(
          `
          INSERT INTO mastra_messages (id, content, role, created_at, thread_id)
          VALUES ($1, $2, $3, $4, $5)
          ON CONFLICT (id) DO UPDATE
            SET content = EXCLUDED.content, role = EXCLUDED.role, created_at = EXCLUDED.created_at
          `,
          // new Date(...) also accepts timestamps that were deserialized as strings.
          [id, content, role, new Date(createdAt).toISOString(), threadId],
        );
      }
    });
    return messages;
  }

  /** Returns a thread's messages oldest first; ties on created_at are broken by id. */
  async getMessages(threadId: string): Promise<MessageType[]> {
    await this.ensureTablesExist();
    const result = await this.pool.query<MessageType>(
      `
      SELECT ${MESSAGE_COLUMNS}
      FROM mastra_messages
      WHERE thread_id = $1
      ORDER BY created_at ASC, id ASC
      `,
      [threadId],
    );
    return result.rows;
  }

  /** Creates and stores a new thread with a random UUID and both timestamps set to now. */
  async createThread(title?: string, metadata?: Record<string, unknown>): Promise<ThreadType> {
    const now = new Date();
    const thread: ThreadType = {
      id: randomUUID(),
      title,
      createdAt: now,
      updatedAt: now,
      metadata,
    };
    // saveThread() takes care of creating the tables.
    return this.saveThread(thread);
  }

  /**
   * Appends a message to an existing thread.
   *
   * @throws Error if the thread does not exist.
   */
  async addMessage(threadId: string, content: string, role: 'user' | 'assistant'): Promise<MessageType> {
    const thread = await this.getThreadById(threadId);
    if (!thread) {
      throw new Error(`Thread with ID ${threadId} does not exist.`);
    }

    const message: MessageType = {
      id: randomUUID(),
      content,
      role,
      createdAt: new Date(),
      threadId,
    };
    await this.saveMessages([message]);
    return message;
  }
}
@@ -0,0 +1,245 @@
1
+ import { ThreadType, MessageType } from '@mastra/core';
2
+
3
+ import { LocalRedisProvider, RedisMemory } from './';
4
+
5
describe('RedisMemory', () => {
  let memory: RedisMemory;
  let provider: LocalRedisProvider;

  /** Even indices are user turns, odd indices assistant turns. */
  const roleFor = (index: number) => (index % 2 === 0 ? ('user' as const) : ('assistant' as const));

  beforeAll(async () => {
    provider = new LocalRedisProvider();
    memory = new RedisMemory(provider);
  });

  afterAll(async () => {
    // Cancel pending lock timers before closing the client they would use.
    await memory.cleanup();
    await provider.quit();
  });

  // Every test starts from an empty Redis database.
  beforeEach(async () => {
    await provider.flushall();
  });

  describe('Thread Operations', () => {
    it('should create a thread with metadata', async () => {
      const thread = await memory.createThread('Test Thread', { testKey: 'testValue' });

      const expected = {
        id: expect.any(String),
        title: 'Test Thread',
        createdAt: expect.any(Date),
        updatedAt: expect.any(Date),
        metadata: { testKey: 'testValue' },
      };
      expect(thread).toEqual(expected);
    });

    it('should retrieve a created thread', async () => {
      const created = await memory.createThread('Test Thread');
      expect(await memory.getThreadById(created.id)).toEqual(created);
    });

    it('should return null for non-existent thread', async () => {
      expect(await memory.getThreadById('nonexistent')).toBeNull();
    });

    it('should update thread title', async () => {
      const thread = await memory.createThread('Initial Title');
      thread.title = 'Updated Title';
      await memory.saveThread(thread);

      const reloaded = await memory.getThreadById(thread.id);
      expect(reloaded?.title).toBe('Updated Title');
    });
  });

  describe('Message Operations', () => {
    let testThread: ThreadType;

    beforeEach(async () => {
      testThread = await memory.createThread('Test Thread');
    });

    it('should add a single message', async () => {
      const added = await memory.addMessage(testThread.id, 'Hello World', 'user');
      const stored = await memory.getMessages(testThread.id);

      expect(stored).toHaveLength(1);
      expect(stored[0]).toEqual(added);
    });

    it('should maintain message order', async () => {
      const turns = [
        ['First', 'user'],
        ['Second', 'assistant'],
        ['Third', 'user'],
      ] as const;
      for (const [text, role] of turns) {
        await memory.addMessage(testThread.id, text, role);
      }

      const stored = await memory.getMessages(testThread.id);
      expect(stored).toHaveLength(3);
      expect(stored.map(({ content }) => content)).toEqual(turns.map(([text]) => text));
    });

    it('should handle bulk message saves', async () => {
      const batch: MessageType[] = (['user', 'assistant'] as const).map((role, index) => ({
        id: `msg${index + 1}`,
        threadId: testThread.id,
        content: `Message ${index + 1}`,
        role,
        createdAt: new Date(),
      }));

      await memory.saveMessages(batch);
      expect(await memory.getMessages(testThread.id)).toHaveLength(2);
    });

    it('should prevent duplicate messages', async () => {
      const original = await memory.addMessage(testThread.id, 'Test', 'user');
      // Saving the very same message a second time must not create a copy.
      await memory.saveMessages([original]);

      expect(await memory.getMessages(testThread.id)).toHaveLength(1);
    });
  });

  describe('Thread Management', () => {
    it('should list all thread IDs', async () => {
      const first = await memory.createThread('Thread 1');
      const second = await memory.createThread('Thread 2');

      const allIds = await memory.getAllThreadIds();
      expect(allIds).toContain(first.id);
      expect(allIds).toContain(second.id);
    });

    it('should delete thread and its messages', async () => {
      const doomed = await memory.createThread('Delete Test');
      await memory.addMessage(doomed.id, 'Test Message', 'user');

      await memory.deleteThread(doomed.id);

      const [threadAfter, messagesAfter] = [
        await memory.getThreadById(doomed.id),
        await memory.getMessages(doomed.id),
      ];
      expect(threadAfter).toBeNull();
      expect(messagesAfter).toHaveLength(0);
    });

    it('should retrieve multiple threads', async () => {
      const first = await memory.createThread('Thread 1');
      const second = await memory.createThread('Thread 2');

      const fetched = await memory.getThreads([first.id, second.id]);
      expect(fetched).toHaveLength(2);
      expect(fetched.map(({ title }) => title)).toEqual(['Thread 1', 'Thread 2']);
    });
  });

  describe('Error Handling', () => {
    it('should handle saving messages to non-existent thread', async () => {
      const orphan: MessageType = {
        id: 'test',
        threadId: 'nonexistent',
        content: 'Test',
        role: 'user',
        createdAt: new Date(),
      };

      await memory.saveMessages([orphan]);
      expect(await memory.getMessages('nonexistent')).toHaveLength(1);
    });

    it('should handle concurrent message saves', async () => {
      const thread = await memory.createThread('Concurrent Test');

      // Fire all three saves without waiting in between.
      await Promise.all(
        ['Message 1', 'Message 2', 'Message 3'].map(text => memory.addMessage(thread.id, text, 'user')),
      );

      expect(await memory.getMessages(thread.id)).toHaveLength(3);
    });
  });

  describe('Performance', () => {
    it('should handle large number of messages', async () => {
      const thread = await memory.createThread('Bulk Test');
      const messageCount = 100;

      const outgoing: MessageType[] = [];
      for (let i = 0; i < messageCount; i++) {
        outgoing.push({
          id: memory['generateId'](),
          content: `Message ${i}`,
          role: roleFor(i),
          createdAt: new Date(),
          threadId: thread.id,
        });
      }

      // One saveMessages call for the whole set.
      await memory.saveMessages(outgoing);

      const retrievedMessages = await memory.getMessages(thread.id);
      expect(retrievedMessages).toHaveLength(messageCount);

      // Timestamps can collide, so order by the number embedded in the content.
      const ordinal = (message: MessageType) => parseInt(message.content.split(' ')[1] || '0', 10);
      const byOrdinal = [...retrievedMessages].sort((left, right) => ordinal(left) - ordinal(right));

      byOrdinal.forEach((message, position) => {
        expect(message.content).toBe(`Message ${position}`);
      });
    });

    it('should handle concurrent batch saves', async () => {
      const thread = await memory.createThread('Concurrent Batch Test');
      const batchSize = 20;
      const numberOfBatches = 5;

      const batches: MessageType[][] = [];
      for (let batchIndex = 0; batchIndex < numberOfBatches; batchIndex++) {
        const batch: MessageType[] = [];
        for (let i = 0; i < batchSize; i++) {
          batch.push({
            id: memory['generateId'](),
            content: `Batch ${batchIndex} Message ${i}`,
            role: roleFor(i),
            createdAt: new Date(),
            threadId: thread.id,
          });
        }
        batches.push(batch);
      }

      // All batches race for the same thread's message list.
      await Promise.all(batches.map(batch => memory.saveMessages(batch)));

      const stored = await memory.getMessages(thread.id);
      expect(stored).toHaveLength(batchSize * numberOfBatches);

      const storedContents = new Set(stored.map(({ content }) => content));
      for (const batch of batches) {
        for (const { content } of batch) {
          expect(storedContents.has(content)).toBeTruthy();
        }
      }
    });
  });
});
@@ -0,0 +1,189 @@
1
import { randomUUID } from 'crypto';

import { MastraMemory, ThreadType, MessageType } from '@mastra/core';

import { RedisClient } from './types';
4
+
5
+ export * from './types';
6
+ export * from './providers';
7
+
8
/**
 * {@link MastraMemory} stored through a {@link RedisClient}.
 *
 * Key layout used by this class:
 * - `thread:<id>`   – the thread object
 * - `messages:<id>` – all messages of a thread as one array, sorted by createdAt, then id
 * - `threads`       – set of every saved thread id
 * - `lock:<key>`    – advisory lock taken while a message array is rewritten
 */
export class RedisMemory extends MastraMemory {
  private redis: RedisClient;
  private threadPrefix = 'thread:';
  private messagePrefix = 'messages:';
  /** Pending lock-expiry timers keyed by lock key, so cleanup() can cancel them. */
  private lockTimeouts: Map<string, NodeJS.Timeout> = new Map();

  constructor(redis: RedisClient) {
    super();
    this.redis = redis;
  }

  /**
   * Loads a thread and revives its timestamps into Date objects.
   *
   * @returns The thread, or the client's value for a missing key (expected to
   *   be `null` — confirm against RedisClient).
   */
  async getThreadById(threadId: string): Promise<ThreadType | null> {
    const thread = await this.redis.get(`${this.threadPrefix}${threadId}`);
    // Timestamps may come back as strings (e.g. after JSON serialization).
    if (thread && typeof thread.createdAt === 'string') {
      thread.createdAt = new Date(thread.createdAt);
      thread.updatedAt = new Date(thread.updatedAt);
    }
    return thread;
  }

  /**
   * Stores a thread and registers its id in the `threads` set.
   *
   * Note: sets `thread.updatedAt` to now on the passed object itself and
   * returns that same object.
   */
  async saveThread(thread: ThreadType): Promise<ThreadType> {
    thread.updatedAt = new Date();
    await this.redis.set(`${this.threadPrefix}${thread.id}`, thread);
    await this.redis.sadd('threads', thread.id);
    return thread;
  }

  /**
   * Runs `operation` while holding the advisory lock `lock:<key>`.
   *
   * Acquisition makes 3 attempts, with random 0–100 ms pauses between them.
   * It throws `Error('Could not acquire lock')` if all attempts fail.
   *
   * The lock auto-releases after 5 s through a timer in this process only. The
   * lock key has no expiry in Redis itself, so a crash while holding it leaves
   * the key behind. Once that timer has fired, the lock may belong to another
   * holder, so this call no longer deletes it.
   */
  private async withLock<T>(key: string, operation: () => Promise<T>): Promise<T> {
    const lockKey = `lock:${key}`;
    const lockTimeout = 5000;
    let locked = false;
    let expired = false;
    let timeoutId: NodeJS.Timeout | undefined;

    try {
      // Try to acquire the lock with a few short randomized retries.
      for (let attempt = 0; attempt < 3; attempt++) {
        locked = await this.redis.sadd(lockKey, '1');
        if (locked) break;
        await new Promise(resolve => setTimeout(resolve, Math.random() * 100));
      }

      if (!locked) {
        throw new Error('Could not acquire lock');
      }

      // Safety net: release the lock if the operation overruns the timeout.
      const ownTimeout = setTimeout(async () => {
        expired = true;
        // Drop our bookkeeping before the key is freed, so a new holder's entry is never removed.
        if (this.lockTimeouts.get(lockKey) === ownTimeout) {
          this.lockTimeouts.delete(lockKey);
        }
        try {
          await this.redis.del(lockKey);
        } catch {
          // Best-effort release; ignore errors.
        }
      }, lockTimeout);
      timeoutId = ownTimeout;
      this.lockTimeouts.set(lockKey, ownTimeout);

      return await operation();
    } finally {
      if (timeoutId !== undefined) {
        clearTimeout(timeoutId);
        // Only remove the map entry if it is still ours.
        if (this.lockTimeouts.get(lockKey) === timeoutId) {
          this.lockTimeouts.delete(lockKey);
        }
      }
      // After expiry the key may have been re-acquired by someone else; leave it alone.
      if (locked && !expired) {
        try {
          await this.redis.del(lockKey);
        } catch {
          // Best-effort release; ignore errors.
        }
      }
    }
  }

  /** Cancels all pending lock-expiry timers (lock keys themselves are not deleted). */
  async cleanup(): Promise<void> {
    for (const timeout of this.lockTimeouts.values()) {
      clearTimeout(timeout);
    }
    this.lockTimeouts.clear();
  }

  /**
   * Merges messages into their threads' stored arrays.
   *
   * Messages are grouped by thread. Each thread's array is rewritten under its
   * lock. A message whose id already exists replaces the stored one. If the
   * thread exists, its updatedAt is bumped. Threads that do not exist are not
   * rejected.
   *
   * @returns The input array (or a new empty array for empty input).
   */
  async saveMessages(messages: MessageType[]): Promise<MessageType[]> {
    if (!messages.length) return [];

    const messagesByThread = messages.reduce(
      (acc, message) => {
        const key = `${this.messagePrefix}${message.threadId}`;
        if (!acc[key]) {
          acc[key] = [];
        }
        acc[key].push({
          ...message,
          createdAt: new Date(message.createdAt),
        });
        return acc;
      },
      {} as Record<string, MessageType[]>,
    );

    for (const [key, threadMessages] of Object.entries(messagesByThread)) {
      await this.withLock(key, async () => {
        const existingMessages = (await this.redis.get(key)) || [];

        // Keyed by id so new messages overwrite stored duplicates.
        const messageMap = new Map<string, MessageType>();
        existingMessages.forEach((msg: MessageType) => {
          messageMap.set(msg.id, {
            ...msg,
            createdAt: new Date(msg.createdAt),
          });
        });
        threadMessages.forEach(msg => {
          messageMap.set(msg.id, msg);
        });

        const updatedMessages = Array.from(messageMap.values());
        updatedMessages.sort((a, b) => {
          const timeCompare = a.createdAt.getTime() - b.createdAt.getTime();
          return timeCompare === 0 ? a.id.localeCompare(b.id) : timeCompare;
        });

        await this.redis.set(key, updatedMessages);
        if (threadMessages?.[0]?.threadId) {
          const thread = await this.getThreadById(threadMessages?.[0]?.threadId);
          if (thread) {
            thread.updatedAt = new Date();
            await this.redis.set(`${this.threadPrefix}${thread.id}`, thread);
          }
        }
      });
    }

    return messages;
  }

  /** Creates a message with a fresh id and the current time, saves it, and returns it. */
  async addMessage(threadId: string, content: string, role: 'user' | 'assistant'): Promise<MessageType> {
    const message: MessageType = {
      id: this.generateId(),
      content,
      role,
      createdAt: new Date(),
      threadId,
    };

    await this.saveMessages([message]);
    return message;
  }

  /** Returns a thread's stored messages with createdAt revived to Date (empty if none). */
  async getMessages(threadId: string): Promise<MessageType[]> {
    const messages = (await this.redis.get(`${this.messagePrefix}${threadId}`)) || [];
    return messages.map((msg: MessageType) => ({
      ...msg,
      createdAt: new Date(msg.createdAt),
    }));
  }

  /** Returns every id in the `threads` set. */
  async getAllThreadIds(): Promise<string[]> {
    return this.redis.smembers('threads');
  }

  /** Removes a thread, its message array, and its entry in the `threads` set in one pipeline. */
  async deleteThread(threadId: string): Promise<void> {
    const pipeline = this.redis.pipeline();
    pipeline.del(`${this.threadPrefix}${threadId}`);
    pipeline.del(`${this.messagePrefix}${threadId}`);
    pipeline.srem('threads', threadId);
    await pipeline.exec();
  }

  /** Loads several threads in parallel, skipping ids that resolve to `null`; input order is kept. */
  async getThreads(threadIds: string[]): Promise<ThreadType[]> {
    const threads = await Promise.all(threadIds.map(id => this.getThreadById(id)));
    return threads.filter((t): t is ThreadType => t !== null);
  }

  /**
   * Generates a random UUID for new message ids.
   *
   * Uses the imported `randomUUID` rather than the global `crypto`, which Node
   * only exposes without a flag from v19 onwards.
   */
  protected generateId(): string {
    return randomUUID();
  }
}