@mastra/pg 0.2.7-alpha.1 → 0.2.7-alpha.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,17 @@
1
+ import type { MetricResult } from '@mastra/core/eval';
1
2
  import type { MessageType, StorageThreadType } from '@mastra/core/memory';
2
- import { MastraStorage } from '@mastra/core/storage';
3
+ import {
4
+ MastraStorage,
5
+ TABLE_MESSAGES,
6
+ TABLE_THREADS,
7
+ TABLE_TRACES,
8
+ TABLE_WORKFLOW_SNAPSHOT,
9
+ TABLE_EVALS,
10
+ } from '@mastra/core/storage';
3
11
  import type { EvalRow, StorageColumn, StorageGetMessagesArg, TABLE_NAMES } from '@mastra/core/storage';
4
12
  import type { WorkflowRunState } from '@mastra/core/workflows';
5
13
  import pgPromise from 'pg-promise';
14
+ import type { ISSLConfig } from 'pg-promise/typescript/pg-subset';
6
15
 
7
16
  export type PostgresConfig =
8
17
  | {
@@ -11,6 +20,7 @@ export type PostgresConfig =
11
20
  database: string;
12
21
  user: string;
13
22
  password: string;
23
+ ssl?: boolean | ISSLConfig;
14
24
  }
15
25
  | {
16
26
  connectionString: string;
@@ -32,12 +42,56 @@ export class PostgresStore extends MastraStorage {
32
42
  database: config.database,
33
43
  user: config.user,
34
44
  password: config.password,
45
+ ssl: config.ssl,
35
46
  },
36
47
  );
37
48
  }
38
49
 
39
- getEvalsByAgentName(_agentName: string, _type?: 'test' | 'live'): Promise<EvalRow[]> {
40
- throw new Error('Method not implemented.');
50
+ getEvalsByAgentName(agentName: string, type?: 'test' | 'live'): Promise<EvalRow[]> {
51
+ try {
52
+ const baseQuery = `SELECT * FROM ${TABLE_EVALS} WHERE agent_name = $1`;
53
+ const typeCondition =
54
+ type === 'test'
55
+ ? " AND test_info IS NOT NULL AND test_info->>'testPath' IS NOT NULL"
56
+ : type === 'live'
57
+ ? " AND (test_info IS NULL OR test_info->>'testPath' IS NULL)"
58
+ : '';
59
+
60
+ const query = `${baseQuery}${typeCondition} ORDER BY created_at DESC`;
61
+
62
+ return this.db.manyOrNone(query, [agentName]).then(rows => rows?.map(row => this.transformEvalRow(row)) ?? []);
63
+ } catch (error) {
64
+ // Handle case where table doesn't exist yet
65
+ if (error instanceof Error && error.message.includes('relation') && error.message.includes('does not exist')) {
66
+ return Promise.resolve([]);
67
+ }
68
+ console.error('Failed to get evals for the specified agent: ' + (error as any)?.message);
69
+ throw error;
70
+ }
71
+ }
72
+
73
+ private transformEvalRow(row: Record<string, any>): EvalRow {
74
+ let testInfoValue = null;
75
+ if (row.test_info) {
76
+ try {
77
+ testInfoValue = typeof row.test_info === 'string' ? JSON.parse(row.test_info) : row.test_info;
78
+ } catch (e) {
79
+ console.warn('Failed to parse test_info:', e);
80
+ }
81
+ }
82
+
83
+ return {
84
+ agentName: row.agent_name as string,
85
+ input: row.input as string,
86
+ output: row.output as string,
87
+ result: row.result as MetricResult,
88
+ metricName: row.metric_name as string,
89
+ instructions: row.instructions as string,
90
+ testInfo: testInfoValue,
91
+ globalRunId: row.global_run_id as string,
92
+ runId: row.run_id as string,
93
+ createdAt: row.created_at as string,
94
+ };
41
95
  }
42
96
 
43
97
  async batchInsert({ tableName, records }: { tableName: TABLE_NAMES; records: Record<string, any>[] }): Promise<void> {
@@ -60,12 +114,14 @@ export class PostgresStore extends MastraStorage {
60
114
  page,
61
115
  perPage,
62
116
  attributes,
117
+ filters,
63
118
  }: {
64
119
  name?: string;
65
120
  scope?: string;
66
121
  page: number;
67
122
  perPage: number;
68
123
  attributes?: Record<string, string>;
124
+ filters?: Record<string, any>;
69
125
  }): Promise<any[]> {
70
126
  let idx = 1;
71
127
  const limit = perPage;
@@ -86,6 +142,12 @@ export class PostgresStore extends MastraStorage {
86
142
  });
87
143
  }
