@mastra/upstash 0.12.1 → 0.12.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +53 -0
- package/dist/_tsup-dts-rollup.d.cts +342 -40
- package/dist/_tsup-dts-rollup.d.ts +342 -40
- package/dist/index.cjs +1133 -612
- package/dist/index.js +1134 -613
- package/docker-compose.yaml +1 -1
- package/package.json +5 -5
- package/src/storage/domains/legacy-evals/index.ts +279 -0
- package/src/storage/domains/memory/index.ts +902 -0
- package/src/storage/domains/operations/index.ts +168 -0
- package/src/storage/domains/scores/index.ts +216 -0
- package/src/storage/domains/traces/index.ts +172 -0
- package/src/storage/domains/utils.ts +57 -0
- package/src/storage/domains/workflows/index.ts +243 -0
- package/src/storage/index.test.ts +13 -0
- package/src/storage/index.ts +143 -1416
- package/src/storage/upstash.test.ts +0 -1461
|
@@ -0,0 +1,902 @@
|
|
|
1
|
+
import { MessageList } from '@mastra/core/agent';
|
|
2
|
+
import type { MastraMessageContentV2 } from '@mastra/core/agent';
|
|
3
|
+
import { ErrorCategory, ErrorDomain, MastraError } from '@mastra/core/error';
|
|
4
|
+
import type { MastraMessageV1, MastraMessageV2, StorageThreadType } from '@mastra/core/memory';
|
|
5
|
+
import {
|
|
6
|
+
MemoryStorage,
|
|
7
|
+
TABLE_RESOURCES,
|
|
8
|
+
TABLE_THREADS,
|
|
9
|
+
resolveMessageLimit,
|
|
10
|
+
TABLE_MESSAGES,
|
|
11
|
+
} from '@mastra/core/storage';
|
|
12
|
+
import type { StorageGetMessagesArg, PaginationInfo, StorageResourceType } from '@mastra/core/storage';
|
|
13
|
+
import type { Redis } from '@upstash/redis';
|
|
14
|
+
import type { StoreOperationsUpstash } from '../operations';
|
|
15
|
+
import { ensureDate, getKey, processRecord } from '../utils';
|
|
16
|
+
|
|
17
|
+
/**
 * Builds the Redis key of the sorted set that holds a thread's message ids in order.
 *
 * @param threadId - id of the thread whose message index is addressed
 * @returns the key `thread:<threadId>:messages`
 */
function getThreadMessagesKey(threadId: string): string {
  const prefix = 'thread:';
  const suffix = ':messages';
  return prefix + threadId + suffix;
}
|
|
20
|
+
|
|
21
|
+
/**
 * Builds the Redis key under which a single message record is stored.
 *
 * The key comes from the shared `getKey` helper, using the messages table name plus the
 * owning thread id and the message id. Callers in this file also pass `'*'` for either
 * argument to build SCAN patterns.
 *
 * @param threadId - id of the thread that owns the message (or `'*'` for a pattern)
 * @param messageId - id of the message (or `'*'` for a pattern)
 */
function getMessageKey(threadId: string, messageId: string): string {
  return getKey(TABLE_MESSAGES, { threadId, id: messageId });
}
|
|
25
|
+
|
|
26
|
+
export class StoreMemoryUpstash extends MemoryStorage {
|
|
27
|
+
private client: Redis;
private operations: StoreOperationsUpstash;

/**
 * Creates the Upstash-backed memory storage domain.
 *
 * @param client - Upstash Redis client used for direct reads/writes and pipelines
 * @param operations - shared table operations store (this class uses its load, insert,
 *   scanKeys and scanAndDelete helpers)
 */
constructor({ client: redisClient, operations: tableOps }: { client: Redis; operations: StoreOperationsUpstash }) {
  super();
  this.operations = tableOps;
  this.client = redisClient;
}
|
|
34
|
+
|
|
35
|
+
/**
 * Loads a single thread record by id.
 *
 * `createdAt`/`updatedAt` are normalized to `Date`s via `ensureDate`. A `metadata` field
 * stored as a JSON string is parsed back into an object.
 *
 * @param threadId - id of the thread to load
 * @returns the thread, or `null` when no record exists
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_GET_THREAD_BY_ID_FAILED`) when loading or parsing fails
 */
async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
  try {
    const stored = await this.operations.load<StorageThreadType>({
      tableName: TABLE_THREADS,
      keys: { id: threadId },
    });

    if (stored) {
      const { createdAt, updatedAt, metadata } = stored;
      return {
        ...stored,
        createdAt: ensureDate(createdAt)!,
        updatedAt: ensureDate(updatedAt)!,
        metadata: typeof metadata === 'string' ? JSON.parse(metadata) : metadata,
      };
    }

    return null;
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_GET_THREAD_BY_ID_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          threadId,
        },
      },
      error,
    );
  }
}
|
|
64
|
+
|
|
65
|
+
/**
 * Returns every thread owned by `resourceId`, sorted by `createdAt` descending (newest first).
 *
 * All thread keys are scanned and fetched in one pipeline, then filtered client-side.
 * Cost therefore grows with the total number of threads, not only this resource's.
 *
 * Failures are logged (trackException + error) and swallowed; an empty array is returned.
 *
 * @deprecated use getThreadsByResourceIdPaginated instead
 */
