@mastra/memory 0.0.2-alpha.2 → 0.0.2-alpha.21
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +147 -0
- package/dist/index.d.ts +2 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/kv/upstash.d.ts +72 -0
- package/dist/kv/upstash.d.ts.map +1 -0
- package/dist/memory.cjs.development.js +953 -480
- package/dist/memory.cjs.development.js.map +1 -1
- package/dist/memory.cjs.production.min.js +1 -1
- package/dist/memory.cjs.production.min.js.map +1 -1
- package/dist/memory.esm.js +953 -481
- package/dist/memory.esm.js.map +1 -1
- package/dist/postgres/index.d.ts +29 -18
- package/dist/postgres/index.d.ts.map +1 -0
- package/package.json +2 -3
- package/src/index.ts +1 -0
- package/src/kv/upstash.test.ts +253 -0
- package/src/kv/upstash.ts +298 -0
- package/src/postgres/index.ts +177 -276
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
import { MessageType, ThreadType } from '@mastra/core';
|
|
2
|
+
import { randomUUID } from 'crypto';
|
|
3
|
+
import dotenv from 'dotenv';
|
|
4
|
+
|
|
5
|
+
import { UpstashKVMemory } from './upstash';
|
|
6
|
+
|
|
7
|
+
dotenv.config();

