@mastra/memory 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,230 @@
1
+ import { ThreadType } from '@mastra/core';
2
+
3
+ import { CloudflareKVMemory } from './';
4
+ import { MockKV } from './kv';
5
+
6
describe('CloudflareKVMemory', () => {
  let memory: CloudflareKVMemory;
  let mockKV: MockKV;

  // Fresh in-memory namespace per test; cleanup() flushes it afterwards so no state leaks between tests.
  beforeEach(() => {
    mockKV = new MockKV();
    memory = new CloudflareKVMemory(mockKV);
  });

  afterEach(async () => {
    await memory.cleanup();
  });

  describe('Thread Operations', () => {
    it('should create and retrieve a thread', async () => {
      const thread = await memory.createThread('Test Thread', {
        testKey: 'testValue',
      });

      expect(thread).toEqual({
        id: expect.any(String),
        title: 'Test Thread',
        createdAt: expect.any(Date),
        updatedAt: expect.any(Date),
        metadata: { testKey: 'testValue' },
      });

      const retrieved = await memory.getThreadById(thread.id);
      expect(retrieved).toEqual(thread);
    });

    it('should return null for non-existent thread', async () => {
      const thread = await memory.getThreadById('nonexistent');
      expect(thread).toBeNull();
    });

    it('should update thread timestamps', async () => {
      const thread = await memory.createThread('Initial Title');
      const originalUpdatedAt = thread.updatedAt;

      // Wait a bit to ensure timestamp difference
      await new Promise(resolve => setTimeout(resolve, 100));

      thread.title = 'Updated Title';
      const updated = await memory.saveThread(thread);

      expect(updated.updatedAt.getTime()).toBeGreaterThan(originalUpdatedAt.getTime());
    });
  });

  describe('Message Operations', () => {
    let testThread: ThreadType;

    beforeEach(async () => {
      testThread = await memory.createThread('Message Test Thread');
    });

    it('should add and retrieve messages', async () => {
      await memory.addMessage(testThread.id, 'Hello', 'user');
      await memory.addMessage(testThread.id, 'Hi there', 'assistant');

      const messages = await memory.getMessages(testThread.id);

      expect(messages).toHaveLength(2);

      // Verify both messages are present without assuming order
      expect(messages).toEqual(
        expect.arrayContaining([
          expect.objectContaining({
            content: 'Hello',
            role: 'user',
            threadId: testThread.id,
          }),
          expect.objectContaining({
            content: 'Hi there',
            role: 'assistant',
            threadId: testThread.id,
          }),
        ]),
      );

      // Also verify that each message has all required fields
      messages.forEach(msg => {
        expect(msg).toEqual(
          expect.objectContaining({
            id: expect.any(String),
            content: expect.any(String),
            role: expect.stringMatching(/^(user|assistant)$/),
            createdAt: expect.any(Date),
            threadId: testThread.id,
          }),
        );
      });
    });

    it('should handle message ordering', async () => {
      const messages = [
        {
          id: memory['generateId'](),
          content: 'Message 1',
          role: 'user' as const,
          createdAt: new Date('2024-01-01'),
          threadId: testThread.id,
        },
        {
          id: memory['generateId'](),
          content: 'Message 2',
          role: 'assistant' as const,
          createdAt: new Date('2024-01-02'),
          threadId: testThread.id,
        },
      ];

      await memory.saveMessages(messages);
      const retrieved = await memory.getMessages(testThread.id);

      expect(retrieved).toHaveLength(2);
      expect(retrieved?.[0]?.content).toBe('Message 1');
      expect(retrieved?.[1]?.content).toBe('Message 2');
    });

    it('should handle duplicate message saves', async () => {
      const message = await memory.addMessage(testThread.id, 'Duplicate', 'user');

      // Saving the same message (same id) again must not create a second copy.
      await memory.saveMessages([message]);

      const messages = await memory.getMessages(testThread.id);
      expect(messages).toHaveLength(1);
    });
  });

  describe('Batch Operations', () => {
    it('should handle multiple threads', async () => {
      const thread1 = await memory.createThread('Thread 1');
      const thread2 = await memory.createThread('Thread 2');

      const allThreadIds = await memory.getAllThreadIds();
      expect(allThreadIds).toContain(thread1.id);
      expect(allThreadIds).toContain(thread2.id);

      const threads = await memory.getThreads([thread1.id, thread2.id]);
      expect(threads).toHaveLength(2);
    });

    it('should delete thread and associated messages', async () => {
      const thread = await memory.createThread('Delete Test');
      await memory.addMessage(thread.id, 'Test Message', 'user');

      await memory.deleteThread(thread.id);

      const deletedThread = await memory.getThreadById(thread.id);
      const messages = await memory.getMessages(thread.id);

      expect(deletedThread).toBeNull();
      expect(messages).toHaveLength(0);
    });
  });

  describe('Error Handling', () => {
    it('should handle concurrent message saves', async () => {
      const thread = await memory.createThread('Concurrent Test');

      const messagesToSave = ['Message 1', 'Message 2', 'Message 3'].map(content => ({
        id: memory['generateId'](),
        content,
        role: 'user' as const,
        createdAt: new Date(),
        threadId: thread.id,
      }));

      // NOTE(review): this is a single saveMessages call for one batch. The provider's version check
      // is a read followed by a write (no atomic compare-and-swap), so genuinely racing writers are
      // not exercised here.
      await memory.saveMessages(messagesToSave);

      const messages = await memory.getMessages(thread.id);
      expect(messages).toHaveLength(3);
    });
  });

  describe('Performance', () => {
    it('should handle large number of messages', async () => {
      const thread = await memory.createThread('Bulk Test');
      const messageCount = 100;

      // Create many messages
      const messages = Array.from({ length: messageCount }, (_, i) => ({
        id: memory['generateId'](),
        content: `Message ${i}`,
        role: i % 2 === 0 ? ('user' as const) : ('assistant' as const),
        createdAt: new Date(),
        threadId: thread.id,
      }));

      await memory.saveMessages(messages);
      const retrieved = await memory.getMessages(thread.id);

      expect(retrieved).toHaveLength(messageCount);
    });
  });
});
@@ -0,0 +1,169 @@
1
+ import { MastraMemory, MessageType, ThreadType } from '@mastra/core';
2
+
3
+ import { CloudflareKVProvider, KVNamespace } from './kv';
4
+
5
/**
 * MastraMemory implementation backed by a Cloudflare Workers KV namespace.
 *
 * Storage layout (values are JSON-encoded by CloudflareKVProvider):
 * - `thread:<threadId>`   — the thread record
 * - `messages:<threadId>` — all messages of the thread as one array, sorted by createdAt, then id
 * - `threads`             — array of every saved thread id
 *
 * NOTE(review): KV has no atomic compare-and-swap. The version check used by saveMessages is a read
 * followed by a write, so it narrows, but does not eliminate, the window for lost updates between
 * concurrent writers.
 */
