@mastra/memory 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -0
- package/dist/cloudflare/index.d.ts +18 -0
- package/dist/cloudflare/kv.d.ts +53 -0
- package/dist/index.d.ts +3 -0
- package/dist/index.js +8 -0
- package/dist/memory.cjs.development.js +2369 -0
- package/dist/memory.cjs.development.js.map +1 -0
- package/dist/memory.cjs.production.min.js +2 -0
- package/dist/memory.cjs.production.min.js.map +1 -0
- package/dist/memory.esm.js +2361 -0
- package/dist/memory.esm.js.map +1 -0
- package/dist/postgres/index.d.ts +16 -0
- package/dist/redis/index.d.ts +22 -0
- package/dist/redis/providers.d.ts +26 -0
- package/dist/redis/types.d.ts +17 -0
- package/docker-compose.yaml +18 -0
- package/jest.config.ts +19 -0
- package/package.json +49 -0
- package/src/cloudflare/index.test.ts +230 -0
- package/src/cloudflare/index.ts +169 -0
- package/src/cloudflare/kv.ts +139 -0
- package/src/index.ts +3 -0
- package/src/postgres/index.test.ts +60 -0
- package/src/postgres/index.ts +256 -0
- package/src/redis/index.test.ts +245 -0
- package/src/redis/index.ts +189 -0
- package/src/redis/providers.ts +191 -0
- package/src/redis/types.ts +18 -0
- package/tsconfig.json +11 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import { MastraMemory, MessageType, ThreadType } from '@mastra/core';
|
|
2
|
+
import { randomUUID } from 'crypto';
|
|
3
|
+
import pg from 'pg';
|
|
4
|
+
|
|
5
|
+
const { Pool } = pg;
|
|
6
|
+
|
|
7
|
+
/**
 * Thread columns returned to callers, aliased to the camelCase `ThreadType` keys.
 * Postgres folds unquoted identifiers to lower case, so the aliases must be double-quoted:
 * an unquoted `created_at AS createdAt` comes back as a `createdat` property.
 */
const THREAD_COLUMNS = 'id, title, created_at AS "createdAt", updated_at AS "updatedAt", metadata';

/**
 * Message columns returned to callers, aliased to the camelCase `MessageType` keys
 * (quoted for the same lower-case folding reason as the thread columns).
 */
const MESSAGE_COLUMNS = 'id, content, role, created_at AS "createdAt", thread_id AS "threadId"';

/**
 * Reports whether `value` looks like a UUID Postgres will accept: 32 hex digits once
 * hyphens and surrounding braces are removed.
 *
 * Ids failing this check can never match a `UUID` column. Checking first lets lookups answer
 * "not found" instead of surfacing Postgres' `invalid input syntax for type uuid` error.
 */
function isUuid(value: string): boolean {
  return typeof value === 'string' && /^[0-9a-f]{32}$/i.test(value.replace(/[{}-]/g, ''));
}

/**
 * Postgres-backed `MastraMemory`.
 *
 * Threads are stored in `mastra_threads` and messages in `mastra_messages`; each message's
 * `thread_id` references `mastra_threads(id)`. Both tables are created lazily by
 * `ensureTablesExist`.
 */
export class PgMemory extends MastraMemory {
  private readonly pool: pg.Pool;

  /** Memoised table-creation promise; `null` before the first attempt and after a failed one. */
  private tablesReady: Promise<void> | null = null;

  /**
   * @param connectionString - Postgres connection string handed to `pg.Pool`.
   */
  constructor(connectionString: string) {
    super();
    this.pool = new Pool({ connectionString });
  }

  /**
   * Deletes every message and thread, then closes the connection pool.
   *
   * The instance is unusable afterwards. The checked-out client is always released and the
   * pool is closed even when a delete fails, so a failed teardown leaves no sockets open.
   */
  async drop(): Promise<void> {
    try {
      await this.withTransaction(async client => {
        // Messages first: they reference threads through a foreign key.
        await client.query('DELETE FROM mastra_messages');
        await client.query('DELETE FROM mastra_threads');
      });
    } finally {
      await this.pool.end();
    }
  }