// This suite talks to a live Upstash / Vercel KV instance, so abort module
// loading early (with a clear message) when its REST credentials are absent.
const hasKvCredentials = Boolean(process.env.KV_REST_API_URL) && Boolean(process.env.KV_REST_API_TOKEN);
if (!hasKvCredentials) {
  throw new Error('Required Vercel KV environment variables are not set');
}
|
|
13
|
+
|
|
14
|
+
describe('KVMemory Integration Tests', () => {
  // Each run writes under its own key namespace; afterAll() wipes it via drop().
  const testPrefix = `test_${Date.now()}`;
  let memory: UpstashKVMemory;

  /** Builds a fresh, not-yet-persisted thread owned by 'test-resource'. */
  const makeThread = (title: string, metadata: ThreadType['metadata']): ThreadType => ({
    id: randomUUID(),
    title,
    createdAt: new Date(),
    updatedAt: new Date(),
    resourceid: 'test-resource',
    metadata,
  });

  /** Creates a client bound to this run's prefix with the given token budget. */
  const createMemory = (maxTokens: number): UpstashKVMemory =>
    new UpstashKVMemory({
      url: process.env.KV_REST_API_URL!,
      token: process.env.KV_REST_API_TOKEN!,
      prefix: testPrefix,
      maxTokens,
    });

  beforeAll(() => {
    memory = createMemory(1000);
  });

  afterAll(async () => {
    await memory.drop();
  });

  describe('Thread Operations', () => {
    let testThread: ThreadType;

    beforeEach(() => {
      testThread = makeThread('Integration Test Thread', { test: true });
    });

    it('should create and retrieve a thread', async () => {
      expect(await memory.saveThread({ thread: testThread })).toEqual(testThread);
      expect(await memory.getThreadById({ threadId: testThread.id })).toEqual(testThread);
    });

    it('should find threads by resource ID', async () => {
      const matching = { ...testThread, id: randomUUID() };
      const other = { ...testThread, id: randomUUID(), resourceid: 'different-resource' };

      for (const thread of [matching, other]) {
        await memory.saveThread({ thread });
      }

      const threads = await memory.getThreadsByResourceId({ resourceid: 'test-resource' });
      const ids = threads.map(t => t.id);

      expect(threads.length).toBeGreaterThanOrEqual(1);
      expect(ids.includes(matching.id)).toBe(true);
      expect(ids.includes(other.id)).toBe(false);
    });

    it('should update thread title and metadata', async () => {
      await memory.saveThread({ thread: testThread });

      const newTitle = 'Updated Title';
      const newMetadata = { updated: true };
      const { title, metadata, updatedAt } = await memory.updateThread(testThread.id, newTitle, newMetadata);

      expect(title).toBe(newTitle);
      expect(metadata).toEqual(newMetadata);
      expect(updatedAt).not.toEqual(testThread.updatedAt);
    });

    it('should delete a thread', async () => {
      await memory.saveThread({ thread: testThread });
      await memory.deleteThread(testThread.id);

      expect(await memory.getThreadById({ threadId: testThread.id })).toBeNull();
    });
  });

  describe('Message Operations', () => {
    let testThread: ThreadType;
    let testMessage: MessageType;

    beforeEach(async () => {
      testThread = makeThread('Test Thread for Messages', {});
      testMessage = {
        id: randomUUID(),
        content: 'Test message content',
        role: 'user' as const,
        type: 'text' as const,
        createdAt: new Date(),
        threadId: testThread.id,
      };

      await memory.saveThread({ thread: testThread });
    });

    it('should save and retrieve messages', async () => {
      await memory.saveMessages({ messages: [testMessage] });

      const result = await memory.getMessages({ threadId: testThread.id });

      expect(result.messages.length).toBe(1);
      expect(result.messages?.[0]?.content).toBe(testMessage.content);
      expect(result.uiMessages.length).toBe(1);
    });

    it('should handle message deletion', async () => {
      const removed = { ...testMessage, id: randomUUID() };
      const kept = { ...testMessage, id: randomUUID() };

      await memory.saveMessages({ messages: [removed, kept] });
      await memory.deleteMessage(removed.id);

      const { messages: remaining } = await memory.getMessages({ threadId: testThread.id });

      expect(remaining.length).toBe(1);
      expect(remaining?.[0]?.id).toBe(kept.id);
    });

    it('should respect token limits in context window', async () => {
      // 100 characters: roughly 25 tokens, far above the 5-token budget used below
      const longMessage: MessageType = { ...testMessage, content: 'a'.repeat(100), id: randomUUID() };
      // 13 characters: only a few tokens, fits the budget
      const shortMessage: MessageType = { ...testMessage, content: 'Short message', id: randomUUID() };

      await memory.saveMessages({ messages: [longMessage, shortMessage] });

      // Same namespace, but a client that may only spend 5 tokens on context
      const lowTokenMemory = createMemory(5);
      const context = await lowTokenMemory.getContextWindow({ threadId: testThread.id, format: 'raw' });

      // Only the newest (short) message should make it into the window
      expect(context.length).toBe(1);
      expect(context?.[0]?.id).toBe(shortMessage.id);
    });

    it('should filter messages by date range', async () => {
      const oldMessage: MessageType = { ...testMessage, id: randomUUID(), createdAt: new Date('2023-01-01') };
      const newMessage: MessageType = { ...testMessage, id: randomUUID(), createdAt: new Date('2024-01-01') };

      await memory.saveMessages({ messages: [oldMessage, newMessage] });

      const context = await memory.getContextWindow({
        threadId: testThread.id,
        format: 'raw',
        startDate: new Date('2023-12-31'),
        endDate: new Date('2024-12-31'),
      });

      expect(context.length).toBe(1);
      expect(context?.[0]?.id).toBe(newMessage.id);
    });
  });

  describe('Tool Cache Operations', () => {
    let testThread: ThreadType;

    beforeEach(async () => {
      testThread = makeThread('Test Thread for Tool Cache', {});
      await memory.saveThread({ thread: testThread });
    });

    it('should cache and validate tool call arguments', async () => {
      const toolName = 'testTool';
      const toolArgs = { test: true };

      const message: MessageType = {
        id: randomUUID(),
        content: 'Tool test',
        role: 'assistant' as const,
        type: 'text' as const,
        createdAt: new Date(),
        threadId: testThread.id,
        toolCallArgs: [toolArgs],
        toolNames: [toolName],
      };

      await memory.saveMessages({ messages: [message] });

      // saveMessages records only an expiry for the tool-call args, never a
      // result, so the cache lookup is expected to come back empty.
      const result = await memory.getToolResult({ threadId: testThread.id, toolArgs, toolName });

      expect(result).toBeNull();
    });
  });
});
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
import { MastraMemory, MessageType as BaseMastraMessageType, ThreadType, MessageResponse } from '@mastra/core';
|
|
2
|
+
import { Redis } from '@upstash/redis';
|
|
3
|
+
import { ToolResultPart, Message as AiMessage, TextPart } from 'ai';
|
|
4
|
+
import crypto from 'crypto';
|
|
5
|
+
|
|
6
|
+
/**
 * Value stored under a tool-cache key.
 * `expireAt` is a timestamp string (written via `Date#toISOString`); the entry
 * is treated as valid only while that instant lies in the future.
 */
