@mastra/pg 0.1.6-alpha.0 → 0.1.6-alpha.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,18 @@
1
- import type { Filter } from '@mastra/core/filter';
2
- import { type IndexStats, type QueryResult, MastraVector } from '@mastra/core/vector';
1
+ import { MastraVector } from '@mastra/core/vector';
2
+ import type {
3
+ IndexStats,
4
+ QueryResult,
5
+ QueryVectorParams,
6
+ CreateIndexParams,
7
+ UpsertVectorParams,
8
+ ParamsToArgs,
9
+ } from '@mastra/core/vector';
10
+ import type { VectorFilter } from '@mastra/core/vector/filter';
3
11
  import pg from 'pg';
4
12
 
5
13
  import { PGFilterTranslator } from './filter';
6
14
  import { buildFilterQuery } from './sql-builder';
7
- import { type IndexConfig, type IndexType } from './types';
15
+ import type { IndexConfig, IndexType } from './types';
8
16
 
9
17
  export interface PGIndexStats extends IndexStats {
10
18
  type: IndexType;
@@ -16,6 +24,31 @@ export interface PGIndexStats extends IndexStats {
16
24
  };
17
25
  }
18
26
 
27
+ interface PgQueryVectorParams extends QueryVectorParams {
28
+ minScore?: number;
29
+ /**
30
+ * HNSW search parameter. Controls the size of the dynamic candidate
31
+ * list during search. Higher values improve accuracy at the cost of speed.
32
+ */
33
+ ef?: number;
34
+ /**
35
+ * IVFFlat probe parameter. Number of cells to visit during search.
36
+ * Higher values improve accuracy at the cost of speed.
37
+ */
38
+ probes?: number;
39
+ }
40
+
41
+ interface PgCreateIndexParams extends CreateIndexParams {
42
+ indexConfig?: IndexConfig;
43
+ buildIndex?: boolean;
44
+ }
45
+
46
+ interface PgDefineIndexParams {
47
+ indexName: string;
48
+ metric: 'cosine' | 'euclidean' | 'dotproduct';
49
+ indexConfig: IndexConfig;
50
+ }
51
+
19
52
  export class PgVector extends MastraVector {
20
53
  private pool: pg.Pool;
21
54
  private indexCache: Map<string, PGIndexStats> = new Map();
@@ -41,10 +74,9 @@ export class PgVector extends MastraVector {
41
74
  }) ?? basePool;
42
75
  }
43
76
 