async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
  try {
    const threadKeys = await this.operations.scanKeys(`${TABLE_THREADS}:*`);
    if (!threadKeys.length) {
      return [];
    }

    const batch = this.client.pipeline();
    for (const threadKey of threadKeys) {
      batch.get(threadKey);
    }
    const rows = (await batch.exec()) as (StorageThreadType | null)[];

    const owned: StorageThreadType[] = rows
      .filter((row): row is StorageThreadType => !!row && row.resourceId === resourceId)
      .map(row => ({
        ...row,
        createdAt: ensureDate(row.createdAt)!,
        updatedAt: ensureDate(row.updatedAt)!,
        metadata: typeof row.metadata === 'string' ? JSON.parse(row.metadata) : row.metadata,
      }));

    return owned.sort((left, right) => right.createdAt.getTime() - left.createdAt.getTime());
  } catch (error) {
    const mastraError = new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_GET_THREADS_BY_RESOURCE_ID_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          resourceId,
        },
      },
      error,
    );
    this.logger?.trackException(mastraError);
    this.logger.error(mastraError.toString());
    return [];
  }
}
|
|
113
|
+
|
|
114
|
+
/**
 * Returns one page of the threads owned by `resourceId`, newest first.
 *
 * Pagination is done in memory over the full list from `getThreadsByResourceId`.
 * `page` is zero-based. When `page` or `perPage` is missing at runtime, it falls back
 * to 0 or 100 respectively.
 *
 * Note: `getThreadsByResourceId` already swallows its own errors and returns `[]`.
 * The catch below only covers failures outside that call.
 *
 * @returns the page of threads plus `total`, `page`, `perPage` and `hasMore`.
 *   On failure the error is logged and an empty page is returned.
 */
public async getThreadsByResourceIdPaginated(args: {
  resourceId: string;
  page: number;
  perPage: number;
}): Promise<PaginationInfo & { threads: StorageThreadType[] }> {
  const { resourceId, page = 0, perPage = 100 } = args;

  try {
    const threads = await this.getThreadsByResourceId({ resourceId });
    const offset = page * perPage;
    const stopIndex = offset + perPage;

    return {
      threads: threads.slice(offset, stopIndex),
      total: threads.length,
      page,
      perPage,
      hasMore: stopIndex < threads.length,
    };
  } catch (error) {
    const mastraError = new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_GET_THREADS_BY_RESOURCE_ID_PAGINATED_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          resourceId,
          page,
          perPage,
        },
      },
      error,
    );
    this.logger?.trackException(mastraError);
    this.logger.error(mastraError.toString());
    return {
      threads: [],
      total: 0,
      page,
      perPage,
      hasMore: false,
    };
  }
}
|
|
162
|
+
|
|
163
|
+
/**
 * Writes a thread record through the operations store.
 *
 * @param thread - the thread to store
 * @returns the same `thread` object that was passed in
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_SAVE_THREAD_FAILED`) after logging it, if the write fails
 */
async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
  try {
    await this.operations.insert({ tableName: TABLE_THREADS, record: thread });
  } catch (error) {
    const mastraError = new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_SAVE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          threadId: thread.id,
        },
      },
      error,
    );
    this.logger?.trackException(mastraError);
    this.logger.error(mastraError.toString());
    throw mastraError;
  }

  return thread;
}
|
|
187
|
+
|
|
188
|
+
/**
 * Sets a thread's title and shallow-merges `metadata` over its existing metadata.
 *
 * All other fields, including `updatedAt`, are copied unchanged from the stored thread.
 * Errors raised by the initial `getThreadById` lookup propagate as-is.
 *
 * @returns the updated thread that was saved
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_UPDATE_THREAD_FAILED`, category USER) if the thread does not exist
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_UPDATE_THREAD_FAILED`, category THIRD_PARTY) if saving fails
 */
async updateThread({
  id,
  title,
  metadata,
}: {
  id: string;
  title: string;
  metadata: Record<string, unknown>;
}): Promise<StorageThreadType> {
  const current = await this.getThreadById({ threadId: id });
  if (!current) {
    throw new MastraError({
      id: 'STORAGE_UPSTASH_STORAGE_UPDATE_THREAD_FAILED',
      domain: ErrorDomain.STORAGE,
      category: ErrorCategory.USER,
      text: `Thread ${id} not found`,
      details: {
        threadId: id,
      },
    });
  }

  const mergedMetadata = { ...current.metadata, ...metadata };
  const updatedThread = { ...current, title, metadata: mergedMetadata };

  try {
    await this.saveThread({ thread: updatedThread });
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_UPDATE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          threadId: id,
        },
      },
      error,
    );
  }

  return updatedThread;
}
|
|
236
|
+
|
|
237
|
+
/**
 * Deletes a thread together with its messages.
 *
 * A single pipeline removes three things:
 * - the thread record,
 * - the thread's message-order sorted set,
 * - every message key listed in that set.
 *
 * Afterwards, any remaining message keys matching this thread are SCAN-deleted.
 *
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_DELETE_THREAD_FAILED`) on any Redis failure
 */