type ToolCacheData = {
  result?: ToolResultPart['result'];
  expireAt: string;
};

/** Core Mastra message plus an optional token estimate attached by this adapter. */
type MessageType = BaseMastraMessageType & {
  tokens?: number;
};

/**
 * Thread as persisted in Redis: the same as `ThreadType`, except that both
 * timestamps are kept as strings (they are turned back into `Date`s on read).
 */
type SerializedThreadType = Omit<ThreadType, 'updatedAt' | 'createdAt'> & {
  updatedAt: string;
  createdAt: string;
};
|
|
20
|
+
|
|
21
|
+
/**
 * Mastra memory backend that stores threads, messages and tool-call cache
 * entries in Upstash Redis (via `@upstash/redis`).
 *
 * Key layout, namespaced by `prefix` (default `'mastra'`):
 * - `<prefix>:thread:<threadId>`   – thread, with timestamps stored as ISO strings
 * - `<prefix>:messages:<threadId>` – Redis list of messages in insertion order (RPUSH)
 * - `<prefix>:tool:<sha256>`       – tool-call cache entry (`ToolCacheData`)
 *
 * NOTE(review): thread lookup by resource, message deletion and drop() enumerate
 * keys with KEYS, which scans the whole keyspace; consider SCAN or a secondary
 * index (e.g. a per-resource set of thread ids) for large datasets.
 */
export class UpstashKVMemory extends MastraMemory {
  /** Lifetime granted to newly cached tool-call args, measured from the message's createdAt. */
  private static readonly TOOL_CACHE_TTL_MS = 5 * 60 * 1000;

  private prefix: string;

  kv: Redis;

  /**
   * @param config.url - Upstash REST endpoint URL.
   * @param config.token - Upstash REST token.
   * @param config.prefix - Key namespace; an omitted or empty value falls back to `'mastra'`.
   * @param config.maxTokens - Optional token budget enforced by `getContextWindow`.
   */
  constructor(config: { url: string; token: string; prefix?: string; maxTokens?: number }) {
    super();
    this.prefix = config.prefix || 'mastra';
    this.MAX_CONTEXT_TOKENS = config.maxTokens;

    this.kv = new Redis({
      url: config.url,
      token: config.token,
    });
  }

  private getThreadKey(threadId: string): string {
    return `${this.prefix}:thread:${threadId}`;
  }

  private getMessagesKey(threadId: string): string {
    return `${this.prefix}:messages:${threadId}`;
  }

  private getToolCacheKey(hashedArgs: string): string {
    return `${this.prefix}:tool:${hashedArgs}`;
  }

  /**
   * SHA-256 (hex) of a tool call's identity, used as the tool-cache key suffix.
   * The JSON key order (args, threadId, toolName) is part of the cache-key format
   * and must not change, or existing cache entries become unreachable.
   */
  private hashToolCall(args: unknown, threadId: string, toolName: string | undefined): string {
    return crypto.createHash('sha256').update(JSON.stringify({ args, threadId, toolName })).digest('hex');
  }

  /** True when a cache entry exists and its `expireAt` is still in the future. */
  private isLiveToolCacheEntry(entry: ToolCacheData | null | undefined): entry is ToolCacheData {
    return entry != null && new Date(entry.expireAt) > new Date();
  }

  /**
   * Flattens a message's content to plain text for token estimation.
   * String content is returned as-is, whatever the role. For an array of content
   * parts, the string `text` of every part that has one is joined. Other parts
   * (tool calls/results, images) contribute nothing.
   */
  private extractTextContent(message: MessageType): string {
    const content: unknown = message.content;
    if (typeof content === 'string') {
      return content;
    }
    if (Array.isArray(content)) {
      return content
        .map((part: unknown) => {
          if (typeof part === 'object' && part !== null && 'text' in part) {
            const { text } = part as { text: unknown };
            return typeof text === 'string' ? text : '';
          }
          return '';
        })
        .join(' ');
    }
    return '';
  }

  /** Returns the thread with the given id, or null if none is stored. */
  async getThreadById({ threadId }: { threadId: string }): Promise<ThreadType | null> {
    const thread = await this.kv.get<SerializedThreadType>(this.getThreadKey(threadId));
    return thread ? this.parseThread(thread) : null;
  }

