@mastra/pg 0.1.5-alpha.0 → 0.1.5-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -42,6 +42,32 @@ describe('PgVector', () => {
42
42
  it('should throw error if dimension is invalid', async () => {
43
43
  await expect(vectorDB.createIndex(`testIndexNameFail`, 0)).rejects.toThrow();
44
44
  });
45
+
46
+ it('should create index with flat type', async () => {
47
+ await vectorDB.createIndex(testIndexName2, 3, 'cosine', { type: 'flat' });
48
+ const stats = await vectorDB.describeIndex(testIndexName2);
49
+ expect(stats.type).toBe('flat');
50
+ });
51
+
52
+ it('should create index with hnsw type', async () => {
53
+ await vectorDB.createIndex(testIndexName2, 3, 'cosine', {
54
+ type: 'hnsw',
55
+ hnsw: { m: 16, efConstruction: 64 }, // Any reasonable values work
56
+ });
57
+ const stats = await vectorDB.describeIndex(testIndexName2);
58
+ expect(stats.type).toBe('hnsw');
59
+ expect(stats.config.m).toBe(16);
60
+ });
61
+
62
+ it('should create index with ivfflat type and lists', async () => {
63
+ await vectorDB.createIndex(testIndexName2, 3, 'cosine', {
64
+ type: 'ivfflat',
65
+ ivf: { lists: 100 },
66
+ });
67
+ const stats = await vectorDB.describeIndex(testIndexName2);
68
+ expect(stats.type).toBe('ivfflat');
69
+ expect(stats.config.lists).toBe(100);
70
+ });
45
71
  });
46
72
 
