@mastra/pg 0.1.6-alpha.1 → 0.1.6-alpha.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,13 @@
1
- import type { Filter } from '@mastra/core/filter';
2
1
  import { MastraVector } from '@mastra/core/vector';
3
- import type { IndexStats, QueryResult } from '@mastra/core/vector';
2
+ import type {
3
+ IndexStats,
4
+ QueryResult,
5
+ QueryVectorParams,
6
+ CreateIndexParams,
7
+ UpsertVectorParams,
8
+ ParamsToArgs,
9
+ } from '@mastra/core/vector';
10
+ import type { VectorFilter } from '@mastra/core/vector/filter';
4
11
  import pg from 'pg';
5
12
 
6
13
  import { PGFilterTranslator } from './filter';
@@ -17,6 +24,31 @@ export interface PGIndexStats extends IndexStats {
17
24
  };
18
25
  }
19
26
 
27
+ interface PgQueryVectorParams extends QueryVectorParams {
28
+ minScore?: number;
29
+ /**
30
+ * HNSW search parameter. Controls the size of the dynamic candidate
31
+ * list during search. Higher values improve accuracy at the cost of speed.
32
+ */
33
+ ef?: number;
34
+ /**
35
+ * IVFFlat probe parameter. Number of cells to visit during search.
36
+ * Higher values improve accuracy at the cost of speed.
37
+ */
38
+ probes?: number;
39
+ }
40
+
41
+ interface PgCreateIndexParams extends CreateIndexParams {
42
+ indexConfig?: IndexConfig;
43
+ buildIndex?: boolean;
44
+ }
45
+
46
+ interface PgDefineIndexParams {
47
+ indexName: string;
48
+ metric: 'cosine' | 'euclidean' | 'dotproduct';
49
+ indexConfig: IndexConfig;
50
+ }
51
+
20
52
  export class PgVector extends MastraVector {
21
53
  private pool: pg.Pool;
22
54
  private indexCache: Map<string, PGIndexStats> = new Map();
@@ -42,10 +74,9 @@ export class PgVector extends MastraVector {
42
74
  }) ?? basePool;
43
75
  }
44
76
 
