@mastra/pg 0.12.3 → 0.12.4-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,900 @@
1
+ import { MessageList } from '@mastra/core/agent';
2
+ import type { MastraMessageContentV2 } from '@mastra/core/agent';
3
+ import { ErrorCategory, ErrorDomain, MastraError } from '@mastra/core/error';
4
+ import type { MastraMessageV1, MastraMessageV2, StorageThreadType } from '@mastra/core/memory';
5
+ import {
6
+ MemoryStorage,
7
+ resolveMessageLimit,
8
+ TABLE_MESSAGES,
9
+ TABLE_RESOURCES,
10
+ TABLE_THREADS,
11
+ } from '@mastra/core/storage';
12
+ import type { StorageGetMessagesArg, PaginationInfo, StorageResourceType } from '@mastra/core/storage';
13
+ import type { IDatabase } from 'pg-promise';
14
+ import type { StoreOperationsPG } from '../operations';
15
+ import { getTableName, getSchemaName } from '../utils';
16
+
17
/**
 * PostgreSQL implementation of the memory storage domain (threads, messages, resources).
 *
 * Every table name is resolved through `getTableName` / `getSchemaName` so the configured schema
 * prefix is applied consistently. Queries go through a pg-promise `IDatabase` handle; multi-statement
 * writes run inside `client.tx` so they commit or roll back together.
 */
export class MemoryPG extends MemoryStorage {
  private client: IDatabase<{}>;
  private schema: string;
  private operations: StoreOperationsPG;

  /**
   * @param client - pg-promise database handle used for every query in this class.
   * @param schema - schema name, normalized with `getSchemaName` when building table names.
   * @param operations - shared store operations; used here for the generic insert in `saveResource`.
   */
  constructor({
    client,
    schema,
    operations,
  }: {
    client: IDatabase<{}>;
    schema: string;
    operations: StoreOperationsPG;
  }) {
    super();
    this.client = client;
    this.schema = schema;
    this.operations = operations;
  }

