@mastra/memory 0.0.1 → 0.0.2-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,245 +0,0 @@
1
- import { ThreadType, MessageType } from '@mastra/core';
2
-
3
- import { LocalRedisProvider, RedisMemory } from './';
4
-
5
/**
 * Integration tests for RedisMemory running against a local Redis server via
 * LocalRedisProvider. `beforeEach` flushes the whole database, so every test
 * starts from an empty store.
 */
describe('RedisMemory', () => {
  let provider: LocalRedisProvider;
  let memory: RedisMemory;

  // Produces `count` messages for one thread. Roles alternate starting with 'user';
  // ids come from the memory instance's own (protected) id generator.
  const makeMessages = (threadId: string, count: number, contentFor: (index: number) => string) =>
    Array.from({ length: count }, (_, index) => ({
      id: memory['generateId'](),
      content: contentFor(index),
      role: index % 2 === 0 ? ('user' as const) : ('assistant' as const),
      createdAt: new Date(),
      threadId,
    }));

  beforeAll(async () => {
    provider = new LocalRedisProvider();
    memory = new RedisMemory(provider);
  });

  afterAll(async () => {
    // Drop pending lock timers first, then close the connection.
    await memory.cleanup();
    await provider.quit();
  });

  beforeEach(async () => {
    await provider.flushall();
  });

  describe('Thread Operations', () => {
    it('should create a thread with metadata', async () => {
      const created = await memory.createThread('Test Thread', { testKey: 'testValue' });

      expect(created).toEqual({
        id: expect.any(String),
        title: 'Test Thread',
        createdAt: expect.any(Date),
        updatedAt: expect.any(Date),
        metadata: { testKey: 'testValue' },
      });
    });

    it('should retrieve a created thread', async () => {
      const original = await memory.createThread('Test Thread');
      const fetched = await memory.getThreadById(original.id);
      expect(fetched).toEqual(original);
    });

    it('should return null for non-existent thread', async () => {
      const missing = await memory.getThreadById('nonexistent');
      expect(missing).toBeNull();
    });

    it('should update thread title', async () => {
      const target = await memory.createThread('Initial Title');
      target.title = 'Updated Title';
      await memory.saveThread(target);

      const reloaded = await memory.getThreadById(target.id);
      expect(reloaded?.title).toBe('Updated Title');
    });
  });

  describe('Message Operations', () => {
    let thread: ThreadType;

    beforeEach(async () => {
      thread = await memory.createThread('Test Thread');
    });

    it('should add a single message', async () => {
      const added = await memory.addMessage(thread.id, 'Hello World', 'user');

      const stored = await memory.getMessages(thread.id);
      expect(stored).toHaveLength(1);
      expect(stored[0]).toEqual(added);
    });

    it('should maintain message order', async () => {
      const script: Array<[string, 'user' | 'assistant']> = [
        ['First', 'user'],
        ['Second', 'assistant'],
        ['Third', 'user'],
      ];
      // Sequential on purpose: each add completes before the next starts.
      for (const [content, role] of script) {
        await memory.addMessage(thread.id, content, role);
      }

      const stored = await memory.getMessages(thread.id);
      expect(stored).toHaveLength(3);
      expect(stored.map(({ content }) => content)).toEqual(['First', 'Second', 'Third']);
    });

    it('should handle bulk message saves', async () => {
      const bulk: MessageType[] = [
        { id: 'msg1', threadId: thread.id, content: 'Message 1', role: 'user', createdAt: new Date() },
        { id: 'msg2', threadId: thread.id, content: 'Message 2', role: 'assistant', createdAt: new Date() },
      ];

      await memory.saveMessages(bulk);
      const stored = await memory.getMessages(thread.id);
      expect(stored).toHaveLength(2);
    });

    it('should prevent duplicate messages', async () => {
      const once = await memory.addMessage(thread.id, 'Test', 'user');
      // Saving a message whose id already exists must not add a second copy.
      await memory.saveMessages([once]);

      const stored = await memory.getMessages(thread.id);
      expect(stored).toHaveLength(1);
    });
  });

  describe('Thread Management', () => {
    it('should list all thread IDs', async () => {
      const first = await memory.createThread('Thread 1');
      const second = await memory.createThread('Thread 2');

      const ids = await memory.getAllThreadIds();
      for (const { id } of [first, second]) {
        expect(ids).toContain(id);
      }
    });

    it('should delete thread and its messages', async () => {
      const doomed = await memory.createThread('Delete Test');
      await memory.addMessage(doomed.id, 'Test Message', 'user');

      await memory.deleteThread(doomed.id);

      const threadAfter = await memory.getThreadById(doomed.id);
      const messagesAfter = await memory.getMessages(doomed.id);
      expect(threadAfter).toBeNull();
      expect(messagesAfter).toHaveLength(0);
    });

    it('should retrieve multiple threads', async () => {
      // Array-literal elements are evaluated left to right, so creation stays sequential.
      const created = [await memory.createThread('Thread 1'), await memory.createThread('Thread 2')];

      const fetched = await memory.getThreads(created.map(t => t.id));
      expect(fetched).toHaveLength(2);
      expect(fetched.map(t => t.title)).toEqual(['Thread 1', 'Thread 2']);
    });
  });

  describe('Error Handling', () => {
    it('should handle saving messages to non-existent thread', async () => {
      const orphan: MessageType = {
        id: 'test',
        threadId: 'nonexistent',
        content: 'Test',
        role: 'user',
        createdAt: new Date(),
      };

      await memory.saveMessages([orphan]);
      const stored = await memory.getMessages('nonexistent');
      expect(stored).toHaveLength(1);
    });

    it('should handle concurrent message saves', async () => {
      const thread = await memory.createThread('Concurrent Test');

      // Start all three adds at once so they contend for the same thread lock.
      await Promise.all(
        ['Message 1', 'Message 2', 'Message 3'].map(text => memory.addMessage(thread.id, text, 'user')),
      );

      const stored = await memory.getMessages(thread.id);
      expect(stored).toHaveLength(3);
    });
  });

  describe('Performance', () => {
    it('should handle large number of messages', async () => {
      const thread = await memory.createThread('Bulk Test');
      const messageCount = 100;

      // One saveMessages call for the whole batch.
      await memory.saveMessages(makeMessages(thread.id, messageCount, i => `Message ${i}`));

      const retrieved = await memory.getMessages(thread.id);
      expect(retrieved).toHaveLength(messageCount);

      // Order by numeric suffix, then check the sequence is exactly 0..messageCount-1.
      const suffixOf = (content: string) => parseInt(content.split(' ')[1] || '0');
      retrieved
        .sort((a, b) => suffixOf(a.content) - suffixOf(b.content))
        .forEach((msg, i) => {
          expect(msg.content).toBe(`Message ${i}`);
        });
    });

    it('should handle concurrent batch saves', async () => {
      const thread = await memory.createThread('Concurrent Batch Test');
      const batchSize = 20;
      const numberOfBatches = 5;
      const labelFor = (batch: number, index: number) => `Batch ${batch} Message ${index}`;

      const batches = Array.from({ length: numberOfBatches }, (_, batch) =>
        makeMessages(thread.id, batchSize, index => labelFor(batch, index)),
      );

      // All batches target the same thread and are saved concurrently.
      await Promise.all(batches.map(batch => memory.saveMessages(batch)));

      const stored = await memory.getMessages(thread.id);
      expect(stored).toHaveLength(batchSize * numberOfBatches);

      const contents = new Set(stored.map(m => m.content));
      for (let batch = 0; batch < numberOfBatches; batch++) {
        for (let index = 0; index < batchSize; index++) {
          expect(contents.has(labelFor(batch, index))).toBeTruthy();
        }
      }
    });
  });
});
@@ -1,189 +0,0 @@
1
- import { MastraMemory, ThreadType, MessageType } from '@mastra/core';
2
-
3
- import { RedisClient } from './types';
4
-
5
- export * from './types';
6
- export * from './providers';
7
-
8
/**
 * Redis-backed implementation of MastraMemory.
 *
 * Keys written by this class:
 * - `thread:<threadId>`   — the thread object
 * - `messages:<threadId>` — the thread's messages, kept sorted by `createdAt`
 * - `threads`             — set of thread ids registered through `saveThread`
 * - `lock:<key>`          — set used as a best-effort mutex around message writes
 */