  /**
   * Creates `mastra_threads` and `mastra_messages` if they do not exist yet.
   *
   * The work is memoised. Concurrent and later callers share one promise, so the DDL runs once
   * per instance instead of before every query. A failed attempt clears the memo so the next
   * call retries.
   * NOTE: tables dropped externally after the first successful call are not recreated.
   */
  async ensureTablesExist(): Promise<void> {
    if (!this.tablesReady) {
      this.tablesReady = this.createTables().catch((error: unknown) => {
        this.tablesReady = null;
        throw error;
      });
    }
    return this.tablesReady;
  }

  /**
   * Issues the idempotent DDL for both tables. Threads are created first because messages
   * reference them.
   *
   * `CREATE TABLE IF NOT EXISTS` already handles the "exists" case, and it checks the current
   * schema. A separate `information_schema` lookup by bare table name is redundant and could
   * match a same-named table in another schema.
   */
  private async createTables(): Promise<void> {
    await this.withClient(async client => {
      await client.query(`
        CREATE TABLE IF NOT EXISTS mastra_threads (
          id UUID PRIMARY KEY,
          title TEXT,
          created_at TIMESTAMP WITH TIME ZONE NOT NULL,
          updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
          metadata JSONB
        );
      `);
      await client.query(`
        CREATE TABLE IF NOT EXISTS mastra_messages (
          id UUID PRIMARY KEY,
          content TEXT NOT NULL,
          role VARCHAR(20) NOT NULL,
          created_at TIMESTAMP WITH TIME ZONE NOT NULL,
          thread_id UUID NOT NULL,
          FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
        );
      `);
    });
  }

  /** Runs `fn` with a pooled client and always releases the client afterwards. */
  private async withClient<T>(fn: (client: pg.PoolClient) => Promise<T>): Promise<T> {
    const client = await this.pool.connect();
    try {
      return await fn(client);
    } finally {
      client.release();
    }
  }

  /**
   * Runs `fn` inside `BEGIN` / `COMMIT` on a single pooled client, rolling back on error.
   *
   * The original error is always rethrown. If the `ROLLBACK` itself fails, the connection is in
   * an unknown state, so it is destroyed (`release(err)`) rather than returned to the pool.
   */
  private async withTransaction<T>(fn: (client: pg.PoolClient) => Promise<T>): Promise<T> {
    const client = await this.pool.connect();
    let brokenConnection: Error | undefined;
    try {
      await client.query('BEGIN');
      const result = await fn(client);
      await client.query('COMMIT');
      return result;
    } catch (error) {
      try {
        await client.query('ROLLBACK');
      } catch (rollbackError) {
        brokenConnection = rollbackError instanceof Error ? rollbackError : new Error(String(rollbackError));
      }
      throw error;
    } finally {
      client.release(brokenConnection);
    }
  }

  /**
   * Replaces a thread's title and metadata and bumps `updated_at` to the database's `NOW()`.
   *
   * @returns The updated thread, with camelCase keys.
   * @throws Error when no thread with `id` exists.
   */
  async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
    if (!isUuid(id)) {
      throw new Error(`Thread with ID ${id} does not exist.`);
    }
    await this.ensureTablesExist();

    const result = await this.withClient(client =>
      client.query<ThreadType>(
        `
          UPDATE mastra_threads
          SET title = $1, metadata = $2, updated_at = NOW()
          WHERE id = $3
          RETURNING ${THREAD_COLUMNS}
        `,
        [title, JSON.stringify(metadata), id],
      ),
    );

