@mastra/pg 0.1.6-alpha.1 → 0.1.6-alpha.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +11 -6
- package/CHANGELOG.md +25 -0
- package/README.md +22 -17
- package/dist/_tsup-dts-rollup.d.cts +329 -0
- package/dist/_tsup-dts-rollup.d.ts +45 -20
- package/dist/index.cjs +1050 -0
- package/dist/index.d.cts +4 -0
- package/dist/index.js +24 -17
- package/package.json +7 -3
- package/src/storage/index.ts +2 -2
- package/src/vector/filter.ts +5 -5
- package/src/vector/index.test.ts +696 -314
- package/src/vector/index.ts +62 -42
- package/src/vector/performance.helpers.ts +1 -1
- package/src/vector/sql-builder.ts +10 -6
- package/src/vector/vector.performance.test.ts +14 -17
package/src/vector/index.ts
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
1
|
-
import type { Filter } from '@mastra/core/filter';
|
|
2
1
|
import { MastraVector } from '@mastra/core/vector';
|
|
3
|
-
import type {
|
|
2
|
+
import type {
|
|
3
|
+
IndexStats,
|
|
4
|
+
QueryResult,
|
|
5
|
+
QueryVectorParams,
|
|
6
|
+
CreateIndexParams,
|
|
7
|
+
UpsertVectorParams,
|
|
8
|
+
ParamsToArgs,
|
|
9
|
+
} from '@mastra/core/vector';
|
|
10
|
+
import type { VectorFilter } from '@mastra/core/vector/filter';
|
|
4
11
|
import pg from 'pg';
|
|
5
12
|
|
|
6
13
|
import { PGFilterTranslator } from './filter';
|
|
@@ -17,6 +24,31 @@ export interface PGIndexStats extends IndexStats {
|
|
|
17
24
|
};
|
|
18
25
|
}
|
|
19
26
|
|
|
27
|
+
interface PgQueryVectorParams extends QueryVectorParams {
|
|
28
|
+
minScore?: number;
|
|
29
|
+
/**
|
|
30
|
+
* HNSW search parameter. Controls the size of the dynamic candidate
|
|
31
|
+
* list during search. Higher values improve accuracy at the cost of speed.
|
|
32
|
+
*/
|
|
33
|
+
ef?: number;
|
|
34
|
+
/**
|
|
35
|
+
* IVFFlat probe parameter. Number of cells to visit during search.
|
|
36
|
+
* Higher values improve accuracy at the cost of speed.
|
|
37
|
+
*/
|
|
38
|
+
probes?: number;
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
interface PgCreateIndexParams extends CreateIndexParams {
|
|
42
|
+
indexConfig?: IndexConfig;
|
|
43
|
+
buildIndex?: boolean;
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
interface PgDefineIndexParams {
|
|
47
|
+
indexName: string;
|
|
48
|
+
metric: 'cosine' | 'euclidean' | 'dotproduct';
|
|
49
|
+
indexConfig: IndexConfig;
|
|
50
|
+
}
|
|
51
|
+
|
|
20
52
|
export class PgVector extends MastraVector {
|
|
21
53
|
private pool: pg.Pool;
|
|
22
54
|
private indexCache: Map<string, PGIndexStats> = new Map();
|
|
@@ -42,10 +74,9 @@ export class PgVector extends MastraVector {
|
|
|
42
74
|
}) ?? basePool;
|
|
43
75
|
}
|
|
44
76
|
|
|
45
|
-
transformFilter(filter?:
|
|
46
|
-
const
|
|
47
|
-
|
|
48
|
-
return translatedFilter;
|
|
77
|
+
transformFilter(filter?: VectorFilter) {
|
|
78
|
+
const translator = new PGFilterTranslator();
|
|
79
|
+
return translator.translate(filter);
|
|
49
80
|
}
|
|
50
81
|
|
|
51
82
|
async getIndexInfo(indexName: string): Promise<PGIndexStats> {
|
|
@@ -55,18 +86,10 @@ export class PgVector extends MastraVector {
|
|
|
55
86
|
return this.indexCache.get(indexName)!;
|
|
56
87
|
}
|
|
57
88
|
|
|
58
|
-
async query(
|
|
59
|
-
|
|
60
|
-
queryVector
|
|
61
|
-
|
|
62
|
-
filter?: Filter,
|
|
63
|
-
includeVector: boolean = false,
|
|
64
|
-
minScore: number = 0, // Optional minimum score threshold
|
|
65
|
-
options?: {
|
|
66
|
-
ef?: number; // For HNSW
|
|
67
|
-
probes?: number; // For IVF
|
|
68
|
-
},
|
|
69
|
-
): Promise<QueryResult[]> {
|
|
89
|
+
async query(...args: ParamsToArgs<PgQueryVectorParams>): Promise<QueryResult[]> {
|
|
90
|
+
const params = this.normalizeArgs<PgQueryVectorParams>('query', args, ['minScore', 'ef', 'probes']);
|
|
91
|
+
const { indexName, queryVector, topK = 10, filter, includeVector = false, minScore = 0, ef, probes } = params;
|
|
92
|
+
|
|
70
93
|
const client = await this.pool.connect();
|
|
71
94
|
try {
|
|
72
95
|
const vectorStr = `[${queryVector.join(',')}]`;
|
|
@@ -79,13 +102,13 @@ export class PgVector extends MastraVector {
|
|
|
79
102
|
// Set HNSW search parameter if applicable
|
|
80
103
|
if (indexInfo.type === 'hnsw') {
|
|
81
104
|
// Calculate ef and clamp between 1 and 1000
|
|
82
|
-
const calculatedEf =
|
|
105
|
+
const calculatedEf = ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
|
|
83
106
|
const searchEf = Math.min(1000, Math.max(1, calculatedEf));
|
|
84
107
|
await client.query(`SET LOCAL hnsw.ef_search = ${searchEf}`);
|
|
85
108
|
}
|
|
86
109
|
|
|
87
|
-
if (indexInfo.type === 'ivfflat' &&
|
|
88
|
-
await client.query(`SET LOCAL ivfflat.probes = ${
|
|
110
|
+
if (indexInfo.type === 'ivfflat' && probes) {
|
|
111
|
+
await client.query(`SET LOCAL ivfflat.probes = ${probes}`);
|
|
89
112
|
}
|
|
90
113
|
|
|
91
114
|
const query = `
|
|
@@ -116,12 +139,11 @@ export class PgVector extends MastraVector {
|
|
|
116
139
|
}
|
|
117
140
|
}
|
|
118
141
|
|
|
119
|
-
async upsert(
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
): Promise<string[]> {
|
|
142
|
+
async upsert(...args: ParamsToArgs<UpsertVectorParams>): Promise<string[]> {
|
|
143
|
+
const params = this.normalizeArgs<UpsertVectorParams>('upsert', args);
|
|
144
|
+
|
|
145
|
+
const { indexName, vectors, metadata, ids } = params;
|
|
146
|
+
|
|
125
147
|
// Start a transaction
|
|
126
148
|
const client = await this.pool.connect();
|
|
127
149
|
try {
|
|
@@ -152,13 +174,11 @@ export class PgVector extends MastraVector {
|
|
|
152
174
|
}
|
|
153
175
|
}
|
|
154
176
|
|
|
155
|
-
async createIndex(
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
metric
|
|
159
|
-
|
|
160
|
-
defineIndex: boolean = true,
|
|
161
|
-
): Promise<void> {
|
|
177
|
+
async createIndex(...args: ParamsToArgs<PgCreateIndexParams>): Promise<void> {
|
|
178
|
+
const params = this.normalizeArgs<PgCreateIndexParams>('createIndex', args, ['indexConfig', 'buildIndex']);
|
|
179
|
+
|
|
180
|
+
const { indexName, dimension, metric = 'cosine', indexConfig = {}, buildIndex = true } = params;
|
|
181
|
+
|
|
162
182
|
const client = await this.pool.connect();
|
|
163
183
|
try {
|
|
164
184
|
// Validate inputs
|
|
@@ -193,8 +213,8 @@ export class PgVector extends MastraVector {
|
|
|
193
213
|
);
|
|
194
214
|
`);
|
|
195
215
|
|
|
196
|
-
if (
|
|
197
|
-
await this.
|
|
216
|
+
if (buildIndex) {
|
|
217
|
+
await this.buildIndex({ indexName, metric, indexConfig });
|
|
198
218
|
}
|
|
199
219
|
} catch (error: any) {
|
|
200
220
|
console.error('Failed to create vector table:', error);
|
|
@@ -212,14 +232,14 @@ export class PgVector extends MastraVector {
|
|
|
212
232
|
metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
|
|
213
233
|
indexConfig: IndexConfig,
|
|
214
234
|
): Promise<void> {
|
|
215
|
-
return this.buildIndex(indexName, metric, indexConfig);
|
|
235
|
+
return this.buildIndex({ indexName, metric, indexConfig });
|
|
216
236
|
}
|
|
217
237
|
|
|
218
|
-
async buildIndex(
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
indexConfig
|
|
222
|
-
|
|
238
|
+
async buildIndex(...args: ParamsToArgs<PgDefineIndexParams>): Promise<void> {
|
|
239
|
+
const params = this.normalizeArgs<PgDefineIndexParams>('buildIndex', args, ['metric', 'indexConfig']);
|
|
240
|
+
|
|
241
|
+
const { indexName, metric = 'cosine', indexConfig } = params;
|
|
242
|
+
|
|
223
243
|
const client = await this.pool.connect();
|
|
224
244
|
try {
|
|
225
245
|
await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);
|
|
@@ -236,7 +236,7 @@ export interface TestConfig {
|
|
|
236
236
|
|
|
237
237
|
export async function warmupQuery(vectorDB: PgVector, indexName: string, dimension: number, k: number) {
|
|
238
238
|
const warmupVector = generateRandomVectors(1, dimension)[0] as number[];
|
|
239
|
-
await vectorDB.query(indexName, warmupVector, k);
|
|
239
|
+
await vectorDB.query({ indexName, queryVector: warmupVector, topK: k });
|
|
240
240
|
}
|
|
241
241
|
|
|
242
242
|
export async function measureLatency<T>(fn: () => Promise<T>): Promise<[number, T]> {
|
|
@@ -5,8 +5,8 @@ import type {
|
|
|
5
5
|
ElementOperator,
|
|
6
6
|
LogicalOperator,
|
|
7
7
|
RegexOperator,
|
|
8
|
-
|
|
9
|
-
} from '@mastra/core/filter';
|
|
8
|
+
VectorFilter,
|
|
9
|
+
} from '@mastra/core/vector/filter';
|
|
10
10
|
|
|
11
11
|
export type OperatorType =
|
|
12
12
|
| BasicOperator
|
|
@@ -180,7 +180,7 @@ export const handleKey = (key: string) => {
|
|
|
180
180
|
return key.replace(/\./g, ',');
|
|
181
181
|
};
|
|
182
182
|
|
|
183
|
-
export function buildFilterQuery(filter:
|
|
183
|
+
export function buildFilterQuery(filter: VectorFilter, minScore: number): FilterResult {
|
|
184
184
|
const values = [minScore];
|
|
185
185
|
|
|
186
186
|
function buildCondition(key: string, value: any, parentPath: string): string {
|
|
@@ -232,7 +232,11 @@ export function buildFilterQuery(filter: Filter, minScore: number): FilterResult
|
|
|
232
232
|
return operatorResult.sql;
|
|
233
233
|
}
|
|
234
234
|
|
|
235
|
-
function handleLogicalOperator(
|
|
235
|
+
function handleLogicalOperator(
|
|
236
|
+
key: '$and' | '$or' | '$not' | '$nor',
|
|
237
|
+
value: VectorFilter[],
|
|
238
|
+
parentPath: string,
|
|
239
|
+
): string {
|
|
236
240
|
if (key === '$not') {
|
|
237
241
|
// For top-level $not
|
|
238
242
|
const entries = Object.entries(value);
|
|
@@ -256,8 +260,8 @@ export function buildFilterQuery(filter: Filter, minScore: number): FilterResult
|
|
|
256
260
|
}
|
|
257
261
|
|
|
258
262
|
const joinOperator = key === '$or' || key === '$nor' ? 'OR' : 'AND';
|
|
259
|
-
const conditions = value.map((f:
|
|
260
|
-
const entries = Object.entries(f);
|
|
263
|
+
const conditions = value.map((f: VectorFilter) => {
|
|
264
|
+
const entries = Object.entries(f || {});
|
|
261
265
|
if (entries.length === 0) return '';
|
|
262
266
|
|
|
263
267
|
const [firstKey, firstValue] = entries[0] || [];
|
|
@@ -17,6 +17,7 @@ import {
|
|
|
17
17
|
generateSkewedVectors,
|
|
18
18
|
getHNSWConfig,
|
|
19
19
|
getIndexDescription,
|
|
20
|
+
warmupQuery,
|
|
20
21
|
} from './performance.helpers';
|
|
21
22
|
import type { IndexConfig, IndexType } from './types';
|
|
22
23
|
|
|
@@ -92,8 +93,7 @@ async function smartWarmup(
|
|
|
92
93
|
const cacheKey = `${dimension}-${k}-${indexType}`;
|
|
93
94
|
if (!warmupCache.has(cacheKey)) {
|
|
94
95
|
console.log(`Warming up ${indexType} index for ${dimension}d vectors, k=${k}`);
|
|
95
|
-
|
|
96
|
-
await vectorDB.query(testIndexName, warmupVector, k);
|
|
96
|
+
await warmupQuery(vectorDB, testIndexName, dimension, k);
|
|
97
97
|
warmupCache.set(cacheKey, true);
|
|
98
98
|
}
|
|
99
99
|
}
|
|
@@ -162,13 +162,13 @@ describe('PostgreSQL Index Performance', () => {
|
|
|
162
162
|
// Create index and insert vectors
|
|
163
163
|
const lists = getListCount(indexConfig, testConfig.size);
|
|
164
164
|
|
|
165
|
-
await vectorDB.createIndex(
|
|
166
|
-
testIndexName,
|
|
167
|
-
testConfig.dimension,
|
|
168
|
-
'cosine',
|
|
165
|
+
await vectorDB.createIndex({
|
|
166
|
+
indexName: testIndexName,
|
|
167
|
+
dimension: testConfig.dimension,
|
|
168
|
+
metric: 'cosine',
|
|
169
169
|
indexConfig,
|
|
170
|
-
indexType === 'ivfflat',
|
|
171
|
-
);
|
|
170
|
+
buildIndex: indexType === 'ivfflat',
|
|
171
|
+
});
|
|
172
172
|
|
|
173
173
|
console.log(
|
|
174
174
|
`Batched bulk upserting ${testVectors.length} ${distType} vectors into index ${testIndexName}`,
|
|
@@ -177,7 +177,7 @@ describe('PostgreSQL Index Performance', () => {
|
|
|
177
177
|
await batchedBulkUpsert(vectorDB, testIndexName, testVectors, batchSizes);
|
|
178
178
|
if (indexType === 'hnsw' || rebuild) {
|
|
179
179
|
console.log('rebuilding index');
|
|
180
|
-
await vectorDB.buildIndex(testIndexName, 'cosine', indexConfig);
|
|
180
|
+
await vectorDB.buildIndex({ indexName: testIndexName, metric: 'cosine', indexConfig });
|
|
181
181
|
console.log('index rebuilt');
|
|
182
182
|
}
|
|
183
183
|
await smartWarmup(vectorDB, testIndexName, indexType, testConfig.dimension, testConfig.k);
|
|
@@ -193,15 +193,12 @@ describe('PostgreSQL Index Performance', () => {
|
|
|
193
193
|
const expectedNeighbors = findNearestBruteForce(queryVector, testVectors, testConfig.k);
|
|
194
194
|
|
|
195
195
|
const [latency, actualResults] = await measureLatency(async () =>
|
|
196
|
-
vectorDB.query(
|
|
197
|
-
testIndexName,
|
|
196
|
+
vectorDB.query({
|
|
197
|
+
indexName: testIndexName,
|
|
198
198
|
queryVector,
|
|
199
|
-
testConfig.k,
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
0,
|
|
203
|
-
{ ef }, // For HNSW
|
|
204
|
-
),
|
|
199
|
+
topK: testConfig.k,
|
|
200
|
+
ef, // For HNSW
|
|
201
|
+
}),
|
|
205
202
|
);
|
|
206
203
|
|
|
207
204
|
const actualNeighbors = actualResults.map(r => r.metadata?.index);
|