@mastra/upstash 0.12.1 → 0.12.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,902 @@
1
+ import { MessageList } from '@mastra/core/agent';
2
+ import type { MastraMessageContentV2 } from '@mastra/core/agent';
3
+ import { ErrorCategory, ErrorDomain, MastraError } from '@mastra/core/error';
4
+ import type { MastraMessageV1, MastraMessageV2, StorageThreadType } from '@mastra/core/memory';
5
+ import {
6
+ MemoryStorage,
7
+ TABLE_RESOURCES,
8
+ TABLE_THREADS,
9
+ resolveMessageLimit,
10
+ TABLE_MESSAGES,
11
+ } from '@mastra/core/storage';
12
+ import type { StorageGetMessagesArg, PaginationInfo, StorageResourceType } from '@mastra/core/storage';
13
+ import type { Redis } from '@upstash/redis';
14
+ import type { StoreOperationsUpstash } from '../operations';
15
+ import { ensureDate, getKey, processRecord } from '../utils';
16
+
17
/**
 * Builds the Redis key of the per-thread sorted set whose members are the ids of the
 * thread's messages (the set's rank order is the thread's message order).
 */
function getThreadMessagesKey(threadId: string): string {
  const sortedSetKey = `thread:${threadId}:messages`;
  return sortedSetKey;
}
20
+
21
/**
 * Builds the Redis key under which one message is stored, scoped by its thread.
 * Callers in this file also pass '*' for either argument to obtain a SCAN pattern.
 */
function getMessageKey(threadId: string, messageId: string): string {
  return getKey(TABLE_MESSAGES, { threadId, id: messageId });
}
25
+
26
/**
 * Memory-domain storage (threads, messages, resources) backed by Upstash Redis.
 *
 * Layout used by this class:
 * - threads are read/written through `operations` and addressed directly at
 *   `getKey(TABLE_THREADS, { id })`;
 * - every message is stored at `getMessageKey(threadId, messageId)`;
 * - each thread has a sorted set (`getThreadMessagesKey(threadId)`) whose members are
 *   message ids; its rank order is the order in which messages are returned;
 * - resources are stored at `${TABLE_RESOURCES}:${resourceId}`.
 */
export class StoreMemoryUpstash extends MemoryStorage {
  private client: Redis;
  private operations: StoreOperationsUpstash;

  constructor({ client, operations }: { client: Redis; operations: StoreOperationsUpstash }) {
    super();
    this.client = client;
    this.operations = operations;
  }

  /**
   * Normalizes a raw thread record read from Redis: `createdAt`/`updatedAt` are converted
   * with `ensureDate`, and string-encoded metadata is JSON-parsed.
   */
  private parseThreadRecord(thread: StorageThreadType): StorageThreadType {
    return {
      ...thread,
      createdAt: ensureDate(thread.createdAt)!,
      updatedAt: ensureDate(thread.updatedAt)!,
      metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
    };
  }

