@mastra/pg 0.2.10-alpha.4 → 0.2.10-alpha.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -64,12 +64,20 @@ export class PgVector extends MastraVector {
64
64
  private describeIndexCache: Map<string, PGIndexStats> = new Map();
65
65
  private createdIndexes = new Map<string, number>();
66
66
  private mutexesByName = new Map<string, Mutex>();
67
+ private schema?: string;
68
+ private setupSchemaPromise: Promise<void> | null = null;
67
69
  private installVectorExtensionPromise: Promise<void> | null = null;
68
70
  private vectorExtensionInstalled: boolean | undefined = undefined;
71
+ private schemaSetupComplete: boolean | undefined = undefined;
69
72
 
70
- constructor(connectionString: string) {
73
+ constructor(connectionString: string);
74
+ constructor(config: { connectionString: string; schemaName?: string });
75
+ constructor(config: string | { connectionString: string; schemaName?: string }) {
71
76
  super();
72
77
 
78
+ const connectionString = typeof config === 'string' ? config : config.connectionString;
79
+ this.schema = typeof config === 'string' ? undefined : config.schemaName;
80
+
73
81
  const basePool = new pg.Pool({
74
82
  connectionString,
75
83
  max: 20, // Maximum number of clients in the pool
@@ -108,6 +116,10 @@ export class PgVector extends MastraVector {
108
116
  return this.mutexesByName.get(indexName)!;
109
117
  }
110
118
 
119
+ private getTableName(indexName: string) {
120
+ return this.schema ? `${this.schema}.${indexName}` : indexName;
121
+ }
122
+
111
123
  transformFilter(filter?: VectorFilter) {
112
124
  const translator = new PGFilterTranslator();
113
125
  return translator.translate(filter);
@@ -149,6 +161,8 @@ export class PgVector extends MastraVector {
149
161
  await client.query(`SET LOCAL ivfflat.probes = ${probes}`);
150
162
  }
151
163
 
164
+ const tableName = this.getTableName(indexName);
165
+
152
166
  const query = `
153
167
  WITH vector_scores AS (
154
168
  SELECT
@@ -156,7 +170,7 @@ export class PgVector extends MastraVector {
156
170
  1 - (embedding <=> '${vectorStr}'::vector) as score,
157
171
  metadata
158
172
  ${includeVector ? ', embedding' : ''}
159
- FROM ${indexName}
173
+ FROM ${tableName}
160
174
  ${filterQuery}
161
175
  )
162
176
  SELECT *
@@ -181,6 +195,7 @@ export class PgVector extends MastraVector {
181
195
  const params = this.normalizeArgs<UpsertVectorParams>('upsert', args);
182
196
 
183
197
  const { indexName, vectors, metadata, ids } = params;
198
+ const tableName = this.getTableName(indexName);
184
199
 
185
200
  // Start a transaction
186
201
  const client = await this.pool.connect();
@@ -190,7 +205,7 @@ export class PgVector extends MastraVector {
190
205
 
191
206
  for (let i = 0; i < vectors.length; i++) {
192
207
  const query = `
193
- INSERT INTO ${indexName} (vector_id, embedding, metadata)
208
+ INSERT INTO ${tableName} (vector_id, embedding, metadata)
194
209
  VALUES ($1, $2::vector, $3::jsonb)
195
210
  ON CONFLICT (vector_id)
196
211
  DO UPDATE SET
@@ -231,6 +246,57 @@ export class PgVector extends MastraVector {
231
246
  const existingIndexCacheKey = this.createdIndexes.get(indexName);
232
247
  return existingIndexCacheKey && existingIndexCacheKey === newKey;
233
248
  }
249
+ private async setupSchema(client: pg.PoolClient) {
250
+ if (!this.schema || this.schemaSetupComplete) {
251
+ return;
252
+ }
253
+
254
+ if (!this.setupSchemaPromise) {
255
+ this.setupSchemaPromise = (async () => {
256
+ try {
257
+ // First check if schema exists and we have usage permission
258
+ const schemaCheck = await client.query(
259
+ `
260
+ SELECT EXISTS (
261
+ SELECT 1 FROM information_schema.schemata
262
+ WHERE schema_name = $1
263
+ )
264
+ `,
265
+ [this.schema],
266
+ );
267
+
268
+ const schemaExists = schemaCheck.rows[0].exists;
269
+
270
+ if (!schemaExists) {
271
+ try {
272
+ await client.query(`CREATE SCHEMA IF NOT EXISTS ${this.schema}`);
273
+ this.logger.info(`Schema "${this.schema}" created successfully`);
274
+ } catch (error) {
275
+ this.logger.error(`Failed to create schema "${this.schema}"`, { error });
276
+ throw new Error(
277
+ `Unable to create schema "${this.schema}". This requires CREATE privilege on the database. ` +
278
+ `Either create the schema manually or grant CREATE privilege to the user.`,
279
+ );
280
+ }
281
+ }
282
+
283
+ // If we got here, schema exists and we can use it
284
+ this.schemaSetupComplete = true;
285
+ this.logger.debug(`Schema "${this.schema}" is ready for use`);
286
+ } catch (error) {
287
+ // Reset flags so we can retry
288
+ this.schemaSetupComplete = undefined;
289
+ this.setupSchemaPromise = null;
290
+ throw error;
291
+ } finally {
292
+ this.setupSchemaPromise = null;
293
+ }
294
+ })();
295
+ }
296
+
297
+ await this.setupSchemaPromise;
298
+ }
299
+
234
300
  async createIndex(...args: ParamsToArgs<PgCreateIndexParams> | PgCreateIndexArgs): Promise<void> {
235
301
  const params = this.normalizeArgs<PgCreateIndexParams, PgCreateIndexArgs>('createIndex', args, [
236
302
  'indexConfig',
@@ -238,6 +304,7 @@ export class PgVector extends MastraVector {
238
304
  ]);
239
305
 
240
306
  const { indexName, dimension, metric = 'cosine', indexConfig = {}, buildIndex = true } = params;
307
+ const tableName = this.getTableName(indexName);
241
308
 
242
309
  // Validate inputs
243
310
  if (!indexName.match(/^[a-zA-Z_][a-zA-Z0-9_]*$/)) {
@@ -262,17 +329,21 @@ export class PgVector extends MastraVector {
262
329
  }
263
330
 
264
331
  const client = await this.pool.connect();
332
+
265
333
  try {
266
- // install vector extension
334
+ // Setup schema if needed
335
+ await this.setupSchema(client);
336
+
337
+ // Install vector extension first (needs to be in public schema)
267
338
  await this.installVectorExtension(client);
268
339
  await client.query(`
269
- CREATE TABLE IF NOT EXISTS ${indexName} (
340
+ CREATE TABLE IF NOT EXISTS ${tableName} (
270
341
  id SERIAL PRIMARY KEY,
271
342
  vector_id TEXT UNIQUE NOT NULL,
272
343
  embedding vector(${dimension}),
273
344
  metadata JSONB DEFAULT '{}'::jsonb
274
345
  );
275
- `);
346
+ `);
276
347
  this.createdIndexes.set(indexName, indexCacheKey);
277
348
 
278
349
  if (buildIndex) {
@@ -280,7 +351,6 @@ export class PgVector extends MastraVector {
280
351
  }
281
352
  } catch (error: any) {
282
353
  this.createdIndexes.delete(indexName);
283
- console.error('Failed to create vector table:', error);
284
354
  throw error;
285
355
  } finally {
286
356
  client.release();
@@ -319,8 +389,10 @@ export class PgVector extends MastraVector {
319
389
  const mutex = this.getMutexByName(`build-${indexName}`);
320
390
  // Use async-mutex instead of advisory lock for perf (over 2x as fast)
321
391
  await mutex.runExclusive(async () => {
392
+ const tableName = this.getTableName(indexName);
393
+
322
394
  if (this.createdIndexes.has(indexName)) {
323
- await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);
395
+ await client.query(`DROP INDEX IF EXISTS ${tableName}_vector_idx`);
324
396
  }
325
397
 
326
398
  if (indexConfig.type === 'flat') {
@@ -338,7 +410,7 @@ export class PgVector extends MastraVector {
338
410
 
339
411
  indexSQL = `
340
412
  CREATE INDEX IF NOT EXISTS ${indexName}_vector_idx
341
- ON ${indexName}
413
+ ON ${tableName}
342
414
  USING hnsw (embedding ${metricOp})
343
415
  WITH (
344
416
  m = ${m},
@@ -350,12 +422,12 @@ export class PgVector extends MastraVector {
350
422
  if (indexConfig.ivf?.lists) {
351
423
  lists = indexConfig.ivf.lists;
352
424
  } else {
353
- const size = (await client.query(`SELECT COUNT(*) FROM ${indexName}`)).rows[0].count;
425
+ const size = (await client.query(`SELECT COUNT(*) FROM ${tableName}`)).rows[0].count;
354
426
  lists = Math.max(100, Math.min(4000, Math.floor(Math.sqrt(size) * 2)));
355
427
  }
356
428
  indexSQL = `
357
429
  CREATE INDEX IF NOT EXISTS ${indexName}_vector_idx
358
- ON ${indexName}
430
+ ON ${tableName}
359
431
  USING ivfflat (embedding ${metricOp})
360
432
  WITH (lists = ${lists});
361
433
  `;
@@ -423,10 +495,10 @@ export class PgVector extends MastraVector {
423
495
  const vectorTablesQuery = `
424
496
  SELECT DISTINCT table_name
425
497
  FROM information_schema.columns
426
- WHERE table_schema = 'public'
498
+ WHERE table_schema = $1
427
499
  AND udt_name = 'vector';
428
500
  `;
429
- const vectorTables = await client.query(vectorTablesQuery);
501
+ const vectorTables = await client.query(vectorTablesQuery, [this.schema || 'public']);
430
502
  return vectorTables.rows.map(row => row.table_name);
431
503
  } finally {
432
504
  client.release();
@@ -436,6 +508,8 @@ export class PgVector extends MastraVector {
436
508
  async describeIndex(indexName: string): Promise<PGIndexStats> {
437
509
  const client = await this.pool.connect();
438
510
  try {
511
+ const tableName = this.getTableName(indexName);
512
+
439
513
  // Get vector dimension
440
514
  const dimensionQuery = `
441
515
  SELECT atttypmod as dimension
@@ -445,8 +519,9 @@ export class PgVector extends MastraVector {
445
519
  `;
446
520
 
447
521
  // Get row count
448
- const countQuery = ` SELECT COUNT(*) as count
449
- FROM ${indexName};
522
+ const countQuery = `
523
+ SELECT COUNT(*) as count
524
+ FROM ${tableName};
450
525
  `;
451
526
 
452
527
  // Get index metric type
@@ -459,11 +534,11 @@ export class PgVector extends MastraVector {
459
534
  JOIN pg_class c ON i.indexrelid = c.oid
460
535
  JOIN pg_am am ON c.relam = am.oid
461
536
  JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
462
- WHERE c.relname = '${indexName}_vector_idx';
537
+ WHERE c.relname = '${tableName}_vector_idx';
463
538
  `;
464
539
 
465
540
  const [dimResult, countResult, indexResult] = await Promise.all([
466
- client.query(dimensionQuery, [indexName]),
541
+ client.query(dimensionQuery, [tableName]),
467
542
  client.query(countQuery),
468
543
  client.query(indexQuery),
469
544
  ]);
@@ -512,8 +587,9 @@ export class PgVector extends MastraVector {
512
587
  async deleteIndex(indexName: string): Promise<void> {
513
588
  const client = await this.pool.connect();
514
589
  try {
590
+ const tableName = this.getTableName(indexName);
515
591
  // Drop the table
516
- await client.query(`DROP TABLE IF EXISTS ${indexName} CASCADE`);
592
+ await client.query(`DROP TABLE IF EXISTS ${tableName} CASCADE`);
517
593
  this.createdIndexes.delete(indexName);
518
594
  } catch (error: any) {
519
595
  await client.query('ROLLBACK');
@@ -526,7 +602,8 @@ export class PgVector extends MastraVector {
526
602
  async truncateIndex(indexName: string) {
527
603
  const client = await this.pool.connect();
528
604
  try {
529
- await client.query(`TRUNCATE ${indexName}`);
605
+ const tableName = this.getTableName(indexName);
606
+ await client.query(`TRUNCATE ${tableName}`);
530
607
  } catch (e: any) {
531
608
  await client.query('ROLLBACK');
532
609
  throw new Error(`Failed to truncate vector table: ${e.message}`);
@@ -572,10 +649,12 @@ export class PgVector extends MastraVector {
572
649
  return;
573
650
  }
574
651
 
652
+ const tableName = this.getTableName(indexName);
653
+
575
654
  // query looks like this:
576
655
  // UPDATE table SET embedding = $2::vector, metadata = $3::jsonb WHERE id = $1
577
656
  const query = `
578
- UPDATE ${indexName}
657
+ UPDATE ${tableName}
579
658
  SET ${updateParts.join(', ')}
580
659
  WHERE vector_id = $1
581
660
  `;
@@ -589,8 +668,9 @@ export class PgVector extends MastraVector {
589
668
  async deleteIndexById(indexName: string, id: string): Promise<void> {
590
669
  const client = await this.pool.connect();
591
670
  try {
671
+ const tableName = this.getTableName(indexName);
592
672
  const query = `
593
- DELETE FROM ${indexName}
673
+ DELETE FROM ${tableName}
594
674
  WHERE vector_id = $1
595
675
  `;
596
676
  await client.query(query, [id]);