export class CloudflareKVMemory extends MastraMemory {
  private readonly kv: CloudflareKVProvider;
  private readonly threadPrefix = 'thread:';
  private readonly messagePrefix = 'messages:';

  /** @param namespace KV namespace used for all storage. */
  constructor(namespace: KVNamespace) {
    super();
    this.kv = new CloudflareKVProvider(namespace);
  }

  /**
   * Loads a thread by id.
   *
   * @returns The thread with `createdAt`/`updatedAt` revived into Date objects (JSON stores them as
   *          strings), or null when no thread is stored under that id.
   */
  async getThreadById(threadId: string): Promise<ThreadType | null> {
    const thread = await this.kv.get<ThreadType>(`${this.threadPrefix}${threadId}`);
    if (thread && typeof thread.createdAt === 'string') {
      thread.createdAt = new Date(thread.createdAt);
      thread.updatedAt = new Date(thread.updatedAt);
    }
    return thread;
  }

  /**
   * Persists a thread and registers its id in the `threads` set.
   *
   * Note: stamps `updatedAt` on the passed-in object (it is mutated) and returns that same object.
   */
  async saveThread(thread: ThreadType): Promise<ThreadType> {
    thread.updatedAt = new Date();
    await this.kv.set(`${this.threadPrefix}${thread.id}`, thread);
    await this.kv.sadd('threads', thread.id);
    return thread;
  }

  /**
   * Runs `operation`, retrying on any thrown error with randomized exponential backoff.
   *
   * @param maxRetries Total number of attempts.
   * @throws The last error thrown by `operation`, or a generic Error if no attempt ran.
   */
  private async retryOperation<T>(operation: () => Promise<T>, maxRetries = 5): Promise<T> {
    let lastError: unknown;

    for (let attempt = 0; attempt < maxRetries; attempt++) {
      try {
        return await operation();
      } catch (error) {
        lastError = error;
        // No point sleeping after the final attempt: we are about to give up.
        if (attempt < maxRetries - 1) {
          await new Promise(resolve => setTimeout(resolve, Math.random() * 100 * 2 ** attempt));
        }
      }
    }

    throw lastError ?? new Error(`Operation failed after ${maxRetries} attempts`);
  }

