@mastra/pg 0.3.1-alpha.2 → 0.3.1-alpha.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +9 -9
- package/CHANGELOG.md +17 -0
- package/dist/_tsup-dts-rollup.d.cts +14 -12
- package/dist/_tsup-dts-rollup.d.ts +14 -12
- package/dist/index.cjs +126 -52
- package/dist/index.js +126 -52
- package/package.json +4 -4
- package/src/storage/index.test.ts +216 -45
- package/src/storage/index.ts +143 -66
- package/src/vector/index.test.ts +41 -0
- package/src/vector/index.ts +31 -5
package/src/storage/index.ts
CHANGED
|
@@ -8,7 +8,14 @@ import {
|
|
|
8
8
|
TABLE_WORKFLOW_SNAPSHOT,
|
|
9
9
|
TABLE_EVALS,
|
|
10
10
|
} from '@mastra/core/storage';
|
|
11
|
-
import type {
|
|
11
|
+
import type {
|
|
12
|
+
EvalRow,
|
|
13
|
+
StorageColumn,
|
|
14
|
+
StorageGetMessagesArg,
|
|
15
|
+
TABLE_NAMES,
|
|
16
|
+
WorkflowRun,
|
|
17
|
+
WorkflowRuns,
|
|
18
|
+
} from '@mastra/core/storage';
|
|
12
19
|
import type { WorkflowRunState } from '@mastra/core/workflows';
|
|
13
20
|
import pgPromise from 'pg-promise';
|
|
14
21
|
import type { ISSLConfig } from 'pg-promise/typescript/pg-subset';
|
|
@@ -561,7 +568,7 @@ export class PostgresStore extends MastraStorage {
|
|
|
561
568
|
}
|
|
562
569
|
}
|
|
563
570
|
|
|
564
|
-
async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T> {
|
|
571
|
+
async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T[]> {
|
|
565
572
|
try {
|
|
566
573
|
const messages: any[] = [];
|
|
567
574
|
const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
|
|
@@ -645,7 +652,7 @@ export class PostgresStore extends MastraStorage {
|
|
|
645
652
|
}
|
|
646
653
|
});
|
|
647
654
|
|
|
648
|
-
return messages as T;
|
|
655
|
+
return messages as T[];
|
|
649
656
|
} catch (error) {
|
|
650
657
|
console.error('Error getting messages:', error);
|
|
651
658
|
throw error;
|
|
@@ -748,96 +755,166 @@ export class PostgresStore extends MastraStorage {
|
|
|
748
755
|
}
|
|
749
756
|
}
|
|
750
757
|
|
|
758
|
+
private async hasColumn(table: string, column: string): Promise<boolean> {
|
|
759
|
+
// Use this.schema to scope the check
|
|
760
|
+
const schema = this.schema || 'public';
|
|
761
|
+
const result = await this.db.oneOrNone(
|
|
762
|
+
`SELECT 1 FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 AND (column_name = $3 OR column_name = $4)`,
|
|
763
|
+
[schema, table, column, column.toLowerCase()],
|
|
764
|
+
);
|
|
765
|
+
return !!result;
|
|
766
|
+
}
|
|
767
|
+
|
|
768
|
+
private parseWorkflowRun(row: any): WorkflowRun {
|
|
769
|
+
let parsedSnapshot: WorkflowRunState | string = row.snapshot as string;
|
|
770
|
+
if (typeof parsedSnapshot === 'string') {
|
|
771
|
+
try {
|
|
772
|
+
parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
|
|
773
|
+
} catch (e) {
|
|
774
|
+
// If parsing fails, return the raw snapshot string
|
|
775
|
+
console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
|
|
776
|
+
}
|
|
777
|
+
}
|
|
778
|
+
|
|
779
|
+
return {
|
|
780
|
+
workflowName: row.workflow_name,
|
|
781
|
+
runId: row.run_id,
|
|
782
|
+
snapshot: parsedSnapshot,
|
|
783
|
+
createdAt: row.createdAt,
|
|
784
|
+
updatedAt: row.updatedAt,
|
|
785
|
+
resourceId: row.resourceId,
|
|
786
|
+
};
|
|
787
|
+
}
|
|
788
|
+
|
|
751
789
|
async getWorkflowRuns({
|
|
752
790
|
workflowName,
|
|
753
791
|
fromDate,
|
|
754
792
|
toDate,
|
|
755
793
|
limit,
|
|
756
794
|
offset,
|
|
795
|
+
resourceId,
|
|
757
796
|
}: {
|
|
758
797
|
workflowName?: string;
|
|
759
798
|
fromDate?: Date;
|
|
760
799
|
toDate?: Date;
|
|
761
800
|
limit?: number;
|
|
762
801
|
offset?: number;
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
let paramIndex = 1;
|
|
776
|
-
|
|
777
|
-
if (workflowName) {
|
|
778
|
-
conditions.push(`workflow_name = $${paramIndex}`);
|
|
779
|
-
values.push(workflowName);
|
|
780
|
-
paramIndex++;
|
|
781
|
-
}
|
|
782
|
-
|
|
783
|
-
if (fromDate) {
|
|
784
|
-
conditions.push(`"createdAt" >= $${paramIndex}`);
|
|
785
|
-
values.push(fromDate);
|
|
786
|
-
paramIndex++;
|
|
787
|
-
}
|
|
802
|
+
resourceId?: string;
|
|
803
|
+
} = {}): Promise<WorkflowRuns> {
|
|
804
|
+
try {
|
|
805
|
+
const conditions: string[] = [];
|
|
806
|
+
const values: any[] = [];
|
|
807
|
+
let paramIndex = 1;
|
|
808
|
+
|
|
809
|
+
if (workflowName) {
|
|
810
|
+
conditions.push(`workflow_name = $${paramIndex}`);
|
|
811
|
+
values.push(workflowName);
|
|
812
|
+
paramIndex++;
|
|
813
|
+
}
|
|
788
814
|
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
815
|
+
if (resourceId) {
|
|
816
|
+
const hasResourceId = await this.hasColumn(TABLE_WORKFLOW_SNAPSHOT, 'resourceId');
|
|
817
|
+
if (hasResourceId) {
|
|
818
|
+
conditions.push(`"resourceId" = $${paramIndex}`);
|
|
819
|
+
values.push(resourceId);
|
|
820
|
+
paramIndex++;
|
|
821
|
+
} else {
|
|
822
|
+
console.warn(`[${TABLE_WORKFLOW_SNAPSHOT}] resourceId column not found. Skipping resourceId filter.`);
|
|
823
|
+
}
|
|
824
|
+
}
|
|
794
825
|
|
|
795
|
-
|
|
826
|
+
if (fromDate) {
|
|
827
|
+
conditions.push(`"createdAt" >= $${paramIndex}`);
|
|
828
|
+
values.push(fromDate);
|
|
829
|
+
paramIndex++;
|
|
830
|
+
}
|
|
796
831
|
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
total =
|
|
805
|
-
|
|
832
|
+
if (toDate) {
|
|
833
|
+
conditions.push(`"createdAt" <= $${paramIndex}`);
|
|
834
|
+
values.push(toDate);
|
|
835
|
+
paramIndex++;
|
|
836
|
+
}
|
|
837
|
+
const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';
|
|
838
|
+
|
|
839
|
+
let total = 0;
|
|
840
|
+
// Only get total count when using pagination
|
|
841
|
+
if (limit !== undefined && offset !== undefined) {
|
|
842
|
+
const countResult = await this.db.one(
|
|
843
|
+
`SELECT COUNT(*) as count FROM ${this.getTableName(TABLE_WORKFLOW_SNAPSHOT)} ${whereClause}`,
|
|
844
|
+
values,
|
|
845
|
+
);
|
|
846
|
+
total = Number(countResult.count);
|
|
847
|
+
}
|
|
806
848
|
|
|
807
|
-
|
|
808
|
-
|
|
849
|
+
// Get results
|
|
850
|
+
const query = `
|
|
809
851
|
SELECT * FROM ${this.getTableName(TABLE_WORKFLOW_SNAPSHOT)}
|
|
810
852
|
${whereClause}
|
|
811
853
|
ORDER BY "createdAt" DESC
|
|
812
854
|
${limit !== undefined && offset !== undefined ? ` LIMIT $${paramIndex} OFFSET $${paramIndex + 1}` : ''}
|
|
813
855
|
`;
|
|
814
856
|
|
|
815
|
-
|
|
857
|
+
const queryValues = limit !== undefined && offset !== undefined ? [...values, limit, offset] : values;
|
|
816
858
|
|
|
817
|
-
|
|
859
|
+
const result = await this.db.manyOrNone(query, queryValues);
|
|
818
860
|
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
861
|
+
const runs = (result || []).map(row => {
|
|
862
|
+
return this.parseWorkflowRun(row);
|
|
863
|
+
});
|
|
864
|
+
|
|
865
|
+
// Use runs.length as total when not paginating
|
|
866
|
+
return { runs, total: total || runs.length };
|
|
867
|
+
} catch (error) {
|
|
868
|
+
console.error('Error getting workflow runs:', error);
|
|
869
|
+
throw error;
|
|
870
|
+
}
|
|
871
|
+
}
|
|
872
|
+
|
|
873
|
+
async getWorkflowRunById({
|
|
874
|
+
runId,
|
|
875
|
+
workflowName,
|
|
876
|
+
}: {
|
|
877
|
+
runId: string;
|
|
878
|
+
workflowName?: string;
|
|
879
|
+
}): Promise<WorkflowRun | null> {
|
|
880
|
+
try {
|
|
881
|
+
const conditions: string[] = [];
|
|
882
|
+
const values: any[] = [];
|
|
883
|
+
let paramIndex = 1;
|
|
884
|
+
|
|
885
|
+
if (runId) {
|
|
886
|
+
conditions.push(`run_id = $${paramIndex}`);
|
|
887
|
+
values.push(runId);
|
|
888
|
+
paramIndex++;
|
|
828
889
|
}
|
|
829
890
|
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
};
|
|
837
|
-
|
|
891
|
+
if (workflowName) {
|
|
892
|
+
conditions.push(`workflow_name = $${paramIndex}`);
|
|
893
|
+
values.push(workflowName);
|
|
894
|
+
paramIndex++;
|
|
895
|
+
}
|
|
896
|
+
|
|
897
|
+
const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';
|
|
898
|
+
|
|
899
|
+
// Get results
|
|
900
|
+
const query = `
|
|
901
|
+
SELECT * FROM ${this.getTableName(TABLE_WORKFLOW_SNAPSHOT)}
|
|
902
|
+
${whereClause}
|
|
903
|
+
`;
|
|
904
|
+
|
|
905
|
+
const queryValues = values;
|
|
838
906
|
|
|
839
|
-
|
|
840
|
-
|
|
907
|
+
const result = await this.db.oneOrNone(query, queryValues);
|
|
908
|
+
|
|
909
|
+
if (!result) {
|
|
910
|
+
return null;
|
|
911
|
+
}
|
|
912
|
+
|
|
913
|
+
return this.parseWorkflowRun(result);
|
|
914
|
+
} catch (error) {
|
|
915
|
+
console.error('Error getting workflow run by ID:', error);
|
|
916
|
+
throw error;
|
|
917
|
+
}
|
|
841
918
|
}
|
|
842
919
|
|
|
843
920
|
async close(): Promise<void> {
|
package/src/vector/index.test.ts
CHANGED
|
@@ -2436,4 +2436,45 @@ describe('PgVector', () => {
|
|
|
2436
2436
|
});
|
|
2437
2437
|
});
|
|
2438
2438
|
});
|
|
2439
|
+
|
|
2440
|
+
describe('PoolConfig Custom Options', () => {
|
|
2441
|
+
it('should apply custom values to properties with default values', async () => {
|
|
2442
|
+
const db = new PgVector({
|
|
2443
|
+
connectionString,
|
|
2444
|
+
pgPoolOptions: {
|
|
2445
|
+
max: 5,
|
|
2446
|
+
idleTimeoutMillis: 10000,
|
|
2447
|
+
connectionTimeoutMillis: 1000,
|
|
2448
|
+
},
|
|
2449
|
+
});
|
|
2450
|
+
|
|
2451
|
+
expect(db['pool'].options.max).toBe(5);
|
|
2452
|
+
expect(db['pool'].options.idleTimeoutMillis).toBe(10000);
|
|
2453
|
+
expect(db['pool'].options.connectionTimeoutMillis).toBe(1000);
|
|
2454
|
+
});
|
|
2455
|
+
|
|
2456
|
+
it('should pass properties with no default values', async () => {
|
|
2457
|
+
const db = new PgVector({
|
|
2458
|
+
connectionString,
|
|
2459
|
+
pgPoolOptions: {
|
|
2460
|
+
ssl: false,
|
|
2461
|
+
},
|
|
2462
|
+
});
|
|
2463
|
+
|
|
2464
|
+
expect(db['pool'].options.ssl).toBe(false);
|
|
2465
|
+
});
|
|
2466
|
+
it('should keep default values when custom values are added', async () => {
|
|
2467
|
+
const db = new PgVector({
|
|
2468
|
+
connectionString,
|
|
2469
|
+
pgPoolOptions: {
|
|
2470
|
+
ssl: false,
|
|
2471
|
+
},
|
|
2472
|
+
});
|
|
2473
|
+
|
|
2474
|
+
expect(db['pool'].options.max).toBe(20);
|
|
2475
|
+
expect(db['pool'].options.idleTimeoutMillis).toBe(30000);
|
|
2476
|
+
expect(db['pool'].options.connectionTimeoutMillis).toBe(2000);
|
|
2477
|
+
expect(db['pool'].options.ssl).toBe(false);
|
|
2478
|
+
});
|
|
2479
|
+
});
|
|
2439
2480
|
});
|
package/src/vector/index.ts
CHANGED
|
@@ -71,23 +71,49 @@ export class PgVector extends MastraVector {
|
|
|
71
71
|
private schemaSetupComplete: boolean | undefined = undefined;
|
|
72
72
|
|
|
73
73
|
constructor(connectionString: string);
|
|
74
|
-
constructor(config: {
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
74
|
+
constructor(config: {
|
|
75
|
+
connectionString: string;
|
|
76
|
+
schemaName?: string;
|
|
77
|
+
pgPoolOptions?: Omit<pg.PoolConfig, 'connectionString'>;
|
|
78
|
+
});
|
|
79
|
+
constructor(
|
|
80
|
+
config:
|
|
81
|
+
| string
|
|
82
|
+
| {
|
|
83
|
+
connectionString: string;
|
|
84
|
+
schemaName?: string;
|
|
85
|
+
pgPoolOptions?: Omit<pg.PoolConfig, 'connectionString'>;
|
|
86
|
+
},
|
|
87
|
+
) {
|
|
88
|
+
let connectionString: string;
|
|
89
|
+
let pgPoolOptions: Omit<pg.PoolConfig, 'connectionString'> | undefined;
|
|
90
|
+
let schemaName: string | undefined;
|
|
91
|
+
|
|
92
|
+
if (typeof config === 'string') {
|
|
93
|
+
connectionString = config;
|
|
94
|
+
schemaName = undefined;
|
|
95
|
+
pgPoolOptions = undefined;
|
|
96
|
+
} else {
|
|
97
|
+
connectionString = config.connectionString;
|
|
98
|
+
schemaName = config.schemaName;
|
|
99
|
+
pgPoolOptions = config.pgPoolOptions;
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
if (!connectionString || connectionString.trim() === '') {
|
|
78
103
|
throw new Error(
|
|
79
104
|
'PgVector: connectionString must be provided and cannot be empty. Passing an empty string may cause fallback to local Postgres defaults.',
|
|
80
105
|
);
|
|
81
106
|
}
|
|
82
107
|
super();
|
|
83
108
|
|
|
84
|
-
this.schema =
|
|
109
|
+
this.schema = schemaName;
|
|
85
110
|
|
|
86
111
|
const basePool = new pg.Pool({
|
|
87
112
|
connectionString,
|
|
88
113
|
max: 20, // Maximum number of clients in the pool
|
|
89
114
|
idleTimeoutMillis: 30000, // Close idle connections after 30 seconds
|
|
90
115
|
connectionTimeoutMillis: 2000, // Fail fast if can't connect
|
|
116
|
+
...pgPoolOptions,
|
|
91
117
|
});
|
|
92
118
|
|
|
93
119
|
const telemetry = this.__getTelemetry();
|