88
144
 
145
+ if (filters) {
146
+ Object.entries(filters).forEach(([key, value]) => {
147
+ conditions.push(`${key} = \$${idx++}`);
148
+ });
149
+ }
150
+
89
151
  const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';
90
152
 
91
153
  if (name) {
@@ -102,9 +164,15 @@ export class PostgresStore extends MastraStorage {
102
164
  }
103
165
  }
104
166
 
167
+ if (filters) {
168
+ for (const [key, value] of Object.entries(filters)) {
169
+ args.push(value);
170
+ }
171
+ }
172
+
105
173
  console.log(
106
174
  'QUERY',
107
- `SELECT * FROM ${MastraStorage.TABLE_TRACES} ${whereClause} ORDER BY "createdAt" DESC LIMIT ${limit} OFFSET ${offset}`,
175
+ `SELECT * FROM ${TABLE_TRACES} ${whereClause} ORDER BY "createdAt" DESC LIMIT ${limit} OFFSET ${offset}`,
108
176
  args,
109
177
  );
110
178
 
@@ -123,10 +191,7 @@ export class PostgresStore extends MastraStorage {
123
191
  endTime: string;
124
192
  other: any;
125
193
  createdAt: string;
126
- }>(
127
- `SELECT * FROM ${MastraStorage.TABLE_TRACES} ${whereClause} ORDER BY "createdAt" DESC LIMIT ${limit} OFFSET ${offset}`,
128
- args,
129
- );
194
+ }>(`SELECT * FROM ${TABLE_TRACES} ${whereClause} ORDER BY "createdAt" DESC LIMIT ${limit} OFFSET ${offset}`, args);
130
195
 