  /** Returns every stored thread whose `resourceid` matches (loads all threads under the prefix). */
  async getThreadsByResourceId({ resourceid }: { resourceid: string }): Promise<ThreadType[]> {
    const keys = await this.kv.keys(`${this.prefix}:thread:*`);
    const threads = await Promise.all(keys.map(key => this.kv.get<SerializedThreadType>(key)));

    return threads
      .filter((thread): thread is SerializedThreadType => thread?.resourceid === resourceid)
      .map(thread => this.parseThread(thread));
  }

  /** Stores (or overwrites) a thread and returns the input unchanged. */
  async saveThread({ thread }: { thread: ThreadType }): Promise<ThreadType> {
    const key = this.getThreadKey(thread.id);
    const serializedThread: SerializedThreadType = {
      ...thread,
      createdAt: thread.createdAt.toISOString(),
      updatedAt: thread.updatedAt.toISOString(),
    };
    await this.kv.set(key, serializedThread);
    return thread;
  }

  /**
   * Replaces a thread's title and metadata and bumps `updatedAt` to now.
   * @throws Error when the thread does not exist.
   */
  async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
    const key = this.getThreadKey(id);
    const thread = await this.kv.get<SerializedThreadType>(key);

    if (!thread) {
      throw new Error(`Thread ${id} not found`);
    }

    const updatedThread: SerializedThreadType = {
      ...thread,
      title,
      metadata,
      updatedAt: new Date().toISOString(),
    };