export class RedisMemory extends MastraMemory {
  private redis: RedisClient;
  private threadPrefix = 'thread:';
  private messagePrefix = 'messages:';
  /** Pending lease timers for locks held by this instance, keyed by lock key. */
  private lockTimeouts: Map<string, NodeJS.Timeout> = new Map();
  /** Number of attempts `withLock` makes before giving up. */
  private lockAttempts = 10;
  /** Upper bound, in ms, of the random back-off between lock attempts. */
  private lockRetryDelayMs = 100;
  /** Lease, in ms, after which a held lock is released even if its operation is still running. */
  private lockTimeoutMs = 5000;

  constructor(redis: RedisClient) {
    super();
    this.redis = redis;
  }

  /**
   * Loads a thread by id.
   *
   * @returns The stored thread with `createdAt`/`updatedAt` as Dates. For a missing key,
   *   returns whatever the client returned.
   */
  async getThreadById(threadId: string): Promise<ThreadType | null> {
    const thread = await this.redis.get(`${this.threadPrefix}${threadId}`);
    if (thread) {
      // Providers normally revive Dates. Plain ISO strings are tolerated, and each field
      // is converted on its own so a missing `updatedAt` is not turned into an Invalid Date.
      if (typeof thread.createdAt === 'string') thread.createdAt = new Date(thread.createdAt);
      if (typeof thread.updatedAt === 'string') thread.updatedAt = new Date(thread.updatedAt);
    }
    return thread;
  }