  /**
   * Creates a new message in `threadId` and stores it.
   *
   * Goes through saveMessages so that single-message writes participate in the same version check
   * as batch writes, instead of overwriting the message array unconditionally.
   */
  async addMessage(threadId: string, content: string, role: 'user' | 'assistant'): Promise<MessageType> {
    const message: MessageType = {
      id: this.generateId(),
      content,
      role,
      createdAt: new Date(),
      threadId,
    };

    await this.saveMessages([message]);
    return message;
  }

  /**
   * Upserts messages (keyed by id) into their threads' message arrays.
   *
   * Messages are grouped per thread. Each thread's array is updated by read → merge → versioned
   * write. A version conflict is thrown so retryOperation re-reads after a backoff, with a bounded
   * number of attempts.
   *
   * @returns The `messages` argument as given.
   * @throws If a thread's write keeps conflicting after all retry attempts.
   */
  async saveMessages(messages: MessageType[]): Promise<MessageType[]> {
    if (!messages.length) return [];

    const messagesByKey = new Map<string, MessageType[]>();
    for (const message of messages) {
      const key = `${this.messagePrefix}${message.threadId}`;
      const bucket = messagesByKey.get(key) ?? [];
      bucket.push({ ...message, createdAt: new Date(message.createdAt) });
      messagesByKey.set(key, bucket);
    }

    await Promise.all(
      Array.from(messagesByKey.entries()).map(([key, threadMessages]) =>
        this.retryOperation(async () => {
          const { data: existingMessages, version } = await this.kv.getWithVersion<MessageType[]>(key);
          const merged = this.mergeStoredMessages(existingMessages ?? [], threadMessages);

          const saved = await this.kv.setWithVersion(key, merged, version);
          if (!saved) {
            // Another writer bumped the version since we read it; throwing hands control back to
            // retryOperation, which backs off and re-reads instead of spinning forever.
            throw new Error(`Version conflict while saving messages under ${key}`);
          }
        }),
      ),
    );

    return messages;
  }

  /**
   * Merges incoming messages over stored ones (same id → incoming wins). The result is sorted by
   * createdAt, with id as the tie-breaker so the order is deterministic.
   */
  private mergeStoredMessages(existing: MessageType[], incoming: MessageType[]): MessageType[] {
    const byId = new Map<string, MessageType>();
    for (const msg of existing) {
      // Stored dates come back from JSON as strings.
      byId.set(msg.id, { ...msg, createdAt: new Date(msg.createdAt) });
    }
    for (const msg of incoming) {
      byId.set(msg.id, msg);
    }
    return Array.from(byId.values()).sort(
      (a, b) => a.createdAt.getTime() - b.createdAt.getTime() || a.id.localeCompare(b.id),
    );
  }

  /** Returns the thread's messages in stored order, with `createdAt` revived into Date objects. */
  async getMessages(threadId: string): Promise<MessageType[]> {
    const messages = (await this.kv.get<MessageType[]>(`${this.messagePrefix}${threadId}`)) || [];
    return messages.map(msg => ({
      ...msg,
      createdAt: new Date(msg.createdAt),
    }));
  }

  /** Returns the ids recorded in the `threads` set. */
  async getAllThreadIds(): Promise<string[]> {
    return this.kv.smembers('threads');
  }

  /** Removes a thread, its messages, its message-version counter, and its entry in the `threads` set. */
  async deleteThread(threadId: string): Promise<void> {
    const messagesKey = `${this.messagePrefix}${threadId}`;
    await Promise.all([
      this.kv.del(`${this.threadPrefix}${threadId}`),
      this.kv.del(messagesKey),
      // getWithVersion/setWithVersion keep a companion `<key>:version` entry; delete it as well so it
      // does not outlive the thread.
      this.kv.del(`${messagesKey}:version`),
      this.kv.srem('threads', threadId),
    ]);
  }

  /** Loads the given threads, silently skipping ids that have no stored thread. */
  async getThreads(threadIds: string[]): Promise<ThreadType[]> {
    const threads = await Promise.all(threadIds.map(id => this.getThreadById(id)));
    return threads.filter((t): t is ThreadType => t !== null);
  }

  /**
   * Deletes every key in the underlying namespace — not only keys written by this class.
   * Intended for tests.
   */
  async cleanup(): Promise<void> {
    await this.kv.flushall();
  }
}
@@ -0,0 +1,139 @@
1
/**
 * Structural subset of the Cloudflare Workers KV namespace binding that this package relies on.
 *
 * The real binding offers more (for example `getWithMetadata`, options on `put`, and additional
 * `get` return types). Only the members declared here are used, so a test double only needs to
 * implement these.
 */