    const thread = result.rows[0];
    if (!thread) {
      throw new Error(`Thread with ID ${id} does not exist.`);
    }
    return thread;
  }

  /**
   * Deletes a thread and all of its messages atomically.
   * A no-op when the id cannot exist (not a UUID) or when nothing matches.
   */
  async deleteThread(id: string): Promise<void> {
    if (!isUuid(id)) return;
    await this.ensureTablesExist();

    await this.withTransaction(async client => {
      // Messages first: they reference the thread through a foreign key.
      await client.query('DELETE FROM mastra_messages WHERE thread_id = $1', [id]);
      await client.query('DELETE FROM mastra_threads WHERE id = $1', [id]);
    });
  }

  /** Deletes a single message by id; a no-op when the id cannot exist or nothing matches. */
  async deleteMessage(id: string): Promise<void> {
    if (!isUuid(id)) return;
    await this.ensureTablesExist();

    await this.withClient(client => client.query('DELETE FROM mastra_messages WHERE id = $1', [id]));
  }

  /**
   * Looks up a thread by id.
   *
   * @returns The thread with camelCase keys, or `null` when it does not exist.
   */
  async getThreadById(threadId: string): Promise<ThreadType | null> {
    if (!isUuid(threadId)) return null;
    await this.ensureTablesExist();

    const result = await this.withClient(client =>
      client.query<ThreadType>(`SELECT ${THREAD_COLUMNS} FROM mastra_threads WHERE id = $1`, [threadId]),
    );
    return result.rows[0] ?? null;
  }

  /**
   * Inserts a thread, or overwrites `title`, `updated_at` and `metadata` when the id exists.
   * `created_at` is kept from the first insert.
   *
   * @returns The stored row with camelCase keys.
   */
  async saveThread(thread: ThreadType): Promise<ThreadType> {
    await this.ensureTablesExist();

    const { id, title, createdAt, updatedAt, metadata } = thread;
    const result = await this.withClient(client =>
      client.query<ThreadType>(
        `
          INSERT INTO mastra_threads (id, title, created_at, updated_at, metadata)
          VALUES ($1, $2, $3, $4, $5)
          ON CONFLICT (id) DO UPDATE SET title = $2, updated_at = $4, metadata = $5
          RETURNING ${THREAD_COLUMNS}
        `,
        [id, title, createdAt, updatedAt, JSON.stringify(metadata)],
      ),
    );

    const saved = result.rows[0];
    if (!saved) {
      throw new Error(`Failed to save thread ${id}.`);
    }
    return saved;
  }

  /**
   * Inserts all `messages` in one transaction: either every row is written or none is.
   *
   * @returns The `messages` array that was passed in.
   * @throws The database error, e.g. on a duplicate id or an unknown `threadId`.
   */
  async saveMessages(messages: MessageType[]): Promise<MessageType[]> {
    if (messages.length === 0) return messages;
    await this.ensureTablesExist();

    await this.withTransaction(async client => {
      for (const { id, content, role, createdAt, threadId } of messages) {
        await client.query(
          `
            INSERT INTO mastra_messages (id, content, role, created_at, thread_id)
            VALUES ($1, $2, $3, $4, $5)
          `,
          // `new Date(...)` also accepts ISO strings, e.g. messages that went through JSON.
          [id, content, role, new Date(createdAt).toISOString(), threadId],
        );
      }
    });
    return messages;
  }

  /**
   * Returns a thread's messages, oldest first, with camelCase keys.
   * Messages that share a `created_at` value have no guaranteed relative order.
   */
  async getMessages(threadId: string): Promise<MessageType[]> {
    if (!isUuid(threadId)) return [];
    await this.ensureTablesExist();

    const result = await this.withClient(client =>
      client.query<MessageType>(
        `
          SELECT ${MESSAGE_COLUMNS}
          FROM mastra_messages
          WHERE thread_id = $1
          ORDER BY created_at ASC
        `,
        [threadId],
      ),
    );
    return result.rows;
  }

  /** Creates and persists a new thread with a random UUID and identical created/updated times. */
  async createThread(title?: string, metadata?: Record<string, unknown>): Promise<ThreadType> {
    const now = new Date();
    const thread: ThreadType = {
      id: randomUUID(),
      title,
      createdAt: now,
      updatedAt: now,
      metadata,
    };
    return this.saveThread(thread);
  }

  /**
   * Appends one message to an existing thread.
   *
   * @throws Error when the thread does not exist.
   */
  async addMessage(threadId: string, content: string, role: 'user' | 'assistant'): Promise<MessageType> {
    // getThreadById/saveMessages both ensure the tables exist.
    const thread = await this.getThreadById(threadId);
    if (!thread) {
      throw new Error(`Thread with ID ${threadId} does not exist.`);
    }

    const message: MessageType = {
      id: randomUUID(),
      content,
      role,
      createdAt: new Date(),
      threadId,
    };

    await this.saveMessages([message]);
    return message;
  }
}
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
import { ThreadType, MessageType } from '@mastra/core';
|
|
2
|
+
|
|
3
|
+
import { LocalRedisProvider, RedisMemory } from './';
|
|
4
|
+
|
|
5
|
+
describe('RedisMemory', () => {
  let memory: RedisMemory;
  let provider: LocalRedisProvider;

  beforeAll(async () => {
    provider = new LocalRedisProvider();
    memory = new RedisMemory(provider);
  });

  afterAll(async () => {
    // Close the connection even if lock-timer cleanup throws, so Jest can exit.
    try {
      await memory.cleanup();
    } finally {
      await provider.quit();
    }
  });

  beforeEach(async () => {
    // Each test starts from an empty database.
    await provider.flushall();
  });

  describe('Thread Operations', () => {
    it('should create a thread with metadata', async () => {
      const thread = await memory.createThread('Test Thread', {
        testKey: 'testValue',
      });

      expect(thread).toEqual({
        id: expect.any(String),
        title: 'Test Thread',
        createdAt: expect.any(Date),
        updatedAt: expect.any(Date),
        metadata: { testKey: 'testValue' },
      });
    });

    it('should retrieve a created thread', async () => {
      const created = await memory.createThread('Test Thread');
      const retrieved = await memory.getThreadById(created.id);

      expect(retrieved).toEqual(created);
    });

    it('should return null for non-existent thread', async () => {
      const thread = await memory.getThreadById('nonexistent');
      expect(thread).toBeNull();
    });

    it('should update thread title', async () => {
      const thread = await memory.createThread('Initial Title');
      thread.title = 'Updated Title';

      await memory.saveThread(thread);
      const retrieved = await memory.getThreadById(thread.id);

      expect(retrieved?.title).toBe('Updated Title');
    });
  });

  describe('Message Operations', () => {
    let testThread: ThreadType;

    beforeEach(async () => {
      testThread = await memory.createThread('Test Thread');
    });

    it('should add a single message', async () => {
      const message = await memory.addMessage(testThread.id, 'Hello World', 'user');

      const messages = await memory.getMessages(testThread.id);
      expect(messages).toHaveLength(1);
      expect(messages[0]).toEqual(message);
    });

    it('should maintain message order', async () => {
      await memory.addMessage(testThread.id, 'First', 'user');
      await memory.addMessage(testThread.id, 'Second', 'assistant');
      await memory.addMessage(testThread.id, 'Third', 'user');

      const messages = await memory.getMessages(testThread.id);
      expect(messages).toHaveLength(3);
      expect(messages.map(m => m.content)).toEqual(['First', 'Second', 'Third']);
    });

    it('should handle bulk message saves', async () => {
      const messages: MessageType[] = [
        {
          id: 'msg1',
          threadId: testThread.id,
          content: 'Message 1',
          role: 'user',
          createdAt: new Date(),
        },
        {
          id: 'msg2',
          threadId: testThread.id,
          content: 'Message 2',
          role: 'assistant',
          createdAt: new Date(),
        },
      ];

      await memory.saveMessages(messages);
      const retrieved = await memory.getMessages(testThread.id);
      expect(retrieved).toHaveLength(2);
    });

    it('should prevent duplicate messages', async () => {
      const message = await memory.addMessage(testThread.id, 'Test', 'user');
      await memory.saveMessages([message]); // Save the same message again.

      const messages = await memory.getMessages(testThread.id);
      expect(messages).toHaveLength(1);
      expect(messages[0]?.content).toBe('Test');
    });
  });

  describe('Thread Management', () => {
    it('should list all thread IDs', async () => {
      const thread1 = await memory.createThread('Thread 1');
      const thread2 = await memory.createThread('Thread 2');

      const ids = await memory.getAllThreadIds();
      expect(ids).toContain(thread1.id);
      expect(ids).toContain(thread2.id);
    });

    it('should delete thread and its messages', async () => {
      const thread = await memory.createThread('Delete Test');
      await memory.addMessage(thread.id, 'Test Message', 'user');

      await memory.deleteThread(thread.id);

      const deletedThread = await memory.getThreadById(thread.id);
      const messages = await memory.getMessages(thread.id);

      expect(deletedThread).toBeNull();
      expect(messages).toHaveLength(0);
    });

    it('should retrieve multiple threads', async () => {
      const thread1 = await memory.createThread('Thread 1');
      const thread2 = await memory.createThread('Thread 2');

      const threads = await memory.getThreads([thread1.id, thread2.id]);
      expect(threads).toHaveLength(2);
      expect(threads.map(t => t.title)).toEqual(['Thread 1', 'Thread 2']);
    });
  });

  describe('Error Handling', () => {
    it('should handle saving messages to non-existent thread', async () => {
      const message: MessageType = {
        id: 'test',
        threadId: 'nonexistent',
        content: 'Test',
        role: 'user',
        createdAt: new Date(),
      };

      await memory.saveMessages([message]);
      const messages = await memory.getMessages('nonexistent');
      expect(messages).toHaveLength(1);
    });

    it('should handle concurrent message saves', async () => {
      const thread = await memory.createThread('Concurrent Test');

      // Start the three saves without awaiting in between.
      await Promise.all([
        memory.addMessage(thread.id, 'Message 1', 'user'),
        memory.addMessage(thread.id, 'Message 2', 'user'),
        memory.addMessage(thread.id, 'Message 3', 'user'),
      ]);

      const messages = await memory.getMessages(thread.id);
      expect(messages).toHaveLength(3);
    });
  });

  describe('Performance', () => {
    /** Extracts N from content shaped like "Message N" (0 if absent). */
    const messageIndex = (content: string): number => parseInt(content.split(' ')[1] || '0', 10);

    it('should handle large number of messages', async () => {
      const thread = await memory.createThread('Bulk Test');
      const messageCount = 100;

      const messages = Array.from({ length: messageCount }, (_, i) => ({
        id: memory['generateId'](),
        content: `Message ${i}`,
        role: i % 2 === 0 ? ('user' as const) : ('assistant' as const),
        createdAt: new Date(),
        threadId: thread.id,
      }));

      // Save all messages in one operation.
      await memory.saveMessages(messages);

      const retrievedMessages = await memory.getMessages(thread.id);
      expect(retrievedMessages).toHaveLength(messageCount);

      // The messages share (near-)identical timestamps, so this checks content, not storage
      // order. Sort a copy by index so the stored result is not mutated.
      const byIndex = [...retrievedMessages].sort((a, b) => messageIndex(a.content) - messageIndex(b.content));
      expect(byIndex.map(m => m.content)).toEqual(messages.map(m => m.content));
    });

    it('should handle concurrent batch saves', async () => {
      const thread = await memory.createThread('Concurrent Batch Test');
      const batchSize = 20;
      const numberOfBatches = 5;

      const batches = Array.from({ length: numberOfBatches }, (_, batchIndex) =>
        Array.from({ length: batchSize }, (_, i) => ({
          id: memory['generateId'](),
          content: `Batch ${batchIndex} Message ${i}`,
          role: i % 2 === 0 ? ('user' as const) : ('assistant' as const),
          createdAt: new Date(),
          threadId: thread.id,
        })),
      );

      // Save all batches concurrently.
      await Promise.all(batches.map(batch => memory.saveMessages(batch)));

      const messages = await memory.getMessages(thread.id);
      expect(messages).toHaveLength(batchSize * numberOfBatches);

      // Every expected message is present exactly once (no losses, no duplicates).
      const expectedContents: string[] = [];
      for (let batchIndex = 0; batchIndex < numberOfBatches; batchIndex++) {
        for (let i = 0; i < batchSize; i++) {
          expectedContents.push(`Batch ${batchIndex} Message ${i}`);
        }
      }
      expect(messages.map(m => m.content).sort()).toEqual(expectedContents.sort());
    });
  });
});
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
import { MastraMemory, ThreadType, MessageType } from '@mastra/core';
import { randomUUID } from 'crypto';