131
196
  if (!result) {
132
197
  return [];
@@ -172,7 +237,7 @@ export class PostgresStore extends MastraStorage {
172
237
  ${columns}
173
238
  );
174
239
  ${
175
- tableName === MastraStorage.TABLE_WORKFLOW_SNAPSHOT
240
+ tableName === TABLE_WORKFLOW_SNAPSHOT
176
241
  ? `
177
242
  DO $$ BEGIN
178
243
  IF NOT EXISTS (
@@ -233,7 +298,7 @@ export class PostgresStore extends MastraStorage {
233
298
  }
234
299
 
235
300
  // If this is a workflow snapshot, parse the snapshot field
236
- if (tableName === MastraStorage.TABLE_WORKFLOW_SNAPSHOT) {
301
+ if (tableName === TABLE_WORKFLOW_SNAPSHOT) {
237
302
  const snapshot = result as any;
238
303
  if (typeof snapshot.snapshot === 'string') {
239
304
  snapshot.snapshot = JSON.parse(snapshot.snapshot);
@@ -258,7 +323,7 @@ export class PostgresStore extends MastraStorage {
258
323
  metadata,
259
324
  "createdAt",
260
325
  "updatedAt"
261
- FROM "${MastraStorage.TABLE_THREADS}"
326
+ FROM "${TABLE_THREADS}"
262
327
  WHERE id = $1`,
263
328
  [threadId],
264
329
  );
@@ -289,7 +354,7 @@ export class PostgresStore extends MastraStorage {
289
354
  metadata,
290
355
  "createdAt",
291
356
  "updatedAt"
292
- FROM "${MastraStorage.TABLE_THREADS}"
357
+ FROM "${TABLE_THREADS}"
293
358
  WHERE "resourceId" = $1`,
294
359
  [resourceId],
295
360
  );
@@ -309,7 +374,7 @@ export class PostgresStore extends MastraStorage {
309
374
  async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
310
375
  try {
311
376
  await this.db.none(
312
- `INSERT INTO "${MastraStorage.TABLE_THREADS}" (
377
+ `INSERT INTO "${TABLE_THREADS}" (
313
378
  id,
314
379
  "resourceId",
315
380
  title,
@@ -363,7 +428,7 @@ export class PostgresStore extends MastraStorage {
363
428
  };
364
429
 
365
430
  const thread = await this.db.one<StorageThreadType>(
366
- `UPDATE "${MastraStorage.TABLE_THREADS}"
431
+ `UPDATE "${TABLE_THREADS}"
367
432
  SET title = $1,
368
433
  metadata = $2,
369
434
  "updatedAt" = $3
@@ -388,10 +453,10 @@ export class PostgresStore extends MastraStorage {
388
453
  try {
389
454
  await this.db.tx(async t => {
390
455
  // First delete all messages associated with this thread
391
- await t.none(`DELETE FROM "${MastraStorage.TABLE_MESSAGES}" WHERE thread_id = $1`, [threadId]);
456
+ await t.none(`DELETE FROM "${TABLE_MESSAGES}" WHERE thread_id = $1`, [threadId]);
392
457
 
393
458
  // Then delete the thread
394
- await t.none(`DELETE FROM "${MastraStorage.TABLE_THREADS}" WHERE id = $1`, [threadId]);
459
+ await t.none(`DELETE FROM "${TABLE_THREADS}" WHERE id = $1`, [threadId]);
395
460
  });
396
461
  } catch (error) {
397
462
  console.error('Error deleting thread:', error);
@@ -412,7 +477,7 @@ export class PostgresStore extends MastraStorage {
412
477
  SELECT
413
478
  *,
414
479
  ROW_NUMBER() OVER (ORDER BY "createdAt" DESC) as row_num
415
- FROM "${MastraStorage.TABLE_MESSAGES}"
480
+ FROM "${TABLE_MESSAGES}"
416
481
  WHERE thread_id = $1
417
482
  )
418
483
  SELECT
@@ -458,7 +523,7 @@ export class PostgresStore extends MastraStorage {
458
523
  type,
459
524
  "createdAt",
460
525
  thread_id AS "threadId"
461
- FROM "${MastraStorage.TABLE_MESSAGES}"
526
+ FROM "${TABLE_MESSAGES}"
462
527
  WHERE thread_id = $1
463
528
  AND id != ALL($2)
464
529
  ORDER BY "createdAt" DESC
@@ -508,7 +573,7 @@ export class PostgresStore extends MastraStorage {
508
573
  await this.db.tx(async t => {
509
574
  for (const message of messages) {
510
575
  await t.none(
511
- `INSERT INTO "${MastraStorage.TABLE_MESSAGES}" (id, thread_id, content, "createdAt", role, type)
576
+ `INSERT INTO "${TABLE_MESSAGES}" (id, thread_id, content, "createdAt", role, type)
512
577
  VALUES ($1, $2, $3, $4, $5, $6)`,
513
578
  [
514
579
  message.id,
@@ -541,7 +606,7 @@ export class PostgresStore extends MastraStorage {
541
606
  try {
542
607
  const now = new Date().toISOString();
543
608
  await this.db.none(
544
- `INSERT INTO "${MastraStorage.TABLE_WORKFLOW_SNAPSHOT}" (
609
+ `INSERT INTO "${TABLE_WORKFLOW_SNAPSHOT}" (
545
610
  workflow_name,
546
611
  run_id,
547
612
  snapshot,
@@ -568,7 +633,7 @@ export class PostgresStore extends MastraStorage {
568
633
  }): Promise<WorkflowRunState | null> {
569
634
  try {
570
635
  const result = await this.load({
571
- tableName: MastraStorage.TABLE_WORKFLOW_SNAPSHOT,
636
+ tableName: TABLE_WORKFLOW_SNAPSHOT,
572
637
  keys: {
573
638
  workflow_name: workflowName,
574
639
  run_id: runId,
@@ -586,6 +651,98 @@ export class PostgresStore extends MastraStorage {
586
651
  }
587
652
  }
588
653
 
654
+ async getWorkflowRuns({
655
+ workflowName,
656
+ fromDate,
657
+ toDate,
658
+ limit,
659
+ offset,
660
+ }: {
661
+ workflowName?: string;
662
+ fromDate?: Date;
663
+ toDate?: Date;
664
+ limit?: number;
665
+ offset?: number;
666
+ } = {}): Promise<{
667
+ runs: Array<{
668
+ workflowName: string;
669
+ runId: string;
670
+ snapshot: WorkflowRunState | string;
671
+ createdAt: Date;
672
+ updatedAt: Date;
673
+ }>;
674
+ total: number;
675
+ }> {
676
+ const conditions: string[] = [];
677
+ const values: any[] = [];
678
+ let paramIndex = 1;
679
+
680
+ if (workflowName) {
681
+ conditions.push(`workflow_name = $${paramIndex}`);
682
+ values.push(workflowName);
683
+ paramIndex++;
684
+ }
685
+
686
+ if (fromDate) {
687
+ conditions.push(`"createdAt" >= $${paramIndex}`);
688
+ values.push(fromDate);
689
+ paramIndex++;
690
+ }
691
+
692
+ if (toDate) {
693
+ conditions.push(`"createdAt" <= $${paramIndex}`);
694
+ values.push(toDate);
695
+ paramIndex++;
696
+ }
697
+
698
+ const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';
699
+
700
+ let total = 0;
701
+ // Only get total count when using pagination
702
+ if (limit !== undefined && offset !== undefined) {
703
+ const countResult = await this.db.one(
704
+ `SELECT COUNT(*) as count FROM ${TABLE_WORKFLOW_SNAPSHOT} ${whereClause}`,
705
+ values,
706
+ );
707
+ total = Number(countResult.count);
708
+ }
709
+
710
+ // Get results
711
+ const query = `
712
+ SELECT * FROM ${TABLE_WORKFLOW_SNAPSHOT}
713
+ ${whereClause}
714
+ ORDER BY "createdAt" DESC
715
+ ${limit !== undefined && offset !== undefined ? ` LIMIT $${paramIndex} OFFSET $${paramIndex + 1}` : ''}
716
+ `;
717
+
718
+ const queryValues = limit !== undefined && offset !== undefined ? [...values, limit, offset] : values;
719
+
720
+ const result = await this.db.manyOrNone(query, queryValues);
721
+
722
+ const runs = (result || []).map(row => {
723
+ let parsedSnapshot: WorkflowRunState | string = row.snapshot as string;
724
+ if (typeof parsedSnapshot === 'string') {
725
+ try {
726
+ parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
727
+ } catch (e) {
728
+ // If parsing fails, return the raw snapshot string
729
+ console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
730
+ }
731
+ }
732
+
733
+ return {
734
+ workflowName: row.workflow_name,
735
+ runId: row.run_id,
736
+ snapshot: parsedSnapshot,
737
+ createdAt: row.createdAt,
738
+ updatedAt: row.updatedAt,
739
+ };
740
+ });
741
+
742
+ // Use runs.length as total when not paginating
743
+ return { runs, total: total || runs.length };
744
+ }
745
+
589
746
  async close(): Promise<void> {
590
747
  this.pgp.end();
591
748
  }