async deleteThread({ threadId }: { threadId: string }): Promise<void> {
  const threadKey = getKey(TABLE_THREADS, { id: threadId });
  const threadMessagesKey = getThreadMessagesKey(threadId);

  try {
    const messageIds: string[] = await this.client.zrange(threadMessagesKey, 0, -1);

    const batch = this.client.pipeline();
    batch.del(threadKey);
    batch.del(threadMessagesKey);
    for (const messageId of messageIds) {
      batch.del(getMessageKey(threadId, messageId));
    }
    await batch.exec();

    // Sweep any message keys for this thread that were not listed in the sorted set
    await this.operations.scanAndDelete(getMessageKey(threadId, '*'));
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_DELETE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          threadId,
        },
      },
      error,
    );
  }
}
|
|
272
|
+
|
|
273
|
+
/**
 * Persists messages that belong to a single thread (the thread of `messages[0]`).
 *
 * For each message this:
 *  - deletes any copy of the same message id stored under a different thread, and
 *    removes it from that thread's sorted set,
 *  - stores the message record (including its `_index`) under its thread-scoped key,
 *  - adds the id to the thread's message-order sorted set.
 * The owning thread's `updatedAt` is bumped once, together with the first batch.
 * Writes are grouped into pipelines of up to 1000 messages.
 *
 * @param args.messages - messages to store; the thread of `messages[0]` must exist
 * @param args.format - shape of the returned messages: `'v1'` (default) or `'v2'`
 * @returns the saved messages, normalized through `MessageList`
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_INVALID_ARGS`) if the first message
 *   has no threadId, its thread does not exist, or the thread lookup fails
 * @throws Error if any message is missing `threadId` or `resourceId`
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_FAILED`) on Redis failures while writing
 */
async saveMessages(args: { messages: MastraMessageV1[]; format?: undefined | 'v1' }): Promise<MastraMessageV1[]>;
async saveMessages(args: { messages: MastraMessageV2[]; format: 'v2' }): Promise<MastraMessageV2[]>;
async saveMessages(
  args: { messages: MastraMessageV1[]; format?: undefined | 'v1' } | { messages: MastraMessageV2[]; format: 'v2' },
): Promise<MastraMessageV2[] | MastraMessageV1[]> {
  const { messages, format = 'v1' } = args;
  if (messages.length === 0) return [];

  const threadId = messages[0]?.threadId;
  try {
    if (!threadId) {
      throw new Error('Thread ID is required');
    }

    // Check if thread exists
    const thread = await this.getThreadById({ threadId });
    if (!thread) {
      throw new Error(`Thread ${threadId} not found`);
    }
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_INVALID_ARGS',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.USER,
      },
      error,
    );
  }

  // Tag each message with its position in this call; used below as its sorted-set score.
  const messagesWithIndex = messages.map((message, index) => {
    if (!message.threadId) {
      throw new Error(
        `Expected to find a threadId for message, but couldn't find one. An unexpected error has occurred.`,
      );
    }
    if (!message.resourceId) {
      throw new Error(
        `Expected to find a resourceId for message, but couldn't find one. An unexpected error has occurred.`,
      );
    }
    return {
      ...message,
      _index: index,
    };
  });

  // All messages are treated as belonging to the thread of messages[0]; only that
  // thread's updatedAt is bumped below.
  const threadKey = getKey(TABLE_THREADS, { id: threadId });

  try {
    // Read the raw thread record inside the try, so a Redis failure here is reported as
    // STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_FAILED like every other Redis call in this
    // method, instead of escaping as an unwrapped client error.
    const existingThread = await this.client.get<StorageThreadType>(threadKey);

    const batchSize = 1000;
    for (let i = 0; i < messagesWithIndex.length; i += batchSize) {
      const batch = messagesWithIndex.slice(i, i + batchSize);
      const pipeline = this.client.pipeline();

      for (const message of batch) {
        const key = getMessageKey(message.threadId!, message.id);
        const createdAtScore = new Date(message.createdAt).getTime();
        // NOTE(review): `_index` is always set above, so the score is the message's position
        // within THIS call (0..n-1), never its createdAt. Messages saved by separate calls
        // restart at 0 and can interleave with earlier ones in the thread's sorted set.
        // Confirm whether createdAt ordering was intended.
        const score = message._index !== undefined ? message._index : createdAtScore;

        // If this message id is already stored under another thread, move it: delete the old
        // record and drop it from the old thread's sorted set.
        const existingKeyPattern = getMessageKey('*', message.id);
        const keys = await this.operations.scanKeys(existingKeyPattern);

        if (keys.length > 0) {
          const pipeline2 = this.client.pipeline();
          keys.forEach(key => pipeline2.get(key));
          const results = await pipeline2.exec();
          const existingMessages = results.filter(
            (msg): msg is MastraMessageV2 | MastraMessageV1 => msg !== null,
          ) as (MastraMessageV2 | MastraMessageV1)[];
          for (const existingMessage of existingMessages) {
            const existingMessageKey = getMessageKey(existingMessage.threadId!, existingMessage.id);
            if (existingMessage && existingMessage.threadId !== message.threadId) {
              pipeline.del(existingMessageKey);
              // Remove from old thread's sorted set
              pipeline.zrem(getThreadMessagesKey(existingMessage.threadId!), existingMessage.id);
            }
          }
        }

        // Store the message data
        pipeline.set(key, message);

        // Add to sorted set for this thread
        pipeline.zadd(getThreadMessagesKey(message.threadId!), {
          score,
          member: message.id,
        });
      }

      // Update the thread's updatedAt field (only in the first batch)
      if (i === 0 && existingThread) {
        const updatedThread = {
          ...existingThread,
          updatedAt: new Date(),
        };
        pipeline.set(threadKey, processRecord(TABLE_THREADS, updatedThread).processedRecord);
      }

      await pipeline.exec();
    }

    const list = new MessageList().add(messages, 'memory');
    if (format === `v2`) return list.get.all.v2();
    return list.get.all.v1();
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          threadId,
        },
      },
      error,
    );
  }
}
|
|
396
|
+
|
|
397
|
+
/**
 * Fetches the messages explicitly requested via `selectBy.include`, plus the requested
 * number of neighbours before and after each one.
 *
 * - Each include item may name its own `threadId`; otherwise `threadId` is used.
 * - Neighbours are resolved by rank within that thread's message-order sorted set.
 * - An item whose id has no rank in the set contributes no neighbours, but its own key
 *   is still fetched.
 * - Records that no longer exist are dropped.
 *
 * @returns the fetched messages, or `[]` when `selectBy.include` is absent or empty
 */