44
- transformFilter(filter?: Filter) {
45
- const pgFilter = new PGFilterTranslator();
46
- const translatedFilter = pgFilter.translate(filter ?? {});
47
- return translatedFilter;
77
+ transformFilter(filter?: VectorFilter) {
78
+ const translator = new PGFilterTranslator();
79
+ return translator.translate(filter);
48
80
  }
49
81
 
50
82
  async getIndexInfo(indexName: string): Promise<PGIndexStats> {
@@ -54,18 +86,10 @@ export class PgVector extends MastraVector {
54
86
  return this.indexCache.get(indexName)!;
55
87
  }
56
88
 
57
- async query(
58
- indexName: string,
59
- queryVector: number[],
60
- topK: number = 10,
61
- filter?: Filter,
62
- includeVector: boolean = false,
63
- minScore: number = 0, // Optional minimum score threshold
64
- options?: {
65
- ef?: number; // For HNSW
66
- probes?: number; // For IVF
67
- },
68
- ): Promise<QueryResult[]> {
89
+ async query(...args: ParamsToArgs<PgQueryVectorParams>): Promise<QueryResult[]> {
90
+ const params = this.normalizeArgs<PgQueryVectorParams>('query', args, ['minScore', 'ef', 'probes']);
91
+ const { indexName, queryVector, topK = 10, filter, includeVector = false, minScore = 0, ef, probes } = params;
92
+
69
93
  const client = await this.pool.connect();
70
94
  try {
71
95
  const vectorStr = `[${queryVector.join(',')}]`;
@@ -78,13 +102,13 @@ export class PgVector extends MastraVector {
78
102
  // Set HNSW search parameter if applicable
79
103
  if (indexInfo.type === 'hnsw') {
80
104
  // Calculate ef and clamp between 1 and 1000
81
- const calculatedEf = options?.ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
105
+ const calculatedEf = ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
82
106
  const searchEf = Math.min(1000, Math.max(1, calculatedEf));
83
107
  await client.query(`SET LOCAL hnsw.ef_search = ${searchEf}`);
84
108
  }
85
109
 
86
- if (indexInfo.type === 'ivfflat' && options?.probes) {
87
- await client.query(`SET LOCAL ivfflat.probes = ${options.probes}`);
110
+ if (indexInfo.type === 'ivfflat' && probes) {
111
+ await client.query(`SET LOCAL ivfflat.probes = ${probes}`);
88
112
  }
89
113
 
90
114
  const query = `
@@ -115,12 +139,11 @@ export class PgVector extends MastraVector {
115
139
  }
116
140
  }
117
141
 
118
- async upsert(
119
- indexName: string,
120
- vectors: number[][],
121
- metadata?: Record<string, any>[],
122
- ids?: string[],
123
- ): Promise<string[]> {
142
+ async upsert(...args: ParamsToArgs<UpsertVectorParams>): Promise<string[]> {
143
+ const params = this.normalizeArgs<UpsertVectorParams>('upsert', args);
144
+
145
+ const { indexName, vectors, metadata, ids } = params;
146
+
124
147
  // Start a transaction
125
148
  const client = await this.pool.connect();
126
149
  try {
@@ -151,13 +174,11 @@ export class PgVector extends MastraVector {
151
174
  }
152
175
  }
153
176
 
154
- async createIndex(
155
- indexName: string,
156
- dimension: number,
157
- metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
158
- indexConfig: IndexConfig = {},
159
- defineIndex: boolean = true,
160
- ): Promise<void> {
177
+ async createIndex(...args: ParamsToArgs<PgCreateIndexParams>): Promise<void> {
178
+ const params = this.normalizeArgs<PgCreateIndexParams>('createIndex', args, ['indexConfig', 'buildIndex']);
179
+
180
+ const { indexName, dimension, metric = 'cosine', indexConfig = {}, buildIndex = true } = params;
181
+
161
182
  const client = await this.pool.connect();
162
183
  try {
163
184
  // Validate inputs
@@ -192,8 +213,8 @@ export class PgVector extends MastraVector {
192
213
  );
193
214
  `);
194
215
 
195
- if (defineIndex) {
196
- await this.defineIndex(indexName, metric, indexConfig);
216
+ if (buildIndex) {
217
+ await this.buildIndex({ indexName, metric, indexConfig });
197
218
  }
198
219
  } catch (error: any) {
199
220
  console.error('Failed to create vector table:', error);
@@ -203,11 +224,22 @@ export class PgVector extends MastraVector {
203
224
  }
204
225
  }
205
226
 
227
+ /**
228
+ * @deprecated This function is deprecated. Use buildIndex instead
229
+ */
206
230
  async defineIndex(
207
231
  indexName: string,
208
232
  metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
209
233
  indexConfig: IndexConfig,
210
234
  ): Promise<void> {
235
+ return this.buildIndex({ indexName, metric, indexConfig });
236
+ }
237
+
238
+ async buildIndex(...args: ParamsToArgs<PgDefineIndexParams>): Promise<void> {
239
+ const params = this.normalizeArgs<PgDefineIndexParams>('buildIndex', args, ['metric', 'indexConfig']);
240
+
241
+ const { indexName, metric = 'cosine', indexConfig } = params;
242
+
211
243
  const client = await this.pool.connect();
212
244
  try {
213
245
  await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);
@@ -1,6 +1,6 @@
1
- import { type IndexConfig, type IndexType } from './types';
1
+ import type { IndexConfig, IndexType } from './types';
2
2
 
3
- import { PgVector } from '.';
3
+ import type { PgVector } from '.';
4
4
 
5
5
  export interface TestResult {
6
6
  distribution: string;
@@ -236,7 +236,7 @@ export interface TestConfig {
236
236
 
237
237
  export async function warmupQuery(vectorDB: PgVector, indexName: string, dimension: number, k: number) {
238
238
  const warmupVector = generateRandomVectors(1, dimension)[0] as number[];
239
- await vectorDB.query(indexName, warmupVector, k);
239
+ await vectorDB.query({ indexName, queryVector: warmupVector, topK: k });
240
240
  }
241
241
 
242
242
  export async function measureLatency<T>(fn: () => Promise<T>): Promise<[number, T]> {
@@ -1,12 +1,12 @@
1
- import {
2
- type BasicOperator,
3
- type NumericOperator,
4
- type ArrayOperator,
5
- type ElementOperator,
6
- type LogicalOperator,
7
- type RegexOperator,
8
- type Filter,
9
- } from '@mastra/core/filter';
1
+ import type {
2
+ BasicOperator,
3
+ NumericOperator,
4
+ ArrayOperator,
5
+ ElementOperator,
6
+ LogicalOperator,
7
+ RegexOperator,
8
+ VectorFilter,
9
+ } from '@mastra/core/vector/filter';
10
10
 
11
11
  export type OperatorType =
12
12
  | BasicOperator
@@ -180,7 +180,7 @@ export const handleKey = (key: string) => {
180
180
  return key.replace(/\./g, ',');
181
181
  };
182
182
 
183
- export function buildFilterQuery(filter: Filter, minScore: number): FilterResult {
183
+ export function buildFilterQuery(filter: VectorFilter, minScore: number): FilterResult {
184
184
  const values = [minScore];
185
185
 
186
186
  function buildCondition(key: string, value: any, parentPath: string): string {
@@ -232,7 +232,11 @@ export function buildFilterQuery(filter: Filter, minScore: number): FilterResult
232
232
  return operatorResult.sql;
233
233
  }
234
234
 
235
- function handleLogicalOperator(key: '$and' | '$or' | '$not' | '$nor', value: Filter[], parentPath: string): string {
235
+ function handleLogicalOperator(
236
+ key: '$and' | '$or' | '$not' | '$nor',
237
+ value: VectorFilter[],
238
+ parentPath: string,
239
+ ): string {
236
240
  if (key === '$not') {
237
241
  // For top-level $not
238
242
  const entries = Object.entries(value);
@@ -256,8 +260,8 @@ export function buildFilterQuery(filter: Filter, minScore: number): FilterResult
256
260
  }
257
261
 
258
262
  const joinOperator = key === '$or' || key === '$nor' ? 'OR' : 'AND';
259
- const conditions = value.map((f: Filter) => {
260
- const entries = Object.entries(f);
263
+ const conditions = value.map((f: VectorFilter) => {
264
+ const entries = Object.entries(f || {});
261
265
  if (entries.length === 0) return '';
262
266
 
263
267
  const [firstKey, firstValue] = entries[0] || [];
@@ -1,10 +1,9 @@
1
1
  import pg from 'pg';
2
2
  import { describe, it, beforeAll, afterAll, beforeEach, afterEach } from 'vitest';
3
3
 
4
+ import type { TestConfig, TestResult } from './performance.helpers';
4
5
  import {
5
6
  baseTestConfigs,
6
- TestConfig,
7
- TestResult,
8
7
  calculateTimeout,
9
8
  generateRandomVectors,
10
9
  findNearestBruteForce,
@@ -18,8 +17,9 @@ import {
18
17
  generateSkewedVectors,
19
18
  getHNSWConfig,
20
19
  getIndexDescription,
20
+ warmupQuery,
21
21
  } from './performance.helpers';
22
- import { IndexConfig, IndexType } from './types';
22
+ import type { IndexConfig, IndexType } from './types';
23
23
 
24
24
  import { PgVector } from '.';
25
25
 
@@ -93,8 +93,7 @@ async function smartWarmup(
93
93
  const cacheKey = `${dimension}-${k}-${indexType}`;
94
94
  if (!warmupCache.has(cacheKey)) {
95
95
  console.log(`Warming up ${indexType} index for ${dimension}d vectors, k=${k}`);
96
- const warmupVector = generateRandomVectors(1, dimension)[0] as number[];
97
- await vectorDB.query(testIndexName, warmupVector, k);
96
+ await warmupQuery(vectorDB, testIndexName, dimension, k);
98
97
  warmupCache.set(cacheKey, true);
99
98
  }
100
99
  }
@@ -163,13 +162,13 @@ describe('PostgreSQL Index Performance', () => {
163
162
  // Create index and insert vectors
164
163
  const lists = getListCount(indexConfig, testConfig.size);
165
164
 
166
- await vectorDB.createIndex(
167
- testIndexName,
168
- testConfig.dimension,
169
- 'cosine',
165
+ await vectorDB.createIndex({
166
+ indexName: testIndexName,
167
+ dimension: testConfig.dimension,
168
+ metric: 'cosine',
170
169
  indexConfig,
171
- indexType === 'ivfflat',
172
- );
170
+ buildIndex: indexType === 'ivfflat',
171
+ });
173
172
 
174
173
  console.log(
175
174
  `Batched bulk upserting ${testVectors.length} ${distType} vectors into index ${testIndexName}`,
@@ -178,7 +177,7 @@ describe('PostgreSQL Index Performance', () => {
178
177
  await batchedBulkUpsert(vectorDB, testIndexName, testVectors, batchSizes);
179
178
  if (indexType === 'hnsw' || rebuild) {
180
179
  console.log('rebuilding index');
181
- await vectorDB.defineIndex(testIndexName, 'cosine', indexConfig);
180
+ await vectorDB.buildIndex({ indexName: testIndexName, metric: 'cosine', indexConfig });
182
181
  console.log('index rebuilt');
183
182
  }
184
183
  await smartWarmup(vectorDB, testIndexName, indexType, testConfig.dimension, testConfig.k);
@@ -194,15 +193,12 @@ describe('PostgreSQL Index Performance', () => {
194
193
  const expectedNeighbors = findNearestBruteForce(queryVector, testVectors, testConfig.k);
195
194
 
196
195
  const [latency, actualResults] = await measureLatency(async () =>
197
- vectorDB.query(
198
- testIndexName,
196
+ vectorDB.query({
197
+ indexName: testIndexName,
199
198
  queryVector,
200
- testConfig.k,
201
- undefined,
202
- false,
203
- 0,
204
- { ef }, // For HNSW
205
- ),
199
+ topK: testConfig.k,
200
+ ef, // For HNSW
201
+ }),
206
202
  );
207
203
 
208
204
  const actualNeighbors = actualResults.map(r => r.metadata?.index);