@mastra/pg 0.3.1-alpha.2 → 0.3.1-alpha.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,14 @@ import {
8
8
  TABLE_WORKFLOW_SNAPSHOT,
9
9
  TABLE_EVALS,
10
10
  } from '@mastra/core/storage';
11
- import type { EvalRow, StorageColumn, StorageGetMessagesArg, TABLE_NAMES } from '@mastra/core/storage';
11
+ import type {
12
+ EvalRow,
13
+ StorageColumn,
14
+ StorageGetMessagesArg,
15
+ TABLE_NAMES,
16
+ WorkflowRun,
17
+ WorkflowRuns,
18
+ } from '@mastra/core/storage';
12
19
  import type { WorkflowRunState } from '@mastra/core/workflows';
13
20
  import pgPromise from 'pg-promise';
14
21
  import type { ISSLConfig } from 'pg-promise/typescript/pg-subset';
@@ -561,7 +568,7 @@ export class PostgresStore extends MastraStorage {
561
568
  }
562
569
  }
563
570
 
564
- async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T> {
571
+ async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T[]> {
565
572
  try {
566
573
  const messages: any[] = [];
567
574
  const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
@@ -645,7 +652,7 @@ export class PostgresStore extends MastraStorage {
645
652
  }
646
653
  });
647
654
 
648
- return messages as T;
655
+ return messages as T[];
649
656
  } catch (error) {
650
657
  console.error('Error getting messages:', error);
651
658
  throw error;
@@ -748,96 +755,166 @@ export class PostgresStore extends MastraStorage {
748
755
  }
749
756
  }
750
757
 
758
/**
 * Reports whether `table` (in this store's schema) has a column named `column`.
 *
 * The lower-cased form of `column` is accepted as well. This covers a column
 * that was created as an unquoted identifier, because Postgres folds unquoted
 * identifiers to lower case.
 *
 * @param table - Unqualified table name to inspect.
 * @param column - Column name to look for.
 * @returns `true` if a matching column exists, `false` otherwise.
 */
private async hasColumn(table: string, column: string): Promise<boolean> {
  // Scope the lookup to the configured schema, falling back to `public`.
  const schema = this.schema || 'public';
  // LIMIT 1: if the table contains both the exact and the lower-cased variant,
  // oneOrNone would otherwise reject with a "multiple rows" error.
  const result = await this.db.oneOrNone(
    `SELECT 1 FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 AND (column_name = $3 OR column_name = $4) LIMIT 1`,
    [schema, table, column, column.toLowerCase()],
  );
  return !!result;
}
767
+
768
/**
 * Maps a raw row from the workflow-snapshot table to a `WorkflowRun`.
 *
 * A string snapshot is decoded as JSON. If decoding fails, a warning is logged
 * and the raw string is kept. A non-string snapshot is passed through unchanged.
 */