private async _getIncludedMessages(
  threadId: string,
  selectBy: StorageGetMessagesArg['selectBy'],
): Promise<MastraMessageV2[] | MastraMessageV1[]> {
  const includes = selectBy?.include;
  if (!includes?.length) {
    return [];
  }

  const wanted = new Set<string>();
  const threadOf: Record<string, string> = {};
  const remember = (messageId: string, ownerThreadId: string) => {
    wanted.add(messageId);
    threadOf[messageId] = ownerThreadId;
  };

  for (const item of includes) {
    const itemThreadId = item.threadId || threadId;
    remember(item.id, itemThreadId);
    const setKey = getThreadMessagesKey(itemThreadId);

    const rank = await this.client.zrank(setKey, item.id);
    if (rank === null) continue;

    if (item.withPreviousMessages) {
      const from = Math.max(0, rank - item.withPreviousMessages);
      const before = rank === 0 ? [] : await this.client.zrange(setKey, from, rank - 1);
      before.forEach(neighbourId => remember(neighbourId as string, itemThreadId));
    }

    if (item.withNextMessages) {
      const after = await this.client.zrange(setKey, rank + 1, rank + item.withNextMessages);
      after.forEach(neighbourId => remember(neighbourId as string, itemThreadId));
    }
  }

  const reader = this.client.pipeline();
  wanted.forEach(messageId => {
    reader.get(getMessageKey(threadOf[messageId] || threadId, messageId as string));
  });
  const rows = await reader.exec();
  return rows.filter(row => row !== null) as MastraMessageV2[] | MastraMessageV1[];
}
|
|
449
|
+
|
|
450
|
+
/**
 * Returns messages of a thread, ordered by their position in the thread's sorted set.
 *
 * Selection:
 *  - `selectBy.last` (resolved via `resolveMessageLimit`) restricts thread messages to the
 *    last N ids of the sorted set; without it every thread message is returned.
 *  - `selectBy.include` adds specific messages (and optional neighbours), possibly from
 *    other threads. Messages whose id is not in this thread's set sort first.
 *
 * Duplicates are removed, the internal `_index` field is stripped, and `createdAt` is
 * revived as a `Date`. With `format: 'v2'`, a missing `content` is replaced by a single
 * empty text part.
 *
 * @deprecated use getMessagesPaginated instead
 * @throws MastraError (`STORAGE_UPSTASH_STORAGE_GET_MESSAGES_FAILED`) on Redis failures
 */