import { RedisClient } from './types';
|
|
4
|
+
|
|
5
|
+
export * from './types';
|
|
6
|
+
export * from './providers';
|
|
7
|
+
|
|
8
|
+
/**
 * Redis-backed `MastraMemory`.
 *
 * Key layout:
 * - `thread:<id>`   → the thread object
 * - `messages:<id>` → the thread's messages as one array, kept sorted by `createdAt`
 * - `threads`       → set of every saved thread id
 * - `lock:<key>`    → set used as a mutex around read-modify-write of a message list
 *
 * Values are handed to the injected `RedisClient` as objects. Dates are revived on read,
 * which suggests the client serialises them as strings — confirm against `./providers`.
 */
export class RedisMemory extends MastraMemory {
  private readonly redis: RedisClient;
  private readonly threadPrefix = 'thread:';
  private readonly messagePrefix = 'messages:';
  /** Pending safety-release timers for held locks, keyed by lock key. */
  private readonly lockTimeouts: Map<string, NodeJS.Timeout> = new Map();

  constructor(redis: RedisClient) {
    super();
    this.redis = redis;
  }

  /**
   * Loads a thread, reviving `createdAt` / `updatedAt` when they were stored as strings.
   *
   * @returns The thread, or `null` when the key is absent.
   */
  async getThreadById(threadId: string): Promise<ThreadType | null> {
    const thread = await this.redis.get(`${this.threadPrefix}${threadId}`);
    if (!thread) return null;
    // Revive each date independently so a partially-revived record is still normalised.
    if (typeof thread.createdAt === 'string') thread.createdAt = new Date(thread.createdAt);
    if (typeof thread.updatedAt === 'string') thread.updatedAt = new Date(thread.updatedAt);
    return thread;
  }