45
- transformFilter(filter?: Filter) {
46
- const pgFilter = new PGFilterTranslator();
47
- const translatedFilter = pgFilter.translate(filter ?? {});
48
- return translatedFilter;
77
+ transformFilter(filter?: VectorFilter) {
78
+ const translator = new PGFilterTranslator();
79
+ return translator.translate(filter);
49
80
  }
50
81
 
51
82
  async getIndexInfo(indexName: string): Promise<PGIndexStats> {
@@ -55,18 +86,10 @@ export class PgVector extends MastraVector {
55
86
  return this.indexCache.get(indexName)!;
56
87
  }
57
88
 
58
- async query(
59
- indexName: string,
60
- queryVector: number[],
61
- topK: number = 10,
62
- filter?: Filter,
63
- includeVector: boolean = false,
64
- minScore: number = 0, // Optional minimum score threshold
65
- options?: {
66
- ef?: number; // For HNSW
67
- probes?: number; // For IVF
68
- },
69
- ): Promise<QueryResult[]> {
89
+ async query(...args: ParamsToArgs<PgQueryVectorParams>): Promise<QueryResult[]> {
90
+ const params = this.normalizeArgs<PgQueryVectorParams>('query', args, ['minScore', 'ef', 'probes']);
91
+ const { indexName, queryVector, topK = 10, filter, includeVector = false, minScore = 0, ef, probes } = params;
92
+
70
93
  const client = await this.pool.connect();
71
94
  try {
72
95
  const vectorStr = `[${queryVector.join(',')}]`;
@@ -79,13 +102,13 @@ export class PgVector extends MastraVector {
79
102
  // Set HNSW search parameter if applicable
80
103
  if (indexInfo.type === 'hnsw') {
81
104
  // Calculate ef and clamp between 1 and 1000
82
- const calculatedEf = options?.ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
105
+ const calculatedEf = ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
83
106
  const searchEf = Math.min(1000, Math.max(1, calculatedEf));
84
107
  await client.query(`SET LOCAL hnsw.ef_search = ${searchEf}`);
85
108
  }
86
109
 
87
- if (indexInfo.type === 'ivfflat' && options?.probes) {
88
- await client.query(`SET LOCAL ivfflat.probes = ${options.probes}`);
110
+ if (indexInfo.type === 'ivfflat' && probes) {
111
+ await client.query(`SET LOCAL ivfflat.probes = ${probes}`);
89
112
  }
90
113
 
91
114
  const query = `
@@ -116,12 +139,11 @@ export class PgVector extends MastraVector {
116
139
  }
117
140
  }
118
141
 
119
- async upsert(
120
- indexName: string,
121
- vectors: number[][],
122
- metadata?: Record<string, any>[],
123
- ids?: string[],
124
- ): Promise<string[]> {
142
+ async upsert(...args: ParamsToArgs<UpsertVectorParams>): Promise<string[]> {
143
+ const params = this.normalizeArgs<UpsertVectorParams>('upsert', args);
144
+
145
+ const { indexName, vectors, metadata, ids } = params;
146
+
125
147
  // Start a transaction
126
148
  const client = await this.pool.connect();
127
149
  try {
@@ -152,13 +174,11 @@ export class PgVector extends MastraVector {
152
174
  }
153
175
  }
154
176
 
155
- async createIndex(
156
- indexName: string,
157
- dimension: number,
158
- metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
159
- indexConfig: IndexConfig = {},
160
- defineIndex: boolean = true,
161
- ): Promise<void> {
177
+ async createIndex(...args: ParamsToArgs<PgCreateIndexParams>): Promise<void> {
178
+ const params = this.normalizeArgs<PgCreateIndexParams>('createIndex', args, ['indexConfig', 'buildIndex']);
179
+
180
+ const { indexName, dimension, metric = 'cosine', indexConfig = {}, buildIndex = true } = params;
181
+
162
182
  const client = await this.pool.connect();
163
183
  try {
164
184
  // Validate inputs
@@ -193,8 +213,8 @@ export class PgVector extends MastraVector {
193
213
  );
194
214
  `);
195
215
 
196
- if (defineIndex) {
197
- await this.defineIndex(indexName, metric, indexConfig);
216
+ if (buildIndex) {
217
+ await this.buildIndex({ indexName, metric, indexConfig });
198
218
  }
199
219
  } catch (error: any) {
200
220
  console.error('Failed to create vector table:', error);
@@ -212,14 +232,14 @@ export class PgVector extends MastraVector {
212
232
  metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
213
233
  indexConfig: IndexConfig,
214
234
  ): Promise<void> {
215
- return this.buildIndex(indexName, metric, indexConfig);
235
+ return this.buildIndex({ indexName, metric, indexConfig });
216
236
  }
217
237
 
218
- async buildIndex(
219
- indexName: string,
220
- metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
221
- indexConfig: IndexConfig,
222
- ): Promise<void> {
238
+ async buildIndex(...args: ParamsToArgs<PgDefineIndexParams>): Promise<void> {
239
+ const params = this.normalizeArgs<PgDefineIndexParams>('buildIndex', args, ['metric', 'indexConfig']);
240
+
241
+ const { indexName, metric = 'cosine', indexConfig } = params;
242
+
223
243
  const client = await this.pool.connect();
224
244
  try {
225
245
  await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);
@@ -236,7 +236,7 @@ export interface TestConfig {
236
236
 
237
237
  export async function warmupQuery(vectorDB: PgVector, indexName: string, dimension: number, k: number) {
238
238
  const warmupVector = generateRandomVectors(1, dimension)[0] as number[];
239
- await vectorDB.query(indexName, warmupVector, k);
239
+ await vectorDB.query({ indexName, queryVector: warmupVector, topK: k });
240
240
  }
241
241
 
242
242
  export async function measureLatency<T>(fn: () => Promise<T>): Promise<[number, T]> {
@@ -5,8 +5,8 @@ import type {
5
5
  ElementOperator,
6
6
  LogicalOperator,
7
7
  RegexOperator,
8
- Filter,
9
- } from '@mastra/core/filter';
8
+ VectorFilter,
9
+ } from '@mastra/core/vector/filter';
10
10
 
11
11
  export type OperatorType =
12
12
  | BasicOperator
@@ -180,7 +180,7 @@ export const handleKey = (key: string) => {
180
180
  return key.replace(/\./g, ',');
181
181
  };
182
182
 
183
- export function buildFilterQuery(filter: Filter, minScore: number): FilterResult {
183
+ export function buildFilterQuery(filter: VectorFilter, minScore: number): FilterResult {
184
184
  const values = [minScore];
185
185
 
186
186
  function buildCondition(key: string, value: any, parentPath: string): string {
@@ -232,7 +232,11 @@ export function buildFilterQuery(filter: Filter, minScore: number): FilterResult
232
232
  return operatorResult.sql;
233
233
  }
234
234
 
235
- function handleLogicalOperator(key: '$and' | '$or' | '$not' | '$nor', value: Filter[], parentPath: string): string {
235
+ function handleLogicalOperator(
236
+ key: '$and' | '$or' | '$not' | '$nor',
237
+ value: VectorFilter[],
238
+ parentPath: string,
239
+ ): string {
236
240
  if (key === '$not') {
237
241
  // For top-level $not
238
242
  const entries = Object.entries(value);
@@ -256,8 +260,8 @@ export function buildFilterQuery(filter: Filter, minScore: number): FilterResult
256
260
  }
257
261
 
258
262
  const joinOperator = key === '$or' || key === '$nor' ? 'OR' : 'AND';
259
- const conditions = value.map((f: Filter) => {
260
- const entries = Object.entries(f);
263
+ const conditions = value.map((f: VectorFilter) => {
264
+ const entries = Object.entries(f || {});
261
265
  if (entries.length === 0) return '';
262
266
 
263
267
  const [firstKey, firstValue] = entries[0] || [];
@@ -17,6 +17,7 @@ import {
17
17
  generateSkewedVectors,
18
18
  getHNSWConfig,
19
19
  getIndexDescription,
20
+ warmupQuery,
20
21
  } from './performance.helpers';
21
22
  import type { IndexConfig, IndexType } from './types';
22
23
 
@@ -92,8 +93,7 @@ async function smartWarmup(
92
93
  const cacheKey = `${dimension}-${k}-${indexType}`;
93
94
  if (!warmupCache.has(cacheKey)) {
94
95
  console.log(`Warming up ${indexType} index for ${dimension}d vectors, k=${k}`);
95
- const warmupVector = generateRandomVectors(1, dimension)[0] as number[];
96
- await vectorDB.query(testIndexName, warmupVector, k);
96
+ await warmupQuery(vectorDB, testIndexName, dimension, k);
97
97
  warmupCache.set(cacheKey, true);
98
98
  }
99
99
  }
@@ -162,13 +162,13 @@ describe('PostgreSQL Index Performance', () => {
162
162
  // Create index and insert vectors
163
163
  const lists = getListCount(indexConfig, testConfig.size);
164
164
 
165
- await vectorDB.createIndex(
166
- testIndexName,
167
- testConfig.dimension,
168
- 'cosine',
165
+ await vectorDB.createIndex({
166
+ indexName: testIndexName,
167
+ dimension: testConfig.dimension,
168
+ metric: 'cosine',
169
169
  indexConfig,
170
- indexType === 'ivfflat',
171
- );
170
+ buildIndex: indexType === 'ivfflat',
171
+ });
172
172
 
173
173
  console.log(
174
174
  `Batched bulk upserting ${testVectors.length} ${distType} vectors into index ${testIndexName}`,
@@ -177,7 +177,7 @@ describe('PostgreSQL Index Performance', () => {
177
177
  await batchedBulkUpsert(vectorDB, testIndexName, testVectors, batchSizes);
178
178
  if (indexType === 'hnsw' || rebuild) {
179
179
  console.log('rebuilding index');
180
- await vectorDB.buildIndex(testIndexName, 'cosine', indexConfig);
180
+ await vectorDB.buildIndex({ indexName: testIndexName, metric: 'cosine', indexConfig });
181
181
  console.log('index rebuilt');
182
182
  }
183
183
  await smartWarmup(vectorDB, testIndexName, indexType, testConfig.dimension, testConfig.k);
@@ -193,15 +193,12 @@ describe('PostgreSQL Index Performance', () => {
193
193
  const expectedNeighbors = findNearestBruteForce(queryVector, testVectors, testConfig.k);
194
194
 
195
195
  const [latency, actualResults] = await measureLatency(async () =>
196
- vectorDB.query(
197
- testIndexName,
196
+ vectorDB.query({
197
+ indexName: testIndexName,
198
198
  queryVector,
199
- testConfig.k,
200
- undefined,
201
- false,
202
- 0,
203
- { ef }, // For HNSW
204
- ),
199
+ topK: testConfig.k,
200
+ ef, // For HNSW
201
+ }),
205
202
  );
206
203
 
207
204
  const actualNeighbors = actualResults.map(r => r.metadata?.index);