public async getMessages(args: StorageGetMessagesArg & { format?: 'v1' }): Promise<MastraMessageV1[]>;
public async getMessages(args: StorageGetMessagesArg & { format: 'v2' }): Promise<MastraMessageV2[]>;
public async getMessages({
  threadId,
  selectBy,
  format,
}: StorageGetMessagesArg & { format?: 'v1' | 'v2' }): Promise<MastraMessageV1[] | MastraMessageV2[]> {
  const threadMessagesKey = getThreadMessagesKey(threadId);
  try {
    // One snapshot of the thread's ordered ids serves both selection and sorting, instead
    // of re-reading the sorted set for the "all" / "last N" cases.
    const allMessageIds = await this.client.zrange(threadMessagesKey, 0, -1);
    const limit = resolveMessageLimit({ last: selectBy?.last, defaultLimit: Number.MAX_SAFE_INTEGER });

    if (limit === 0 && !selectBy?.include) {
      return [];
    }

    let selectedIds: unknown[] = [];
    if (limit === Number.MAX_SAFE_INTEGER) {
      // Get all messages
      selectedIds = allMessageIds;
    } else if (limit > 0) {
      // Last `limit` ids: the same members as ZRANGE key -limit -1 (clamped when limit > length)
      selectedIds = allMessageIds.slice(-limit);
    }
    const messageIds = new Set<string>(selectedIds as string[]);

    const includedMessages = await this._getIncludedMessages(threadId, selectBy);

    // Fetch the selected thread messages in parallel; missing records are dropped
    const threadMessages = (
      await Promise.all(
        Array.from(messageIds).map(id =>
          this.client.get<MastraMessageV2 & { _index?: number }>(getMessageKey(threadId, id)),
        ),
      )
    ).filter((msg): msg is MastraMessageV2 & { _index?: number } => Boolean(msg));

    const messages = [...includedMessages, ...threadMessages];

    // Sort by position in the sorted set using an O(1) lookup (was indexOf per comparison).
    // Ids not in this thread's set get -1 and therefore sort first, as before.
    const position = new Map<unknown, number>();
    allMessageIds.forEach((id, index) => {
      if (!position.has(id)) position.set(id, index);
    });
    const rankOf = (id: string) => position.get(id) ?? -1;
    messages.sort((a, b) => rankOf(a!.id) - rankOf(b!.id));

    const seen = new Set<string>();
    const dedupedMessages = messages.filter(row => {
      if (seen.has(row.id)) return false;
      seen.add(row.id);
      return true;
    });

    // Remove _index before returning and handle format conversion properly
    const prepared = dedupedMessages
      .filter(message => message !== null && message !== undefined)
      .map(message => {
        const { _index, ...messageWithoutIndex } = message as MastraMessageV2 & { _index?: number };
        return messageWithoutIndex as unknown as MastraMessageV1;
      });

    // For backward compatibility, return messages directly without using MessageList
    // since MessageList has deduplication logic that can cause issues
    if (format === 'v2') {
      // Convert V1 format back to V2 format
      return prepared.map(msg => ({
        ...msg,
        createdAt: new Date(msg.createdAt),
        content: msg.content || { format: 2, parts: [{ type: 'text', text: '' }] },
      })) as MastraMessageV2[];
    }

    return prepared.map(msg => ({
      ...msg,
      createdAt: new Date(msg.createdAt),
    }));
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_GET_MESSAGES_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          threadId,
        },
      },
      error,
    );
  }
}
|
|
554
|
+
|
|
555
|
+
/**
 * Returns one page of a thread's messages plus any explicitly included messages.
 *
 * Thread messages are processed as follows:
 * - read from the thread's sorted set (only the last `selectBy.last` ids when set),
 * - filtered by `pagination.dateRange` on `createdAt`,
 * - paginated with zero-based `page` and `perPage` (defaults 0 and 40).
 *
 * `total` and `hasMore` describe only those thread messages. Messages from
 * `selectBy.include` are always added on top, even when the thread itself has no messages.
 * The combined list is normalized through `MessageList`.
 *
 * On failure the error is logged and an empty page is returned.
 */