  /**
   * Loads one thread by id.
   *
   * Timestamps prefer the `createdAtZ` / `updatedAtZ` columns and fall back to `createdAt` /
   * `updatedAt` when those are empty. Metadata stored as a string is JSON-parsed.
   *
   * @returns the thread, or `null` when no row has this id.
   * @throws MastraError `MASTRA_STORAGE_PG_STORE_GET_THREAD_BY_ID_FAILED` when the query fails.
   */
  async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
    try {
      const tableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });

      const thread = await this.client.oneOrNone<StorageThreadType & { createdAtZ: Date; updatedAtZ: Date }>(
        `SELECT * FROM ${tableName} WHERE id = $1`,
        [threadId],
      );

      if (!thread) {
        return null;
      }

      return {
        id: thread.id,
        resourceId: thread.resourceId,
        title: thread.title,
        metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
        createdAt: thread.createdAtZ || thread.createdAt,
        updatedAt: thread.updatedAtZ || thread.updatedAt,
      };
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_THREAD_BY_ID_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
    }
  }

  /**
   * Lists every thread owned by a resource, newest `createdAt` first.
   *
   * Failures are logged and reported as an empty list rather than thrown.
   *
   * @deprecated use getThreadsByResourceIdPaginated instead
   */
  public async getThreadsByResourceId(args: { resourceId: string }): Promise<StorageThreadType[]> {
    const { resourceId } = args;

    try {
      const tableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
      const baseQuery = `FROM ${tableName} WHERE "resourceId" = $1`;
      const queryParams: any[] = [resourceId];

      const dataQuery = `SELECT id, "resourceId", title, metadata, "createdAt", "updatedAt" ${baseQuery} ORDER BY "createdAt" DESC`;
      const rows = await this.client.manyOrNone(dataQuery, queryParams);
      return (rows || []).map(thread => ({
        ...thread,
        metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
        createdAt: thread.createdAt,
        updatedAt: thread.updatedAt,
      }));
    } catch (error) {
      this.logger.error(`Error getting threads for resource ${resourceId}:`, error);
      return [];
    }
  }

  /**
   * Returns one page of a resource's threads, newest `createdAt` first.
   *
   * `page` is zero-based and `perPage` defaults to 100. On failure the error is logged and tracked,
   * and an empty page is returned. That empty page reports the same `perPage` the query would have used.
   */
  public async getThreadsByResourceIdPaginated(args: {
    resourceId: string;
    page: number;
    perPage: number;
  }): Promise<PaginationInfo & { threads: StorageThreadType[] }> {
    const { resourceId, page = 0, perPage: perPageInput } = args;
    // Resolved once so the success path and the error fallback agree (an explicit 0 is honored).
    const perPage = perPageInput !== undefined ? perPageInput : 100;
    try {
      const tableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
      const baseQuery = `FROM ${tableName} WHERE "resourceId" = $1`;
      const queryParams: any[] = [resourceId];
      const currentOffset = page * perPage;

      const countQuery = `SELECT COUNT(*) ${baseQuery}`;
      const countResult = await this.client.one(countQuery, queryParams);
      // COUNT(*) arrives as a string (bigint) from the driver.
      const total = parseInt(countResult.count, 10);

      if (total === 0) {
        return {
          threads: [],
          total: 0,
          page,
          perPage,
          hasMore: false,
        };
      }

      const dataQuery = `SELECT id, "resourceId", title, metadata, "createdAt", "updatedAt" ${baseQuery} ORDER BY "createdAt" DESC LIMIT $2 OFFSET $3`;
      const rows = await this.client.manyOrNone(dataQuery, [...queryParams, perPage, currentOffset]);

      const threads = (rows || []).map(thread => ({
        ...thread,
        metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
        createdAt: thread.createdAt, // Assuming already Date objects or ISO strings
        updatedAt: thread.updatedAt,
      }));

      return {
        threads,
        total,
        page,
        perPage,
        hasMore: currentOffset + threads.length < total,
      };
    } catch (error) {
      const mastraError = new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_THREADS_BY_RESOURCE_ID_PAGINATED_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            resourceId,
            page,
          },
        },
        error,
      );
      this.logger?.error?.(mastraError.toString());
      this.logger?.trackException(mastraError);
      return { threads: [], total: 0, page, perPage, hasMore: false };
    }
  }

  /**
   * Inserts a thread, or overwrites every column of an existing thread with the same id.
   *
   * Metadata is stored as a JSON string (or NULL when absent). Each timestamp is written to both its
   * plain and its `Z` column.
   *
   * @returns the thread object that was passed in.
   * @throws MastraError `MASTRA_STORAGE_PG_STORE_SAVE_THREAD_FAILED` when the upsert fails.
   */
  async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
    try {
      const tableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
      await this.client.none(
        `INSERT INTO ${tableName} (
          id,
          "resourceId",
          title,
          metadata,
          "createdAt",
          "createdAtZ",
          "updatedAt",
          "updatedAtZ"
        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
        ON CONFLICT (id) DO UPDATE SET
          "resourceId" = EXCLUDED."resourceId",
          title = EXCLUDED.title,
          metadata = EXCLUDED.metadata,
          "createdAt" = EXCLUDED."createdAt",
          "createdAtZ" = EXCLUDED."createdAtZ",
          "updatedAt" = EXCLUDED."updatedAt",
          "updatedAtZ" = EXCLUDED."updatedAtZ"`,
        [
          thread.id,
          thread.resourceId,
          thread.title,
          thread.metadata ? JSON.stringify(thread.metadata) : null,
          thread.createdAt,
          thread.createdAt,
          thread.updatedAt,
          thread.updatedAt,
        ],
      );

      return thread;
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_SAVE_THREAD_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId: thread.id,
          },
        },
        error,
      );
    }
  }

  /**
   * Updates a thread's title and shallow-merges new metadata over the existing metadata.
   *
   * Both updatedAt columns are set to the current time.
   *
   * @returns the updated row as read back via `RETURNING *`.
   * @throws MastraError `MASTRA_STORAGE_PG_STORE_UPDATE_THREAD_FAILED` with category USER when the
   *   thread does not exist, or THIRD_PARTY when the UPDATE fails.
   */
  async updateThread({
    id,
    title,
    metadata,
  }: {
    id: string;
    title: string;
    metadata: Record<string, unknown>;
  }): Promise<StorageThreadType> {
    const threadTableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
    // First get the existing thread to merge metadata
    const existingThread = await this.getThreadById({ threadId: id });
    if (!existingThread) {
      throw new MastraError({
        id: 'MASTRA_STORAGE_PG_STORE_UPDATE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.USER,
        text: `Thread ${id} not found`,
        details: {
          threadId: id,
          title,
        },
      });
    }

    // Shallow merge: keys in the new metadata win over existing keys.
    const mergedMetadata = {
      ...existingThread.metadata,
      ...metadata,
    };

    try {
      const thread = await this.client.one<StorageThreadType & { createdAtZ: Date; updatedAtZ: Date }>(
        `UPDATE ${threadTableName}
                    SET
                        title = $1,
                        metadata = $2,
                        "updatedAt" = $3,
                        "updatedAtZ" = $3
                    WHERE id = $4
                    RETURNING *
                `,
        [title, mergedMetadata, new Date().toISOString(), id],
      );

      return {
        id: thread.id,
        resourceId: thread.resourceId,
        title: thread.title,
        metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
        createdAt: thread.createdAtZ || thread.createdAt,
        updatedAt: thread.updatedAtZ || thread.updatedAt,
      };
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_UPDATE_THREAD_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId: id,
            title,
          },
        },
        error,
      );
    }
  }

  /**
   * Deletes a thread and all of its messages in a single transaction.
   *
   * @throws MastraError `MASTRA_STORAGE_PG_STORE_DELETE_THREAD_FAILED` when either DELETE fails.
   */
  async deleteThread({ threadId }: { threadId: string }): Promise<void> {
    try {
      const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
      const threadTableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
      await this.client.tx(async t => {
        // First delete all messages associated with this thread
        await t.none(`DELETE FROM ${tableName} WHERE thread_id = $1`, [threadId]);

        // Then delete the thread
        await t.none(`DELETE FROM ${threadTableName} WHERE id = $1`, [threadId]);
      });
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_DELETE_THREAD_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
    }
  }

  /**
   * Fetches the messages named in `selectBy.include`, together with their surrounding context.
   *
   * For each entry, the thread (`inc.threadId`, or `threadId` as a fallback) is ranked with
   * `ROW_NUMBER() OVER (orderByStatement)`. The selection is the target message, plus up to
   * `withPreviousMessages` rows ranked after it and up to `withNextMessages` rows ranked before it.
   * With the `ORDER BY "createdAt" DESC` both callers pass, rows ranked after the target are older
   * and rows ranked before it are newer.
   *
   * @returns the rows ordered by `createdAt` ascending and de-duplicated by id (first occurrence
   *   kept), or `null` when there is nothing to include. Row `content` is returned as stored.
   */
  private async _getIncludedMessages({
    threadId,
    selectBy,
    orderByStatement,
  }: {
    threadId: string;
    selectBy: StorageGetMessagesArg['selectBy'];
    orderByStatement: string;
  }) {
    const include = selectBy?.include;
    // An empty list would otherwise produce a UNION of zero queries, i.e. invalid SQL.
    if (!include || include.length === 0) return null;

    const unionQueries: string[] = [];
    const params: any[] = [];
    let paramIdx = 1;
    const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });

    for (const inc of include) {
      const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
      // if threadId is provided, use it, otherwise use threadId from args
      const searchId = inc.threadId || threadId;
      unionQueries.push(
        `
            SELECT * FROM (
              WITH ordered_messages AS (
                SELECT 
                  *,
                  ROW_NUMBER() OVER (${orderByStatement}) as row_num
                FROM ${tableName}
                WHERE thread_id = $${paramIdx}
              )
              SELECT
                m.id, 
                m.content, 
                m.role, 
                m.type,
                m."createdAt", 
                m.thread_id AS "threadId",
                m."resourceId"
              FROM ordered_messages m
              WHERE m.id = $${paramIdx + 1}
              OR EXISTS (
                SELECT 1 FROM ordered_messages target
                WHERE target.id = $${paramIdx + 1}
                AND (
                  -- Get previous messages based on the max withPreviousMessages
                  (m.row_num <= target.row_num + $${paramIdx + 2} AND m.row_num > target.row_num)
                  OR
                  -- Get next messages based on the max withNextMessages
                  (m.row_num >= target.row_num - $${paramIdx + 3} AND m.row_num < target.row_num)
                )
              )
            ) AS query_${paramIdx}
            `,
      );
      params.push(searchId, id, withPreviousMessages, withNextMessages);
      paramIdx += 4;
    }
    // The trailing ORDER BY applies to the whole UNION ALL result.
    const finalQuery = unionQueries.join(' UNION ALL ') + ' ORDER BY "createdAt" ASC';
    const includedRows = await this.client.manyOrNone(finalQuery, params);
    // Overlapping context windows can return the same message twice.
    const seen = new Set<string>();
    const dedupedRows = includedRows.filter(row => {
      if (seen.has(row.id)) return false;
      seen.add(row.id);
      return true;
    });
    return dedupedRows;
  }

  /**
   * Returns the most recent messages of a thread plus any requested `include` context, sorted oldest
   * first.
   *
   * The number of recent messages comes from `resolveMessageLimit` (`selectBy.last`, default 40).
   * Messages already returned through `include` are excluded from the recent-message query. Failures
   * are logged and tracked, and an empty list is returned.
   *
   * @deprecated use getMessagesPaginated instead
   */
  public async getMessages(args: StorageGetMessagesArg & { format?: 'v1' }): Promise<MastraMessageV1[]>;
  public async getMessages(args: StorageGetMessagesArg & { format: 'v2' }): Promise<MastraMessageV2[]>;
  public async getMessages(
    args: StorageGetMessagesArg & {
      format?: 'v1' | 'v2';
    },
  ): Promise<MastraMessageV1[] | MastraMessageV2[]> {
    const { threadId, format, selectBy } = args;

    const selectStatement = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId", "resourceId"`;
    const orderByStatement = `ORDER BY "createdAt" DESC`;
    const limit = resolveMessageLimit({ last: selectBy?.last, defaultLimit: 40 });

    try {
      const rows: any[] = [];
      const include = selectBy?.include || [];

      if (include.length) {
        const includeMessages = await this._getIncludedMessages({ threadId, selectBy, orderByStatement });
        if (includeMessages) {
          rows.push(...includeMessages);
        }
      }

      // Placeholders start at $2 because $1 is the thread id; the LIMIT placeholder comes last.
      const excludeIds = rows.map(m => m.id);
      const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
      const excludeIdsParam = excludeIds.map((_, idx) => `$${idx + 2}`).join(', ');
      const query = `${selectStatement} FROM ${tableName} WHERE thread_id = $1 
        ${excludeIds.length ? `AND id NOT IN (${excludeIdsParam})` : ''}
        ${orderByStatement}
        LIMIT $${excludeIds.length + 2}
        `;
      const queryParams: any[] = [threadId, ...excludeIds, limit];
      const remainingRows = await this.client.manyOrNone(query, queryParams);
      rows.push(...remainingRows);

      const fetchedMessages = rows.map(message => {
        if (typeof message.content === 'string') {
          try {
            message.content = JSON.parse(message.content);
          } catch {
            /* ignore */
          }
        }
        if (message.type === 'v2') delete message.type;
        return message as MastraMessageV1;
      });

      // Sort all messages by creation date
      const sortedMessages = fetchedMessages.sort(
        (a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime(),
      );

      return format === 'v2'
        ? sortedMessages.map(
            m =>
              ({ ...m, content: m.content || { format: 2, parts: [{ type: 'text', text: '' }] } }) as MastraMessageV2,
          )
        : sortedMessages;
    } catch (error) {
      const mastraError = new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_MESSAGES_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
      this.logger?.error?.(mastraError.toString());
      this.logger?.trackException(mastraError);
      return [];
    }
  }

  /**
   * Returns one page of a thread's messages, optionally filtered by `createdAt` range, plus any
   * requested `include` context.
   *
   * `page` is zero-based. `perPage` defaults to `resolveMessageLimit(selectBy.last, 40)`. Included
   * messages are excluded from the paged query so they are not returned twice. The combined set is
   * normalized through `MessageList` into the requested `format`.
   *
   * Any failure, including a failure of the include query, is logged and tracked. In that case an
   * empty page is returned that reports the same `perPage` the query would have used.
   */
  public async getMessagesPaginated(
    args: StorageGetMessagesArg & {
      format?: 'v1' | 'v2';
    },
  ): Promise<PaginationInfo & { messages: MastraMessageV1[] | MastraMessageV2[] }> {
    const { threadId, format, selectBy } = args;
    const { page = 0, perPage: perPageInput, dateRange } = selectBy?.pagination || {};
    const fromDate = dateRange?.start;
    const toDate = dateRange?.end;

    const selectStatement = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId", "resourceId"`;
    const orderByStatement = `ORDER BY "createdAt" DESC`;
    // Resolved before the try so the error fallback reports the same page size as the success path.
    const perPage =
      perPageInput !== undefined ? perPageInput : resolveMessageLimit({ last: selectBy?.last, defaultLimit: 40 });

    const messages: MastraMessageV2[] = [];

    try {
      // Inside the try so a failing include query is handled like every other query failure here.
      if (selectBy?.include?.length) {
        const includeMessages = await this._getIncludedMessages({ threadId, selectBy, orderByStatement });
        if (includeMessages) {
          messages.push(...includeMessages);
        }
      }

      const currentOffset = page * perPage;

      const conditions: string[] = [`thread_id = $1`];
      const queryParams: any[] = [threadId];
      let paramIndex = 2;

      if (fromDate) {
        conditions.push(`"createdAt" >= $${paramIndex++}`);
        queryParams.push(fromDate);
      }
      if (toDate) {
        conditions.push(`"createdAt" <= $${paramIndex++}`);
        queryParams.push(toDate);
      }
      const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';

      const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
      const countQuery = `SELECT COUNT(*) FROM ${tableName} ${whereClause}`;
      const countResult = await this.client.one(countQuery, queryParams);
      const total = parseInt(countResult.count, 10);

      if (total === 0 && messages.length === 0) {
        return {
          messages: [],
          total: 0,
          page,
          perPage,
          hasMore: false,
        };
      }

      // Exclusion placeholders continue after the date-range ones; LIMIT/OFFSET come last.
      const excludeIds = messages.map(m => m.id);
      const excludeIdsParam = excludeIds.map((_, idx) => `$${idx + paramIndex}`).join(', ');
      paramIndex += excludeIds.length;

      const excludeClause = excludeIds.length ? `AND id NOT IN (${excludeIdsParam})` : '';
      const dataQuery = `${selectStatement} FROM ${tableName} ${whereClause} ${excludeClause} ${orderByStatement} LIMIT $${paramIndex++} OFFSET $${paramIndex++}`;
      const rows = await this.client.manyOrNone(dataQuery, [...queryParams, ...excludeIds, perPage, currentOffset]);
      messages.push(...(rows || []));

      // Parse content back to objects if they were stringified during storage
      const messagesWithParsedContent = messages.map(message => {
        if (typeof message.content === 'string') {
          try {
            return { ...message, content: JSON.parse(message.content) };
          } catch {
            // If parsing fails, leave as string (V1 message)
            return message;
          }
        }
        return message;
      });

      const list = new MessageList().add(messagesWithParsedContent, 'memory');
      const messagesToReturn = format === `v2` ? list.get.all.v2() : list.get.all.v1();

      return {
        messages: messagesToReturn,
        total,
        page,
        perPage,
        hasMore: currentOffset + rows.length < total,
      };
    } catch (error) {
      const mastraError = new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_MESSAGES_PAGINATED_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
            page,
          },
        },
        error,
      );
      this.logger?.error?.(mastraError.toString());
      this.logger?.trackException(mastraError);
      return { messages: [], total: 0, page, perPage, hasMore: false };
    }
  }

  /**
   * Upserts a batch of messages and bumps the updatedAt columns of the first message's thread, all
   * in one transaction.
   *
   * The thread of `messages[0]` must already exist. Every message must carry a `threadId` and a
   * `resourceId`. Non-string content is stored as JSON text, and a missing `type` is stored as 'v2'.
   *
   * @returns the saved messages normalized through `MessageList` into the requested `format`.
   * @throws MastraError `MASTRA_STORAGE_PG_STORE_SAVE_MESSAGES_FAILED` when the thread id is
   *   missing, the thread does not exist, or the transaction fails.
   */
  async saveMessages(args: { messages: MastraMessageV1[]; format?: undefined | 'v1' }): Promise<MastraMessageV1[]>;
  async saveMessages(args: { messages: MastraMessageV2[]; format: 'v2' }): Promise<MastraMessageV2[]>;
  async saveMessages({
    messages,
    format,
  }:
    | { messages: MastraMessageV1[]; format?: undefined | 'v1' }
    | { messages: MastraMessageV2[]; format: 'v2' }): Promise<MastraMessageV2[] | MastraMessageV1[]> {
    if (messages.length === 0) return messages;

    const threadId = messages[0]?.threadId;
    if (!threadId) {
      throw new MastraError({
        id: 'MASTRA_STORAGE_PG_STORE_SAVE_MESSAGES_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        text: `Thread ID is required`,
      });
    }

    // Check if thread exists
    const thread = await this.getThreadById({ threadId });
    if (!thread) {
      throw new MastraError({
        id: 'MASTRA_STORAGE_PG_STORE_SAVE_MESSAGES_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        text: `Thread ${threadId} not found`,
        details: {
          threadId,
        },
      });
    }

    try {
      const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
      await this.client.tx(async t => {
        // Queue every message upsert plus the thread timestamp bump and await them together. They share
        // the transaction's connection, so the goal is atomicity, not parallel execution.
        const messageInserts = messages.map(message => {
          if (!message.threadId) {
            throw new Error(
              `Expected to find a threadId for message, but couldn't find one. An unexpected error has occurred.`,
            );
          }
          if (!message.resourceId) {
            throw new Error(
              `Expected to find a resourceId for message, but couldn't find one. An unexpected error has occurred.`,
            );
          }
          return t.none(
            `INSERT INTO ${tableName} (id, thread_id, content, "createdAt", "createdAtZ", role, type, "resourceId") 
             VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
             ON CONFLICT (id) DO UPDATE SET
              thread_id = EXCLUDED.thread_id,
              content = EXCLUDED.content,
              role = EXCLUDED.role,
              type = EXCLUDED.type,
              "resourceId" = EXCLUDED."resourceId"`,
            [
              message.id,
              message.threadId,
              typeof message.content === 'string' ? message.content : JSON.stringify(message.content),
              message.createdAt || new Date().toISOString(),
              message.createdAt || new Date().toISOString(),
              message.role,
              message.type || 'v2',
              message.resourceId,
            ],
          );
        });

        const threadTableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
        const threadUpdate = t.none(
          `UPDATE ${threadTableName} 
                        SET 
                            "updatedAt" = $1,
                            "updatedAtZ" = $1
                        WHERE id = $2
                    `,
          [new Date().toISOString(), threadId],
        );

        await Promise.all([...messageInserts, threadUpdate]);
      });

      // Parse content back to objects if they were stringified during storage
      const messagesWithParsedContent = messages.map(message => {
        if (typeof message.content === 'string') {
          try {
            return { ...message, content: JSON.parse(message.content) };
          } catch {
            // If parsing fails, leave as string (V1 message)
            return message;
          }
        }
        return message;
      });

      const list = new MessageList().add(messagesWithParsedContent, 'memory');
      if (format === `v2`) return list.get.all.v2();
      return list.get.all.v1();
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_SAVE_MESSAGES_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            threadId,
          },
        },
        error,
      );
    }
  }

  /**
   * Applies partial updates to existing messages, then re-reads and returns them.
   *
   * `content` is merged in code: existing content, then the patch, with `content.metadata`
   * shallow-merged when both sides have it. The merged object then replaces the whole column. Other
   * fields are written to same-named columns (`threadId` maps to `thread_id`). The updatedAt columns
   * of every affected thread, old and new, are set to NOW(). Ids that do not exist are ignored.
   *
   * @returns the re-fetched rows for the requested ids, with `content` JSON-parsed where possible.
   */
  async updateMessages({
    messages,
  }: {
    messages: (Partial<Omit<MastraMessageV2, 'createdAt'>> & {
      id: string;
      content?: {
        metadata?: MastraMessageContentV2['metadata'];
        content?: MastraMessageContentV2['content'];
      };
    })[];
  }): Promise<MastraMessageV2[]> {
    if (messages.length === 0) {
      return [];
    }

    const messageIds = messages.map(m => m.id);

    const selectQuery = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId", "resourceId" FROM ${getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) })} WHERE id IN ($1:list)`;

    const existingMessagesDb = await this.client.manyOrNone(selectQuery, [messageIds]);

    if (existingMessagesDb.length === 0) {
      return [];
    }

    // Parse content from string to object for merging
    const existingMessages: MastraMessageV2[] = existingMessagesDb.map(msg => {
      if (typeof msg.content === 'string') {
        try {
          msg.content = JSON.parse(msg.content);
        } catch {
          // ignore if not valid json
        }
      }
      return msg as MastraMessageV2;
    });

    const threadIdsToUpdate = new Set<string>();

    await this.client.tx(async t => {
      const queries = [];
      const columnMapping: Record<string, string> = {
        threadId: 'thread_id',
      };

      for (const existingMessage of existingMessages) {
        const updatePayload = messages.find(m => m.id === existingMessage.id);
        if (!updatePayload) continue;

        const { id, ...fieldsToUpdate } = updatePayload;
        if (Object.keys(fieldsToUpdate).length === 0) continue;

        // Both the old and (if moved) the new thread get their updatedAt bumped.
        threadIdsToUpdate.add(existingMessage.threadId!);
        if (updatePayload.threadId && updatePayload.threadId !== existingMessage.threadId) {
          threadIdsToUpdate.add(updatePayload.threadId);
        }

        const setClauses: string[] = [];
        const values: any[] = [];
        let paramIndex = 1;

        const updatableFields = { ...fieldsToUpdate };

        // Special handling for content: merge in code, then update the whole field
        if (updatableFields.content) {
          const newContent = {
            ...existingMessage.content,
            ...updatableFields.content,
            // Deep merge metadata if it exists on both
            ...(existingMessage.content?.metadata && updatableFields.content.metadata
              ? {
                  metadata: {
                    ...existingMessage.content.metadata,
                    ...updatableFields.content.metadata,
                  },
                }
              : {}),
          };
          setClauses.push(`content = $${paramIndex++}`);
          values.push(newContent);
          delete updatableFields.content;
        }

        for (const [key, value] of Object.entries(updatableFields)) {
          // Field names come from the caller's object keys; double any embedded quote so the name
          // cannot escape the quoted identifier (PostgreSQL identifier quoting rules).
          const dbColumn = (columnMapping[key] ?? key).replace(/"/g, '""');
          setClauses.push(`"${dbColumn}" = $${paramIndex++}`);
          values.push(value);
        }

        if (setClauses.length > 0) {
          values.push(id);
          const sql = `UPDATE ${getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) })} SET ${setClauses.join(', ')} WHERE id = $${paramIndex}`;
          queries.push(t.none(sql, values));
        }
      }

      if (threadIdsToUpdate.size > 0) {
        queries.push(
          t.none(
            `UPDATE ${getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) })} SET "updatedAt" = NOW(), "updatedAtZ" = NOW() WHERE id IN ($1:list)`,
            [Array.from(threadIdsToUpdate)],
          ),
        );
      }

      if (queries.length > 0) {
        await t.batch(queries);
      }
    });

    // Re-fetch to return the fully updated messages
    const updatedMessages = await this.client.manyOrNone<MastraMessageV2>(selectQuery, [messageIds]);

    return (updatedMessages || []).map(message => {
      if (typeof message.content === 'string') {
        try {
          message.content = JSON.parse(message.content);
        } catch {
          /* ignore */
        }
      }
      return message;
    });
  }

  /**
   * Loads a resource by id.
   *
   * Timestamps prefer the `Z` columns, and metadata stored as a string is JSON-parsed.
   *
   * @returns the resource, or `null` when no row has this id.
   */
  async getResourceById({ resourceId }: { resourceId: string }): Promise<StorageResourceType | null> {
    const tableName = getTableName({ indexName: TABLE_RESOURCES, schemaName: getSchemaName(this.schema) });
    const result = await this.client.oneOrNone<StorageResourceType & { createdAtZ: Date; updatedAtZ: Date }>(
      `SELECT * FROM ${tableName} WHERE id = $1`,
      [resourceId],
    );

    if (!result) {
      return null;
    }

    return {
      id: result.id,
      createdAt: result.createdAtZ || result.createdAt,
      updatedAt: result.updatedAtZ || result.updatedAt,
      workingMemory: result.workingMemory,
      metadata: typeof result.metadata === 'string' ? JSON.parse(result.metadata) : result.metadata,
    };
  }

  /**
   * Inserts a resource via the shared `operations.insert`, storing metadata as a JSON string.
   *
   * @returns the resource object that was passed in.
   */
  async saveResource({ resource }: { resource: StorageResourceType }): Promise<StorageResourceType> {
    await this.operations.insert({
      tableName: TABLE_RESOURCES,
      record: {
        ...resource,
        metadata: JSON.stringify(resource.metadata),
      },
    });

    return resource;
  }

  /**
   * Updates a resource's working memory and/or metadata, creating the resource if it does not exist.
   *
   * `workingMemory` is written only when it is not `undefined`. `metadata` is shallow-merged over the
   * stored metadata and written only when provided. Both updatedAt columns are always set to the
   * current time.
   *
   * @returns the merged resource as it was written.
   */
  async updateResource({
    resourceId,
    workingMemory,
    metadata,
  }: {
    resourceId: string;
    workingMemory?: string;
    metadata?: Record<string, unknown>;
  }): Promise<StorageResourceType> {
    const existingResource = await this.getResourceById({ resourceId });

    if (!existingResource) {
      // Create new resource if it doesn't exist
      const newResource: StorageResourceType = {
        id: resourceId,
        workingMemory,
        metadata: metadata || {},
        createdAt: new Date(),
        updatedAt: new Date(),
      };
      return this.saveResource({ resource: newResource });
    }

    const updatedResource = {
      ...existingResource,
      workingMemory: workingMemory !== undefined ? workingMemory : existingResource.workingMemory,
      metadata: {
        ...existingResource.metadata,
        ...metadata,
      },
      updatedAt: new Date(),
    };

    const tableName = getTableName({ indexName: TABLE_RESOURCES, schemaName: getSchemaName(this.schema) });

    const updates: string[] = [];
    const values: unknown[] = [];
    // Appends a bound value and returns its positional placeholder, so numbering cannot drift.
    const bind = (value: unknown): string => {
      values.push(value);
      return `$${values.length}`;
    };

    if (workingMemory !== undefined) {
      updates.push(`"workingMemory" = ${bind(workingMemory)}`);
    }

    if (metadata) {
      updates.push(`metadata = ${bind(JSON.stringify(updatedResource.metadata))}`);
    }

    // Both timestamp columns receive the same instant through a single placeholder.
    const updatedAtParam = bind(updatedResource.updatedAt.toISOString());
    updates.push(`"updatedAt" = ${updatedAtParam}`, `"updatedAtZ" = ${updatedAtParam}`);

    const idParam = bind(resourceId);

    await this.client.none(`UPDATE ${tableName} SET ${updates.join(', ')} WHERE id = ${idParam}`, values);

    return updatedResource;
  }
}