  /**
   * Persists a thread and registers its id in the `threads` set.
   * Sets `thread.updatedAt` to now on the passed object before writing.
   */
  async saveThread(thread: ThreadType): Promise<ThreadType> {
    thread.updatedAt = new Date();
    await this.redis.set(`${this.threadPrefix}${thread.id}`, thread);
    await this.redis.sadd('threads', thread.id);
    return thread;
  }

  /**
   * Runs `operation` while holding a Redis-backed lock for `key`.
   *
   * The lock is the set `lock:<key>` containing the member '1'; whoever's SADD adds the
   * member owns it. A lease timer force-releases the lock after `lockTimeoutMs`.
   *
   * @throws Error('Could not acquire lock') if the lock is still held after `lockAttempts` tries.
   */
  private async withLock<T>(key: string, operation: () => Promise<T>): Promise<T> {
    const lockKey = `lock:${key}`;
    let locked = false;
    // Set once the lease timer has released the lock. Afterwards the key may belong to
    // another caller, so `finally` must not delete it.
    let expired = false;
    let timeoutId: NodeJS.Timeout | undefined;

    // Forget our timer only if the map still points at it. After an expiry, another caller
    // may have registered its own timer under the same lock key.
    const untrack = () => {
      if (timeoutId !== undefined && this.lockTimeouts.get(lockKey) === timeoutId) {
        this.lockTimeouts.delete(lockKey);
      }
    };

    try {
      for (let attempt = 1; attempt <= this.lockAttempts; attempt++) {
        // SADD returns how many members it added; non-zero means this call created the lock.
        locked = Boolean(await this.redis.sadd(lockKey, '1'));
        // No point sleeping after the final failed attempt.
        if (locked || attempt === this.lockAttempts) break;
        await new Promise(resolve => setTimeout(resolve, Math.random() * this.lockRetryDelayMs));
      }

      if (!locked) {
        throw new Error('Could not acquire lock');
      }

      // Lease: release the lock even if `operation` hangs.
      timeoutId = setTimeout(async () => {
        expired = true;
        untrack();
        try {
          await this.redis.del(lockKey);
        } catch {
          // Best-effort release; errors are ignored.
        }
      }, this.lockTimeoutMs);

      this.lockTimeouts.set(lockKey, timeoutId);

      return await operation();
    } finally {
      if (timeoutId !== undefined) {
        clearTimeout(timeoutId);
        untrack();
      }
      if (locked && !expired) {
        try {
          await this.redis.del(lockKey);
        } catch {
          // Best-effort release; errors are ignored.
        }
      }
    }
  }

  /**
   * Cancels all pending lock lease timers held by this instance.
   * Does not delete any lock keys in Redis.
   */
  async cleanup(): Promise<void> {
    for (const timeout of this.lockTimeouts.values()) {
      clearTimeout(timeout);
    }
    this.lockTimeouts.clear();
  }

