@mastra/pg 0.10.2-alpha.0 → 0.10.2-alpha.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +20 -0
- package/dist/_tsup-dts-rollup.d.cts +33 -40
- package/dist/_tsup-dts-rollup.d.ts +33 -40
- package/dist/index.cjs +286 -225
- package/dist/index.js +286 -225
- package/package.json +4 -4
- package/src/storage/index.test.ts +308 -191
- package/src/storage/index.ts +348 -345
package/src/storage/index.ts
CHANGED
|
@@ -12,11 +12,13 @@ import {
|
|
|
12
12
|
} from '@mastra/core/storage';
|
|
13
13
|
import type {
|
|
14
14
|
EvalRow,
|
|
15
|
+
PaginationInfo,
|
|
15
16
|
StorageColumn,
|
|
16
17
|
StorageGetMessagesArg,
|
|
17
18
|
TABLE_NAMES,
|
|
18
19
|
WorkflowRun,
|
|
19
20
|
WorkflowRuns,
|
|
21
|
+
PaginationArgs,
|
|
20
22
|
} from '@mastra/core/storage';
|
|
21
23
|
import { parseSqlIdentifier } from '@mastra/core/utils';
|
|
22
24
|
import type { WorkflowRunState } from '@mastra/core/workflows';
|
|
@@ -85,6 +87,14 @@ export class PostgresStore extends MastraStorage {
|
|
|
85
87
|
);
|
|
86
88
|
}
|
|
87
89
|
|
|
90
|
+
public get supports(): {
|
|
91
|
+
selectByIncludeResourceScope: boolean;
|
|
92
|
+
} {
|
|
93
|
+
return {
|
|
94
|
+
selectByIncludeResourceScope: true,
|
|
95
|
+
};
|
|
96
|
+
}
|
|
97
|
+
|
|
88
98
|
private getTableName(indexName: string) {
|
|
89
99
|
const parsedIndexName = parseSqlIdentifier(indexName, 'table name');
|
|
90
100
|
const parsedSchemaName = this.schema ? parseSqlIdentifier(this.schema, 'schema name') : undefined;
|
|
@@ -154,6 +164,9 @@ export class PostgresStore extends MastraStorage {
|
|
|
154
164
|
}
|
|
155
165
|
}
|
|
156
166
|
|
|
167
|
+
/**
|
|
168
|
+
* @deprecated use getTracesPaginated instead
|
|
169
|
+
*/
|
|
157
170
|
public async getTraces(args: {
|
|
158
171
|
name?: string;
|
|
159
172
|
scope?: string;
|
|
@@ -163,55 +176,32 @@ export class PostgresStore extends MastraStorage {
|
|
|
163
176
|
perPage?: number;
|
|
164
177
|
fromDate?: Date;
|
|
165
178
|
toDate?: Date;
|
|
166
|
-
}): Promise<any[]
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
perPage?: number;
|
|
189
|
-
attributes?: Record<string, string>;
|
|
190
|
-
filters?: Record<string, any>;
|
|
191
|
-
fromDate?: Date;
|
|
192
|
-
toDate?: Date;
|
|
193
|
-
returnPaginationResults?: boolean;
|
|
194
|
-
}): Promise<
|
|
195
|
-
| any[]
|
|
196
|
-
| {
|
|
197
|
-
traces: any[];
|
|
198
|
-
total: number;
|
|
199
|
-
page: number;
|
|
200
|
-
perPage: number;
|
|
201
|
-
hasMore: boolean;
|
|
202
|
-
}
|
|
179
|
+
}): Promise<any[]> {
|
|
180
|
+
if (args.fromDate || args.toDate) {
|
|
181
|
+
(args as any).dateRange = {
|
|
182
|
+
start: args.fromDate,
|
|
183
|
+
end: args.toDate,
|
|
184
|
+
};
|
|
185
|
+
}
|
|
186
|
+
const result = await this.getTracesPaginated(args);
|
|
187
|
+
return result.traces;
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
public async getTracesPaginated(
|
|
191
|
+
args: {
|
|
192
|
+
name?: string;
|
|
193
|
+
scope?: string;
|
|
194
|
+
attributes?: Record<string, string>;
|
|
195
|
+
filters?: Record<string, any>;
|
|
196
|
+
} & PaginationArgs,
|
|
197
|
+
): Promise<
|
|
198
|
+
PaginationInfo & {
|
|
199
|
+
traces: any[];
|
|
200
|
+
}
|
|
203
201
|
> {
|
|
204
|
-
const {
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
page,
|
|
208
|
-
perPage: perPageInput,
|
|
209
|
-
attributes,
|
|
210
|
-
filters,
|
|
211
|
-
fromDate,
|
|
212
|
-
toDate,
|
|
213
|
-
returnPaginationResults,
|
|
214
|
-
} = args;
|
|
202
|
+
const { name, scope, page = 0, perPage: perPageInput, attributes, filters, dateRange } = args;
|
|
203
|
+
const fromDate = dateRange?.start;
|
|
204
|
+
const toDate = dateRange?.end;
|
|
215
205
|
|
|
216
206
|
const perPage = perPageInput !== undefined ? perPageInput : 100; // Default perPage
|
|
217
207
|
const currentOffset = page * perPage;
|
|
@@ -258,7 +248,7 @@ export class PostgresStore extends MastraStorage {
|
|
|
258
248
|
const countResult = await this.db.one(countQuery, queryParams);
|
|
259
249
|
const total = parseInt(countResult.count, 10);
|
|
260
250
|
|
|
261
|
-
if (total === 0
|
|
251
|
+
if (total === 0) {
|
|
262
252
|
return {
|
|
263
253
|
traces: [],
|
|
264
254
|
total: 0,
|
|
@@ -266,11 +256,11 @@ export class PostgresStore extends MastraStorage {
|
|
|
266
256
|
perPage,
|
|
267
257
|
hasMore: false,
|
|
268
258
|
};
|
|
269
|
-
} else if (total === 0) {
|
|
270
|
-
return [];
|
|
271
259
|
}
|
|
272
260
|
|
|
273
|
-
const dataQuery = `SELECT * FROM ${this.getTableName(
|
|
261
|
+
const dataQuery = `SELECT * FROM ${this.getTableName(
|
|
262
|
+
TABLE_TRACES,
|
|
263
|
+
)} ${whereClause} ORDER BY "createdAt" DESC LIMIT $${paramIndex++} OFFSET $${paramIndex++}`;
|
|
274
264
|
const finalQueryParams = [...queryParams, perPage, currentOffset];
|
|
275
265
|
|
|
276
266
|
const rows = await this.db.manyOrNone<any>(dataQuery, finalQueryParams);
|
|
@@ -291,17 +281,13 @@ export class PostgresStore extends MastraStorage {
|
|
|
291
281
|
createdAt: row.createdAt,
|
|
292
282
|
}));
|
|
293
283
|
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
};
|
|
302
|
-
} else {
|
|
303
|
-
return traces;
|
|
304
|
-
}
|
|
284
|
+
return {
|
|
285
|
+
traces,
|
|
286
|
+
total,
|
|
287
|
+
page,
|
|
288
|
+
perPage,
|
|
289
|
+
hasMore: currentOffset + traces.length < total,
|
|
290
|
+
};
|
|
305
291
|
}
|
|
306
292
|
|
|
307
293
|
private async setupSchema() {
|
|
@@ -404,6 +390,57 @@ export class PostgresStore extends MastraStorage {
|
|
|
404
390
|
}
|
|
405
391
|
}
|
|
406
392
|
|
|
393
|
+
protected getDefaultValue(type: StorageColumn['type']): string {
|
|
394
|
+
switch (type) {
|
|
395
|
+
case 'timestamp':
|
|
396
|
+
return 'DEFAULT NOW()';
|
|
397
|
+
case 'jsonb':
|
|
398
|
+
return "DEFAULT '{}'::jsonb";
|
|
399
|
+
default:
|
|
400
|
+
return super.getDefaultValue(type);
|
|
401
|
+
}
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
/**
|
|
405
|
+
* Alters table schema to add columns if they don't exist
|
|
406
|
+
* @param tableName Name of the table
|
|
407
|
+
* @param schema Schema of the table
|
|
408
|
+
* @param ifNotExists Array of column names to add if they don't exist
|
|
409
|
+
*/
|
|
410
|
+
async alterTable({
|
|
411
|
+
tableName,
|
|
412
|
+
schema,
|
|
413
|
+
ifNotExists,
|
|
414
|
+
}: {
|
|
415
|
+
tableName: TABLE_NAMES;
|
|
416
|
+
schema: Record<string, StorageColumn>;
|
|
417
|
+
ifNotExists: string[];
|
|
418
|
+
}): Promise<void> {
|
|
419
|
+
const fullTableName = this.getTableName(tableName);
|
|
420
|
+
|
|
421
|
+
try {
|
|
422
|
+
for (const columnName of ifNotExists) {
|
|
423
|
+
if (schema[columnName]) {
|
|
424
|
+
const columnDef = schema[columnName];
|
|
425
|
+
const sqlType = this.getSqlType(columnDef.type);
|
|
426
|
+
const nullable = columnDef.nullable === false ? 'NOT NULL' : '';
|
|
427
|
+
const defaultValue = columnDef.nullable === false ? this.getDefaultValue(columnDef.type) : '';
|
|
428
|
+
const parsedColumnName = parseSqlIdentifier(columnName, 'column name');
|
|
429
|
+
const alterSql =
|
|
430
|
+
`ALTER TABLE ${fullTableName} ADD COLUMN IF NOT EXISTS "${parsedColumnName}" ${sqlType} ${nullable} ${defaultValue}`.trim();
|
|
431
|
+
|
|
432
|
+
await this.db.none(alterSql);
|
|
433
|
+
this.logger?.debug?.(`Ensured column ${parsedColumnName} exists in table ${fullTableName}`);
|
|
434
|
+
}
|
|
435
|
+
}
|
|
436
|
+
} catch (error) {
|
|
437
|
+
this.logger?.error?.(
|
|
438
|
+
`Error altering table ${tableName}: ${error instanceof Error ? error.message : String(error)}`,
|
|
439
|
+
);
|
|
440
|
+
throw new Error(`Failed to alter table ${tableName}: ${error}`);
|
|
441
|
+
}
|
|
442
|
+
}
|
|
443
|
+
|
|
407
444
|
async clearTable({ tableName }: { tableName: TABLE_NAMES }): Promise<void> {
|
|
408
445
|
try {
|
|
409
446
|
await this.db.none(`TRUNCATE TABLE ${this.getTableName(tableName)} CASCADE`);
|
|
@@ -491,82 +528,76 @@ export class PostgresStore extends MastraStorage {
|
|
|
491
528
|
}
|
|
492
529
|
}
|
|
493
530
|
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
perPage: number;
|
|
500
|
-
hasMore: boolean;
|
|
501
|
-
}>;
|
|
502
|
-
public async getThreadsByResourceId(args: { resourceId: string; page?: number; perPage?: number }): Promise<
|
|
503
|
-
| StorageThreadType[]
|
|
504
|
-
| {
|
|
505
|
-
threads: StorageThreadType[];
|
|
506
|
-
total: number;
|
|
507
|
-
page: number;
|
|
508
|
-
perPage: number;
|
|
509
|
-
hasMore: boolean;
|
|
510
|
-
}
|
|
511
|
-
> {
|
|
512
|
-
const { resourceId, page, perPage: perPageInput } = args;
|
|
531
|
+
/**
|
|
532
|
+
* @deprecated use getThreadsByResourceIdPaginated instead
|
|
533
|
+
*/
|
|
534
|
+
public async getThreadsByResourceId(args: { resourceId: string }): Promise<StorageThreadType[]> {
|
|
535
|
+
const { resourceId } = args;
|
|
513
536
|
|
|
514
537
|
try {
|
|
515
538
|
const baseQuery = `FROM ${this.getTableName(TABLE_THREADS)} WHERE "resourceId" = $1`;
|
|
516
539
|
const queryParams: any[] = [resourceId];
|
|
517
540
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
perPage,
|
|
532
|
-
hasMore: false,
|
|
533
|
-
};
|
|
534
|
-
}
|
|
541
|
+
const dataQuery = `SELECT id, "resourceId", title, metadata, "createdAt", "updatedAt" ${baseQuery} ORDER BY "createdAt" DESC`;
|
|
542
|
+
const rows = await this.db.manyOrNone(dataQuery, queryParams);
|
|
543
|
+
return (rows || []).map(thread => ({
|
|
544
|
+
...thread,
|
|
545
|
+
metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
|
|
546
|
+
createdAt: thread.createdAt,
|
|
547
|
+
updatedAt: thread.updatedAt,
|
|
548
|
+
}));
|
|
549
|
+
} catch (error) {
|
|
550
|
+
this.logger.error(`Error getting threads for resource ${resourceId}:`, error);
|
|
551
|
+
return [];
|
|
552
|
+
}
|
|
553
|
+
}
|
|
535
554
|
|
|
536
|
-
|
|
537
|
-
|
|
555
|
+
public async getThreadsByResourceIdPaginated(
|
|
556
|
+
args: {
|
|
557
|
+
resourceId: string;
|
|
558
|
+
} & PaginationArgs,
|
|
559
|
+
): Promise<PaginationInfo & { threads: StorageThreadType[] }> {
|
|
560
|
+
const { resourceId, page = 0, perPage: perPageInput } = args;
|
|
561
|
+
try {
|
|
562
|
+
const baseQuery = `FROM ${this.getTableName(TABLE_THREADS)} WHERE "resourceId" = $1`;
|
|
563
|
+
const queryParams: any[] = [resourceId];
|
|
564
|
+
const perPage = perPageInput !== undefined ? perPageInput : 100;
|
|
565
|
+
const currentOffset = page * perPage;
|
|
538
566
|
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
createdAt: thread.createdAt, // Assuming already Date objects or ISO strings
|
|
543
|
-
updatedAt: thread.updatedAt,
|
|
544
|
-
}));
|
|
567
|
+
const countQuery = `SELECT COUNT(*) ${baseQuery}`;
|
|
568
|
+
const countResult = await this.db.one(countQuery, queryParams);
|
|
569
|
+
const total = parseInt(countResult.count, 10);
|
|
545
570
|
|
|
571
|
+
if (total === 0) {
|
|
546
572
|
return {
|
|
547
|
-
threads,
|
|
548
|
-
total,
|
|
573
|
+
threads: [],
|
|
574
|
+
total: 0,
|
|
549
575
|
page,
|
|
550
576
|
perPage,
|
|
551
|
-
hasMore:
|
|
577
|
+
hasMore: false,
|
|
552
578
|
};
|
|
553
|
-
} else {
|
|
554
|
-
// Non-paginated path
|
|
555
|
-
const dataQuery = `SELECT id, "resourceId", title, metadata, "createdAt", "updatedAt" ${baseQuery} ORDER BY "createdAt" DESC`;
|
|
556
|
-
const rows = await this.db.manyOrNone(dataQuery, queryParams);
|
|
557
|
-
return (rows || []).map(thread => ({
|
|
558
|
-
...thread,
|
|
559
|
-
metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
|
|
560
|
-
createdAt: thread.createdAt,
|
|
561
|
-
updatedAt: thread.updatedAt,
|
|
562
|
-
}));
|
|
563
579
|
}
|
|
580
|
+
|
|
581
|
+
const dataQuery = `SELECT id, "resourceId", title, metadata, "createdAt", "updatedAt" ${baseQuery} ORDER BY "createdAt" DESC LIMIT $2 OFFSET $3`;
|
|
582
|
+
const rows = await this.db.manyOrNone(dataQuery, [...queryParams, perPage, currentOffset]);
|
|
583
|
+
|
|
584
|
+
const threads = (rows || []).map(thread => ({
|
|
585
|
+
...thread,
|
|
586
|
+
metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
|
|
587
|
+
createdAt: thread.createdAt, // Assuming already Date objects or ISO strings
|
|
588
|
+
updatedAt: thread.updatedAt,
|
|
589
|
+
}));
|
|
590
|
+
|
|
591
|
+
return {
|
|
592
|
+
threads,
|
|
593
|
+
total,
|
|
594
|
+
page,
|
|
595
|
+
perPage,
|
|
596
|
+
hasMore: currentOffset + threads.length < total,
|
|
597
|
+
};
|
|
564
598
|
} catch (error) {
|
|
565
599
|
this.logger.error(`Error getting threads for resource ${resourceId}:`, error);
|
|
566
|
-
|
|
567
|
-
return { threads: [], total: 0, page, perPage: perPageInput || 100, hasMore: false };
|
|
568
|
-
}
|
|
569
|
-
return [];
|
|
600
|
+
return { threads: [], total: 0, page, perPage: perPageInput || 100, hasMore: false };
|
|
570
601
|
}
|
|
571
602
|
}
|
|
572
603
|
|
|
@@ -663,203 +694,193 @@ export class PostgresStore extends MastraStorage {
|
|
|
663
694
|
}
|
|
664
695
|
}
|
|
665
696
|
|
|
697
|
+
/**
|
|
698
|
+
* @deprecated use getMessagesPaginated instead
|
|
699
|
+
*/
|
|
666
700
|
public async getMessages(args: StorageGetMessagesArg & { format?: 'v1' }): Promise<MastraMessageV1[]>;
|
|
667
701
|
public async getMessages(args: StorageGetMessagesArg & { format: 'v2' }): Promise<MastraMessageV2[]>;
|
|
668
702
|
public async getMessages(
|
|
669
703
|
args: StorageGetMessagesArg & {
|
|
670
704
|
format?: 'v1' | 'v2';
|
|
671
|
-
page: number;
|
|
672
|
-
perPage?: number;
|
|
673
|
-
fromDate?: Date;
|
|
674
|
-
toDate?: Date;
|
|
675
705
|
},
|
|
676
|
-
): Promise<{
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
706
|
+
): Promise<MastraMessageV1[] | MastraMessageV2[]> {
|
|
707
|
+
const { threadId, format, selectBy } = args;
|
|
708
|
+
|
|
709
|
+
const selectStatement = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId"`;
|
|
710
|
+
const orderByStatement = `ORDER BY "createdAt" DESC`;
|
|
711
|
+
|
|
712
|
+
try {
|
|
713
|
+
let rows: any[] = [];
|
|
714
|
+
const include = selectBy?.include || [];
|
|
715
|
+
|
|
716
|
+
if (include.length) {
|
|
717
|
+
const unionQueries: string[] = [];
|
|
718
|
+
const params: any[] = [];
|
|
719
|
+
let paramIdx = 1;
|
|
720
|
+
|
|
721
|
+
for (const inc of include) {
|
|
722
|
+
const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
|
|
723
|
+
// if threadId is provided, use it, otherwise use threadId from args
|
|
724
|
+
const searchId = inc.threadId || threadId;
|
|
725
|
+
unionQueries.push(
|
|
726
|
+
`
|
|
727
|
+
SELECT * FROM (
|
|
728
|
+
WITH ordered_messages AS (
|
|
729
|
+
SELECT
|
|
730
|
+
*,
|
|
731
|
+
ROW_NUMBER() OVER (${orderByStatement}) as row_num
|
|
732
|
+
FROM ${this.getTableName(TABLE_MESSAGES)}
|
|
733
|
+
WHERE thread_id = $${paramIdx}
|
|
734
|
+
)
|
|
735
|
+
SELECT
|
|
736
|
+
m.id,
|
|
737
|
+
m.content,
|
|
738
|
+
m.role,
|
|
739
|
+
m.type,
|
|
740
|
+
m."createdAt",
|
|
741
|
+
m.thread_id AS "threadId",
|
|
742
|
+
m."resourceId"
|
|
743
|
+
FROM ordered_messages m
|
|
744
|
+
WHERE m.id = $${paramIdx + 1}
|
|
745
|
+
OR EXISTS (
|
|
746
|
+
SELECT 1 FROM ordered_messages target
|
|
747
|
+
WHERE target.id = $${paramIdx + 1}
|
|
748
|
+
AND (
|
|
749
|
+
-- Get previous messages based on the max withPreviousMessages
|
|
750
|
+
(m.row_num <= target.row_num + $${paramIdx + 2} AND m.row_num > target.row_num)
|
|
751
|
+
OR
|
|
752
|
+
-- Get next messages based on the max withNextMessages
|
|
753
|
+
(m.row_num >= target.row_num - $${paramIdx + 3} AND m.row_num < target.row_num)
|
|
754
|
+
)
|
|
755
|
+
)
|
|
756
|
+
)
|
|
757
|
+
`, // Keep ASC for final sorting after fetching context
|
|
758
|
+
);
|
|
759
|
+
params.push(searchId, id, withPreviousMessages, withNextMessages);
|
|
760
|
+
paramIdx += 4;
|
|
761
|
+
}
|
|
762
|
+
const finalQuery = unionQueries.join(' UNION ALL ') + ' ORDER BY "createdAt" ASC';
|
|
763
|
+
const includedRows = await this.db.manyOrNone(finalQuery, params);
|
|
764
|
+
const dedupedRows = Object.values(
|
|
765
|
+
includedRows.reduce(
|
|
766
|
+
(acc, row) => {
|
|
767
|
+
acc[row.id] = row;
|
|
768
|
+
return acc;
|
|
769
|
+
},
|
|
770
|
+
{} as Record<string, (typeof includedRows)[0]>,
|
|
771
|
+
),
|
|
772
|
+
);
|
|
773
|
+
rows = dedupedRows;
|
|
774
|
+
} else {
|
|
775
|
+
const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
|
|
776
|
+
if (limit === 0 && selectBy?.last !== false) {
|
|
777
|
+
// if last is explicitly false, we fetch all
|
|
778
|
+
// Do nothing, rows will be empty, and we return empty array later.
|
|
779
|
+
} else {
|
|
780
|
+
let query = `${selectStatement} FROM ${this.getTableName(
|
|
781
|
+
TABLE_MESSAGES,
|
|
782
|
+
)} WHERE thread_id = $1 ${orderByStatement}`;
|
|
783
|
+
const queryParams: any[] = [threadId];
|
|
784
|
+
if (limit !== undefined && selectBy?.last !== false) {
|
|
785
|
+
query += ` LIMIT $2`;
|
|
786
|
+
queryParams.push(limit);
|
|
787
|
+
}
|
|
788
|
+
rows = await this.db.manyOrNone(query, queryParams);
|
|
789
|
+
}
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
const fetchedMessages = (rows || []).map(message => {
|
|
793
|
+
if (typeof message.content === 'string') {
|
|
794
|
+
try {
|
|
795
|
+
message.content = JSON.parse(message.content);
|
|
796
|
+
} catch {
|
|
797
|
+
/* ignore */
|
|
798
|
+
}
|
|
799
|
+
}
|
|
800
|
+
if (message.type === 'v2') delete message.type;
|
|
801
|
+
return message as MastraMessageV1;
|
|
802
|
+
});
|
|
803
|
+
|
|
804
|
+
// Sort all messages by creation date
|
|
805
|
+
const sortedMessages = fetchedMessages.sort(
|
|
806
|
+
(a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime(),
|
|
807
|
+
);
|
|
808
|
+
|
|
809
|
+
return format === 'v2'
|
|
810
|
+
? sortedMessages.map(
|
|
811
|
+
m =>
|
|
812
|
+
({ ...m, content: m.content || { format: 2, parts: [{ type: 'text', text: '' }] } }) as MastraMessageV2,
|
|
813
|
+
)
|
|
814
|
+
: sortedMessages;
|
|
815
|
+
} catch (error) {
|
|
816
|
+
this.logger.error('Error getting messages:', error);
|
|
817
|
+
return [];
|
|
818
|
+
}
|
|
819
|
+
}
|
|
820
|
+
|
|
821
|
+
public async getMessagesPaginated(
|
|
684
822
|
args: StorageGetMessagesArg & {
|
|
685
823
|
format?: 'v1' | 'v2';
|
|
686
|
-
page?: number;
|
|
687
|
-
perPage?: number;
|
|
688
|
-
fromDate?: Date;
|
|
689
|
-
toDate?: Date;
|
|
690
824
|
},
|
|
691
|
-
): Promise<
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
total: number;
|
|
697
|
-
page: number;
|
|
698
|
-
perPage: number;
|
|
699
|
-
hasMore: boolean;
|
|
700
|
-
}
|
|
701
|
-
> {
|
|
702
|
-
const { threadId, format, page, perPage: perPageInput, fromDate, toDate, selectBy } = args;
|
|
825
|
+
): Promise<PaginationInfo & { messages: MastraMessageV1[] | MastraMessageV2[] }> {
|
|
826
|
+
const { threadId, format, selectBy } = args;
|
|
827
|
+
const { page = 0, perPage: perPageInput, dateRange } = selectBy?.pagination || {};
|
|
828
|
+
const fromDate = dateRange?.start;
|
|
829
|
+
const toDate = dateRange?.end;
|
|
703
830
|
|
|
704
831
|
const selectStatement = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId"`;
|
|
705
832
|
const orderByStatement = `ORDER BY "createdAt" DESC`;
|
|
706
833
|
|
|
707
834
|
try {
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
const currentOffset = page * perPage;
|
|
835
|
+
const perPage = perPageInput !== undefined ? perPageInput : 40;
|
|
836
|
+
const currentOffset = page * perPage;
|
|
711
837
|
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
838
|
+
const conditions: string[] = [`thread_id = $1`];
|
|
839
|
+
const queryParams: any[] = [threadId];
|
|
840
|
+
let paramIndex = 2;
|
|
715
841
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
const countQuery = `SELECT COUNT(*) FROM ${this.getTableName(TABLE_MESSAGES)} ${whereClause}`;
|
|
727
|
-
const countResult = await this.db.one(countQuery, queryParams);
|
|
728
|
-
const total = parseInt(countResult.count, 10);
|
|
729
|
-
|
|
730
|
-
if (total === 0) {
|
|
731
|
-
return {
|
|
732
|
-
messages: [],
|
|
733
|
-
total: 0,
|
|
734
|
-
page,
|
|
735
|
-
perPage,
|
|
736
|
-
hasMore: false,
|
|
737
|
-
};
|
|
738
|
-
}
|
|
842
|
+
if (fromDate) {
|
|
843
|
+
conditions.push(`"createdAt" >= $${paramIndex++}`);
|
|
844
|
+
queryParams.push(fromDate);
|
|
845
|
+
}
|
|
846
|
+
if (toDate) {
|
|
847
|
+
conditions.push(`"createdAt" <= $${paramIndex++}`);
|
|
848
|
+
queryParams.push(toDate);
|
|
849
|
+
}
|
|
850
|
+
const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';
|
|
739
851
|
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
const fetchedMessages = (rows || []).map(message => {
|
|
744
|
-
if (typeof message.content === 'string') {
|
|
745
|
-
try {
|
|
746
|
-
message.content = JSON.parse(message.content);
|
|
747
|
-
} catch {
|
|
748
|
-
/* ignore */
|
|
749
|
-
}
|
|
750
|
-
}
|
|
751
|
-
if (message.type === 'v2') delete message.type;
|
|
752
|
-
return message as MastraMessageV1;
|
|
753
|
-
});
|
|
754
|
-
|
|
755
|
-
const messagesToReturn =
|
|
756
|
-
format === 'v2'
|
|
757
|
-
? fetchedMessages.map(
|
|
758
|
-
m =>
|
|
759
|
-
({
|
|
760
|
-
...m,
|
|
761
|
-
content: m.content || { format: 2, parts: [{ type: 'text', text: '' }] },
|
|
762
|
-
}) as MastraMessageV2,
|
|
763
|
-
)
|
|
764
|
-
: fetchedMessages;
|
|
852
|
+
const countQuery = `SELECT COUNT(*) FROM ${this.getTableName(TABLE_MESSAGES)} ${whereClause}`;
|
|
853
|
+
const countResult = await this.db.one(countQuery, queryParams);
|
|
854
|
+
const total = parseInt(countResult.count, 10);
|
|
765
855
|
|
|
856
|
+
if (total === 0) {
|
|
766
857
|
return {
|
|
767
|
-
messages:
|
|
768
|
-
total,
|
|
858
|
+
messages: [],
|
|
859
|
+
total: 0,
|
|
769
860
|
page,
|
|
770
861
|
perPage,
|
|
771
|
-
hasMore:
|
|
862
|
+
hasMore: false,
|
|
772
863
|
};
|
|
773
|
-
}
|
|
774
|
-
// Non-paginated path: Handle selectBy.include or selectBy.last
|
|
775
|
-
let rows: any[] = [];
|
|
776
|
-
const include = selectBy?.include || [];
|
|
777
|
-
|
|
778
|
-
if (include.length) {
|
|
779
|
-
rows = await this.db.manyOrNone(
|
|
780
|
-
`
|
|
781
|
-
WITH ordered_messages AS (
|
|
782
|
-
SELECT
|
|
783
|
-
*,
|
|
784
|
-
ROW_NUMBER() OVER (${orderByStatement}) as row_num
|
|
785
|
-
FROM ${this.getTableName(TABLE_MESSAGES)}
|
|
786
|
-
WHERE thread_id = $1
|
|
787
|
-
)
|
|
788
|
-
SELECT
|
|
789
|
-
m.id,
|
|
790
|
-
m.content,
|
|
791
|
-
m.role,
|
|
792
|
-
m.type,
|
|
793
|
-
m."createdAt",
|
|
794
|
-
m.thread_id AS "threadId"
|
|
795
|
-
FROM ordered_messages m
|
|
796
|
-
WHERE m.id = ANY($2)
|
|
797
|
-
OR EXISTS (
|
|
798
|
-
SELECT 1 FROM ordered_messages target
|
|
799
|
-
WHERE target.id = ANY($2)
|
|
800
|
-
AND (
|
|
801
|
-
-- Get previous messages based on the max withPreviousMessages
|
|
802
|
-
(m.row_num <= target.row_num + $3 AND m.row_num > target.row_num)
|
|
803
|
-
OR
|
|
804
|
-
-- Get next messages based on the max withNextMessages
|
|
805
|
-
(m.row_num >= target.row_num - $4 AND m.row_num < target.row_num)
|
|
806
|
-
)
|
|
807
|
-
)
|
|
808
|
-
ORDER BY m."createdAt" ASC
|
|
809
|
-
`, // Keep ASC for final sorting after fetching context
|
|
810
|
-
[
|
|
811
|
-
threadId,
|
|
812
|
-
include.map(i => i.id),
|
|
813
|
-
Math.max(0, ...include.map(i => i.withPreviousMessages || 0)), // Ensure non-negative
|
|
814
|
-
Math.max(0, ...include.map(i => i.withNextMessages || 0)), // Ensure non-negative
|
|
815
|
-
],
|
|
816
|
-
);
|
|
817
|
-
} else {
|
|
818
|
-
const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
|
|
819
|
-
if (limit === 0 && selectBy?.last !== false) {
|
|
820
|
-
// if last is explicitly false, we fetch all
|
|
821
|
-
// Do nothing, rows will be empty, and we return empty array later.
|
|
822
|
-
} else {
|
|
823
|
-
let query = `${selectStatement} FROM ${this.getTableName(TABLE_MESSAGES)} WHERE thread_id = $1 ${orderByStatement}`;
|
|
824
|
-
const queryParams: any[] = [threadId];
|
|
825
|
-
if (limit !== undefined && selectBy?.last !== false) {
|
|
826
|
-
query += ` LIMIT $2`;
|
|
827
|
-
queryParams.push(limit);
|
|
828
|
-
}
|
|
829
|
-
rows = await this.db.manyOrNone(query, queryParams);
|
|
830
|
-
}
|
|
831
|
-
}
|
|
864
|
+
}
|
|
832
865
|
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
} catch {
|
|
838
|
-
/* ignore */
|
|
839
|
-
}
|
|
840
|
-
}
|
|
841
|
-
if (message.type === 'v2') delete message.type;
|
|
842
|
-
return message as MastraMessageV1;
|
|
843
|
-
});
|
|
866
|
+
const dataQuery = `${selectStatement} FROM ${this.getTableName(
|
|
867
|
+
TABLE_MESSAGES,
|
|
868
|
+
)} ${whereClause} ${orderByStatement} LIMIT $${paramIndex++} OFFSET $${paramIndex++}`;
|
|
869
|
+
const rows = await this.db.manyOrNone(dataQuery, [...queryParams, perPage, currentOffset]);
|
|
844
870
|
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
(a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime(),
|
|
848
|
-
);
|
|
871
|
+
const list = new MessageList().add(rows || [], 'memory');
|
|
872
|
+
const messagesToReturn = format === `v2` ? list.get.all.v2() : list.get.all.v1();
|
|
849
873
|
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
}
|
|
874
|
+
return {
|
|
875
|
+
messages: messagesToReturn,
|
|
876
|
+
total,
|
|
877
|
+
page,
|
|
878
|
+
perPage,
|
|
879
|
+
hasMore: currentOffset + rows.length < total,
|
|
880
|
+
};
|
|
857
881
|
} catch (error) {
|
|
858
882
|
this.logger.error('Error getting messages:', error);
|
|
859
|
-
|
|
860
|
-
return { messages: [], total: 0, page, perPage: perPageInput || 40, hasMore: false };
|
|
861
|
-
}
|
|
862
|
-
return [];
|
|
883
|
+
return { messages: [], total: 0, page, perPage: perPageInput || 40, hasMore: false };
|
|
863
884
|
}
|
|
864
885
|
}
|
|
865
886
|
|
|
@@ -887,16 +908,27 @@ export class PostgresStore extends MastraStorage {
|
|
|
887
908
|
|
|
888
909
|
await this.db.tx(async t => {
|
|
889
910
|
for (const message of messages) {
|
|
911
|
+
if (!message.threadId) {
|
|
912
|
+
throw new Error(
|
|
913
|
+
`Expected to find a threadId for message, but couldn't find one. An unexpected error has occurred.`,
|
|
914
|
+
);
|
|
915
|
+
}
|
|
916
|
+
if (!message.resourceId) {
|
|
917
|
+
throw new Error(
|
|
918
|
+
`Expected to find a resourceId for message, but couldn't find one. An unexpected error has occurred.`,
|
|
919
|
+
);
|
|
920
|
+
}
|
|
890
921
|
await t.none(
|
|
891
|
-
`INSERT INTO ${this.getTableName(TABLE_MESSAGES)} (id, thread_id, content, "createdAt", role, type)
|
|
892
|
-
VALUES ($1, $2, $3, $4, $5, $6)`,
|
|
922
|
+
`INSERT INTO ${this.getTableName(TABLE_MESSAGES)} (id, thread_id, content, "createdAt", role, type, "resourceId")
|
|
923
|
+
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
|
893
924
|
[
|
|
894
925
|
message.id,
|
|
895
|
-
threadId,
|
|
926
|
+
message.threadId,
|
|
896
927
|
typeof message.content === 'string' ? message.content : JSON.stringify(message.content),
|
|
897
928
|
message.createdAt || new Date().toISOString(),
|
|
898
929
|
message.role,
|
|
899
930
|
message.type || 'v2',
|
|
931
|
+
message.resourceId,
|
|
900
932
|
],
|
|
901
933
|
);
|
|
902
934
|
}
|
|
@@ -1134,23 +1166,15 @@ export class PostgresStore extends MastraStorage {
|
|
|
1134
1166
|
this.pgp.end();
|
|
1135
1167
|
}
|
|
1136
1168
|
|
|
1137
|
-
async getEvals(
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
fromDate
|
|
1145
|
-
toDate
|
|
1146
|
-
}): Promise<{
|
|
1147
|
-
evals: EvalRow[];
|
|
1148
|
-
total: number;
|
|
1149
|
-
page?: number;
|
|
1150
|
-
perPage?: number;
|
|
1151
|
-
hasMore?: boolean;
|
|
1152
|
-
}> {
|
|
1153
|
-
const { agentName, type, page, perPage, limit, offset, fromDate, toDate } = options || {};
|
|
1169
|
+
async getEvals(
|
|
1170
|
+
options: {
|
|
1171
|
+
agentName?: string;
|
|
1172
|
+
type?: 'test' | 'live';
|
|
1173
|
+
} & PaginationArgs = {},
|
|
1174
|
+
): Promise<PaginationInfo & { evals: EvalRow[] }> {
|
|
1175
|
+
const { agentName, type, page = 0, perPage = 100, dateRange } = options;
|
|
1176
|
+
const fromDate = dateRange?.start;
|
|
1177
|
+
const toDate = dateRange?.end;
|
|
1154
1178
|
|
|
1155
1179
|
const conditions: string[] = [];
|
|
1156
1180
|
const queryParams: any[] = [];
|
|
@@ -1182,50 +1206,29 @@ export class PostgresStore extends MastraStorage {
|
|
|
1182
1206
|
const countQuery = `SELECT COUNT(*) FROM ${this.getTableName(TABLE_EVALS)} ${whereClause}`;
|
|
1183
1207
|
const countResult = await this.db.one(countQuery, queryParams);
|
|
1184
1208
|
const total = parseInt(countResult.count, 10);
|
|
1185
|
-
|
|
1186
|
-
let currentLimit: number;
|
|
1187
|
-
let currentOffset: number;
|
|
1188
|
-
let currentPage: number | undefined = page;
|
|
1189
|
-
let currentPerPage: number | undefined = perPage;
|
|
1190
|
-
let hasMore = false;
|
|
1191
|
-
|
|
1192
|
-
if (limit !== undefined && offset !== undefined) {
|
|
1193
|
-
currentLimit = limit;
|
|
1194
|
-
currentOffset = offset;
|
|
1195
|
-
currentPage = undefined;
|
|
1196
|
-
currentPerPage = undefined;
|
|
1197
|
-
hasMore = currentOffset + currentLimit < total;
|
|
1198
|
-
} else if (page !== undefined && perPage !== undefined) {
|
|
1199
|
-
currentLimit = perPage;
|
|
1200
|
-
currentOffset = page * perPage;
|
|
1201
|
-
hasMore = currentOffset + currentLimit < total;
|
|
1202
|
-
} else {
|
|
1203
|
-
currentLimit = perPage || 100;
|
|
1204
|
-
currentOffset = (page || 0) * currentLimit;
|
|
1205
|
-
if (page === undefined) currentPage = 0;
|
|
1206
|
-
if (currentPerPage === undefined) currentPerPage = currentLimit;
|
|
1207
|
-
hasMore = currentOffset + currentLimit < total;
|
|
1208
|
-
}
|
|
1209
|
+
const currentOffset = page * perPage;
|
|
1209
1210
|
|
|
1210
1211
|
if (total === 0) {
|
|
1211
1212
|
return {
|
|
1212
1213
|
evals: [],
|
|
1213
1214
|
total: 0,
|
|
1214
|
-
page
|
|
1215
|
-
perPage
|
|
1215
|
+
page,
|
|
1216
|
+
perPage,
|
|
1216
1217
|
hasMore: false,
|
|
1217
1218
|
};
|
|
1218
1219
|
}
|
|
1219
1220
|
|
|
1220
|
-
const dataQuery = `SELECT * FROM ${this.getTableName(
|
|
1221
|
-
|
|
1221
|
+
const dataQuery = `SELECT * FROM ${this.getTableName(
|
|
1222
|
+
TABLE_EVALS,
|
|
1223
|
+
)} ${whereClause} ORDER BY created_at DESC LIMIT $${paramIndex++} OFFSET $${paramIndex++}`;
|
|
1224
|
+
const rows = await this.db.manyOrNone(dataQuery, [...queryParams, perPage, currentOffset]);
|
|
1222
1225
|
|
|
1223
1226
|
return {
|
|
1224
1227
|
evals: rows?.map(row => this.transformEvalRow(row)) ?? [],
|
|
1225
1228
|
total,
|
|
1226
|
-
page
|
|
1227
|
-
perPage
|
|
1228
|
-
hasMore,
|
|
1229
|
+
page,
|
|
1230
|
+
perPage,
|
|
1231
|
+
hasMore: currentOffset + (rows?.length ?? 0) < total,
|
|
1229
1232
|
};
|
|
1230
1233
|
}
|
|
1231
1234
|
}
|