@mastra/memory 0.0.2-alpha.8 → 0.1.0-alpha.65
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +440 -0
- package/LICENSE +44 -0
- package/dist/index.d.ts +44 -2
- package/dist/index.js +117 -6
- package/package.json +22 -15
- package/src/index.test.ts +7 -0
- package/src/index.ts +170 -2
- package/vitest.config.ts +8 -0
- package/dist/kv/upstash.d.ts +0 -71
- package/dist/memory.cjs.development.js +0 -1579
- package/dist/memory.cjs.development.js.map +0 -1
- package/dist/memory.cjs.production.min.js +0 -2
- package/dist/memory.cjs.production.min.js.map +0 -1
- package/dist/memory.esm.js +0 -1574
- package/dist/memory.esm.js.map +0 -1
- package/dist/postgres/index.d.ts +0 -59
- package/jest.config.ts +0 -19
- package/src/kv/upstash.test.ts +0 -253
- package/src/kv/upstash.ts +0 -298
- package/src/postgres/index.test.ts +0 -68
- package/src/postgres/index.ts +0 -492
package/src/kv/upstash.ts
DELETED
|
@@ -1,298 +0,0 @@
|
|
|
1
|
-
import { MastraMemory, MessageType as BaseMastraMessageType, ThreadType, MessageResponse } from '@mastra/core';
|
|
2
|
-
import { Redis } from '@upstash/redis';
|
|
3
|
-
import { ToolResultPart, Message as AiMessage, TextPart } from 'ai';
|
|
4
|
-
import crypto from 'crypto';
|
|
5
|
-
|
|
6
|
-
interface ToolCacheData {
|
|
7
|
-
expireAt: string;
|
|
8
|
-
result?: ToolResultPart['result'];
|
|
9
|
-
}
|
|
10
|
-
|
|
11
|
-
interface MessageType extends BaseMastraMessageType {
|
|
12
|
-
tokens?: number;
|
|
13
|
-
}
|
|
14
|
-
|
|
15
|
-
// Internal type for serialized thread data
|
|
16
|
-
interface SerializedThreadType extends Omit<ThreadType, 'createdAt' | 'updatedAt'> {
|
|
17
|
-
createdAt: string;
|
|
18
|
-
updatedAt: string;
|
|
19
|
-
}
|
|
20
|
-
|
|
21
|
-
export class UpstashKVMemory extends MastraMemory {
|
|
22
|
-
private prefix: string;
|
|
23
|
-
|
|
24
|
-
kv: Redis;
|
|
25
|
-
|
|
26
|
-
constructor(config: { url: string; token: string; prefix?: string; maxTokens?: number }) {
|
|
27
|
-
super();
|
|
28
|
-
this.prefix = config.prefix || 'mastra';
|
|
29
|
-
this.MAX_CONTEXT_TOKENS = config.maxTokens;
|
|
30
|
-
|
|
31
|
-
this.kv = new Redis({
|
|
32
|
-
url: config.url,
|
|
33
|
-
token: config.token,
|
|
34
|
-
});
|
|
35
|
-
}
|
|
36
|
-
|
|
37
|
-
private getThreadKey(threadId: string): string {
|
|
38
|
-
return `${this.prefix}:thread:${threadId}`;
|
|
39
|
-
}
|
|
40
|
-
|
|
41
|
-
private getMessagesKey(threadId: string): string {
|
|
42
|
-
return `${this.prefix}:messages:${threadId}`;
|
|
43
|
-
}
|
|
44
|
-
|
|
45
|
-
private getToolCacheKey(hashedArgs: string): string {
|
|
46
|
-
return `${this.prefix}:tool:${hashedArgs}`;
|
|
47
|
-
}
|
|
48
|
-
|
|
49
|
-
async getThreadById({ threadId }: { threadId: string }): Promise<ThreadType | null> {
|
|
50
|
-
const thread = await this.kv.get<SerializedThreadType>(this.getThreadKey(threadId));
|
|
51
|
-
return thread ? this.parseThread(thread) : null;
|
|
52
|
-
}
|
|
53
|
-
|
|
54
|
-
async getThreadsByResourceId({ resourceid }: { resourceid: string }): Promise<ThreadType[]> {
|
|
55
|
-
const pattern = `${this.prefix}:thread:*`;
|
|
56
|
-
const keys = await this.kv.keys(pattern);
|
|
57
|
-
|
|
58
|
-
const threads = await Promise.all(keys.map(key => this.kv.get<SerializedThreadType>(key)));
|
|
59
|
-
|
|
60
|
-
return threads
|
|
61
|
-
.filter(thread => thread?.resourceid === resourceid)
|
|
62
|
-
.map(thread => this.parseThread(thread as SerializedThreadType));
|
|
63
|
-
}
|
|
64
|
-
|
|
65
|
-
async saveThread({ thread }: { thread: ThreadType }): Promise<ThreadType> {
|
|
66
|
-
const key = this.getThreadKey(thread.id);
|
|
67
|
-
const serializedThread: SerializedThreadType = {
|
|
68
|
-
...thread,
|
|
69
|
-
createdAt: thread.createdAt.toISOString(),
|
|
70
|
-
updatedAt: thread.updatedAt.toISOString(),
|
|
71
|
-
};
|
|
72
|
-
await this.kv.set(key, serializedThread);
|
|
73
|
-
return thread;
|
|
74
|
-
}
|
|
75
|
-
|
|
76
|
-
async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
|
|
77
|
-
const key = this.getThreadKey(id);
|
|
78
|
-
const thread = await this.kv.get<SerializedThreadType>(key);
|
|
79
|
-
|
|
80
|
-
if (!thread) {
|
|
81
|
-
throw new Error(`Thread ${id} not found`);
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
const updatedThread: SerializedThreadType = {
|
|
85
|
-
...thread,
|
|
86
|
-
title,
|
|
87
|
-
metadata,
|
|
88
|
-
updatedAt: new Date().toISOString(),
|
|
89
|
-
};
|
|
90
|
-
|
|
91
|
-
await this.kv.set(key, updatedThread);
|
|
92
|
-
return this.parseThread(updatedThread);
|
|
93
|
-
}
|
|
94
|
-
|
|
95
|
-
async deleteThread(id: string): Promise<void> {
|
|
96
|
-
await this.kv.del(this.getThreadKey(id));
|
|
97
|
-
await this.kv.del(this.getMessagesKey(id));
|
|
98
|
-
}
|
|
99
|
-
|
|
100
|
-
/**
|
|
101
|
-
* Tool Cache
|
|
102
|
-
*/
|
|
103
|
-
|
|
104
|
-
async validateToolCallArgs({ hashedArgs }: { hashedArgs: string }): Promise<boolean> {
|
|
105
|
-
const cacheKey = this.getToolCacheKey(hashedArgs);
|
|
106
|
-
const cached = await this.kv.get<ToolCacheData>(cacheKey);
|
|
107
|
-
return !!cached && new Date(cached.expireAt) > new Date();
|
|
108
|
-
}
|
|
109
|
-
|
|
110
|
-
async getToolResult({
|
|
111
|
-
threadId,
|
|
112
|
-
toolArgs,
|
|
113
|
-
toolName,
|
|
114
|
-
}: {
|
|
115
|
-
threadId: string;
|
|
116
|
-
toolArgs: Record<string, unknown>;
|
|
117
|
-
toolName: string;
|
|
118
|
-
}): Promise<ToolResultPart['result'] | null> {
|
|
119
|
-
const hashedToolArgs = crypto
|
|
120
|
-
.createHash('sha256')
|
|
121
|
-
.update(JSON.stringify({ args: toolArgs, threadId, toolName }))
|
|
122
|
-
.digest('hex');
|
|
123
|
-
|
|
124
|
-
const cacheKey = this.getToolCacheKey(hashedToolArgs);
|
|
125
|
-
const cached = await this.kv.get<ToolCacheData>(cacheKey);
|
|
126
|
-
|
|
127
|
-
if (cached && new Date(cached.expireAt) > new Date()) {
|
|
128
|
-
return cached.result || null;
|
|
129
|
-
}
|
|
130
|
-
|
|
131
|
-
return null;
|
|
132
|
-
}
|
|
133
|
-
|
|
134
|
-
async getContextWindow<T extends 'raw' | 'core_message'>({
|
|
135
|
-
threadId,
|
|
136
|
-
startDate,
|
|
137
|
-
endDate,
|
|
138
|
-
format = 'raw' as T,
|
|
139
|
-
}: {
|
|
140
|
-
format?: T;
|
|
141
|
-
threadId: string;
|
|
142
|
-
startDate?: Date;
|
|
143
|
-
endDate?: Date;
|
|
144
|
-
}) {
|
|
145
|
-
const messagesKey = this.getMessagesKey(threadId);
|
|
146
|
-
const messages = await this.kv.lrange<MessageType>(messagesKey, 0, -1);
|
|
147
|
-
|
|
148
|
-
let filteredMessages = messages.filter(msg => msg.type === 'text');
|
|
149
|
-
|
|
150
|
-
if (startDate) {
|
|
151
|
-
filteredMessages = filteredMessages.filter(msg => new Date(msg.createdAt) >= startDate);
|
|
152
|
-
}
|
|
153
|
-
|
|
154
|
-
if (endDate) {
|
|
155
|
-
filteredMessages = filteredMessages.filter(msg => new Date(msg.createdAt) <= endDate);
|
|
156
|
-
}
|
|
157
|
-
|
|
158
|
-
if (this.MAX_CONTEXT_TOKENS) {
|
|
159
|
-
let totalTokens = 0;
|
|
160
|
-
const messagesWithinTokenLimit: MessageType[] = [];
|
|
161
|
-
|
|
162
|
-
// Process messages from newest to oldest
|
|
163
|
-
for (const message of filteredMessages.reverse()) {
|
|
164
|
-
const content =
|
|
165
|
-
message.role === 'assistant'
|
|
166
|
-
? (message.content as Array<TextPart>)[0]?.text || ''
|
|
167
|
-
: (message.content as string);
|
|
168
|
-
|
|
169
|
-
// Use a more aggressive token estimation
|
|
170
|
-
// Roughly estimate 1 token per 4 characters
|
|
171
|
-
const tokens = Math.ceil(content.length / 4);
|
|
172
|
-
|
|
173
|
-
// Check if adding this message would exceed the token limit
|
|
174
|
-
if (totalTokens + tokens > this.MAX_CONTEXT_TOKENS) {
|
|
175
|
-
break;
|
|
176
|
-
}
|
|
177
|
-
|
|
178
|
-
totalTokens += tokens;
|
|
179
|
-
messagesWithinTokenLimit.unshift({
|
|
180
|
-
...message,
|
|
181
|
-
tokens,
|
|
182
|
-
});
|
|
183
|
-
}
|
|
184
|
-
|
|
185
|
-
console.log('Format:', format);
|
|
186
|
-
// Return messages in chronological order
|
|
187
|
-
return this.parseMessages(messagesWithinTokenLimit) as MessageResponse<T>;
|
|
188
|
-
}
|
|
189
|
-
|
|
190
|
-
return this.parseMessages(filteredMessages) as MessageResponse<T>;
|
|
191
|
-
}
|
|
192
|
-
|
|
193
|
-
/**
|
|
194
|
-
* Messages
|
|
195
|
-
*/
|
|
196
|
-
|
|
197
|
-
async getMessages({ threadId }: { threadId: string }): Promise<{ messages: MessageType[]; uiMessages: AiMessage[] }> {
|
|
198
|
-
const messagesKey = this.getMessagesKey(threadId);
|
|
199
|
-
const messages = await this.kv.lrange<MessageType>(messagesKey, 0, -1);
|
|
200
|
-
const parsedMessages = this.parseMessages(messages);
|
|
201
|
-
const uiMessages = this.convertToUIMessages(parsedMessages);
|
|
202
|
-
|
|
203
|
-
return { messages: parsedMessages, uiMessages };
|
|
204
|
-
}
|
|
205
|
-
|
|
206
|
-
async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
|
|
207
|
-
const processedMessages: MessageType[] = [];
|
|
208
|
-
|
|
209
|
-
for (const message of messages) {
|
|
210
|
-
const { threadId, toolCallArgs, toolNames, createdAt } = message;
|
|
211
|
-
const messagesKey = this.getMessagesKey(threadId);
|
|
212
|
-
|
|
213
|
-
const processedMessage = { ...message };
|
|
214
|
-
|
|
215
|
-
if (message.type === 'text') {
|
|
216
|
-
const content =
|
|
217
|
-
message.role === 'assistant'
|
|
218
|
-
? (message.content as Array<TextPart>)[0]?.text || ''
|
|
219
|
-
: (message.content as string);
|
|
220
|
-
processedMessage.tokens = this.estimateTokens(content);
|
|
221
|
-
}
|
|
222
|
-
|
|
223
|
-
if (toolCallArgs?.length) {
|
|
224
|
-
const hashedToolCallArgs = toolCallArgs.map((args, index) =>
|
|
225
|
-
crypto
|
|
226
|
-
.createHash('sha256')
|
|
227
|
-
.update(JSON.stringify({ args, threadId, toolName: toolNames?.[index] }))
|
|
228
|
-
.digest('hex'),
|
|
229
|
-
);
|
|
230
|
-
|
|
231
|
-
let validArgExists = true;
|
|
232
|
-
for (const hashedArg of hashedToolCallArgs) {
|
|
233
|
-
const isValid = await this.validateToolCallArgs({ hashedArgs: hashedArg });
|
|
234
|
-
if (!isValid) {
|
|
235
|
-
validArgExists = false;
|
|
236
|
-
break;
|
|
237
|
-
}
|
|
238
|
-
}
|
|
239
|
-
|
|
240
|
-
const expireAt = validArgExists ? createdAt : new Date(createdAt.getTime() + 5 * 60 * 1000); // 5 minutes
|
|
241
|
-
|
|
242
|
-
for (const hashedArg of hashedToolCallArgs) {
|
|
243
|
-
const cacheKey = this.getToolCacheKey(hashedArg);
|
|
244
|
-
await this.kv.set(cacheKey, { expireAt: expireAt.toISOString() } as ToolCacheData);
|
|
245
|
-
}
|
|
246
|
-
}
|
|
247
|
-
|
|
248
|
-
await this.kv.rpush(messagesKey, processedMessage);
|
|
249
|
-
processedMessages.push(processedMessage);
|
|
250
|
-
}
|
|
251
|
-
|
|
252
|
-
return processedMessages;
|
|
253
|
-
}
|
|
254
|
-
|
|
255
|
-
async deleteMessage(id: string): Promise<void> {
|
|
256
|
-
const pattern = `${this.prefix}:messages:*`;
|
|
257
|
-
const keys = await this.kv.keys(pattern);
|
|
258
|
-
|
|
259
|
-
for (const key of keys) {
|
|
260
|
-
const messages = await this.kv.lrange<MessageType>(key, 0, -1);
|
|
261
|
-
const filteredMessages = messages.filter(msg => msg.id !== id);
|
|
262
|
-
|
|
263
|
-
if (messages.length !== filteredMessages.length) {
|
|
264
|
-
await this.kv.del(key);
|
|
265
|
-
if (filteredMessages.length > 0) {
|
|
266
|
-
await this.kv.rpush(key, ...filteredMessages);
|
|
267
|
-
}
|
|
268
|
-
}
|
|
269
|
-
}
|
|
270
|
-
}
|
|
271
|
-
|
|
272
|
-
/**
|
|
273
|
-
* Cleanup
|
|
274
|
-
*/
|
|
275
|
-
|
|
276
|
-
async drop(): Promise<void> {
|
|
277
|
-
const pattern = `${this.prefix}:*`;
|
|
278
|
-
const keys = await this.kv.keys(pattern);
|
|
279
|
-
if (keys.length > 0) {
|
|
280
|
-
await this.kv.del(...keys);
|
|
281
|
-
}
|
|
282
|
-
}
|
|
283
|
-
|
|
284
|
-
parseThread(thread: SerializedThreadType): ThreadType {
|
|
285
|
-
return {
|
|
286
|
-
...thread,
|
|
287
|
-
createdAt: new Date(thread.createdAt),
|
|
288
|
-
updatedAt: new Date(thread.updatedAt),
|
|
289
|
-
};
|
|
290
|
-
}
|
|
291
|
-
|
|
292
|
-
parseMessages(messages: MessageType[]): MessageType[] {
|
|
293
|
-
return messages.map(message => ({
|
|
294
|
-
...message,
|
|
295
|
-
createdAt: new Date(message.createdAt),
|
|
296
|
-
}));
|
|
297
|
-
}
|
|
298
|
-
}
|
|
@@ -1,68 +0,0 @@
|
|
|
1
|
-
import dotenv from 'dotenv';
|
|
2
|
-
|
|
3
|
-
import { PgMemory } from './';
|
|
4
|
-
|
|
5
|
-
dotenv.config();
|
|
6
|
-
|
|
7
|
-
const connectionString = process.env.DB_URL! || 'postgres://postgres:password@localhost:5434/mastra';
|
|
8
|
-
const resourceid = 'resource';
|
|
9
|
-
|
|
10
|
-
describe('PgMastraMemory', () => {
|
|
11
|
-
let memory: PgMemory;
|
|
12
|
-
|
|
13
|
-
beforeAll(async () => {
|
|
14
|
-
memory = new PgMemory({ connectionString });
|
|
15
|
-
});
|
|
16
|
-
|
|
17
|
-
afterAll(async () => {
|
|
18
|
-
await memory.drop();
|
|
19
|
-
});
|
|
20
|
-
|
|
21
|
-
it('should create and retrieve a thread', async () => {
|
|
22
|
-
const thread = await memory.createThread({ title: 'Test thread', resourceid });
|
|
23
|
-
const retrievedThread = await memory.getThreadById({ threadId: thread.id });
|
|
24
|
-
expect(retrievedThread).toEqual(thread);
|
|
25
|
-
});
|
|
26
|
-
|
|
27
|
-
it('should save and retrieve messages', async () => {
|
|
28
|
-
const thread = await memory.createThread({ title: 'Test thread 2', resourceid });
|
|
29
|
-
const message1 = await memory.addMessage({ threadId: thread.id, content: 'Hello', role: 'user', type: 'text' });
|
|
30
|
-
// const message2 = await memory.addMessage(thread.id, 'World', 'assistant');
|
|
31
|
-
const memoryMessages = await memory.getMessages({ threadId: thread.id });
|
|
32
|
-
const messages = memoryMessages.messages;
|
|
33
|
-
|
|
34
|
-
console.log(messages);
|
|
35
|
-
expect(messages[0]?.content).toEqual(message1.content);
|
|
36
|
-
});
|
|
37
|
-
|
|
38
|
-
it('should update a thread', async () => {
|
|
39
|
-
const thread = await memory.createThread({ title: 'Initial Thread Title', resourceid });
|
|
40
|
-
const updatedThread = await memory.updateThread(thread.id, 'Updated Thread Title', { test: true, updated: true });
|
|
41
|
-
|
|
42
|
-
expect(updatedThread.title).toEqual('Updated Thread Title');
|
|
43
|
-
expect(updatedThread.metadata).toEqual({ test: true, updated: true });
|
|
44
|
-
});
|
|
45
|
-
|
|
46
|
-
it('should delete a thread', async () => {
|
|
47
|
-
const thread = await memory.createThread({ title: 'Thread to Delete', resourceid });
|
|
48
|
-
await memory.deleteThread(thread.id);
|
|
49
|
-
|
|
50
|
-
const retrievedThread = await memory.getThreadById({ threadId: thread.id });
|
|
51
|
-
expect(retrievedThread).toBeNull();
|
|
52
|
-
});
|
|
53
|
-
|
|
54
|
-
it('should delete a message', async () => {
|
|
55
|
-
const thread = await memory.createThread({ title: 'Thread with Message', resourceid });
|
|
56
|
-
const message = await memory.addMessage({
|
|
57
|
-
threadId: thread.id,
|
|
58
|
-
content: 'Message to Delete',
|
|
59
|
-
role: 'user',
|
|
60
|
-
type: 'text',
|
|
61
|
-
});
|
|
62
|
-
await memory.deleteMessage(message.id);
|
|
63
|
-
|
|
64
|
-
const memoryMessages = await memory.getMessages({ threadId: thread.id });
|
|
65
|
-
const messages = memoryMessages.messages;
|
|
66
|
-
expect(messages.length).toEqual(0);
|
|
67
|
-
});
|
|
68
|
-
});
|