export interface KVNamespace {
  /** Returns the stored value as a string, or null when the key is absent. */
  get(key: string): Promise<string | null>;
  get(key: string, type: 'text'): Promise<string | null>;
  /** Returns the stored value parsed as JSON, or null when the key is absent. */
  get<T>(key: string, type: 'json'): Promise<T | null>;
  put(key: string, value: string | ReadableStream | ArrayBuffer | FormData): Promise<void>;
  delete(key: string): Promise<void>;
  // list() is a method of the KV namespace binding itself in Workers. Results are paginated: pass the
  // returned `cursor` back in until `list_complete` is true.
  list(options?: { prefix?: string; limit?: number; cursor?: string }): Promise<{
    keys: { name: string }[];
    list_complete: boolean;
    cursor?: string;
  }>;
}
15
+
16
/**
 * In-memory KVNamespace for tests.
 *
 * Values are kept as strings. `put` accepts only string values. `list` pages through keys in sorted
 * order, using the decimal offset of the next key as its cursor.
 */
export class MockKV implements KVNamespace {
  private readonly store = new Map<string, string>();

  get(key: string): Promise<string | null>;
  get(key: string, type: 'text'): Promise<string | null>;
  get<T>(key: string, type: 'json'): Promise<T | null>;
  async get<T>(key: string, type?: 'text' | 'json'): Promise<string | T | null> {
    const value = this.store.get(key);
    if (value === undefined) return null;
    return type === 'json' ? (JSON.parse(value) as T) : value;
  }

  /** @throws If `value` is not a string (streams/buffers/form data are not supported by the mock). */
  async put(key: string, value: string | ReadableStream | ArrayBuffer | FormData): Promise<void> {
    if (typeof value !== 'string') {
      throw new Error('MockKV only supports string values for testing');
    }
    this.store.set(key, value);
  }

  async delete(key: string): Promise<void> {
    this.store.delete(key);
  }

  /**
   * Lists keys matching `prefix`, `limit` at a time (all remaining keys when no limit is given).
   *
   * @throws If `cursor` is not a cursor previously returned by this method.
   */
  async list(options?: { prefix?: string; limit?: number; cursor?: string }): Promise<{
    keys: { name: string }[];
    list_complete: boolean;
    cursor?: string;
  }> {
    const prefix = options?.prefix;
    // Cloudflare KV returns keys in lexicographic order. Sorting, rather than relying on Map insertion
    // order, keeps the mock's pages in that same order. The default sort compares UTF-16 code units,
    // which matches byte order for ASCII keys.
    const names = Array.from(this.store.keys())
      .filter(name => !prefix || name.startsWith(prefix))
      .sort();

    let start = 0;
    if (options?.cursor) {
      start = Number.parseInt(options.cursor, 10);
      // A garbage cursor would otherwise produce NaN paging that never reports completion.
      if (!Number.isInteger(start) || start < 0) {
        throw new Error(`MockKV: invalid list cursor "${options.cursor}"`);
      }
    }

    const limit = options?.limit ?? names.length;
    const end = start + limit;
    const listComplete = end >= names.length;

    return {
      keys: names.slice(start, end).map(name => ({ name })),
      list_complete: listComplete,
      cursor: listComplete ? undefined : String(end),
    };
  }
}
67
+
68
/**
 * JSON-oriented helper over a KVNamespace. It provides:
 * - typed get/set
 * - a string-set emulation stored as a JSON array
 * - a namespace-wide flush
 * - an optimistic version counter stored under `<key>:version`
 *
 * NOTE(review): sadd, srem and setWithVersion are each a read followed by a separate write, so none
 * of them is atomic with respect to concurrent writers.
 */
export class CloudflareKVProvider {
  private readonly namespace: KVNamespace;

  constructor(namespace: KVNamespace) {
    this.namespace = namespace;
  }

  /** Reads `key` and parses it as JSON; resolves to null when the namespace has no value for it. */
  async get<T>(key: string): Promise<T | null> {
    return this.namespace.get<T>(key, 'json');
  }

  /** JSON-encodes `value` and stores it under `key`. */
  async set(key: string, value: unknown): Promise<void> {
    await this.namespace.put(key, JSON.stringify(value));
  }

  async del(key: string): Promise<void> {
    await this.namespace.delete(key);
  }

  /** Adds `value` to the JSON-array set at `key`. @returns 1 if it was added, 0 if already present. */
  async sadd(key: string, value: string): Promise<number> {
    const members = (await this.get<string[]>(key)) ?? [];
    if (members.includes(value)) return 0;
    members.push(value);
    await this.set(key, members);
    return 1;
  }