47
73
  describe('listIndexes', () => {
@@ -86,6 +112,10 @@ describe('PgVector', () => {
86
112
 
87
113
  const stats = await vectorDB.describeIndex(indexName);
88
114
  expect(stats).toEqual({
115
+ type: 'ivfflat',
116
+ config: {
117
+ lists: 100,
118
+ },
89
119
  dimension: 3,
90
120
  count: 2,
91
121
  metric: 'cosine',
@@ -152,59 +182,61 @@ describe('PgVector', () => {
152
182
  });
153
183
 
154
184
  describe('Basic Query Operations', () => {
155
- const indexName = 'test_query_2';
156
- beforeAll(async () => {
157
- try {
185
+ ['flat', 'hnsw', 'ivfflat'].forEach(indexType => {
186
+ const indexName = `test_query_2_${indexType}`;
187
+ beforeAll(async () => {
188
+ try {
189
+ await vectorDB.deleteIndex(indexName);
190
+ } catch (e) {
191
+ // Ignore if doesn't exist
192
+ }
193
+ await vectorDB.createIndex(indexName, 3);
194
+ });
195
+
196
+ beforeEach(async () => {
197
+ await vectorDB.truncateIndex(indexName);
198
+ const vectors = [
199
+ [1, 0, 0],
200
+ [0.8, 0.2, 0],
201
+ [0, 1, 0],
202
+ ];
203
+ const metadata = [
204
+ { type: 'a', value: 1 },
205
+ { type: 'b', value: 2 },
206
+ { type: 'a', value: 3 },
207
+ ];
208
+ await vectorDB.upsert(indexName, vectors, metadata);
209
+ });
210
+
211
+ afterAll(async () => {
158
212
  await vectorDB.deleteIndex(indexName);
159
- } catch (e) {
160
- // Ignore if doesn't exist
161
- }
162
- await vectorDB.createIndex(indexName, 3);
163
- });
164
-
165
- beforeEach(async () => {
166
- await vectorDB.truncateIndex(indexName);
167
- const vectors = [
168
- [1, 0, 0],
169
- [0.8, 0.2, 0],
170
- [0, 1, 0],
171
- ];
172
- const metadata = [
173
- { type: 'a', value: 1 },
174
- { type: 'b', value: 2 },
175
- { type: 'a', value: 3 },
176
- ];
177
- await vectorDB.upsert(indexName, vectors, metadata);
178
- });
179
-
180
- afterAll(async () => {
181
- await vectorDB.deleteIndex(indexName);
182
- });
213
+ });
183
214
 
184
- it('should return closest vectors', async () => {
185
- const results = await vectorDB.query(indexName, [1, 0, 0], 1);
186
- expect(results).toHaveLength(1);
187
- expect(results[0]?.vector).toBe(undefined);
188
- expect(results[0]?.score).toBeCloseTo(1, 5);
189
- });
215
+ it('should return closest vectors', async () => {
216
+ const results = await vectorDB.query(indexName, [1, 0, 0], 1);
217
+ expect(results).toHaveLength(1);
218
+ expect(results[0]?.vector).toBe(undefined);
219
+ expect(results[0]?.score).toBeCloseTo(1, 5);
220
+ });
190
221
 
191
- it('should return vector with result', async () => {
192
- const results = await vectorDB.query(indexName, [1, 0, 0], 1, undefined, true);
193
- expect(results).toHaveLength(1);
194
- expect(results[0]?.vector).toStrictEqual([1, 0, 0]);
195
- });
222
+ it('should return vector with result', async () => {
223
+ const results = await vectorDB.query(indexName, [1, 0, 0], 1, undefined, true);
224
+ expect(results).toHaveLength(1);
225
+ expect(results[0]?.vector).toStrictEqual([1, 0, 0]);
226
+ });
196
227
 
197
- it('should respect topK parameter', async () => {
198
- const results = await vectorDB.query(indexName, [1, 0, 0], 2);
199
- expect(results).toHaveLength(2);
200
- });
228
+ it('should respect topK parameter', async () => {
229
+ const results = await vectorDB.query(indexName, [1, 0, 0], 2);
230
+ expect(results).toHaveLength(2);
231
+ });
201
232
 
202
- it('should handle filters correctly', async () => {
203
- const results = await vectorDB.query(indexName, [1, 0, 0], 10, { type: 'a' });
233
+ it('should handle filters correctly', async () => {
234
+ const results = await vectorDB.query(indexName, [1, 0, 0], 10, { type: 'a' });
204
235
 
205
- expect(results).toHaveLength(1);
206
- results.forEach(result => {
207
- expect(result?.metadata?.type).toBe('a');
236
+ expect(results).toHaveLength(1);
237
+ results.forEach(result => {
238
+ expect(result?.metadata?.type).toBe('a');
239
+ });
208
240
  });
209
241
  });
210
242
  });
@@ -1202,4 +1234,69 @@ describe('PgVector', () => {
1202
1234
  });
1203
1235
  });
1204
1236
  });
1237
+
1238
+ describe('Search Parameters', () => {
1239
+ const indexName = 'test_search_params';
1240
+ const vectors = [
1241
+ [1, 0, 0], // Query vector will be closest to this
1242
+ [0.8, 0.2, 0], // Second closest
1243
+ [0, 1, 0], // Third (much further)
1244
+ ];
1245
+
1246
+ describe('HNSW Parameters', () => {
1247
+ beforeAll(async () => {
1248
+ await vectorDB.createIndex(indexName, 3, 'cosine', {
1249
+ type: 'hnsw',
1250
+ hnsw: { m: 16, efConstruction: 64 },
1251
+ });
1252
+ await vectorDB.upsert(indexName, vectors);
1253
+ });
1254
+
1255
+ afterAll(async () => {
1256
+ await vectorDB.deleteIndex(indexName);
1257
+ });
1258
+
1259
+ it('should use default ef value', async () => {
1260
+ const results = await vectorDB.query(indexName, [1, 0, 0], 2);
1261
+ expect(results).toHaveLength(2);
1262
+ expect(results[0]?.score).toBeCloseTo(1, 5);
1263
+ expect(results[1]?.score).toBeGreaterThan(0.9); // Second vector should be close
1264
+ });
1265
+
1266
+ it('should respect custom ef value', async () => {
1267
+ const results = await vectorDB.query(indexName, [1, 0, 0], 2, undefined, undefined, undefined, { ef: 100 });
1268
+ expect(results).toHaveLength(2);
1269
+ expect(results[0]?.score).toBeCloseTo(1, 5);
1270
+ expect(results[1]?.score).toBeGreaterThan(0.9);
1271
+ });
1272
+ });
1273
+
1274
+ describe('IVF Parameters', () => {
1275
+ beforeAll(async () => {
1276
+ await vectorDB.createIndex(indexName, 3, 'cosine', {
1277
+ type: 'ivfflat',
1278
+ ivf: { lists: 2 }, // Small number for test data
1279
+ });
1280
+ await vectorDB.upsert(indexName, vectors);
1281
+ });
1282
+
1283
+ afterAll(async () => {
1284
+ await vectorDB.deleteIndex(indexName);
1285
+ });
1286
+
1287
+ it('should use default probe value', async () => {
1288
+ const results = await vectorDB.query(indexName, [1, 0, 0], 2);
1289
+ expect(results).toHaveLength(2);
1290
+ expect(results[0]?.score).toBeCloseTo(1, 5);
1291
+ expect(results[1]?.score).toBeGreaterThan(0.9);
1292
+ });
1293
+
1294
+ it('should respect custom probe value', async () => {
1295
+ const results = await vectorDB.query(indexName, [1, 0, 0], 2, undefined, undefined, undefined, { probes: 2 });
1296
+ expect(results).toHaveLength(2);
1297
+ expect(results[0]?.score).toBeCloseTo(1, 5);
1298
+ expect(results[1]?.score).toBeGreaterThan(0.9);
1299
+ });
1300
+ });
1301
+ });
1205
1302
  });
@@ -4,9 +4,21 @@ import pg from 'pg';
4
4
 
5
5
  import { PGFilterTranslator } from './filter';
6
6
  import { buildFilterQuery } from './sql-builder';
7
+ import { type IndexConfig, type IndexType } from './types';
8
+
9
+ export interface PGIndexStats extends IndexStats {
10
+ type: IndexType;
11
+ config: {
12
+ m?: number;
13
+ efConstruction?: number;
14
+ lists?: number;
15
+ probes?: number;
16
+ };
17
+ }
7
18
 
8
19
  export class PgVector extends MastraVector {
9
20
  private pool: pg.Pool;
21
+ private indexCache: Map<string, PGIndexStats> = new Map();
10
22
 
11
23
  constructor(connectionString: string) {
12
24
  super();
@@ -35,6 +47,13 @@ export class PgVector extends MastraVector {
35
47
  return translatedFilter;
36
48
  }
37
49
 
50
+ async getIndexInfo(indexName: string): Promise<PGIndexStats> {
51
+ if (!this.indexCache.has(indexName)) {
52
+ this.indexCache.set(indexName, await this.describeIndex(indexName));
53
+ }
54
+ return this.indexCache.get(indexName)!;
55
+ }
56
+
38
57
  async query(
39
58
  indexName: string,
40
59
  queryVector: number[],
@@ -42,14 +61,32 @@ export class PgVector extends MastraVector {
42
61
  filter?: Filter,
43
62
  includeVector: boolean = false,
44
63
  minScore: number = 0, // Optional minimum score threshold
64
+ options?: {
65
+ ef?: number; // For HNSW
66
+ probes?: number; // For IVF
67
+ },
45
68
  ): Promise<QueryResult[]> {
46
69
  const client = await this.pool.connect();
47
70
  try {
48
71
  const vectorStr = `[${queryVector.join(',')}]`;
49
-
50
72
  const translatedFilter = this.transformFilter(filter);
51
73
  const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore);
52
74
 
75
+ // Get index type and configuration
76
+ const indexInfo = await this.getIndexInfo(indexName);
77
+
78
+ // Set HNSW search parameter if applicable
79
+ if (indexInfo.type === 'hnsw') {
80
+ // Calculate ef and clamp between 1 and 1000
81
+ const calculatedEf = options?.ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
82
+ const searchEf = Math.min(1000, Math.max(1, calculatedEf));
83
+ await client.query(`SET LOCAL hnsw.ef_search = ${searchEf}`);
84
+ }
85
+
86
+ if (indexInfo.type === 'ivfflat' && options?.probes) {
87
+ await client.query(`SET LOCAL ivfflat.probes = ${options.probes}`);
88
+ }
89
+
53
90
  const query = `
54
91
  WITH vector_scores AS (
55
92
  SELECT
@@ -88,25 +125,23 @@ export class PgVector extends MastraVector {
88
125
  const client = await this.pool.connect();
89
126
  try {
90
127
  await client.query('BEGIN');
91
-
92
128
  const vectorIds = ids || vectors.map(() => crypto.randomUUID());
93
129
 
94
130
  for (let i = 0; i < vectors.length; i++) {
95
131
  const query = `
96
- INSERT INTO ${indexName} (vector_id, embedding, metadata)
97
- VALUES ($1, $2::vector, $3::jsonb)
98
- ON CONFLICT (vector_id)
99
- DO UPDATE SET
100
- embedding = $2::vector,
101
- metadata = $3::jsonb
102
- RETURNING embedding::text
132
+ INSERT INTO ${indexName} (vector_id, embedding, metadata)
133
+ VALUES ($1, $2::vector, $3::jsonb)
134
+ ON CONFLICT (vector_id)
135
+ DO UPDATE SET
136
+ embedding = $2::vector,
137
+ metadata = $3::jsonb
138
+ RETURNING embedding::text
103
139
  `;
104
140
 
105
141
  await client.query(query, [vectorIds[i], `[${vectors[i]?.join(',')}]`, JSON.stringify(metadata?.[i] || {})]);
106
142
  }
107
143
 
108
144
  await client.query('COMMIT');
109
-
110
145
  return vectorIds;
111
146
  } catch (error) {
112
147
  await client.query('ROLLBACK');
@@ -120,6 +155,8 @@ export class PgVector extends MastraVector {
120
155
  indexName: string,
121
156
  dimension: number,
122
157
  metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
158
+ indexConfig: IndexConfig = {},
159
+ defineIndex: boolean = true,
123
160
  ): Promise<void> {
124
161
  const client = await this.pool.connect();
125
162
  try {
@@ -155,16 +192,9 @@ export class PgVector extends MastraVector {
155
192
  );
156
193
  `);
157
194
 
158
- // Create the index
159
- const indexMethod =
160
- metric === 'cosine' ? 'vector_cosine_ops' : metric === 'euclidean' ? 'vector_l2_ops' : 'vector_ip_ops';
161
-
162
- await client.query(`
163
- CREATE INDEX IF NOT EXISTS ${indexName}_vector_idx
164
- ON public.${indexName}
165
- USING ivfflat (embedding ${indexMethod})
166
- WITH (lists = 100);
167
- `);
195
+ if (defineIndex) {
196
+ await this.defineIndex(indexName, metric, indexConfig);
197
+ }
168
198
  } catch (error: any) {
169
199
  console.error('Failed to create vector table:', error);
170
200
  throw error;
@@ -173,6 +203,57 @@ export class PgVector extends MastraVector {
173
203
  }
174
204
  }
175
205
 
206
+ async defineIndex(
207
+ indexName: string,
208
+ metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
209
+ indexConfig: IndexConfig,
210
+ ): Promise<void> {
211
+ const client = await this.pool.connect();
212
+ try {
213
+ await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);
214
+
215
+ if (indexConfig.type === 'flat') return;
216
+
217
+ const metricOp =
218
+ metric === 'cosine' ? 'vector_cosine_ops' : metric === 'euclidean' ? 'vector_l2_ops' : 'vector_ip_ops';
219
+
220
+ let indexSQL: string;
221
+ if (indexConfig.type === 'hnsw') {
222
+ const m = indexConfig.hnsw?.m ?? 8;
223
+ const efConstruction = indexConfig.hnsw?.efConstruction ?? 32;
224
+
225
+ indexSQL = `
226
+ CREATE INDEX ${indexName}_vector_idx
227
+ ON ${indexName}
228
+ USING hnsw (embedding ${metricOp})
229
+ WITH (
230
+ m = ${m},
231
+ ef_construction = ${efConstruction}
232
+ )
233
+ `;
234
+ } else {
235
+ let lists: number;
236
+ if (indexConfig.ivf?.lists) {
237
+ lists = indexConfig.ivf.lists;
238
+ } else {
239
+ const size = (await client.query(`SELECT COUNT(*) FROM ${indexName}`)).rows[0].count;
240
+ lists = Math.max(100, Math.min(4000, Math.floor(Math.sqrt(size) * 2)));
241
+ }
242
+ indexSQL = `
243
+ CREATE INDEX ${indexName}_vector_idx
244
+ ON ${indexName}
245
+ USING ivfflat (embedding ${metricOp})
246
+ WITH (lists = ${lists});
247
+ `;
248
+ }
249
+
250
+ await client.query(indexSQL);
251
+ this.indexCache.delete(indexName);
252
+ } finally {
253
+ client.release();
254
+ }
255
+ }
256
+
176
257
  async listIndexes(): Promise<string[]> {
177
258
  const client = await this.pool.connect();
178
259
  try {
@@ -190,7 +271,7 @@ export class PgVector extends MastraVector {
190
271
  }
191
272
  }
192
273
 
193
- async describeIndex(indexName: string): Promise<IndexStats> {
274
+ async describeIndex(indexName: string): Promise<PGIndexStats> {
194
275
  const client = await this.pool.connect();
195
276
  try {
196
277
  // Get vector dimension
@@ -208,9 +289,10 @@ export class PgVector extends MastraVector {
208
289
  `;
209
290
 
210
291
  // Get index metric type
211
- const metricQuery = `
292
+ const indexQuery = `
212
293
  SELECT
213
294
  am.amname as index_method,
295
+ pg_get_indexdef(i.indexrelid) as index_def,
214
296
  opclass.opcname as operator_class
215
297
  FROM pg_index i
216
298
  JOIN pg_class c ON i.indexrelid = c.oid
@@ -219,29 +301,44 @@ export class PgVector extends MastraVector {
219
301
  WHERE c.relname = '${indexName}_vector_idx';
220
302
  `;
221
303
 
222
- const [dimResult, countResult, metricResult] = await Promise.all([
304
+ const [dimResult, countResult, indexResult] = await Promise.all([
223
305
  client.query(dimensionQuery, [indexName]),
224
306
  client.query(countQuery),
225
- client.query(metricQuery),
307
+ client.query(indexQuery),
226
308
  ]);
227
309
 
310
+ const { index_method, index_def, operator_class } = indexResult.rows[0] || {
311
+ index_method: 'flat',
312
+ index_def: '',
313
+ operator_class: 'cosine',
314
+ };
315
+
228
316
  // Convert pg_vector index method to our metric type
229
- let metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine';
230
- if (metricResult.rows.length > 0) {
231
- const operatorClass = metricResult.rows[0].operator_class;
232
- if (operatorClass.includes('l2')) {
233
- metric = 'euclidean';
234
- } else if (operatorClass.includes('ip')) {
235
- metric = 'dotproduct';
236
- } else if (operatorClass.includes('cosine')) {
237
- metric = 'cosine';
238
- }
317
+ const metric = operator_class.includes('l2')
318
+ ? 'euclidean'
319
+ : operator_class.includes('ip')
320
+ ? 'dotproduct'
321
+ : 'cosine';
322
+
323
+ // Parse index configuration
324
+ const config: { m?: number; efConstruction?: number; lists?: number } = {};
325
+
326
+ if (index_method === 'hnsw') {
327
+ const m = index_def.match(/m\s*=\s*'?(\d+)'?/)?.[1];
328
+ const efConstruction = index_def.match(/ef_construction\s*=\s*'?(\d+)'?/)?.[1];
329
+ if (m) config.m = parseInt(m);
330
+ if (efConstruction) config.efConstruction = parseInt(efConstruction);
331
+ } else if (index_method === 'ivfflat') {
332
+ const lists = index_def.match(/lists\s*=\s*'?(\d+)'?/)?.[1];
333
+ if (lists) config.lists = parseInt(lists);
239
334
  }
240
335
 
241
336
  return {
242
337
  dimension: dimResult.rows[0].dimension,
243
338
  count: parseInt(countResult.rows[0].count),
244
339
  metric,
340
+ type: index_method as 'flat' | 'hnsw' | 'ivfflat',
341
+ config,
245
342
  };
246
343
  } catch (e: any) {
247
344
  await client.query('ROLLBACK');