  /**
   * Merges `messages` into each affected thread's stored message list.
   *
   * Messages are de-duplicated by `id`; a re-saved id replaces the stored copy. The result
   * is sorted by `createdAt`, and the owning thread's `updatedAt` is bumped if the thread exists.
   *
   * @returns The `messages` argument as passed in.
   */
  async saveMessages(messages: MessageType[]): Promise<MessageType[]> {
    if (!messages.length) return [];

    // Group by thread, normalising createdAt to a Date so the sort below is well-defined.
    const byThread = new Map<MessageType['threadId'], MessageType[]>();
    for (const message of messages) {
      const group = byThread.get(message.threadId) ?? [];
      group.push({ ...message, createdAt: new Date(message.createdAt) });
      byThread.set(message.threadId, group);
    }

    for (const [threadId, threadMessages] of byThread) {
      const key = `${this.messagePrefix}${threadId}`;

      await this.withLock(key, async () => {
        const existingMessages: MessageType[] = (await this.redis.get(key)) || [];

        // Keyed by id: stored messages first, then new ones. Map#set on an existing id
        // replaces the value but keeps its original position.
        const merged = new Map<string, MessageType>();
        for (const msg of existingMessages) {
          merged.set(msg.id, { ...msg, createdAt: new Date(msg.createdAt) });
        }
        for (const msg of threadMessages) {
          merged.set(msg.id, msg);
        }

        // Array#sort is stable (ES2019+), so messages sharing a timestamp keep insertion
        // order. Breaking ties by (random) id reordered messages created in the same millisecond.
        const updatedMessages = Array.from(merged.values()).sort(
          (a, b) => a.createdAt.getTime() - b.createdAt.getTime(),
        );

        await this.redis.set(key, updatedMessages);

        if (threadId) {
          const thread = await this.getThreadById(threadId);
          if (thread) {
            thread.updatedAt = new Date();
            await this.redis.set(`${this.threadPrefix}${thread.id}`, thread);
          }
        }
      });
    }

    return messages;
  }

  /** Creates a message with a fresh id and the current time, saves it, and returns it. */
  async addMessage(threadId: string, content: string, role: 'user' | 'assistant'): Promise<MessageType> {
    const message: MessageType = {
      id: this.generateId(),
      content,
      role,
      createdAt: new Date(),
      threadId,
    };

    await this.saveMessages([message]);
    return message;
  }

  /** Returns the thread's stored messages (empty when none) with `createdAt` as Dates. */
  async getMessages(threadId: string): Promise<MessageType[]> {
    const messages = (await this.redis.get(`${this.messagePrefix}${threadId}`)) || [];
    return messages.map((msg: MessageType) => ({
      ...msg,
      createdAt: new Date(msg.createdAt),
    }));
  }

  /** Returns the members of the `threads` set. */
  async getAllThreadIds(): Promise<string[]> {
    return this.redis.smembers('threads');
  }

  /** Deletes the thread record, its message list and its `threads` entry in one pipeline. */
  async deleteThread(threadId: string): Promise<void> {
    const pipeline = this.redis.pipeline();
    pipeline.del(`${this.threadPrefix}${threadId}`);
    pipeline.del(`${this.messagePrefix}${threadId}`);
    pipeline.srem('threads', threadId);
    await pipeline.exec();
  }

  /** Loads several threads concurrently, in input order, dropping ids that resolve to null. */
  async getThreads(threadIds: string[]): Promise<ThreadType[]> {
    const threads = await Promise.all(threadIds.map(id => this.getThreadById(id)));
    return threads.filter((t): t is ThreadType => t !== null);
  }

  /**
   * Generates a random UUID for new messages.
   * NOTE(review): relies on a global `crypto`; confirm it exists on the target Node version.
   */
  protected generateId(): string {
    return crypto.randomUUID();
  }
}
@@ -1,191 +0,0 @@
1
- import { Redis as UpstashRedis } from '@upstash/redis';
2
- import { createClient } from 'redis';
3
-
4
- import { RedisClient, RedisPipeline } from './types';
5
-
6
- // Helper functions for date handling
7
/**
 * Recursively converts a value into a JSON-safe shape.
 *
 * Each `Date` becomes a `{ __type: 'Date', value: <ISO string> }` marker so that
 * `deserializeData` can revive it. Arrays and objects are copied element by
 * element; every other value (including null/undefined) is returned unchanged.
 */