public async getMessagesPaginated(
  args: StorageGetMessagesArg & {
    format?: 'v1' | 'v2';
  },
): Promise<PaginationInfo & { messages: MastraMessageV1[] | MastraMessageV2[] }> {
  const { threadId, selectBy, format } = args;
  const { page = 0, perPage = 40, dateRange } = selectBy?.pagination || {};
  const fromDate = dateRange?.start;
  const toDate = dateRange?.end;
  const threadMessagesKey = getThreadMessagesKey(threadId);
  const messages: (MastraMessageV2 | MastraMessageV1)[] = [];

  try {
    const includedMessages = await this._getIncludedMessages(threadId, selectBy);

    messages.push(...includedMessages);

    const allMessageIds = await this.client.zrange(
      threadMessagesKey,
      args?.selectBy?.last ? -args.selectBy.last : 0,
      -1,
    );

    // An empty thread has nothing to fetch. Skip the pipeline rather than returning early:
    // an early return here would discard the included messages collected above.
    let messagesData: (MastraMessageV2 | MastraMessageV1)[] = [];
    if (allMessageIds.length > 0) {
      const pipeline = this.client.pipeline();
      allMessageIds.forEach(id => pipeline.get(getMessageKey(threadId, id as string)));
      const results = await pipeline.exec();

      // Drop ids whose message record no longer exists
      messagesData = results.filter((msg): msg is MastraMessageV2 | MastraMessageV1 => msg !== null) as (
        | MastraMessageV2
        | MastraMessageV1
      )[];
    }

    // Apply date filters if provided
    if (fromDate) {
      messagesData = messagesData.filter(msg => msg && new Date(msg.createdAt).getTime() >= fromDate.getTime());
    }

    if (toDate) {
      messagesData = messagesData.filter(msg => msg && new Date(msg.createdAt).getTime() <= toDate.getTime());
    }

    // Sort by position in the sorted set using an O(1) lookup instead of indexOf per comparison
    const position = new Map<unknown, number>();
    allMessageIds.forEach((id, index) => {
      if (!position.has(id)) position.set(id, index);
    });
    const rankOf = (id: string) => position.get(id) ?? -1;
    messagesData.sort((a, b) => rankOf(a.id) - rankOf(b.id));

    const total = messagesData.length;

    const start = page * perPage;
    const end = start + perPage;
    const hasMore = end < total;
    const paginatedMessages = messagesData.slice(start, end);

    messages.push(...paginatedMessages);

    const list = new MessageList().add(messages, 'memory');
    const finalMessages = (format === `v2` ? list.get.all.v2() : list.get.all.v1()) as
      | MastraMessageV1[]
      | MastraMessageV2[];

    return {
      messages: finalMessages,
      total,
      page,
      perPage,
      hasMore,
    };
  } catch (error) {
    const mastraError = new MastraError(
      {
        id: 'STORAGE_UPSTASH_STORAGE_GET_MESSAGES_PAGINATED_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: {
          threadId,
        },
      },
      error,
    );
    this.logger?.trackException(mastraError);
    this.logger.error(mastraError.toString());
    return {
      messages: [],
      total: 0,
      page,
      perPage,
      hasMore: false,
    };
  }
}
|
|
654
|
+
|
|
655
|
+
/**
 * Loads a resource record stored under `${TABLE_RESOURCES}:<resourceId>`.
 *
 * - `createdAt` and `updatedAt` are revived as `Date`s.
 * - `metadata` is parsed when stored as a JSON string.
 * - `workingMemory` is handed back as a string: if the client returned it as an
 *   already-parsed object, it is re-serialized.
 *
 * @returns the resource, or `null` when no record exists
 * @throws the underlying error, after logging it
 */
async getResourceById({ resourceId }: { resourceId: string }): Promise<StorageResourceType | null> {
  try {
    const key = `${TABLE_RESOURCES}:${resourceId}`;
    const data = await this.client.get<StorageResourceType>(key);

    if (!data) {
      return null;
    }

    const { workingMemory } = data;
    return {
      ...data,
      createdAt: new Date(data.createdAt),
      updatedAt: new Date(data.updatedAt),
      // Re-serialize objects the client may have auto-parsed. `typeof null` is also
      // 'object', so exclude null explicitly rather than turning it into the string "null".
      workingMemory:
        workingMemory !== null && typeof workingMemory === 'object' ? JSON.stringify(workingMemory) : workingMemory,
      metadata: typeof data.metadata === 'string' ? JSON.parse(data.metadata) : data.metadata,
    };
  } catch (error) {
    this.logger.error('Error getting resource by ID:', error);
    throw error;
  }
}
|
|
677
|
+
|
|
678
|
+
/**
 * Stores a resource under `${TABLE_RESOURCES}:<id>`.
 *
 * `metadata` is written as a JSON string and both dates as ISO-8601 strings.
 * The caller's `resource` object is returned unchanged.
 *
 * @throws the underlying error, after logging it
 */
async saveResource({ resource }: { resource: StorageResourceType }): Promise<StorageResourceType> {
  try {
    const { id, metadata, createdAt, updatedAt } = resource;
    await this.client.set(`${TABLE_RESOURCES}:${id}`, {
      ...resource,
      metadata: JSON.stringify(metadata),
      createdAt: createdAt.toISOString(),
      updatedAt: updatedAt.toISOString(),
    });
    return resource;
  } catch (error) {
    this.logger.error('Error saving resource:', error);
    throw error;
  }
}
|
|
696
|
+
|
|
697
|
+
/**
 * Updates a resource's working memory and/or metadata, creating the resource if absent.
 *
 * - Existing resource:
 *   - `workingMemory` is replaced only when provided,
 *   - `metadata` is shallow-merged over the existing metadata,
 *   - `updatedAt` is set to now.
 * - Missing resource: a new one is created with the given `workingMemory`, the given
 *   `metadata` (or `{}`), and `createdAt`/`updatedAt` set to now.
 *
 * @returns the stored resource
 * @throws the underlying error, after logging it
 */
