@mastra/pg 0.2.10-alpha.3 → 0.2.10-alpha.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +13 -0
- package/dist/_tsup-dts-rollup.d.cts +17 -1
- package/dist/_tsup-dts-rollup.d.ts +17 -1
- package/dist/index.cjs +154 -39
- package/dist/index.js +154 -39
- package/package.json +2 -2
- package/src/storage/index.test.ts +277 -12
- package/src/storage/index.ts +90 -21
- package/src/vector/index.test.ts +483 -0
- package/src/vector/index.ts +101 -21
package/src/vector/index.ts
CHANGED
|
@@ -64,12 +64,20 @@ export class PgVector extends MastraVector {
|
|
|
64
64
|
private describeIndexCache: Map<string, PGIndexStats> = new Map();
|
|
65
65
|
private createdIndexes = new Map<string, number>();
|
|
66
66
|
private mutexesByName = new Map<string, Mutex>();
|
|
67
|
+
private schema?: string;
|
|
68
|
+
private setupSchemaPromise: Promise<void> | null = null;
|
|
67
69
|
private installVectorExtensionPromise: Promise<void> | null = null;
|
|
68
70
|
private vectorExtensionInstalled: boolean | undefined = undefined;
|
|
71
|
+
private schemaSetupComplete: boolean | undefined = undefined;
|
|
69
72
|
|
|
70
|
-
constructor(connectionString: string)
|
|
73
|
+
constructor(connectionString: string);
|
|
74
|
+
constructor(config: { connectionString: string; schemaName?: string });
|
|
75
|
+
constructor(config: string | { connectionString: string; schemaName?: string }) {
|
|
71
76
|
super();
|
|
72
77
|
|
|
78
|
+
const connectionString = typeof config === 'string' ? config : config.connectionString;
|
|
79
|
+
this.schema = typeof config === 'string' ? undefined : config.schemaName;
|
|
80
|
+
|
|
73
81
|
const basePool = new pg.Pool({
|
|
74
82
|
connectionString,
|
|
75
83
|
max: 20, // Maximum number of clients in the pool
|
|
@@ -108,6 +116,10 @@ export class PgVector extends MastraVector {
|
|
|
108
116
|
return this.mutexesByName.get(indexName)!;
|
|
109
117
|
}
|
|
110
118
|
|
|
119
|
+
private getTableName(indexName: string) {
|
|
120
|
+
return this.schema ? `${this.schema}.${indexName}` : indexName;
|
|
121
|
+
}
|
|
122
|
+
|
|
111
123
|
transformFilter(filter?: VectorFilter) {
|
|
112
124
|
const translator = new PGFilterTranslator();
|
|
113
125
|
return translator.translate(filter);
|
|
@@ -149,6 +161,8 @@ export class PgVector extends MastraVector {
|
|
|
149
161
|
await client.query(`SET LOCAL ivfflat.probes = ${probes}`);
|
|
150
162
|
}
|
|
151
163
|
|
|
164
|
+
const tableName = this.getTableName(indexName);
|
|
165
|
+
|
|
152
166
|
const query = `
|
|
153
167
|
WITH vector_scores AS (
|
|
154
168
|
SELECT
|
|
@@ -156,7 +170,7 @@ export class PgVector extends MastraVector {
|
|
|
156
170
|
1 - (embedding <=> '${vectorStr}'::vector) as score,
|
|
157
171
|
metadata
|
|
158
172
|
${includeVector ? ', embedding' : ''}
|
|
159
|
-
FROM ${
|
|
173
|
+
FROM ${tableName}
|
|
160
174
|
${filterQuery}
|
|
161
175
|
)
|
|
162
176
|
SELECT *
|
|
@@ -181,6 +195,7 @@ export class PgVector extends MastraVector {
|
|
|
181
195
|
const params = this.normalizeArgs<UpsertVectorParams>('upsert', args);
|
|
182
196
|
|
|
183
197
|
const { indexName, vectors, metadata, ids } = params;
|
|
198
|
+
const tableName = this.getTableName(indexName);
|
|
184
199
|
|
|
185
200
|
// Start a transaction
|
|
186
201
|
const client = await this.pool.connect();
|
|
@@ -190,7 +205,7 @@ export class PgVector extends MastraVector {
|
|
|
190
205
|
|
|
191
206
|
for (let i = 0; i < vectors.length; i++) {
|
|
192
207
|
const query = `
|
|
193
|
-
INSERT INTO ${
|
|
208
|
+
INSERT INTO ${tableName} (vector_id, embedding, metadata)
|
|
194
209
|
VALUES ($1, $2::vector, $3::jsonb)
|
|
195
210
|
ON CONFLICT (vector_id)
|
|
196
211
|
DO UPDATE SET
|
|
@@ -231,6 +246,57 @@ export class PgVector extends MastraVector {
|
|
|
231
246
|
const existingIndexCacheKey = this.createdIndexes.get(indexName);
|
|
232
247
|
return existingIndexCacheKey && existingIndexCacheKey === newKey;
|
|
233
248
|
}
|
|
249
|
+
private async setupSchema(client: pg.PoolClient) {
|
|
250
|
+
if (!this.schema || this.schemaSetupComplete) {
|
|
251
|
+
return;
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
if (!this.setupSchemaPromise) {
|
|
255
|
+
this.setupSchemaPromise = (async () => {
|
|
256
|
+
try {
|
|
257
|
+
// First check if schema exists and we have usage permission
|
|
258
|
+
const schemaCheck = await client.query(
|
|
259
|
+
`
|
|
260
|
+
SELECT EXISTS (
|
|
261
|
+
SELECT 1 FROM information_schema.schemata
|
|
262
|
+
WHERE schema_name = $1
|
|
263
|
+
)
|
|
264
|
+
`,
|
|
265
|
+
[this.schema],
|
|
266
|
+
);
|
|
267
|
+
|
|
268
|
+
const schemaExists = schemaCheck.rows[0].exists;
|
|
269
|
+
|
|
270
|
+
if (!schemaExists) {
|
|
271
|
+
try {
|
|
272
|
+
await client.query(`CREATE SCHEMA IF NOT EXISTS ${this.schema}`);
|
|
273
|
+
this.logger.info(`Schema "${this.schema}" created successfully`);
|
|
274
|
+
} catch (error) {
|
|
275
|
+
this.logger.error(`Failed to create schema "${this.schema}"`, { error });
|
|
276
|
+
throw new Error(
|
|
277
|
+
`Unable to create schema "${this.schema}". This requires CREATE privilege on the database. ` +
|
|
278
|
+
`Either create the schema manually or grant CREATE privilege to the user.`,
|
|
279
|
+
);
|
|
280
|
+
}
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
// If we got here, schema exists and we can use it
|
|
284
|
+
this.schemaSetupComplete = true;
|
|
285
|
+
this.logger.debug(`Schema "${this.schema}" is ready for use`);
|
|
286
|
+
} catch (error) {
|
|
287
|
+
// Reset flags so we can retry
|
|
288
|
+
this.schemaSetupComplete = undefined;
|
|
289
|
+
this.setupSchemaPromise = null;
|
|
290
|
+
throw error;
|
|
291
|
+
} finally {
|
|
292
|
+
this.setupSchemaPromise = null;
|
|
293
|
+
}
|
|
294
|
+
})();
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
await this.setupSchemaPromise;
|
|
298
|
+
}
|
|
299
|
+
|
|
234
300
|
async createIndex(...args: ParamsToArgs<PgCreateIndexParams> | PgCreateIndexArgs): Promise<void> {
|
|
235
301
|
const params = this.normalizeArgs<PgCreateIndexParams, PgCreateIndexArgs>('createIndex', args, [
|
|
236
302
|
'indexConfig',
|
|
@@ -238,6 +304,7 @@ export class PgVector extends MastraVector {
|
|
|
238
304
|
]);
|
|
239
305
|
|
|
240
306
|
const { indexName, dimension, metric = 'cosine', indexConfig = {}, buildIndex = true } = params;
|
|
307
|
+
const tableName = this.getTableName(indexName);
|
|
241
308
|
|
|
242
309
|
// Validate inputs
|
|
243
310
|
if (!indexName.match(/^[a-zA-Z_][a-zA-Z0-9_]*$/)) {
|
|
@@ -262,17 +329,21 @@ export class PgVector extends MastraVector {
|
|
|
262
329
|
}
|
|
263
330
|
|
|
264
331
|
const client = await this.pool.connect();
|
|
332
|
+
|
|
265
333
|
try {
|
|
266
|
-
//
|
|
334
|
+
// Setup schema if needed
|
|
335
|
+
await this.setupSchema(client);
|
|
336
|
+
|
|
337
|
+
// Install vector extension first (needs to be in public schema)
|
|
267
338
|
await this.installVectorExtension(client);
|
|
268
339
|
await client.query(`
|
|
269
|
-
CREATE TABLE IF NOT EXISTS ${
|
|
340
|
+
CREATE TABLE IF NOT EXISTS ${tableName} (
|
|
270
341
|
id SERIAL PRIMARY KEY,
|
|
271
342
|
vector_id TEXT UNIQUE NOT NULL,
|
|
272
343
|
embedding vector(${dimension}),
|
|
273
344
|
metadata JSONB DEFAULT '{}'::jsonb
|
|
274
345
|
);
|
|
275
|
-
|
|
346
|
+
`);
|
|
276
347
|
this.createdIndexes.set(indexName, indexCacheKey);
|
|
277
348
|
|
|
278
349
|
if (buildIndex) {
|
|
@@ -280,7 +351,6 @@ export class PgVector extends MastraVector {
|
|
|
280
351
|
}
|
|
281
352
|
} catch (error: any) {
|
|
282
353
|
this.createdIndexes.delete(indexName);
|
|
283
|
-
console.error('Failed to create vector table:', error);
|
|
284
354
|
throw error;
|
|
285
355
|
} finally {
|
|
286
356
|
client.release();
|
|
@@ -319,8 +389,10 @@ export class PgVector extends MastraVector {
|
|
|
319
389
|
const mutex = this.getMutexByName(`build-${indexName}`);
|
|
320
390
|
// Use async-mutex instead of advisory lock for perf (over 2x as fast)
|
|
321
391
|
await mutex.runExclusive(async () => {
|
|
392
|
+
const tableName = this.getTableName(indexName);
|
|
393
|
+
|
|
322
394
|
if (this.createdIndexes.has(indexName)) {
|
|
323
|
-
await client.query(`DROP INDEX IF EXISTS ${
|
|
395
|
+
await client.query(`DROP INDEX IF EXISTS ${tableName}_vector_idx`);
|
|
324
396
|
}
|
|
325
397
|
|
|
326
398
|
if (indexConfig.type === 'flat') {
|
|
@@ -338,7 +410,7 @@ export class PgVector extends MastraVector {
|
|
|
338
410
|
|
|
339
411
|
indexSQL = `
|
|
340
412
|
CREATE INDEX IF NOT EXISTS ${indexName}_vector_idx
|
|
341
|
-
ON ${
|
|
413
|
+
ON ${tableName}
|
|
342
414
|
USING hnsw (embedding ${metricOp})
|
|
343
415
|
WITH (
|
|
344
416
|
m = ${m},
|
|
@@ -350,12 +422,12 @@ export class PgVector extends MastraVector {
|
|
|
350
422
|
if (indexConfig.ivf?.lists) {
|
|
351
423
|
lists = indexConfig.ivf.lists;
|
|
352
424
|
} else {
|
|
353
|
-
const size = (await client.query(`SELECT COUNT(*) FROM ${
|
|
425
|
+
const size = (await client.query(`SELECT COUNT(*) FROM ${tableName}`)).rows[0].count;
|
|
354
426
|
lists = Math.max(100, Math.min(4000, Math.floor(Math.sqrt(size) * 2)));
|
|
355
427
|
}
|
|
356
428
|
indexSQL = `
|
|
357
429
|
CREATE INDEX IF NOT EXISTS ${indexName}_vector_idx
|
|
358
|
-
ON ${
|
|
430
|
+
ON ${tableName}
|
|
359
431
|
USING ivfflat (embedding ${metricOp})
|
|
360
432
|
WITH (lists = ${lists});
|
|
361
433
|
`;
|
|
@@ -423,10 +495,10 @@ export class PgVector extends MastraVector {
|
|
|
423
495
|
const vectorTablesQuery = `
|
|
424
496
|
SELECT DISTINCT table_name
|
|
425
497
|
FROM information_schema.columns
|
|
426
|
-
WHERE table_schema =
|
|
498
|
+
WHERE table_schema = $1
|
|
427
499
|
AND udt_name = 'vector';
|
|
428
500
|
`;
|
|
429
|
-
const vectorTables = await client.query(vectorTablesQuery);
|
|
501
|
+
const vectorTables = await client.query(vectorTablesQuery, [this.schema || 'public']);
|
|
430
502
|
return vectorTables.rows.map(row => row.table_name);
|
|
431
503
|
} finally {
|
|
432
504
|
client.release();
|
|
@@ -436,6 +508,8 @@ export class PgVector extends MastraVector {
|
|
|
436
508
|
async describeIndex(indexName: string): Promise<PGIndexStats> {
|
|
437
509
|
const client = await this.pool.connect();
|
|
438
510
|
try {
|
|
511
|
+
const tableName = this.getTableName(indexName);
|
|
512
|
+
|
|
439
513
|
// Get vector dimension
|
|
440
514
|
const dimensionQuery = `
|
|
441
515
|
SELECT atttypmod as dimension
|
|
@@ -445,8 +519,9 @@ export class PgVector extends MastraVector {
|
|
|
445
519
|
`;
|
|
446
520
|
|
|
447
521
|
// Get row count
|
|
448
|
-
const countQuery = `
|
|
449
|
-
|
|
522
|
+
const countQuery = `
|
|
523
|
+
SELECT COUNT(*) as count
|
|
524
|
+
FROM ${tableName};
|
|
450
525
|
`;
|
|
451
526
|
|
|
452
527
|
// Get index metric type
|
|
@@ -459,11 +534,11 @@ export class PgVector extends MastraVector {
|
|
|
459
534
|
JOIN pg_class c ON i.indexrelid = c.oid
|
|
460
535
|
JOIN pg_am am ON c.relam = am.oid
|
|
461
536
|
JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
|
|
462
|
-
WHERE c.relname = '${
|
|
537
|
+
WHERE c.relname = '${tableName}_vector_idx';
|
|
463
538
|
`;
|
|
464
539
|
|
|
465
540
|
const [dimResult, countResult, indexResult] = await Promise.all([
|
|
466
|
-
client.query(dimensionQuery, [
|
|
541
|
+
client.query(dimensionQuery, [tableName]),
|
|
467
542
|
client.query(countQuery),
|
|
468
543
|
client.query(indexQuery),
|
|
469
544
|
]);
|
|
@@ -512,8 +587,9 @@ export class PgVector extends MastraVector {
|
|
|
512
587
|
async deleteIndex(indexName: string): Promise<void> {
|
|
513
588
|
const client = await this.pool.connect();
|
|
514
589
|
try {
|
|
590
|
+
const tableName = this.getTableName(indexName);
|
|
515
591
|
// Drop the table
|
|
516
|
-
await client.query(`DROP TABLE IF EXISTS ${
|
|
592
|
+
await client.query(`DROP TABLE IF EXISTS ${tableName} CASCADE`);
|
|
517
593
|
this.createdIndexes.delete(indexName);
|
|
518
594
|
} catch (error: any) {
|
|
519
595
|
await client.query('ROLLBACK');
|
|
@@ -526,7 +602,8 @@ export class PgVector extends MastraVector {
|
|
|
526
602
|
async truncateIndex(indexName: string) {
|
|
527
603
|
const client = await this.pool.connect();
|
|
528
604
|
try {
|
|
529
|
-
|
|
605
|
+
const tableName = this.getTableName(indexName);
|
|
606
|
+
await client.query(`TRUNCATE ${tableName}`);
|
|
530
607
|
} catch (e: any) {
|
|
531
608
|
await client.query('ROLLBACK');
|
|
532
609
|
throw new Error(`Failed to truncate vector table: ${e.message}`);
|
|
@@ -572,10 +649,12 @@ export class PgVector extends MastraVector {
|
|
|
572
649
|
return;
|
|
573
650
|
}
|
|
574
651
|
|
|
652
|
+
const tableName = this.getTableName(indexName);
|
|
653
|
+
|
|
575
654
|
// query looks like this:
|
|
576
655
|
// UPDATE table SET embedding = $2::vector, metadata = $3::jsonb WHERE id = $1
|
|
577
656
|
const query = `
|
|
578
|
-
UPDATE ${
|
|
657
|
+
UPDATE ${tableName}
|
|
579
658
|
SET ${updateParts.join(', ')}
|
|
580
659
|
WHERE vector_id = $1
|
|
581
660
|
`;
|
|
@@ -589,8 +668,9 @@ export class PgVector extends MastraVector {
|
|
|
589
668
|
async deleteIndexById(indexName: string, id: string): Promise<void> {
|
|
590
669
|
const client = await this.pool.connect();
|
|
591
670
|
try {
|
|
671
|
+
const tableName = this.getTableName(indexName);
|
|
592
672
|
const query = `
|
|
593
|
-
DELETE FROM ${
|
|
673
|
+
DELETE FROM ${tableName}
|
|
594
674
|
WHERE vector_id = $1
|
|
595
675
|
`;
|
|
596
676
|
await client.query(query, [id]);
|