private parseWorkflowRun(row: any): WorkflowRun {
  const {
    snapshot: rawSnapshot,
    workflow_name: workflowName,
    run_id: runId,
    createdAt,
    updatedAt,
    resourceId,
  } = row;

  let snapshot: WorkflowRunState | string = rawSnapshot;
  if (typeof rawSnapshot === 'string') {
    try {
      snapshot = JSON.parse(rawSnapshot) as WorkflowRunState;
    } catch (err) {
      // Keep the undecoded string so callers still see what was stored.
      console.warn(`Failed to parse snapshot for workflow ${workflowName}: ${err}`);
    }
  }

  return { workflowName, runId, snapshot, createdAt, updatedAt, resourceId };
}
788
+
751
789
  async getWorkflowRuns({
752
790
  workflowName,
753
791
  fromDate,
754
792
  toDate,
755
793
  limit,
756
794
  offset,
795
+ resourceId,
757
796
  }: {
758
797
  workflowName?: string;
759
798
  fromDate?: Date;
760
799
  toDate?: Date;
761
800
  limit?: number;
762
801
  offset?: number;
763
- } = {}): Promise<{
764
- runs: Array<{
765
- workflowName: string;
766
- runId: string;
767
- snapshot: WorkflowRunState | string;
768
- createdAt: Date;
769
- updatedAt: Date;
770
- }>;
771
- total: number;
772
- }> {
773
- const conditions: string[] = [];
774
- const values: any[] = [];
775
- let paramIndex = 1;
776
-
777
- if (workflowName) {
778
- conditions.push(`workflow_name = $${paramIndex}`);
779
- values.push(workflowName);
780
- paramIndex++;
781
- }
782
-
783
- if (fromDate) {
784
- conditions.push(`"createdAt" >= $${paramIndex}`);
785
- values.push(fromDate);
786
- paramIndex++;
787
- }
802
+ resourceId?: string;
803
+ } = {}): Promise<WorkflowRuns> {
804
+ try {
805
+ const conditions: string[] = [];
806
+ const values: any[] = [];
807
+ let paramIndex = 1;
808
+
809
+ if (workflowName) {
810
+ conditions.push(`workflow_name = $${paramIndex}`);
811
+ values.push(workflowName);
812
+ paramIndex++;
813
+ }
788
814
 
789
- if (toDate) {
790
- conditions.push(`"createdAt" <= $${paramIndex}`);
791
- values.push(toDate);
792
- paramIndex++;
793
- }
815
+ if (resourceId) {
816
+ const hasResourceId = await this.hasColumn(TABLE_WORKFLOW_SNAPSHOT, 'resourceId');
817
+ if (hasResourceId) {
818
+ conditions.push(`"resourceId" = $${paramIndex}`);
819
+ values.push(resourceId);
820
+ paramIndex++;
821
+ } else {
822
+ console.warn(`[${TABLE_WORKFLOW_SNAPSHOT}] resourceId column not found. Skipping resourceId filter.`);
823
+ }
824
+ }
794
825
 
795
- const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';
826
+ if (fromDate) {
827
+ conditions.push(`"createdAt" >= $${paramIndex}`);
828
+ values.push(fromDate);
829
+ paramIndex++;
830
+ }
796
831
 
797
- let total = 0;
798
- // Only get total count when using pagination
799
- if (limit !== undefined && offset !== undefined) {
800
- const countResult = await this.db.one(
801
- `SELECT COUNT(*) as count FROM ${this.getTableName(TABLE_WORKFLOW_SNAPSHOT)} ${whereClause}`,
802
- values,
803
- );
804
- total = Number(countResult.count);
805
- }
832
+ if (toDate) {
833
+ conditions.push(`"createdAt" <= $${paramIndex}`);
834
+ values.push(toDate);
835
+ paramIndex++;
836
+ }
837
+ const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';
838
+
839
+ let total = 0;
840
+ // Only get total count when using pagination
841
+ if (limit !== undefined && offset !== undefined) {
842
+ const countResult = await this.db.one(
843
+ `SELECT COUNT(*) as count FROM ${this.getTableName(TABLE_WORKFLOW_SNAPSHOT)} ${whereClause}`,
844
+ values,
845
+ );
846
+ total = Number(countResult.count);
847
+ }
806
848
 
807
- // Get results
808
- const query = `
849
+ // Get results
850
+ const query = `
809
851
  SELECT * FROM ${this.getTableName(TABLE_WORKFLOW_SNAPSHOT)}
810
852
  ${whereClause}
811
853
  ORDER BY "createdAt" DESC
812
854
  ${limit !== undefined && offset !== undefined ? ` LIMIT $${paramIndex} OFFSET $${paramIndex + 1}` : ''}
813
855
  `;
814
856
 
815
- const queryValues = limit !== undefined && offset !== undefined ? [...values, limit, offset] : values;
857
+ const queryValues = limit !== undefined && offset !== undefined ? [...values, limit, offset] : values;
816
858
 
817
- const result = await this.db.manyOrNone(query, queryValues);
859
+ const result = await this.db.manyOrNone(query, queryValues);
818
860
 
819
- const runs = (result || []).map(row => {
820
- let parsedSnapshot: WorkflowRunState | string = row.snapshot as string;
821
- if (typeof parsedSnapshot === 'string') {
822
- try {
823
- parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
824
- } catch (e) {
825
- // If parsing fails, return the raw snapshot string
826
- console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
827
- }
861
+ const runs = (result || []).map(row => {
862
+ return this.parseWorkflowRun(row);
863
+ });
864
+
865
+ // Use runs.length as total when not paginating
866
+ return { runs, total: total || runs.length };
867
+ } catch (error) {
868
+ console.error('Error getting workflow runs:', error);
869
+ throw error;
870
+ }
871
+ }
872
+
873
/**
 * Fetches a single workflow run by its run id, optionally narrowed to one workflow.
 *
 * NOTE(review): `run_id` alone is not assumed to be unique across workflows
 * (confirm against the table's unique constraint). When several rows match,
 * the most recently created one is returned instead of letting `oneOrNone`
 * reject with a "multiple rows" error.
 *
 * @returns The parsed run, or `null` if no row matches.
 * @throws Re-throws any database error after logging it.
 */
async getWorkflowRunById({
  runId,
  workflowName,
}: {
  runId: string;
  workflowName?: string;
}): Promise<WorkflowRun | null> {
  try {
    // Always constrain on run_id. An empty WHERE clause would otherwise
    // select arbitrary rows from the whole table.
    const conditions: string[] = ['run_id = $1'];
    const values: any[] = [runId];

    if (workflowName) {
      conditions.push(`workflow_name = $${values.length + 1}`);
      values.push(workflowName);
    }

    const query = `
      SELECT * FROM ${this.getTableName(TABLE_WORKFLOW_SNAPSHOT)}
      WHERE ${conditions.join(' AND ')}
      ORDER BY "createdAt" DESC
      LIMIT 1
    `;

    const result = await this.db.oneOrNone(query, values);
    return result ? this.parseWorkflowRun(result) : null;
  } catch (error) {
    console.error('Error getting workflow run by ID:', error);
    throw error;
  }
}
842
919
 
843
920
  async close(): Promise<void> {
@@ -2436,4 +2436,45 @@ describe('PgVector', () => {
2436
2436
  });
2437
2437
  });
2438
2438
  });
2439
+
2440
describe('PoolConfig Custom Options', () => {
  // Pool settings the PgVector constructor uses unless pgPoolOptions overrides them.
  const defaultPoolSettings = {
    max: 20,
    idleTimeoutMillis: 30000,
    connectionTimeoutMillis: 2000,
  };

  it('should apply custom values to properties with default values', async () => {
    const overrides = { max: 5, idleTimeoutMillis: 10000, connectionTimeoutMillis: 1000 };
    const vector = new PgVector({ connectionString, pgPoolOptions: overrides });

    // Every overridden default must win over the constructor's value.
    expect(vector['pool'].options).toMatchObject(overrides);
  });

  it('should pass properties with no default values', async () => {
    const vector = new PgVector({ connectionString, pgPoolOptions: { ssl: false } });

    expect(vector['pool'].options.ssl).toBe(false);
  });

  it('should keep default values when custom values are added', async () => {
    const vector = new PgVector({ connectionString, pgPoolOptions: { ssl: false } });

    // Adding an unrelated option must not drop any of the defaults.
    expect(vector['pool'].options).toMatchObject({ ...defaultPoolSettings, ssl: false });
  });
});
2439
2480
  });
@@ -71,23 +71,49 @@ export class PgVector extends MastraVector {
71
71
  private schemaSetupComplete: boolean | undefined = undefined;
72
72
 
73
73
  constructor(connectionString: string);
74
- constructor(config: { connectionString: string; schemaName?: string });
75
- constructor(config: string | { connectionString: string; schemaName?: string }) {
76
- const connectionString = typeof config === 'string' ? config : config.connectionString;
77
- if (!connectionString || typeof connectionString !== 'string' || connectionString.trim() === '') {
74
+ constructor(config: {
75
+ connectionString: string;
76
+ schemaName?: string;
77
+ pgPoolOptions?: Omit<pg.PoolConfig, 'connectionString'>;
78
+ });
79
+ constructor(
80
+ config:
81
+ | string
82
+ | {
83
+ connectionString: string;
84
+ schemaName?: string;
85
+ pgPoolOptions?: Omit<pg.PoolConfig, 'connectionString'>;
86
+ },
87
+ ) {
88
+ let connectionString: string;
89
+ let pgPoolOptions: Omit<pg.PoolConfig, 'connectionString'> | undefined;
90
+ let schemaName: string | undefined;
91
+
92
+ if (typeof config === 'string') {
93
+ connectionString = config;
94
+ schemaName = undefined;
95
+ pgPoolOptions = undefined;
96
+ } else {
97
+ connectionString = config.connectionString;
98
+ schemaName = config.schemaName;
99
+ pgPoolOptions = config.pgPoolOptions;
100
+ }
101
+
102
+ if (!connectionString || connectionString.trim() === '') {
78
103
  throw new Error(
79
104
  'PgVector: connectionString must be provided and cannot be empty. Passing an empty string may cause fallback to local Postgres defaults.',
80
105
  );
81
106
  }
82
107
  super();
83
108
 
84
- this.schema = typeof config === 'string' ? undefined : config.schemaName;
109
+ this.schema = schemaName;
85
110
 
86
111
  const basePool = new pg.Pool({
87
112
  connectionString,
88
113
  max: 20, // Maximum number of clients in the pool
89
114
  idleTimeoutMillis: 30000, // Close idle connections after 30 seconds
90
115
  connectionTimeoutMillis: 2000, // Fail fast if can't connect
116
+ ...pgPoolOptions,
91
117
  });
92
118
 
93
119
  const telemetry = this.__getTelemetry();