  /**
   * Stores a thread and registers its id in the `threads` set.
   * Sets `thread.updatedAt` to now on the object that was passed in, and returns that same object.
   */
  async saveThread(thread: ThreadType): Promise<ThreadType> {
    thread.updatedAt = new Date();
    await this.redis.set(`${this.threadPrefix}${thread.id}`, thread);
    await this.redis.sadd('threads', thread.id);
    return thread;
  }

  /**
   * Runs `operation` while holding the Redis-side lock `lock:<key>`.
   *
   * Acquisition is retried with a short randomised back-off until `acquireTimeoutMs` elapses.
   * Contending writers therefore wait for the holder instead of failing after a few attempts.
   * A timer releases the lock after `lockTtlMs` in case `operation` never settles.
   *
   * NOTE(review): only `sadd`/`del` are used, so the lock has no owner token and no Redis-side
   * expiry. The safety timer can release it while `operation` is still running, and a lock left
   * by a crashed process stays set until it is deleted.
   *
   * @throws Error('Could not acquire lock') when the lock is not obtained before the deadline.
   */
  private async withLock<T>(
    key: string,
    operation: () => Promise<T>,
    { acquireTimeoutMs = 5000, lockTtlMs = 5000 }: { acquireTimeoutMs?: number; lockTtlMs?: number } = {},
  ): Promise<T> {
    const lockKey = `lock:${key}`;
    const deadline = Date.now() + acquireTimeoutMs;
    let locked = false;
    let timeoutId: NodeJS.Timeout | undefined;

    try {
      for (;;) {
        locked = Boolean(await this.redis.sadd(lockKey, '1'));
        if (locked) break;
        if (Date.now() >= deadline) {
          throw new Error('Could not acquire lock');
        }
        await new Promise(resolve => setTimeout(resolve, 10 + Math.random() * 90));
      }

      // Safety release in case the operation hangs.
      timeoutId = setTimeout(async () => {
        try {
          await this.redis.del(lockKey);
        } catch {
          // Best-effort: the lock is also released in `finally`.
        }
        this.lockTimeouts.delete(lockKey);
      }, lockTtlMs);
      this.lockTimeouts.set(lockKey, timeoutId);

      return await operation();
    } finally {
      if (timeoutId !== undefined) {
        clearTimeout(timeoutId);
        this.lockTimeouts.delete(lockKey);
      }
      if (locked) {
        try {
          await this.redis.del(lockKey);
        } catch {
          // Best-effort release; the caller's result/error takes precedence.
        }
      }
    }
  }