function serializeData(data: any): any {
  if (data == null) return data;

  if (data instanceof Date) {
    return { __type: 'Date', value: data.toISOString() };
  }

  if (Array.isArray(data)) {
    return data.map(element => serializeData(element));
  }

  if (typeof data !== 'object') {
    return data;
  }

  const copy: any = {};
  for (const prop of Object.keys(data)) {
    copy[prop] = serializeData(data[prop]);
  }
  return copy;
}
27
-
28
/**
 * Inverse of `serializeData`.
 *
 * Any object whose `__type` is 'Date' is replaced by `new Date(value)`. Arrays and
 * other objects are rebuilt recursively; primitives, null and undefined pass
 * through unchanged.
 */
function deserializeData(data: any): any {
  if (data == null || typeof data !== 'object') {
    return data;
  }

  // The Date marker check runs before the array/object branches.
  if (data.__type === 'Date') {
    return new Date(data.value);
  }

  if (Array.isArray(data)) {
    return data.map(element => deserializeData(element));
  }

  const rebuilt: any = {};
  for (const prop of Object.keys(data)) {
    rebuilt[prop] = deserializeData(data[prop]);
  }
  return rebuilt;
}
48
-
49
- // Provider for Upstash/Vercel KV
50
/**
 * RedisClient backed by the Upstash Redis REST client (also used for Vercel KV).
 *
 * Values are passed through `serializeData` before writing and `deserializeData`
 * after reading, so Dates survive a round trip.
 */
export class UpstashProvider implements RedisClient {
  private client: UpstashRedis;

  /**
   * @param url   Upstash REST endpoint URL.
   * @param token Upstash REST access token.
   */
  constructor(url: string, token: string) {
    this.client = new UpstashRedis({
      url,
      token,
    });
  }

  async get(key: string) {
    const data = await this.client.get(key);
    return deserializeData(data);
  }

  async set(key: string, value: any) {
    return this.client.set(key, serializeData(value));
  }

  async del(key: string) {
    return this.client.del(key);
  }

  async sadd(key: string, value: string) {
    return this.client.sadd(key, value);
  }

  async srem(key: string, value: string) {
    return this.client.srem(key, value);
  }

  async smembers(key: string) {
    return this.client.smembers(key);
  }

  async flushall() {
    return this.client.flushall();
  }

  /** Queues commands on an Upstash MULTI transaction; `exec()` runs them and decodes the replies. */
  pipeline() {
    const multi = this.client.multi();
    const pipeline: RedisPipeline = {
      get: (key: string) => {
        multi.get(key);
        return pipeline;
      },
      set: (key: string, value: any) => {
        multi.set(key, JSON.stringify(serializeData(value)));
        return pipeline;
      },
      del: (key: string) => {
        multi.del(key);
        return pipeline;
      },
      srem: (key: string, value: string) => {
        multi.srem(key, value);
        return pipeline;
      },
      exec: async () => {
        const results = await multi.exec();
        // Every reply goes through `revive`. Replies the SDK already decoded into objects
        // previously skipped `deserializeData` and leaked raw `{ __type: 'Date' }` markers.
        return results.map(result => UpstashProvider.revive(result));
      },
    };
    return pipeline;
  }

  /**
   * Decodes one transaction reply.
   *
   * - Strings are JSON-parsed when possible (otherwise returned verbatim).
   * - Arrays are decoded element by element.
   * - All other values go through `deserializeData`, which revives Date markers.
   */
  private static revive(reply: unknown): unknown {
    if (Array.isArray(reply)) {
      return reply.map(item => UpstashProvider.revive(item));
    }
    if (typeof reply === 'string') {
      try {
        return deserializeData(JSON.parse(reply));
      } catch {
        return reply;
      }
    }
    return deserializeData(reply);
  }
}
129
-
130
/**
 * RedisClient backed by the `redis` (node-redis) package.
 *
 * Values are stored as `JSON.stringify(serializeData(value))` and read back with
 * `deserializeData(JSON.parse(...))`, so Dates survive a round trip. The pipeline
 * uses the same encoding.
 */
export class LocalRedisProvider implements RedisClient {
  private client: any;
  /** Settles when the initial `connect()` does; every command awaits it first. */
  private ready: Promise<unknown>;

  /**
   * Creates the client and starts connecting immediately.
   *
   * @param url Redis connection URL; defaults to a local server on the standard port.
   */
  constructor(url = 'redis://localhost:6379') {
    this.client = createClient({ url });
    this.ready = this.client.connect();
    // A fire-and-forget connect() rejection would be unhandled. Mark it handled here;
    // each command awaits `ready` and so still surfaces the connection error.
    void this.ready.catch(() => {});
  }