  /**
   * Loads one thread by id.
   *
   * @returns the normalized thread, or `null` when no record exists.
   * @throws MastraError `STORAGE_UPSTASH_STORAGE_GET_THREAD_BY_ID_FAILED` on any load/parse failure.
   */
  async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
    try {
      const thread = await this.operations.load<StorageThreadType>({
        tableName: TABLE_THREADS,
        keys: { id: threadId },
      });

      return thread ? this.parseThreadRecord(thread) : null;
    } catch (error) {
      throw new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_GET_THREAD_BY_ID_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
    }
  }

  /**
   * Returns every thread whose `resourceId` matches, newest `createdAt` first.
   *
   * All thread keys are scanned and filtered client-side, so the cost grows with the total
   * number of threads stored, not only this resource's. Failures are logged and an empty
   * array is returned instead of throwing.
   *
   * @deprecated use getThreadsByResourceIdPaginated instead
   */
  async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
    try {
      const pattern = `${TABLE_THREADS}:*`;
      const keys = await this.operations.scanKeys(pattern);

      if (keys.length === 0) {
        return [];
      }

      const pipeline = this.client.pipeline();
      keys.forEach(key => pipeline.get(key));
      const results = await pipeline.exec();

      const allThreads: StorageThreadType[] = [];
      for (const result of results) {
        const thread = result as StorageThreadType | null;
        if (thread && thread.resourceId === resourceId) {
          allThreads.push(this.parseThreadRecord(thread));
        }
      }

      allThreads.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
      return allThreads;
    } catch (error) {
      const mastraError = new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_GET_THREADS_BY_RESOURCE_ID_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            resourceId,
          },
        },
        error,
      );
      this.logger?.trackException(mastraError);
      this.logger.error(mastraError.toString());
      return [];
    }
  }

  /**
   * Page through a resource's threads (newest first). `page` is zero-based.
   *
   * Pagination is done in memory over the full result of `getThreadsByResourceId`, which
   * itself returns `[]` on failure; on any error here an empty page is returned.
   */
  public async getThreadsByResourceIdPaginated(args: {
    resourceId: string;
    page: number;
    perPage: number;
  }): Promise<PaginationInfo & { threads: StorageThreadType[] }> {
    const { resourceId, page = 0, perPage = 100 } = args;

    try {
      const allThreads = await this.getThreadsByResourceId({ resourceId });

      const total = allThreads.length;
      const start = page * perPage;
      const end = start + perPage;
      const paginatedThreads = allThreads.slice(start, end);
      const hasMore = end < total;

      return {
        threads: paginatedThreads,
        total,
        page,
        perPage,
        hasMore,
      };
    } catch (error) {
      const mastraError = new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_GET_THREADS_BY_RESOURCE_ID_PAGINATED_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            resourceId,
            page,
            perPage,
          },
        },
        error,
      );
      this.logger?.trackException(mastraError);
      this.logger.error(mastraError.toString());
      return {
        threads: [],
        total: 0,
        page,
        perPage,
        hasMore: false,
      };
    }
  }

  /**
   * Inserts (or overwrites) a thread record via `operations.insert`.
   *
   * @returns the thread exactly as passed in.
   * @throws MastraError `STORAGE_UPSTASH_STORAGE_SAVE_THREAD_FAILED` (also logged).
   */
  async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
    try {
      await this.operations.insert({
        tableName: TABLE_THREADS,
        record: thread,
      });
      return thread;
    } catch (error) {
      const mastraError = new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_SAVE_THREAD_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId: thread.id,
          },
        },
        error,
      );
      this.logger?.trackException(mastraError);
      this.logger.error(mastraError.toString());
      throw mastraError;
    }
  }

  /**
   * Replaces a thread's title and shallow-merges `metadata` into its existing metadata.
   *
   * @throws MastraError (USER category) when the thread does not exist, or (THIRD_PARTY)
   *   when saving fails; both use id `STORAGE_UPSTASH_STORAGE_UPDATE_THREAD_FAILED`.
   */
  async updateThread({
    id,
    title,
    metadata,
  }: {
    id: string;
    title: string;
    metadata: Record<string, unknown>;
  }): Promise<StorageThreadType> {
    const thread = await this.getThreadById({ threadId: id });
    if (!thread) {
      throw new MastraError({
        id: 'STORAGE_UPSTASH_STORAGE_UPDATE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.USER,
        text: `Thread ${id} not found`,
        details: {
          threadId: id,
        },
      });
    }

    const updatedThread = {
      ...thread,
      title,
      metadata: {
        ...thread.metadata,
        ...metadata,
      },
    };

    try {
      await this.saveThread({ thread: updatedThread });
      return updatedThread;
    } catch (error) {
      throw new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_UPDATE_THREAD_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId: id,
          },
        },
        error,
      );
    }
  }

  /**
   * Deletes a thread record, its message sorted set, and every message key of the thread.
   *
   * Messages listed in the sorted set are deleted in one pipeline; a final pattern
   * scan-and-delete removes any message keys of this thread that were not in the set.
   *
   * @throws MastraError `STORAGE_UPSTASH_STORAGE_DELETE_THREAD_FAILED`.
   */
  async deleteThread({ threadId }: { threadId: string }): Promise<void> {
    const threadKey = getKey(TABLE_THREADS, { id: threadId });
    const threadMessagesKey = getThreadMessagesKey(threadId);
    try {
      const messageIds: string[] = await this.client.zrange(threadMessagesKey, 0, -1);

      const pipeline = this.client.pipeline();
      pipeline.del(threadKey);
      pipeline.del(threadMessagesKey);

      for (const messageId of messageIds) {
        pipeline.del(getMessageKey(threadId, messageId as string));
      }

      await pipeline.exec();

      // Bulk delete all message keys for this thread if any remain
      await this.operations.scanAndDelete(getMessageKey(threadId, '*'));
    } catch (error) {
      throw new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_DELETE_THREAD_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
    }
  }

  async saveMessages(args: { messages: MastraMessageV1[]; format?: undefined | 'v1' }): Promise<MastraMessageV1[]>;
  async saveMessages(args: { messages: MastraMessageV2[]; format: 'v2' }): Promise<MastraMessageV2[]>;
  /**
   * Stores messages and indexes them in their thread's sorted set.
   *
   * The thread of the first message must already exist. If a message id is already stored
   * under a different thread, that copy (and its sorted-set entry) is removed, so the
   * message effectively moves. The first message's thread gets its `updatedAt` bumped.
   *
   * @returns the saved messages, converted through `MessageList` into the requested format.
   * @throws MastraError `STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_INVALID_ARGS` (USER) for a
   *   missing thread id / thread / resourceId, or `STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_FAILED`
   *   (THIRD_PARTY) when writing fails.
   */
  async saveMessages(
    args: { messages: MastraMessageV1[]; format?: undefined | 'v1' } | { messages: MastraMessageV2[]; format: 'v2' },
  ): Promise<MastraMessageV2[] | MastraMessageV1[]> {
    const { messages, format = 'v1' } = args;
    if (messages.length === 0) return [];

    const threadId = messages[0]?.threadId;
    try {
      if (!threadId) {
        throw new Error('Thread ID is required');
      }

      // Check if thread exists
      const thread = await this.getThreadById({ threadId });
      if (!thread) {
        throw new Error(`Thread ${threadId} not found`);
      }

      // Validate every message here so bad input is reported as a MastraError like the
      // checks above, instead of escaping as a bare Error.
      for (const message of messages) {
        if (!message.threadId) {
          throw new Error(
            `Expected to find a threadId for message, but couldn't find one. An unexpected error has occurred.`,
          );
        }
        if (!message.resourceId) {
          throw new Error(
            `Expected to find a resourceId for message, but couldn't find one. An unexpected error has occurred.`,
          );
        }
      }
    } catch (error) {
      throw new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_INVALID_ARGS',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.USER,
        },
        error,
      );
    }

    // Tag each message with its position in this call; the tag is persisted with the
    // message and used as its sorted-set score below.
    const messagesWithIndex = messages.map((message, index) => ({
      ...message,
      _index: index,
    }));

    try {
      // Read the thread once (assumes all messages belong to the first message's thread).
      const threadKey = getKey(TABLE_THREADS, { id: threadId });
      const existingThread = await this.client.get<StorageThreadType>(threadKey);

      const batchSize = 1000;
      for (let i = 0; i < messagesWithIndex.length; i += batchSize) {
        const batch = messagesWithIndex.slice(i, i + batchSize);
        const pipeline = this.client.pipeline();

        for (const message of batch) {
          const key = getMessageKey(message.threadId!, message.id);
          // NOTE(review): `_index` is always set above, so the score is the message's position
          // within *this call* (0..n-1), not a global order. Messages saved by separate calls
          // share scores and Redis breaks score ties by member (id) order. Left unchanged
          // because existing sorted sets already hold these scores — confirm before changing.
          const score = message._index;

          // Check if this message id exists in another thread
          const existingKeys = await this.operations.scanKeys(getMessageKey('*', message.id));

          if (existingKeys.length > 0) {
            const lookup = this.client.pipeline();
            existingKeys.forEach(existingKey => lookup.get(existingKey));
            const results = await lookup.exec();
            const existingMessages = results.filter(
              (msg): msg is MastraMessageV2 | MastraMessageV1 => msg !== null,
            ) as (MastraMessageV2 | MastraMessageV1)[];
            for (const existingMessage of existingMessages) {
              if (existingMessage && existingMessage.threadId !== message.threadId) {
                pipeline.del(getMessageKey(existingMessage.threadId!, existingMessage.id));
                // Remove from old thread's sorted set
                pipeline.zrem(getThreadMessagesKey(existingMessage.threadId!), existingMessage.id);
              }
            }
          }

          // Store the message data
          pipeline.set(key, message);

          // Add to sorted set for this thread
          pipeline.zadd(getThreadMessagesKey(message.threadId!), {
            score,
            member: message.id,
          });
        }

        // Update the thread's updatedAt field (only in the first batch)
        if (i === 0 && existingThread) {
          const updatedThread = {
            ...existingThread,
            updatedAt: new Date(),
          };
          pipeline.set(threadKey, processRecord(TABLE_THREADS, updatedThread).processedRecord);
        }

        await pipeline.exec();
      }

      const list = new MessageList().add(messages, 'memory');
      if (format === `v2`) return list.get.all.v2();
      return list.get.all.v1();
    } catch (error) {
      throw new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_SAVE_MESSAGES_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
    }
  }

  /**
   * Fetches the messages requested through `selectBy.include`, plus the requested number of
   * neighbours before/after each one (by rank in that message's thread sorted set).
   *
   * Each include may name its own `threadId`; otherwise `threadId` is used. Includes whose
   * id is not ranked in the sorted set contribute only the message itself (if stored).
   */
  private async _getIncludedMessages(
    threadId: string,
    selectBy: StorageGetMessagesArg['selectBy'],
  ): Promise<MastraMessageV2[] | MastraMessageV1[]> {
    if (!selectBy?.include?.length) {
      return [];
    }

    const messageIds = new Set<string>();
    const messageIdToThreadIds: Record<string, string> = {};

    for (const item of selectBy.include) {
      messageIds.add(item.id);

      // Use per-include threadId if present, else fallback to main threadId
      const itemThreadId = item.threadId || threadId;
      messageIdToThreadIds[item.id] = itemThreadId;
      const itemThreadMessagesKey = getThreadMessagesKey(itemThreadId);

      // Get the rank of this message in the sorted set
      const rank = await this.client.zrank(itemThreadMessagesKey, item.id);
      if (rank === null) continue;

      if (item.withPreviousMessages) {
        const start = Math.max(0, rank - item.withPreviousMessages);
        const prevIds = rank === 0 ? [] : await this.client.zrange(itemThreadMessagesKey, start, rank - 1);
        prevIds.forEach(id => {
          messageIds.add(id as string);
          messageIdToThreadIds[id as string] = itemThreadId;
        });
      }

      if (item.withNextMessages) {
        const nextIds = await this.client.zrange(itemThreadMessagesKey, rank + 1, rank + item.withNextMessages);
        nextIds.forEach(id => {
          messageIds.add(id as string);
          messageIdToThreadIds[id as string] = itemThreadId;
        });
      }
    }

    // Non-empty: every include added at least its own id above.
    const pipeline = this.client.pipeline();
    messageIds.forEach(id => {
      const tId = messageIdToThreadIds[id] || threadId;
      pipeline.get(getMessageKey(tId, id));
    });
    const results = await pipeline.exec();
    return results.filter(result => result !== null) as MastraMessageV2[] | MastraMessageV1[];
  }

  /**
   * Returns a thread's messages (optionally only the last `selectBy.last`) together with any
   * `selectBy.include` messages, ordered by rank in the thread's sorted set and de-duplicated
   * by id. Messages not ranked in this thread (e.g. included from other threads) sort first.
   *
   * @deprecated use getMessagesPaginated instead
   * @throws MastraError `STORAGE_UPSTASH_STORAGE_GET_MESSAGES_FAILED`.
   */
  public async getMessages(args: StorageGetMessagesArg & { format?: 'v1' }): Promise<MastraMessageV1[]>;
  public async getMessages(args: StorageGetMessagesArg & { format: 'v2' }): Promise<MastraMessageV2[]>;
  public async getMessages({
    threadId,
    selectBy,
    format,
  }: StorageGetMessagesArg & { format?: 'v1' | 'v2' }): Promise<MastraMessageV1[] | MastraMessageV2[]> {
    const threadMessagesKey = getThreadMessagesKey(threadId);
    try {
      const allMessageIds: string[] = await this.client.zrange(threadMessagesKey, 0, -1);
      const limit = resolveMessageLimit({ last: selectBy?.last, defaultLimit: Number.MAX_SAFE_INTEGER });

      if (limit === 0 && !selectBy?.include) {
        return [];
      }

      // Pick this thread's ids from the list already fetched (no second ZRANGE needed):
      // everything when unlimited, otherwise the `limit` highest-ranked ids.
      let threadMessageIds: string[] = [];
      if (limit === Number.MAX_SAFE_INTEGER) {
        threadMessageIds = allMessageIds;
      } else if (limit > 0) {
        threadMessageIds = allMessageIds.slice(-limit);
      }

      const includedMessages = await this._getIncludedMessages(threadId, selectBy);

      // Fetch this thread's messages in one round trip. exec() on an empty @upstash/redis
      // pipeline throws ("Pipeline is empty"), hence the guard.
      let threadMessages: (MastraMessageV2 & { _index?: number })[] = [];
      if (threadMessageIds.length > 0) {
        const pipeline = this.client.pipeline();
        threadMessageIds.forEach(id => pipeline.get(getMessageKey(threadId, id)));
        const results = await pipeline.exec();
        threadMessages = results.filter(msg => msg !== null) as (MastraMessageV2 & { _index?: number })[];
      }

      const messages = [...includedMessages, ...threadMessages];

      // Order by rank in this thread's sorted set; unknown ids map to -1 (sort first).
      const position = new Map(allMessageIds.map((id, index) => [id, index] as const));
      const rankOf = (id: string) => position.get(id) ?? -1;
      messages.sort((a, b) => rankOf(a.id) - rankOf(b.id));

      const seen = new Set<string>();
      const dedupedMessages = messages.filter(row => {
        if (seen.has(row.id)) return false;
        seen.add(row.id);
        return true;
      });

      // Remove the persisted _index before returning
      const prepared = dedupedMessages
        .filter(message => message !== null && message !== undefined)
        .map(message => {
          const { _index, ...messageWithoutIndex } = message as MastraMessageV2 & { _index?: number };
          return messageWithoutIndex as unknown as MastraMessageV1;
        });

      // For backward compatibility, return messages directly without using MessageList
      // since MessageList has deduplication logic that can cause issues
      if (format === 'v2') {
        return prepared.map(msg => ({
          ...msg,
          createdAt: new Date(msg.createdAt),
          content: msg.content || { format: 2, parts: [{ type: 'text', text: '' }] },
        })) as MastraMessageV2[];
      }

      return prepared.map(msg => ({
        ...msg,
        createdAt: new Date(msg.createdAt),
      }));
    } catch (error) {
      throw new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_GET_MESSAGES_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
    }
  }

  /**
   * Page through a thread's messages (zero-based `page`), optionally restricted to the last
   * `selectBy.last` ids and to `pagination.dateRange`. `selectBy.include` messages are
   * always added to the page and are not counted in `total`.
   *
   * Failures are logged and an empty page is returned instead of throwing.
   */
  public async getMessagesPaginated(
    args: StorageGetMessagesArg & {
      format?: 'v1' | 'v2';
    },
  ): Promise<PaginationInfo & { messages: MastraMessageV1[] | MastraMessageV2[] }> {
    const { threadId, selectBy, format } = args;
    const { page = 0, perPage = 40, dateRange } = selectBy?.pagination || {};
    const fromDate = dateRange?.start;
    const toDate = dateRange?.end;
    const threadMessagesKey = getThreadMessagesKey(threadId);
    const messages: (MastraMessageV2 | MastraMessageV1)[] = [];

    try {
      const includedMessages = await this._getIncludedMessages(threadId, selectBy);
      messages.push(...includedMessages);

      const allMessageIds: string[] = await this.client.zrange(
        threadMessagesKey,
        selectBy?.last ? -selectBy.last : 0,
        -1,
      );

      // An empty thread must not discard the included messages fetched above, so instead of
      // returning early just skip the (otherwise empty, and therefore throwing) pipeline.
      let messagesData: (MastraMessageV2 | MastraMessageV1)[] = [];
      if (allMessageIds.length > 0) {
        const pipeline = this.client.pipeline();
        allMessageIds.forEach(id => pipeline.get(getMessageKey(threadId, id)));
        const results = await pipeline.exec();
        messagesData = results.filter((msg): msg is MastraMessageV2 | MastraMessageV1 => msg !== null) as (
          | MastraMessageV2
          | MastraMessageV1
        )[];
      }

      if (fromDate) {
        messagesData = messagesData.filter(msg => msg && new Date(msg.createdAt).getTime() >= fromDate.getTime());
      }

      if (toDate) {
        messagesData = messagesData.filter(msg => msg && new Date(msg.createdAt).getTime() <= toDate.getTime());
      }

      // Order by rank in the sorted set (Map lookup instead of indexOf per comparison).
      const position = new Map(allMessageIds.map((id, index) => [id, index] as const));
      const rankOf = (id: string) => position.get(id) ?? -1;
      messagesData.sort((a, b) => rankOf(a.id) - rankOf(b.id));

      const total = messagesData.length;

      const start = page * perPage;
      const end = start + perPage;
      const hasMore = end < total;
      const paginatedMessages = messagesData.slice(start, end);

      messages.push(...paginatedMessages);

      const list = new MessageList().add(messages, 'memory');
      const finalMessages = (format === `v2` ? list.get.all.v2() : list.get.all.v1()) as
        | MastraMessageV1[]
        | MastraMessageV2[];

      return {
        messages: finalMessages,
        total,
        page,
        perPage,
        hasMore,
      };
    } catch (error) {
      const mastraError = new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_GET_MESSAGES_PAGINATED_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
      this.logger?.trackException(mastraError);
      this.logger.error(mastraError.toString());
      return {
        messages: [],
        total: 0,
        page,
        perPage,
        hasMore: false,
      };
    }
  }

  /**
   * Loads a resource by id.
   *
   * @returns the resource with Date timestamps, string (or absent) `workingMemory` and
   *   parsed metadata, or `null` when not stored. Errors are logged and rethrown as-is.
   */
  async getResourceById({ resourceId }: { resourceId: string }): Promise<StorageResourceType | null> {
    try {
      const key = `${TABLE_RESOURCES}:${resourceId}`;
      const data = await this.client.get<StorageResourceType>(key);

      if (!data) {
        return null;
      }

      // The client may auto-parse JSON-looking values; re-stringify real objects so callers
      // always get a string. `typeof null === 'object'`, so null is excluded explicitly
      // (it previously came back as the string "null").
      const workingMemory =
        data.workingMemory !== null && typeof data.workingMemory === 'object'
          ? JSON.stringify(data.workingMemory)
          : data.workingMemory;

      return {
        ...data,
        createdAt: new Date(data.createdAt),
        updatedAt: new Date(data.updatedAt),
        workingMemory,
        metadata: typeof data.metadata === 'string' ? JSON.parse(data.metadata) : data.metadata,
      };
    } catch (error) {
      this.logger.error('Error getting resource by ID:', error);
      throw error;
    }
  }

  /**
   * Stores a resource, serializing metadata to JSON and timestamps to ISO strings.
   *
   * @returns the resource exactly as passed in. Errors are logged and rethrown as-is.
   */
  async saveResource({ resource }: { resource: StorageResourceType }): Promise<StorageResourceType> {
    try {
      const key = `${TABLE_RESOURCES}:${resource.id}`;
      const serializedResource = {
        ...resource,
        metadata: JSON.stringify(resource.metadata),
        createdAt: resource.createdAt.toISOString(),
        updatedAt: resource.updatedAt.toISOString(),
      };

      await this.client.set(key, serializedResource);

      return resource;
    } catch (error) {
      this.logger.error('Error saving resource:', error);
      throw error;
    }
  }

  /**
   * Creates the resource if missing; otherwise replaces `workingMemory` (when provided),
   * shallow-merges `metadata`, and bumps `updatedAt`. Errors are logged and rethrown as-is.
   */
  async updateResource({
    resourceId,
    workingMemory,
    metadata,
  }: {
    resourceId: string;
    workingMemory?: string;
    metadata?: Record<string, unknown>;
  }): Promise<StorageResourceType> {
    try {
      const existingResource = await this.getResourceById({ resourceId });

      if (!existingResource) {
        const newResource: StorageResourceType = {
          id: resourceId,
          workingMemory,
          metadata: metadata || {},
          createdAt: new Date(),
          updatedAt: new Date(),
        };
        return this.saveResource({ resource: newResource });
      }

      const updatedResource = {
        ...existingResource,
        workingMemory: workingMemory !== undefined ? workingMemory : existingResource.workingMemory,
        metadata: {
          ...existingResource.metadata,
          ...metadata,
        },
        updatedAt: new Date(),
      };

      await this.saveResource({ resource: updatedResource });
      return updatedResource;
    } catch (error) {
      this.logger.error('Error updating resource:', error);
      throw error;
    }
  }

  /**
   * Applies partial updates to stored messages, located by id across all threads.
   *
   * `content` is merged into the existing content (its `metadata` merged one level deep);
   * other fields overwrite. Changing `threadId` moves the message key and its sorted-set
   * entry. `updatedAt` of every affected thread is bumped. Ids that are not found are
   * ignored.
   *
   * @returns the stored state of every found message after the update.
   * @throws MastraError `STORAGE_UPSTASH_STORAGE_UPDATE_MESSAGES_FAILED`.
   */
  async updateMessages(args: {
    messages: (Partial<Omit<MastraMessageV2, 'createdAt'>> & {
      id: string;
      content?: { metadata?: MastraMessageContentV2['metadata']; content?: MastraMessageContentV2['content'] };
    })[];
  }): Promise<MastraMessageV2[]> {
    const { messages } = args;

    if (messages.length === 0) {
      return [];
    }

    try {
      const messageIds = messages.map(m => m.id);

      // Message keys embed the thread id, which the caller may not know, so each message is
      // located by scanning for `getMessageKey('*', id)`.
      const existingMessages: (MastraMessageV2 | MastraMessageV1)[] = [];
      const messageIdToKey: Record<string, string> = {};

      for (const messageId of messageIds) {
        const keys = await this.operations.scanKeys(getMessageKey('*', messageId));

        for (const key of keys) {
          const message = await this.client.get<MastraMessageV2 | MastraMessageV1>(key);
          if (message && message.id === messageId) {
            existingMessages.push(message);
            messageIdToKey[messageId] = key;
            break; // Found the message, no need to continue scanning
          }
        }
      }

      if (existingMessages.length === 0) {
        return [];
      }

      const threadIdsToUpdate = new Set<string>();
      const pipeline = this.client.pipeline();
      // exec() on an empty @upstash/redis pipeline throws ("Pipeline is empty"); without this
      // flag an update carrying only ids (nothing to write) failed with a MastraError.
      let hasQueuedCommands = false;

      for (const existingMessage of existingMessages) {
        const updatePayload = messages.find(m => m.id === existingMessage.id);
        if (!updatePayload) continue;

        const { id, ...fieldsToUpdate } = updatePayload;
        if (Object.keys(fieldsToUpdate).length === 0) continue;

        threadIdsToUpdate.add(existingMessage.threadId!);
        const movesThread = Boolean(updatePayload.threadId && updatePayload.threadId !== existingMessage.threadId);
        if (movesThread) {
          threadIdsToUpdate.add(updatePayload.threadId!);
        }

        const updatedMessage = { ...existingMessage };

        // Merge content instead of overwriting it
        if (fieldsToUpdate.content) {
          const existingContent = existingMessage.content as MastraMessageContentV2;
          const newContent = {
            ...existingContent,
            ...fieldsToUpdate.content,
            // Deep merge metadata if it exists on both
            ...(existingContent?.metadata && fieldsToUpdate.content.metadata
              ? {
                  metadata: {
                    ...existingContent.metadata,
                    ...fieldsToUpdate.content.metadata,
                  },
                }
              : {}),
          };
          updatedMessage.content = newContent;
        }

        for (const key in fieldsToUpdate) {
          if (Object.prototype.hasOwnProperty.call(fieldsToUpdate, key) && key !== 'content') {
            (updatedMessage as any)[key] = fieldsToUpdate[key as keyof typeof fieldsToUpdate];
          }
        }

        const key = messageIdToKey[id];
        if (!key) continue;

        if (movesThread) {
          const newThreadId = updatePayload.threadId!;

          // Remove from old thread's sorted set and delete the old message key
          pipeline.zrem(getThreadMessagesKey(existingMessage.threadId!), id);
          pipeline.del(key);

          // Store under the new thread's key and index it in the new thread's sorted set
          const newKey = getMessageKey(newThreadId, id);
          pipeline.set(newKey, updatedMessage);

          const score =
            (updatedMessage as any)._index !== undefined
              ? (updatedMessage as any)._index
              : new Date(updatedMessage.createdAt).getTime();
          pipeline.zadd(getThreadMessagesKey(newThreadId), { score, member: id });

          // The old key is deleted above; point the read-back below at the new location so
          // moved messages are not silently missing from the result.
          messageIdToKey[id] = newKey;
        } else {
          pipeline.set(key, updatedMessage);
        }
        hasQueuedCommands = true;
      }

      const now = new Date();
      for (const threadId of threadIdsToUpdate) {
        if (!threadId) continue;
        const threadKey = getKey(TABLE_THREADS, { id: threadId });
        const existingThread = await this.client.get<StorageThreadType>(threadKey);
        if (existingThread) {
          const updatedThread = {
            ...existingThread,
            updatedAt: now,
          };
          pipeline.set(threadKey, processRecord(TABLE_THREADS, updatedThread).processedRecord);
          hasQueuedCommands = true;
        }
      }

      if (hasQueuedCommands) {
        await pipeline.exec();
      }

      // Read back the stored state of each located message
      const updatedMessages: MastraMessageV2[] = [];
      for (const messageId of messageIds) {
        const key = messageIdToKey[messageId];
        if (!key) continue;
        const updatedMessage = await this.client.get<MastraMessageV2 | MastraMessageV1>(key);
        if (updatedMessage) {
          updatedMessages.push(updatedMessage as MastraMessageV2);
        }
      }

      return updatedMessages;
    } catch (error) {
      throw new MastraError(
        {
          id: 'STORAGE_UPSTASH_STORAGE_UPDATE_MESSAGES_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            messageIds: messages.map(m => m.id).join(','),
          },
        },
        error,
      );
    }
  }
}