async updateResource({
  resourceId,
  workingMemory,
  metadata,
}: {
  resourceId: string;
  workingMemory?: string;
  metadata?: Record<string, unknown>;
}): Promise<StorageResourceType> {
  try {
    const current = await this.getResourceById({ resourceId });

    if (current) {
      const merged = {
        ...current,
        workingMemory: workingMemory === undefined ? current.workingMemory : workingMemory,
        metadata: { ...current.metadata, ...metadata },
        updatedAt: new Date(),
      };
      await this.saveResource({ resource: merged });
      return merged;
    }

    // Create new resource if it doesn't exist
    return this.saveResource({
      resource: {
        id: resourceId,
        workingMemory,
        metadata: metadata || {},
        createdAt: new Date(),
        updatedAt: new Date(),
      },
    });
  } catch (error) {
    this.logger.error('Error updating resource:', error);
    throw error;
  }
}
|
|
738
|
+
|
|
739
|
+
async updateMessages(args: {
|
|
740
|
+
messages: (Partial<Omit<MastraMessageV2, 'createdAt'>> & {
|
|
741
|
+
id: string;
|
|
742
|
+
content?: { metadata?: MastraMessageContentV2['metadata']; content?: MastraMessageContentV2['content'] };
|
|
743
|
+
})[];
|
|
744
|
+
}): Promise<MastraMessageV2[]> {
|
|
745
|
+
const { messages } = args;
|
|
746
|
+
|
|
747
|
+
if (messages.length === 0) {
|
|
748
|
+
return [];
|
|
749
|
+
}
|
|
750
|
+
|
|
751
|
+
try {
|
|
752
|
+
// Get all message IDs to update
|
|
753
|
+
const messageIds = messages.map(m => m.id);
|
|
754
|
+
|
|
755
|
+
// Find all existing messages by scanning for their keys
|
|
756
|
+
const existingMessages: (MastraMessageV2 | MastraMessageV1)[] = [];
|
|
757
|
+
const messageIdToKey: Record<string, string> = {};
|
|
758
|
+
|
|
759
|
+
// Scan for all message keys that match any of the IDs
|
|
760
|
+
for (const messageId of messageIds) {
|
|
761
|
+
const pattern = getMessageKey('*', messageId);
|
|
762
|
+
const keys = await this.operations.scanKeys(pattern);
|
|
763
|
+
|
|
764
|
+
for (const key of keys) {
|
|
765
|
+
const message = await this.client.get<MastraMessageV2 | MastraMessageV1>(key);
|
|
766
|
+
if (message && message.id === messageId) {
|
|
767
|
+
existingMessages.push(message);
|
|
768
|
+
messageIdToKey[messageId] = key;
|
|
769
|
+
break; // Found the message, no need to continue scanning
|
|
770
|
+
}
|
|
771
|
+
}
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
if (existingMessages.length === 0) {
|
|
775
|
+
return [];
|
|
776
|
+
}
|
|
777
|
+
|
|
778
|
+
const threadIdsToUpdate = new Set<string>();
|
|
779
|
+
const pipeline = this.client.pipeline();
|
|
780
|
+
|
|
781
|
+
// Process each existing message for updates
|
|
782
|
+
for (const existingMessage of existingMessages) {
|
|
783
|
+
const updatePayload = messages.find(m => m.id === existingMessage.id);
|
|
784
|
+
if (!updatePayload) continue;
|
|
785
|
+
|
|
786
|
+
const { id, ...fieldsToUpdate } = updatePayload;
|
|
787
|
+
if (Object.keys(fieldsToUpdate).length === 0) continue;
|
|
788
|
+
|
|
789
|
+
// Track thread IDs that need updating
|
|
790
|
+
threadIdsToUpdate.add(existingMessage.threadId!);
|
|
791
|
+
if (updatePayload.threadId && updatePayload.threadId !== existingMessage.threadId) {
|
|
792
|
+
threadIdsToUpdate.add(updatePayload.threadId);
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
// Create updated message object
|
|
796
|
+
const updatedMessage = { ...existingMessage };
|
|
797
|
+
|
|
798
|
+
// Special handling for the content field to merge instead of overwrite
|
|
799
|
+
if (fieldsToUpdate.content) {
|
|
800
|
+
const existingContent = existingMessage.content as MastraMessageContentV2;
|
|
801
|
+
const newContent = {
|
|
802
|
+
...existingContent,
|
|
803
|
+
...fieldsToUpdate.content,
|
|
804
|
+
// Deep merge metadata if it exists on both
|
|
805
|
+
...(existingContent?.metadata && fieldsToUpdate.content.metadata
|
|
806
|
+
? {
|
|
807
|
+
metadata: {
|
|
808
|
+
...existingContent.metadata,
|
|
809
|
+
...fieldsToUpdate.content.metadata,
|
|
810
|
+
},
|
|
811
|
+
}
|
|
812
|
+
: {}),
|
|
813
|
+
};
|
|
814
|
+
updatedMessage.content = newContent;
|
|
815
|
+
}
|
|
816
|
+
|
|
817
|
+
// Update other fields
|
|
818
|
+
for (const key in fieldsToUpdate) {
|
|
819
|
+
if (Object.prototype.hasOwnProperty.call(fieldsToUpdate, key) && key !== 'content') {
|
|
820
|
+
(updatedMessage as any)[key] = fieldsToUpdate[key as keyof typeof fieldsToUpdate];
|
|
821
|
+
}
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
// Update the message in Redis
|
|
825
|
+
const key = messageIdToKey[id];
|
|
826
|
+
if (key) {
|
|
827
|
+
// If the message is being moved to a different thread, we need to handle the key change
|
|
828
|
+
if (updatePayload.threadId && updatePayload.threadId !== existingMessage.threadId) {
|
|
829
|
+
// Remove from old thread's sorted set
|
|
830
|
+
const oldThreadMessagesKey = getThreadMessagesKey(existingMessage.threadId!);
|
|
831
|
+
pipeline.zrem(oldThreadMessagesKey, id);
|
|
832
|
+
|
|
833
|
+
// Delete the old message key
|
|
834
|
+
pipeline.del(key);
|
|
835
|
+
|
|
836
|
+
// Create new message key with new threadId
|
|
837
|
+
const newKey = getMessageKey(updatePayload.threadId, id);
|
|
838
|
+
pipeline.set(newKey, updatedMessage);
|
|
839
|
+
|
|
840
|
+
// Add to new thread's sorted set
|
|
841
|
+
const newThreadMessagesKey = getThreadMessagesKey(updatePayload.threadId);
|
|
842
|
+
const score =
|
|
843
|
+
(updatedMessage as any)._index !== undefined
|
|
844
|
+
? (updatedMessage as any)._index
|
|
845
|
+
: new Date(updatedMessage.createdAt).getTime();
|
|
846
|
+
pipeline.zadd(newThreadMessagesKey, { score, member: id });
|
|
847
|
+
} else {
|
|
848
|
+
// No thread change, just update the existing key
|
|
849
|
+
pipeline.set(key, updatedMessage);
|
|
850
|
+
}
|
|
851
|
+
}
|
|
852
|
+
}
|
|
853
|
+
|
|
854
|
+
// Update thread timestamps
|
|
855
|
+
const now = new Date();
|
|
856
|
+
for (const threadId of threadIdsToUpdate) {
|
|
857
|
+
if (threadId) {
|
|
858
|
+
const threadKey = getKey(TABLE_THREADS, { id: threadId });
|
|
859
|
+
const existingThread = await this.client.get<StorageThreadType>(threadKey);
|
|
860
|
+
if (existingThread) {
|
|
861
|
+
const updatedThread = {
|
|
862
|
+
...existingThread,
|
|
863
|
+
updatedAt: now,
|
|
864
|
+
};
|
|
865
|
+
pipeline.set(threadKey, processRecord(TABLE_THREADS, updatedThread).processedRecord);
|
|
866
|
+
}
|
|
867
|
+
}
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
// Execute all updates
|
|
871
|
+
await pipeline.exec();
|
|
872
|
+
|
|
873
|
+
// Return the updated messages
|
|
874
|
+
const updatedMessages: MastraMessageV2[] = [];
|
|
875
|
+
for (const messageId of messageIds) {
|
|
876
|
+
const key = messageIdToKey[messageId];
|
|
877
|
+
if (key) {
|
|
878
|
+
const updatedMessage = await this.client.get<MastraMessageV2 | MastraMessageV1>(key);
|
|
879
|
+
if (updatedMessage) {
|
|
880
|
+
// Convert to V2 format if needed
|
|
881
|
+
const v2e = updatedMessage as MastraMessageV2;
|
|
882
|
+
updatedMessages.push(v2e);
|
|
883
|
+
}
|
|
884
|
+
}
|
|
885
|
+
}
|
|
886
|
+
|
|
887
|
+
return updatedMessages;
|
|
888
|
+
} catch (error) {
|
|
889
|
+
throw new MastraError(
|
|
890
|
+
{
|
|
891
|
+
id: 'STORAGE_UPSTASH_STORAGE_UPDATE_MESSAGES_FAILED',
|
|
892
|
+
domain: ErrorDomain.STORAGE,
|
|
893
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
894
|
+
details: {
|
|
895
|
+
messageIds: messages.map(m => m.id).join(','),
|
|
896
|
+
},
|
|
897
|
+
},
|
|
898
|
+
error,
|
|
899
|
+
);
|
|
900
|
+
}
|
|
901
|
+
}
|
|
902
|
+
}
|