  async get(key: string) {
    await this.ready;
    const data = await this.client.get(key);
    return data ? deserializeData(JSON.parse(data)) : null;
  }

  async set(key: string, value: any) {
    await this.ready;
    return this.client.set(key, JSON.stringify(serializeData(value)));
  }

  async del(key: string) {
    await this.ready;
    return this.client.del(key);
  }

  async sadd(key: string, value: string) {
    await this.ready;
    return this.client.sAdd(key, value);
  }

  async srem(key: string, value: string) {
    await this.ready;
    return this.client.sRem(key, value);
  }

  async smembers(key: string) {
    await this.ready;
    return this.client.sMembers(key);
  }

  async flushall() {
    await this.ready;
    return this.client.flushAll();
  }

  /** Queues commands on a MULTI transaction; `exec()` runs them and decodes the replies. */
  pipeline() {
    const multi = this.client.multi();
    const pipeline: RedisPipeline = {
      get: (key: string) => {
        multi.get(key);
        return pipeline;
      },
      set: (key: string, value: any) => {
        // Same encoding as `set()`. Without serializeData, Dates were stored as plain ISO
        // strings and read back as strings.
        multi.set(key, JSON.stringify(serializeData(value)));
        return pipeline;
      },
      del: (key: string) => {
        multi.del(key);
        return pipeline;
      },
      srem: (key: string, value: string) => {
        multi.sRem(key, value);
        return pipeline;
      },
      exec: async () => {
        await this.ready;
        const results: unknown[] = await multi.exec();
        // Decode like `get()` does, so pipeline GETs do not return raw JSON text.
        return results.map(result => LocalRedisProvider.revive(result));
      },
    };
    return pipeline;
  }

  /** Closes the connection via node-redis `quit()`. */
  async quit() {
    await this.client.quit();
  }

  /**
   * Decodes one transaction reply.
   *
   * Strings are JSON-parsed and Date markers revived; a string that is not JSON
   * (e.g. a status reply) is returned verbatim. Arrays are decoded element-wise,
   * and anything else is returned unchanged.
   */
  private static revive(reply: unknown): unknown {
    if (Array.isArray(reply)) {
      return reply.map(item => LocalRedisProvider.revive(item));
    }
    if (typeof reply !== 'string') {
      return reply;
    }
    try {
      return deserializeData(JSON.parse(reply));
    } catch {
      return reply;
    }
  }
}
@@ -1,18 +0,0 @@
1
/**
 * Minimal Redis command surface that RedisMemory relies on.
 *
 * Implementations own value encoding. RedisMemory reads back, via `get`, objects it
 * wrote with `set`, so the two should round-trip.
 */
export interface RedisClient {
  // String keys

  /** Reads the value stored at `key`. */
  get(key: string): Promise<any>;
  /** Stores `value` at `key`. */
  set(key: string, value: any): Promise<any>;
  /** Removes `key`. */
  del(key: string): Promise<any>;

  // Sets

  /** Adds members to the set at `key`. RedisMemory's lock treats a truthy result as "added". */
  sadd(key: string, ...values: string[]): Promise<any>;
  /** Removes members from the set at `key`. */
  srem(key: string, ...values: string[]): Promise<any>;
  /** Lists the members of the set at `key`. */
  smembers(key: string): Promise<string[]>;

  // Administration and batching

  /** Deletes every key in the database. */
  flushall(): Promise<any>;
  /** Starts a command batch that is sent when its `exec()` is called. */
  pipeline(): RedisPipeline;
}
11
-
12
/**
 * Chainable batch of Redis commands.
 *
 * Each queueing method returns the pipeline itself; nothing is sent until `exec()`,
 * which resolves with the implementation-specific replies.
 */
export interface RedisPipeline {
  /** Queues a read of `key`. */
  get(key: string): RedisPipeline;
  /** Queues a write of `value` to `key`. */
  set(key: string, value: any): RedisPipeline;
  /** Queues deletion of `key`. */
  del(key: string): RedisPipeline;
  /** Queues removal of `value` from the set at `key`. */
  srem(key: string, value: string): RedisPipeline;
  /** Sends all queued commands. */
  exec(): Promise<any>;
}