  /** Removes `value` from the JSON-array set at `key`. @returns 1 if it was removed, 0 if absent. */
  async srem(key: string, value: string): Promise<number> {
    const members = (await this.get<string[]>(key)) ?? [];
    const index = members.indexOf(value);
    if (index === -1) return 0;
    members.splice(index, 1);
    await this.set(key, members);
    return 1;
  }

  /** Returns the members of the JSON-array set at `key` (empty when the key is absent). */
  async smembers(key: string): Promise<string[]> {
    return (await this.get<string[]>(key)) ?? [];
  }

  /** Deletes every key in the namespace, page by page. */
  async flushall(): Promise<void> {
    let cursor: string | undefined;
    do {
      const result = await this.namespace.list({ cursor });
      await Promise.all(result.keys.map(key => this.namespace.delete(key.name)));
      cursor = result.cursor;
    } while (cursor);
  }

  /**
   * Reads `key` together with its version counter (0 when no version has been written yet).
   *
   * The version is read BEFORE the data. Since setWithVersion writes data before bumping the
   * version, a caller that sees version N+1 then also reads data at least as new as write N+1, so
   * it never pairs a fresh version with stale data. That pairing would let its CAS succeed and drop
   * the other writer's update.
   */
  async getWithVersion<T>(key: string): Promise<{ data: T | null; version: number }> {
    const version = (await this.get<number>(this.versionKeyFor(key))) ?? 0;
    const data = await this.get<T>(key);
    return { data, version };
  }

  /**
   * Writes `value` only if the stored version still equals `expectedVersion`, then bumps the version.
   *
   * @returns false (and writes nothing) if the version changed since it was read.
   */
  async setWithVersion<T>(key: string, value: T, expectedVersion: number): Promise<boolean> {
    const versionKey = this.versionKeyFor(key);
    const currentVersion = (await this.get<number>(versionKey)) ?? 0;

    if (currentVersion !== expectedVersion) {
      return false;
    }

    // Data first, version second (see getWithVersion). Writing both in parallel could publish the new
    // version before the new data.
    await this.set(key, value);
    await this.set(versionKey, expectedVersion + 1);
    return true;
  }

  /** Key under which the version counter for `key` is stored. */
  private versionKeyFor(key: string): string {
    return `${key}:version`;
  }
}
package/src/index.ts ADDED
@@ -0,0 +1,3 @@
1
+ export * from './cloudflare';
2
+ export * from './postgres';
3
+ export * from './redis';
@@ -0,0 +1,60 @@
1
+ import dotenv from 'dotenv';
2
+
3
+ import { PgMemory } from './';
4
+
5
dotenv.config();

// DB_URL (optionally loaded from .env) overrides the local default; an unset or empty value falls back.
const connectionString = process.env.DB_URL || 'postgres://postgres:password@localhost:5434/mastra';

describe('PgMastraMemory', () => {
  let memory: PgMemory;

  beforeAll(() => {
    memory = new PgMemory(connectionString);
  });

  afterAll(async () => {
    // NOTE(review): presumably drops the storage PgMemory created — confirm before pointing DB_URL at a
    // shared database.
    await memory.drop();
  });

  it('should create and retrieve a thread', async () => {
    const thread = await memory.createThread('Test Thread', { test: true });
    const retrievedThread = await memory.getThreadById(thread.id);
    expect(retrievedThread).toEqual(thread);
  });

  it('should save and retrieve messages', async () => {
    const thread = await memory.createThread('Test Thread 2', { test: true });
    const message1 = await memory.addMessage(thread.id, 'Hello', 'user');

    const messages = await memory.getMessages(thread.id);

    expect(messages).toHaveLength(1);
    expect(messages[0]?.content).toEqual(message1.content);
  });

  it('should update a thread', async () => {
    const thread = await memory.createThread('Initial Thread Title', { test: true });
    const updatedThread = await memory.updateThread(thread.id, 'Updated Thread Title', { test: true, updated: true });

    expect(updatedThread.title).toEqual('Updated Thread Title');
    expect(updatedThread.metadata).toEqual({ test: true, updated: true });
  });

  it('should delete a thread', async () => {
    const thread = await memory.createThread('Thread to Delete', { test: true });
    await memory.deleteThread(thread.id);

    const retrievedThread = await memory.getThreadById(thread.id);
    expect(retrievedThread).toBeNull();
  });

  it('should delete a message', async () => {
    const thread = await memory.createThread('Thread with Message', { test: true });
    const message = await memory.addMessage(thread.id, 'Message to Delete', 'user');
    await memory.deleteMessage(message.id);

    const messages = await memory.getMessages(thread.id);
    expect(messages.length).toEqual(0);
  });
});