  /** Cancels every pending lock safety-release timer (call before shutting down the client). */
  async cleanup(): Promise<void> {
    for (const timeout of this.lockTimeouts.values()) {
      clearTimeout(timeout);
    }
    this.lockTimeouts.clear();
  }

  /**
   * Merges `messages` into their threads' stored lists.
   *
   * Messages are de-duplicated by id (a re-saved id replaces its stored copy) and ordered by
   * `createdAt`. Messages with equal timestamps keep insertion order: stored messages first, then
   * new ones in the order given. Each touched thread that exists gets a fresh `updatedAt`.
   *
   * @returns The `messages` array that was passed in (or `[]` when it is empty).
   */
  async saveMessages(messages: MessageType[]): Promise<MessageType[]> {
    if (!messages.length) return [];

    // Group by thread so each thread's list is read and rewritten once.
    const byThread = new Map<string, MessageType[]>();
    for (const message of messages) {
      const group = byThread.get(message.threadId) ?? [];
      group.push({ ...message, createdAt: new Date(message.createdAt) });
      byThread.set(message.threadId, group);
    }

    for (const [threadId, incoming] of byThread) {
      const key = `${this.messagePrefix}${threadId}`;
      await this.withLock(key, async () => {
        const stored = (await this.redis.get(key)) || [];

        // A Map keeps first-insertion order, and `set` on an existing id replaces it in place.
        const merged = new Map<string, MessageType>();
        stored.forEach((msg: MessageType) => {
          merged.set(msg.id, { ...msg, createdAt: new Date(msg.createdAt) });
        });
        for (const msg of incoming) {
          merged.set(msg.id, msg);
        }

        // Array.prototype.sort is stable, so equal timestamps keep insertion order instead of
        // being shuffled by (random) ids.
        const ordered = Array.from(merged.values()).sort(
          (a, b) => a.createdAt.getTime() - b.createdAt.getTime(),
        );
        await this.redis.set(key, ordered);

        if (threadId) {
          const thread = await this.getThreadById(threadId);
          if (thread) {
            thread.updatedAt = new Date();
            await this.redis.set(`${this.threadPrefix}${thread.id}`, thread);
          }
        }
      });
    }

    return messages;
  }