    await this.kv.set(key, updatedThread);
    return this.parseThread(updatedThread);
  }

  /** Deletes a thread together with its message list. */
  async deleteThread(id: string): Promise<void> {
    // A single multi-key DEL: one round trip, and both keys go together.
    await this.kv.del(this.getThreadKey(id), this.getMessagesKey(id));
  }

  /**
   * Tool Cache
   */

  /** True when a live (unexpired) cache entry exists for the given hashed tool call. */
  async validateToolCallArgs({ hashedArgs }: { hashedArgs: string }): Promise<boolean> {
    const cached = await this.kv.get<ToolCacheData>(this.getToolCacheKey(hashedArgs));
    return this.isLiveToolCacheEntry(cached);
  }

  /**
   * Looks up a cached tool result for (threadId, toolName, toolArgs).
   * Returns null when there is no live entry or the entry carries no result.
   */
  async getToolResult({
    threadId,
    toolArgs,
    toolName,
  }: {
    threadId: string;
    toolArgs: Record<string, unknown>;
    toolName: string;
  }): Promise<ToolResultPart['result'] | null> {
    const cacheKey = this.getToolCacheKey(this.hashToolCall(toolArgs, threadId, toolName));
    const cached = await this.kv.get<ToolCacheData>(cacheKey);

    // `??` rather than `||`, so falsy-but-valid results (0, false, '') are not reported as misses.
    return this.isLiveToolCacheEntry(cached) ? (cached.result ?? null) : null;
  }

  /**
   * Returns the thread's text and tool-result messages, optionally restricted to
   * [startDate, endDate]. When MAX_CONTEXT_TOKENS is set, only the most recent
   * messages that fit the budget are kept. Messages are always returned in
   * chronological order.
   */
  async getContextWindow<T extends 'raw' | 'core_message'>({
    threadId,
    startDate,
    endDate,
    // @ts-ignore
    format = 'raw' as T,
  }: {
    format?: T;
    threadId: string;
    startDate?: Date;
    endDate?: Date;
  }) {
    const messages = await this.kv.lrange<MessageType>(this.getMessagesKey(threadId), 0, -1);

    // createdAt may come back from Redis as a string, hence the Date re-parse.
    const filteredMessages = messages.filter(msg => {
      if (msg.type !== 'text' && msg.type !== 'tool-result') {
        return false;
      }
      const createdAt = new Date(msg.createdAt);
      return (!startDate || createdAt >= startDate) && (!endDate || createdAt <= endDate);
    });

    const tokenBudget = this.MAX_CONTEXT_TOKENS;
    if (!tokenBudget) {
      return this.parseMessages(filteredMessages) as MessageResponse<T>;
    }

    // Walk newest -> oldest and stop at the first message that no longer fits, so
    // the window is always a contiguous run of the most recent messages.
    let totalTokens = 0;
    const newestFirst: MessageType[] = [];
    for (const message of [...filteredMessages].reverse()) {
      // Rough heuristic: about one token per 4 characters of text.
      const tokens = Math.ceil(this.extractTextContent(message).length / 4);
      if (totalTokens + tokens > tokenBudget) {
        break;
      }
      totalTokens += tokens;
      newestFirst.push({ ...message, tokens });
    }

    // push + a single reverse (rather than unshift per message) restores chronological order in O(n).
    return this.parseMessages(newestFirst.reverse()) as MessageResponse<T>;
  }

  /**
   * Messages
   */

  /** Returns all messages of a thread, both as stored and converted to UI messages. */
  async getMessages({ threadId }: { threadId: string }): Promise<{ messages: MessageType[]; uiMessages: AiMessage[] }> {
    const messagesKey = this.getMessagesKey(threadId);
    const messages = await this.kv.lrange<MessageType>(messagesKey, 0, -1);
    const parsedMessages = this.parseMessages(messages);
    const uiMessages = this.convertToUIMessages(parsedMessages);

    return { messages: parsedMessages, uiMessages };
  }

  /**
   * Appends messages to their threads' lists. Text messages get a `tokens`
   * estimate, and any tool-call args are recorded in the tool cache.
   * Returns the messages as stored.
   */
  async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
    const processedMessages: MessageType[] = [];

    for (const message of messages) {
      const { threadId, toolCallArgs, toolNames, createdAt } = message;
      const processedMessage: MessageType = { ...message };

      if (message.type === 'text') {
        processedMessage.tokens = this.estimateTokens(this.extractTextContent(message));
      }

      if (toolCallArgs?.length) {
        const hashedToolCallArgs = toolCallArgs.map((args, index) =>
          this.hashToolCall(args, threadId, toolNames?.[index]),
        );

        let allArgsCached = true;
        for (const hashedArgs of hashedToolCallArgs) {
          if (!(await this.validateToolCallArgs({ hashedArgs }))) {
            allArgsCached = false;
            break;
          }
        }

        // NOTE(review): when every arg already has a live entry, its expiry is reset to
        // the message's createdAt (so it goes stale once that instant has passed).
        // Otherwise entries live TOOL_CACHE_TTL_MS past createdAt. This rule is kept
        // from the original implementation; confirm the inversion is intended.
        const expireAt = allArgsCached
          ? createdAt
          : new Date(createdAt.getTime() + UpstashKVMemory.TOOL_CACHE_TTL_MS);
        const entry: ToolCacheData = { expireAt: expireAt.toISOString() };

        // NOTE(review): written without a Redis TTL, so stale entries persist until drop().
        for (const hashedArgs of hashedToolCallArgs) {
          await this.kv.set(this.getToolCacheKey(hashedArgs), entry);
        }
      }

      await this.kv.rpush(this.getMessagesKey(threadId), processedMessage);
      processedMessages.push(processedMessage);
    }

    return processedMessages;
  }

  /** Removes the message with the given id from every thread list that contains it. */
  async deleteMessage(id: string): Promise<void> {
    const keys = await this.kv.keys(`${this.prefix}:messages:*`);

    for (const key of keys) {
      const messages = await this.kv.lrange<MessageType>(key, 0, -1);
      const remaining = messages.filter(msg => msg.id !== id);

      if (remaining.length === messages.length) {
        continue; // not in this thread
      }

      // NOTE(review): read / DEL / RPUSH is not atomic; a message pushed to this list
      // between the LRANGE and the DEL would be lost.
      await this.kv.del(key);
      if (remaining.length > 0) {
        await this.kv.rpush(key, ...remaining);
      }
    }
  }

  /**
   * Cleanup
   */

  /** Deletes every key under this instance's prefix. */
  async drop(): Promise<void> {
    const keys = await this.kv.keys(`${this.prefix}:*`);
    if (keys.length > 0) {
      await this.kv.del(...keys);
    }
  }

  /** Converts a stored thread back to a ThreadType with Date timestamps. */
  parseThread(thread: SerializedThreadType): ThreadType {
    return {
      ...thread,
      createdAt: new Date(thread.createdAt),
      updatedAt: new Date(thread.updatedAt),
    };
  }

  /** Re-hydrates each message's createdAt into a Date. */
  parseMessages(messages: MessageType[]): MessageType[] {
    return messages.map(message => ({
      ...message,
      createdAt: new Date(message.createdAt),
    }));
  }
}
|