  /** Creates a message with a fresh id and the current time, saves it, and returns it. */
  async addMessage(threadId: string, content: string, role: 'user' | 'assistant'): Promise<MessageType> {
    const message: MessageType = {
      id: this.generateId(),
      content,
      role,
      createdAt: new Date(),
      threadId,
    };

    await this.saveMessages([message]);
    return message;
  }

  /** Returns a thread's stored messages (oldest first) with `createdAt` revived to `Date`. */
  async getMessages(threadId: string): Promise<MessageType[]> {
    const messages = (await this.redis.get(`${this.messagePrefix}${threadId}`)) || [];
    return messages.map((msg: MessageType) => ({
      ...msg,
      createdAt: new Date(msg.createdAt),
    }));
  }

  /** Returns the ids registered in the `threads` set. */
  async getAllThreadIds(): Promise<string[]> {
    return this.redis.smembers('threads');
  }

  /**
   * Removes a thread, its message list and its `threads` set entry.
   *
   * Runs under the message-list lock so a concurrent `saveMessages` (which reads, merges and
   * writes the whole list) cannot write deleted messages back afterwards.
   */
  async deleteThread(threadId: string): Promise<void> {
    const messageKey = `${this.messagePrefix}${threadId}`;
    await this.withLock(messageKey, async () => {
      const pipeline = this.redis.pipeline();
      pipeline.del(`${this.threadPrefix}${threadId}`);
      pipeline.del(messageKey);
      pipeline.srem('threads', threadId);
      await pipeline.exec();
    });
  }

  /** Loads several threads in parallel, preserving input order and skipping missing ones. */
  async getThreads(threadIds: string[]): Promise<ThreadType[]> {
    const threads = await Promise.all(threadIds.map(id => this.getThreadById(id)));
    return threads.filter((t): t is ThreadType => t !== null);
  }

  /** Generates a random UUID for new messages (uses Node's `crypto`, not the global). */
  protected generateId(): string {
    return randomUUID();
  }
}
|