@mastra/pg 0.3.4 → 0.4.0-alpha.1
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +37 -0
- package/dist/_tsup-dts-rollup.d.cts +21 -83
- package/dist/_tsup-dts-rollup.d.ts +21 -83
- package/dist/index.cjs +169 -185
- package/dist/index.js +169 -185
- package/docker-compose.perf.yaml +9 -9
- package/package.json +7 -4
- package/src/storage/index.test.ts +32 -51
- package/src/storage/index.ts +13 -17
- package/src/vector/index.test.ts +52 -179
- package/src/vector/index.ts +64 -152
- package/src/vector/sql-builder.ts +110 -77
- package/src/vector/vector.performance.test.ts +2 -2
package/src/vector/index.ts
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { parseSqlIdentifier } from '@mastra/core/utils';
|
|
1
2
|
import { MastraVector } from '@mastra/core/vector';
|
|
2
3
|
import type {
|
|
3
4
|
IndexStats,
|
|
@@ -5,9 +6,10 @@ import type {
|
|
|
5
6
|
QueryVectorParams,
|
|
6
7
|
CreateIndexParams,
|
|
7
8
|
UpsertVectorParams,
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
9
|
+
DescribeIndexParams,
|
|
10
|
+
DeleteIndexParams,
|
|
11
|
+
DeleteVectorParams,
|
|
12
|
+
UpdateVectorParams,
|
|
11
13
|
} from '@mastra/core/vector';
|
|
12
14
|
import type { VectorFilter } from '@mastra/core/vector/filter';
|
|
13
15
|
import { Mutex } from 'async-mutex';
|
|
@@ -42,23 +44,17 @@ interface PgQueryVectorParams extends QueryVectorParams {
|
|
|
42
44
|
probes?: number;
|
|
43
45
|
}
|
|
44
46
|
|
|
45
|
-
type PgQueryVectorArgs = [...QueryVectorArgs, number?, number?, number?];
|
|
46
|
-
|
|
47
47
|
/**
 * Postgres-specific options accepted by `PgVector.createIndex`, on top of the core `CreateIndexParams`.
 */
interface PgCreateIndexParams extends CreateIndexParams {
  /** Whether createIndex should build an index for the table; createIndex defaults this to `true`. */
  buildIndex?: boolean;
  /** Index configuration forwarded by createIndex; createIndex defaults this to `{}`. */
  indexConfig?: IndexConfig;
}
|
|
51
51
|
|
|
52
|
-
type PgCreateIndexArgs = [...CreateIndexArgs, IndexConfig?, boolean?];
|
|
53
|
-
|
|
54
52
|
/**
 * Parameters accepted by `PgVector.buildIndex`.
 */
interface PgDefineIndexParams {
  /** Name of the index whose table the index is built on. */
  indexName: string;
  /** Index configuration passed through to the index setup. */
  indexConfig: IndexConfig;
  /** Distance metric; buildIndex still falls back to `'cosine'` if a caller omits it at runtime. */
  metric: 'cosine' | 'euclidean' | 'dotproduct';
}
|
|
59
57
|
|
|
60
|
-
type PgDefineIndexArgs = [string, 'cosine' | 'euclidean' | 'dotproduct', IndexConfig];
|
|
61
|
-
|
|
62
58
|
export class PgVector extends MastraVector {
|
|
63
59
|
private pool: pg.Pool;
|
|
64
60
|
private describeIndexCache: Map<string, PGIndexStats> = new Map();
|
|
@@ -70,48 +66,15 @@ export class PgVector extends MastraVector {
|
|
|
70
66
|
private vectorExtensionInstalled: boolean | undefined = undefined;
|
|
71
67
|
private schemaSetupComplete: boolean | undefined = undefined;
|
|
72
68
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
constructor(config: {
|
|
69
|
+
constructor({
|
|
70
|
+
connectionString,
|
|
71
|
+
schemaName,
|
|
72
|
+
pgPoolOptions,
|
|
73
|
+
}: {
|
|
79
74
|
connectionString: string;
|
|
80
75
|
schemaName?: string;
|
|
81
76
|
pgPoolOptions?: Omit<pg.PoolConfig, 'connectionString'>;
|
|
82
|
-
})
|
|
83
|
-
constructor(
|
|
84
|
-
config:
|
|
85
|
-
| string
|
|
86
|
-
| {
|
|
87
|
-
connectionString: string;
|
|
88
|
-
schemaName?: string;
|
|
89
|
-
pgPoolOptions?: Omit<pg.PoolConfig, 'connectionString'>;
|
|
90
|
-
},
|
|
91
|
-
) {
|
|
92
|
-
let connectionString: string;
|
|
93
|
-
let pgPoolOptions: Omit<pg.PoolConfig, 'connectionString'> | undefined;
|
|
94
|
-
let schemaName: string | undefined;
|
|
95
|
-
|
|
96
|
-
if (typeof config === 'string') {
|
|
97
|
-
// DEPRECATION WARNING
|
|
98
|
-
console.warn(
|
|
99
|
-
`DEPRECATION WARNING: Passing connectionString as a string to PgVector constructor is deprecated.
|
|
100
|
-
|
|
101
|
-
Please use an object parameter instead:
|
|
102
|
-
new PgVector({ connectionString })
|
|
103
|
-
|
|
104
|
-
The string signature will be removed on May 20th, 2025.`,
|
|
105
|
-
);
|
|
106
|
-
connectionString = config;
|
|
107
|
-
schemaName = undefined;
|
|
108
|
-
pgPoolOptions = undefined;
|
|
109
|
-
} else {
|
|
110
|
-
connectionString = config.connectionString;
|
|
111
|
-
schemaName = config.schemaName;
|
|
112
|
-
pgPoolOptions = config.pgPoolOptions;
|
|
113
|
-
}
|
|
114
|
-
|
|
77
|
+
}) {
|
|
115
78
|
if (!connectionString || connectionString.trim() === '') {
|
|
116
79
|
throw new Error(
|
|
117
80
|
'PgVector: connectionString must be provided and cannot be empty. Passing an empty string may cause fallback to local Postgres defaults.',
|
|
@@ -143,7 +106,7 @@ export class PgVector extends MastraVector {
|
|
|
143
106
|
// warm the created indexes cache so we don't need to check if indexes exist every time
|
|
144
107
|
const existingIndexes = await this.listIndexes();
|
|
145
108
|
void existingIndexes.map(async indexName => {
|
|
146
|
-
const info = await this.getIndexInfo(indexName);
|
|
109
|
+
const info = await this.getIndexInfo({ indexName });
|
|
147
110
|
const key = await this.getIndexCacheKey({
|
|
148
111
|
indexName,
|
|
149
112
|
metric: info.metric,
|
|
@@ -161,7 +124,9 @@ export class PgVector extends MastraVector {
|
|
|
161
124
|
}
|
|
162
125
|
|
|
163
126
|
/**
 * Maps an index name to the table that stores its vectors.
 *
 * The index name (and the configured schema, when there is one) are passed through
 * `parseSqlIdentifier` before use. The result is `schema.index` when a schema is configured,
 * otherwise just the parsed index name.
 */
private getTableName(indexName: string) {
  const table = parseSqlIdentifier(indexName, 'index name');
  if (!this.schema) {
    return table;
  }
  const schema = parseSqlIdentifier(this.schema, 'schema name');
  return schema ? `${schema}.${table}` : table;
}
|
|
166
131
|
|
|
167
132
|
transformFilter(filter?: VectorFilter) {
|
|
@@ -169,29 +134,38 @@ export class PgVector extends MastraVector {
|
|
|
169
134
|
return translator.translate(filter);
|
|
170
135
|
}
|
|
171
136
|
|
|
172
|
-
/**
 * Returns the statistics for `indexName`, memoized in `describeIndexCache`.
 *
 * Only a cache miss calls `describeIndex`; later calls reuse the stored value. The cache entry is
 * written after the `await`, so concurrent first calls for the same index may each call describeIndex.
 */
async getIndexInfo({ indexName }: DescribeIndexParams): Promise<PGIndexStats> {
  const isCached = this.describeIndexCache.has(indexName);
  if (!isCached) {
    const stats = await this.describeIndex({ indexName });
    this.describeIndexCache.set(indexName, stats);
  }
  return this.describeIndexCache.get(indexName)!;
}
|
|
178
143
|
|
|
179
|
-
async query(
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
144
|
+
async query({
|
|
145
|
+
indexName,
|
|
146
|
+
queryVector,
|
|
147
|
+
topK = 10,
|
|
148
|
+
filter,
|
|
149
|
+
includeVector = false,
|
|
150
|
+
minScore = 0,
|
|
151
|
+
ef,
|
|
152
|
+
probes,
|
|
153
|
+
}: PgQueryVectorParams): Promise<QueryResult[]> {
|
|
154
|
+
if (!Number.isInteger(topK) || topK <= 0) {
|
|
155
|
+
throw new Error('topK must be a positive integer');
|
|
156
|
+
}
|
|
157
|
+
if (!Array.isArray(queryVector) || !queryVector.every(x => typeof x === 'number' && Number.isFinite(x))) {
|
|
158
|
+
throw new Error('queryVector must be an array of finite numbers');
|
|
159
|
+
}
|
|
186
160
|
|
|
187
161
|
const client = await this.pool.connect();
|
|
188
162
|
try {
|
|
189
163
|
const vectorStr = `[${queryVector.join(',')}]`;
|
|
190
164
|
const translatedFilter = this.transformFilter(filter);
|
|
191
|
-
const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore);
|
|
165
|
+
const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore, topK);
|
|
192
166
|
|
|
193
167
|
// Get index type and configuration
|
|
194
|
-
const indexInfo = await this.getIndexInfo(indexName);
|
|
168
|
+
const indexInfo = await this.getIndexInfo({ indexName });
|
|
195
169
|
|
|
196
170
|
// Set HNSW search parameter if applicable
|
|
197
171
|
if (indexInfo.type === 'hnsw') {
|
|
@@ -221,7 +195,7 @@ export class PgVector extends MastraVector {
|
|
|
221
195
|
FROM vector_scores
|
|
222
196
|
WHERE score > $1
|
|
223
197
|
ORDER BY score DESC
|
|
224
|
-
LIMIT $
|
|
198
|
+
LIMIT $2`;
|
|
225
199
|
const result = await client.query(query, filterValues);
|
|
226
200
|
|
|
227
201
|
return result.rows.map(({ id, score, metadata, embedding }) => ({
|
|
@@ -235,10 +209,7 @@ export class PgVector extends MastraVector {
|
|
|
235
209
|
}
|
|
236
210
|
}
|
|
237
211
|
|
|
238
|
-
async upsert(
|
|
239
|
-
const params = this.normalizeArgs<UpsertVectorParams>('upsert', args);
|
|
240
|
-
|
|
241
|
-
const { indexName, vectors, metadata, ids } = params;
|
|
212
|
+
async upsert({ indexName, vectors, metadata, ids }: UpsertVectorParams): Promise<string[]> {
|
|
242
213
|
const tableName = this.getTableName(indexName);
|
|
243
214
|
|
|
244
215
|
// Start a transaction
|
|
@@ -270,7 +241,7 @@ export class PgVector extends MastraVector {
|
|
|
270
241
|
if (match) {
|
|
271
242
|
const [, expected, actual] = match;
|
|
272
243
|
throw new Error(
|
|
273
|
-
`Vector dimension mismatch: Index "${
|
|
244
|
+
`Vector dimension mismatch: Index "${indexName}" expects ${expected} dimensions but got ${actual} dimensions. ` +
|
|
274
245
|
`Either use a matching embedding model or delete and recreate the index with the new dimension.`,
|
|
275
246
|
);
|
|
276
247
|
}
|
|
@@ -282,8 +253,13 @@ export class PgVector extends MastraVector {
|
|
|
282
253
|
}
|
|
283
254
|
|
|
284
255
|
private hasher = xxhash();
|
|
285
|
-
/**
 * Derives the cache key for an index from its name, dimension, metric and type.
 *
 * A missing `type` is treated as `'ivfflat'` (the default index type), so omitting it and passing
 * `'ivfflat'` explicitly produce the same key. The concatenated string is hashed with the shared
 * hasher's `h32`.
 */
private async getIndexCacheKey({
  indexName,
  dimension,
  metric,
  type,
}: CreateIndexParams & { type: IndexType | undefined }) {
  const resolvedType = type || 'ivfflat'; // ivfflat is default
  const input = `${indexName}${dimension}${metric}${resolvedType}`;
  const hasher = await this.hasher;
  return hasher.h32(input);
}
|
|
289
265
|
private cachedIndexExists(indexName: string, newKey: number) {
|
|
@@ -341,13 +317,13 @@ export class PgVector extends MastraVector {
|
|
|
341
317
|
await this.setupSchemaPromise;
|
|
342
318
|
}
|
|
343
319
|
|
|
344
|
-
async createIndex(
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
320
|
+
async createIndex({
|
|
321
|
+
indexName,
|
|
322
|
+
dimension,
|
|
323
|
+
metric = 'cosine',
|
|
324
|
+
indexConfig = {},
|
|
325
|
+
buildIndex = true,
|
|
326
|
+
}: PgCreateIndexParams): Promise<void> {
|
|
351
327
|
const tableName = this.getTableName(indexName);
|
|
352
328
|
|
|
353
329
|
// Validate inputs
|
|
@@ -402,27 +378,7 @@ export class PgVector extends MastraVector {
|
|
|
402
378
|
});
|
|
403
379
|
}
|
|
404
380
|
|
|
405
|
-
|
|
406
|
-
* @deprecated This function is deprecated. Use buildIndex instead
|
|
407
|
-
* This function will be removed on May 20th, 2025
|
|
408
|
-
*/
|
|
409
|
-
async defineIndex(
|
|
410
|
-
indexName: string,
|
|
411
|
-
metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
|
|
412
|
-
indexConfig: IndexConfig,
|
|
413
|
-
): Promise<void> {
|
|
414
|
-
console.warn('defineIndex is deprecated. Use buildIndex instead. This function will be removed on May 20th, 2025');
|
|
415
|
-
return this.buildIndex({ indexName, metric, indexConfig });
|
|
416
|
-
}
|
|
417
|
-
|
|
418
|
-
async buildIndex(...args: ParamsToArgs<PgDefineIndexParams> | PgDefineIndexArgs): Promise<void> {
|
|
419
|
-
const params = this.normalizeArgs<PgDefineIndexParams, PgDefineIndexArgs>('buildIndex', args, [
|
|
420
|
-
'metric',
|
|
421
|
-
'indexConfig',
|
|
422
|
-
]);
|
|
423
|
-
|
|
424
|
-
const { indexName, metric = 'cosine', indexConfig } = params;
|
|
425
|
-
|
|
381
|
+
async buildIndex({ indexName, metric = 'cosine', indexConfig }: PgDefineIndexParams): Promise<void> {
|
|
426
382
|
const client = await this.pool.connect();
|
|
427
383
|
try {
|
|
428
384
|
await this.setupIndex({ indexName, metric, indexConfig }, client);
|
|
@@ -552,7 +508,13 @@ export class PgVector extends MastraVector {
|
|
|
552
508
|
}
|
|
553
509
|
}
|
|
554
510
|
|
|
555
|
-
|
|
511
|
+
/**
|
|
512
|
+
* Retrieves statistics about a vector index.
|
|
513
|
+
*
|
|
514
|
+
* @param {string} indexName - The name of the index to describe
|
|
515
|
+
* @returns A promise that resolves to the index statistics including dimension, count and metric
|
|
516
|
+
*/
|
|
517
|
+
async describeIndex({ indexName }: DescribeIndexParams): Promise<PGIndexStats> {
|
|
556
518
|
const client = await this.pool.connect();
|
|
557
519
|
try {
|
|
558
520
|
const tableName = this.getTableName(indexName);
|
|
@@ -648,7 +610,7 @@ export class PgVector extends MastraVector {
|
|
|
648
610
|
}
|
|
649
611
|
}
|
|
650
612
|
|
|
651
|
-
async deleteIndex(indexName:
|
|
613
|
+
async deleteIndex({ indexName }: DeleteIndexParams): Promise<void> {
|
|
652
614
|
const client = await this.pool.connect();
|
|
653
615
|
try {
|
|
654
616
|
const tableName = this.getTableName(indexName);
|
|
@@ -663,7 +625,7 @@ export class PgVector extends MastraVector {
|
|
|
663
625
|
}
|
|
664
626
|
}
|
|
665
627
|
|
|
666
|
-
async truncateIndex(indexName:
|
|
628
|
+
async truncateIndex({ indexName }: DeleteIndexParams): Promise<void> {
|
|
667
629
|
const client = await this.pool.connect();
|
|
668
630
|
try {
|
|
669
631
|
const tableName = this.getTableName(indexName);
|
|
@@ -680,31 +642,6 @@ export class PgVector extends MastraVector {
|
|
|
680
642
|
await this.pool.end();
|
|
681
643
|
}
|
|
682
644
|
|
|
683
|
-
/**
|
|
684
|
-
* @deprecated Use {@link updateVector} instead. This method will be removed on May 20th, 2025.
|
|
685
|
-
*
|
|
686
|
-
* Updates a vector by its ID with the provided vector and/or metadata.
|
|
687
|
-
* @param indexName - The name of the index containing the vector.
|
|
688
|
-
* @param id - The ID of the vector to update.
|
|
689
|
-
* @param update - An object containing the vector and/or metadata to update.
|
|
690
|
-
* @param update.vector - An optional array of numbers representing the new vector.
|
|
691
|
-
* @param update.metadata - An optional record containing the new metadata.
|
|
692
|
-
* @returns A promise that resolves when the update is complete.
|
|
693
|
-
* @throws Will throw an error if no updates are provided or if the update operation fails.
|
|
694
|
-
*/
|
|
695
|
-
async updateIndexById(
|
|
696
|
-
indexName: string,
|
|
697
|
-
id: string,
|
|
698
|
-
update: { vector?: number[]; metadata?: Record<string, any> },
|
|
699
|
-
): Promise<void> {
|
|
700
|
-
this.logger.warn(
|
|
701
|
-
`Deprecation Warning: updateIndexById() is deprecated.
|
|
702
|
-
Please use updateVector() instead.
|
|
703
|
-
updateIndexById() will be removed on May 20th, 2025.`,
|
|
704
|
-
);
|
|
705
|
-
await this.updateVector(indexName, id, update);
|
|
706
|
-
}
|
|
707
|
-
|
|
708
645
|
/**
|
|
709
646
|
* Updates a vector by its ID with the provided vector and/or metadata.
|
|
710
647
|
* @param indexName - The name of the index containing the vector.
|
|
@@ -715,14 +652,7 @@ export class PgVector extends MastraVector {
|
|
|
715
652
|
* @returns A promise that resolves when the update is complete.
|
|
716
653
|
* @throws Will throw an error if no updates are provided or if the update operation fails.
|
|
717
654
|
*/
|
|
718
|
-
async updateVector(
|
|
719
|
-
indexName: string,
|
|
720
|
-
id: string,
|
|
721
|
-
update: {
|
|
722
|
-
vector?: number[];
|
|
723
|
-
metadata?: Record<string, any>;
|
|
724
|
-
},
|
|
725
|
-
): Promise<void> {
|
|
655
|
+
async updateVector({ indexName, id, update }: UpdateVectorParams): Promise<void> {
|
|
726
656
|
if (!update.vector && !update.metadata) {
|
|
727
657
|
throw new Error('No updates provided');
|
|
728
658
|
}
|
|
@@ -766,24 +696,6 @@ export class PgVector extends MastraVector {
|
|
|
766
696
|
}
|
|
767
697
|
}
|
|
768
698
|
|
|
769
|
-
/**
|
|
770
|
-
* @deprecated Use {@link deleteVector} instead. This method will be removed on May 20th, 2025.
|
|
771
|
-
*
|
|
772
|
-
* Deletes a vector by its ID.
|
|
773
|
-
* @param indexName - The name of the index containing the vector.
|
|
774
|
-
* @param id - The ID of the vector to delete.
|
|
775
|
-
* @returns A promise that resolves when the deletion is complete.
|
|
776
|
-
* @throws Will throw an error if the deletion operation fails.
|
|
777
|
-
*/
|
|
778
|
-
async deleteIndexById(indexName: string, id: string): Promise<void> {
|
|
779
|
-
this.logger.warn(
|
|
780
|
-
`Deprecation Warning: deleteIndexById() is deprecated.
|
|
781
|
-
Please use deleteVector() instead.
|
|
782
|
-
deleteIndexById() will be removed on May 20th, 2025.`,
|
|
783
|
-
);
|
|
784
|
-
await this.deleteVector(indexName, id);
|
|
785
|
-
}
|
|
786
|
-
|
|
787
699
|
/**
|
|
788
700
|
* Deletes a vector by its ID.
|
|
789
701
|
* @param indexName - The name of the index containing the vector.
|
|
@@ -791,7 +703,7 @@ export class PgVector extends MastraVector {
|
|
|
791
703
|
* @returns A promise that resolves when the deletion is complete.
|
|
792
704
|
* @throws Will throw an error if the deletion operation fails.
|
|
793
705
|
*/
|
|
794
|
-
async deleteVector(indexName
|
|
706
|
+
async deleteVector({ indexName, id }: DeleteVectorParams): Promise<void> {
|
|
795
707
|
const client = await this.pool.connect();
|
|
796
708
|
try {
|
|
797
709
|
const tableName = this.getTableName(indexName);
|
|
package/src/vector/sql-builder.ts
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { parseFieldKey } from '@mastra/core/utils';
|
|
1
2
|
import type {
|
|
2
3
|
BasicOperator,
|
|
3
4
|
NumericOperator,
|
|
@@ -8,14 +9,15 @@ import type {
|
|
|
8
9
|
VectorFilter,
|
|
9
10
|
} from '@mastra/core/vector/filter';
|
|
10
11
|
|
|
11
|
-
|
|
12
|
+
/**
 * Every filter operator key the Postgres SQL builder implements.
 *
 * `FILTER_OPERATORS` is declared as `Record<OperatorType, OperatorFn>`, so each member here needs a
 * matching implementation there. `$options` is excluded from the regex operators and therefore has
 * no entry of its own.
 */
type OperatorType =
  | LogicalOperator
  | BasicOperator
  | NumericOperator
  | ArrayOperator
  | ElementOperator
  | Exclude<RegexOperator, '$options'>
  | '$contains'
  | '$size';
|
|
19
21
|
|
|
20
22
|
type FilterOperator = {
|
|
21
23
|
sql: string;
|
|
@@ -25,22 +27,27 @@ type FilterOperator = {
|
|
|
25
27
|
|
|
26
28
|
/**
 * Builds the SQL fragment for a single filter operator.
 *
 * `fieldKey` is the metadata key being filtered. `placeholderIndex` is the positional parameter
 * number the fragment must reference as `$n`. `operand` is the raw filter value, passed for
 * operators that need to inspect it (e.g. `$elemMatch`, `$contains`).
 */
type OperatorFn = (fieldKey: string, placeholderIndex: number, operand?: any) => FilterOperator;
|
|
27
29
|
|
|
28
|
-
// Helper functions to create operators
|
|
29
30
|
/**
 * Returns an operator builder for a text comparison (`=` is used for `$eq`, `!=` for `$ne`).
 *
 * The generated SQL compares the metadata value at `key` (extracted as text with `#>>`) with the
 * bound parameter cast to text. When the bound parameter is SQL NULL, the metadata value is instead
 * tested with `IS NULL` (for `=`) or `IS NOT NULL` (for any other symbol).
 */
const createBasicOperator = (symbol: string) => {
  // '=' checks for a missing/null value; every other symbol checks for a present one
  const nullTest = symbol === '=' ? '' : 'NOT';

  return (key: string, paramIndex: number) => {
    const path = parseJsonPathKey(key);
    const placeholder = `$${paramIndex}`;
    const sql = `CASE
WHEN ${placeholder}::text IS NULL THEN metadata#>>'{${path}}' IS ${nullTest} NULL
ELSE metadata#>>'{${path}}' ${symbol} ${placeholder}::text
END`;
    return { sql, needsValue: true };
  };
};
|
|
38
42
|
|
|
39
43
|
/**
 * Returns an operator builder for a numeric comparison (used for `$gt`, `$lte`, ...).
 *
 * The metadata value at `key` is extracted as text and cast to `numeric` before it is compared with
 * the bound parameter.
 * NOTE(review): Postgres raises an error when that text is not a valid numeric literal, so a
 * non-numeric value stored at this path can make the whole query fail.
 */
const createNumericOperator = (symbol: string) => {
  function numericComparison(key: string, paramIndex: number) {
    const path = parseJsonPathKey(key);
    return {
      sql: `(metadata#>>'{${path}}')::numeric ${symbol} $${paramIndex}`,
      needsValue: true,
    };
  }
  return numericComparison;
};
|
|
45
52
|
|
|
46
53
|
function buildElemMatchConditions(value: any, paramIndex: number): { sql: string; values: any[] } {
|
|
@@ -73,7 +80,7 @@ function buildElemMatchConditions(value: any, paramIndex: number): { sql: string
|
|
|
73
80
|
paramValue = val;
|
|
74
81
|
}
|
|
75
82
|
|
|
76
|
-
const operatorFn = FILTER_OPERATORS[paramOperator as
|
|
83
|
+
const operatorFn = FILTER_OPERATORS[paramOperator as OperatorType];
|
|
77
84
|
if (!operatorFn) {
|
|
78
85
|
throw new Error(`Invalid operator: ${paramOperator}`);
|
|
79
86
|
}
|
|
@@ -93,7 +100,7 @@ function buildElemMatchConditions(value: any, paramIndex: number): { sql: string
|
|
|
93
100
|
}
|
|
94
101
|
|
|
95
102
|
// Define all filter operators
|
|
96
|
-
|
|
103
|
+
const FILTER_OPERATORS: Record<OperatorType, OperatorFn> = {
|
|
97
104
|
$eq: createBasicOperator('='),
|
|
98
105
|
$ne: createBasicOperator('!='),
|
|
99
106
|
$gt: createNumericOperator('>'),
|
|
@@ -102,46 +109,56 @@ export const FILTER_OPERATORS: Record<string, OperatorFn> = {
|
|
|
102
109
|
$lte: createNumericOperator('<='),
|
|
103
110
|
|
|
104
111
|
// Array Operators
|
|
105
|
-
$in: (key, paramIndex) =>
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
112
|
+
$in: (key, paramIndex) => {
|
|
113
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
114
|
+
return {
|
|
115
|
+
sql: `(
|
|
116
|
+
CASE
|
|
117
|
+
WHEN jsonb_typeof(metadata->'${jsonPathKey}') = 'array' THEN
|
|
118
|
+
EXISTS (
|
|
119
|
+
SELECT 1 FROM jsonb_array_elements_text(metadata->'${jsonPathKey}') as elem
|
|
120
|
+
WHERE elem = ANY($${paramIndex}::text[])
|
|
121
|
+
)
|
|
122
|
+
ELSE metadata#>>'{${jsonPathKey}}' = ANY($${paramIndex}::text[])
|
|
123
|
+
END
|
|
124
|
+
)`,
|
|
125
|
+
needsValue: true,
|
|
126
|
+
};
|
|
127
|
+
},
|
|
128
|
+
$nin: (key, paramIndex) => {
|
|
129
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
130
|
+
return {
|
|
131
|
+
sql: `(
|
|
132
|
+
CASE
|
|
133
|
+
WHEN jsonb_typeof(metadata->'${jsonPathKey}') = 'array' THEN
|
|
134
|
+
NOT EXISTS (
|
|
135
|
+
SELECT 1 FROM jsonb_array_elements_text(metadata->'${jsonPathKey}') as elem
|
|
136
|
+
WHERE elem = ANY($${paramIndex}::text[])
|
|
137
|
+
)
|
|
138
|
+
ELSE metadata#>>'{${jsonPathKey}}' != ALL($${paramIndex}::text[])
|
|
139
|
+
END
|
|
140
|
+
)`,
|
|
141
|
+
needsValue: true,
|
|
142
|
+
};
|
|
143
|
+
},
|
|
144
|
+
$all: (key, paramIndex) => {
|
|
145
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
146
|
+
return {
|
|
147
|
+
sql: `CASE WHEN array_length($${paramIndex}::text[], 1) IS NULL THEN false
|
|
148
|
+
ELSE (metadata#>'{${jsonPathKey}}')::jsonb ?& $${paramIndex}::text[] END`,
|
|
149
|
+
needsValue: true,
|
|
150
|
+
};
|
|
151
|
+
},
|
|
136
152
|
$elemMatch: (key: string, paramIndex: number, value: any): FilterOperator => {
|
|
137
153
|
const { sql, values } = buildElemMatchConditions(value, paramIndex);
|
|
154
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
138
155
|
return {
|
|
139
156
|
sql: `(
|
|
140
157
|
CASE
|
|
141
|
-
WHEN jsonb_typeof(metadata->'${
|
|
158
|
+
WHEN jsonb_typeof(metadata->'${jsonPathKey}') = 'array' THEN
|
|
142
159
|
EXISTS (
|
|
143
160
|
SELECT 1
|
|
144
|
-
FROM jsonb_array_elements(metadata->'${
|
|
161
|
+
FROM jsonb_array_elements(metadata->'${jsonPathKey}') as elem
|
|
145
162
|
WHERE ${sql}
|
|
146
163
|
)
|
|
147
164
|
ELSE FALSE
|
|
@@ -152,10 +169,13 @@ export const FILTER_OPERATORS: Record<string, OperatorFn> = {
|
|
|
152
169
|
};
|
|
153
170
|
},
|
|
154
171
|
// Element Operators
|
|
155
|
-
$exists: key =>
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
172
|
+
$exists: key => {
|
|
173
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
174
|
+
return {
|
|
175
|
+
sql: `metadata ? '${jsonPathKey}'`,
|
|
176
|
+
needsValue: false,
|
|
177
|
+
};
|
|
178
|
+
},
|
|
159
179
|
|
|
160
180
|
// Logical Operators
|
|
161
181
|
$and: key => ({ sql: `(${key})`, needsValue: false }),
|
|
@@ -164,24 +184,29 @@ export const FILTER_OPERATORS: Record<string, OperatorFn> = {
|
|
|
164
184
|
$nor: key => ({ sql: `NOT (${key})`, needsValue: false }),
|
|
165
185
|
|
|
166
186
|
// Regex Operators
|
|
167
|
-
$regex: (key, paramIndex) =>
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
187
|
+
$regex: (key, paramIndex) => {
|
|
188
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
189
|
+
return {
|
|
190
|
+
sql: `metadata#>>'{${jsonPathKey}}' ~ $${paramIndex}`,
|
|
191
|
+
needsValue: true,
|
|
192
|
+
};
|
|
193
|
+
},
|
|
171
194
|
|
|
172
195
|
$contains: (key, paramIndex, value: any) => {
|
|
196
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
173
197
|
let sql;
|
|
174
198
|
if (Array.isArray(value)) {
|
|
175
|
-
sql = `(metadata->'${
|
|
199
|
+
sql = `(metadata->'${jsonPathKey}') ?& $${paramIndex}`;
|
|
176
200
|
} else if (typeof value === 'string') {
|
|
177
|
-
sql = `metadata->>'${
|
|
201
|
+
sql = `metadata->>'${jsonPathKey}' ILIKE '%' || $${paramIndex} || '%' ESCAPE '\\'`;
|
|
178
202
|
} else {
|
|
179
|
-
sql = `metadata->>'${
|
|
203
|
+
sql = `metadata->>'${jsonPathKey}' = $${paramIndex}`;
|
|
180
204
|
}
|
|
181
205
|
return {
|
|
182
206
|
sql,
|
|
183
207
|
needsValue: true,
|
|
184
|
-
transformValue: () =>
|
|
208
|
+
transformValue: () =>
|
|
209
|
+
Array.isArray(value) ? value.map(String) : typeof value === 'string' ? escapeLikePattern(value) : value,
|
|
185
210
|
};
|
|
186
211
|
},
|
|
187
212
|
/**
|
|
@@ -196,29 +221,37 @@ export const FILTER_OPERATORS: Record<string, OperatorFn> = {
|
|
|
196
221
|
// return JSON.stringify(parts.reduceRight((value, key) => ({ [key]: value }), value));
|
|
197
222
|
// },
|
|
198
223
|
// }),
|
|
199
|
-
$size: (key: string, paramIndex: number) =>
|
|
200
|
-
|
|
224
|
+
$size: (key: string, paramIndex: number) => {
|
|
225
|
+
const jsonPathKey = parseJsonPathKey(key);
|
|
226
|
+
return {
|
|
227
|
+
sql: `(
|
|
201
228
|
CASE
|
|
202
|
-
WHEN jsonb_typeof(metadata#>'{${
|
|
203
|
-
jsonb_array_length(metadata#>'{${
|
|
229
|
+
WHEN jsonb_typeof(metadata#>'{${jsonPathKey}}') = 'array' THEN
|
|
230
|
+
jsonb_array_length(metadata#>'{${jsonPathKey}}') = $${paramIndex}
|
|
204
231
|
ELSE FALSE
|
|
205
232
|
END
|
|
206
233
|
)`,
|
|
207
|
-
|
|
208
|
-
|
|
234
|
+
needsValue: true,
|
|
235
|
+
};
|
|
236
|
+
},
|
|
209
237
|
};
|
|
210
238
|
|
|
211
|
-
|
|
239
|
+
/**
 * Output of the filter builder: a SQL fragment plus the positional parameter values it references.
 */
interface FilterResult {
  /** Values bound to the `$n` placeholders, in order. */
  values: any[];
  /** SQL text containing `$n` placeholders. */
  sql: string;
}
|
|
215
243
|
|
|
216
|
-
|
|
217
|
-
|
|
244
|
+
/**
 * Converts a dotted metadata key (e.g. `user.address.city`) into the comma-separated form used in
 * Postgres text-array path literals such as `metadata#>>'{user,address,city}'`.
 *
 * Non-empty keys are first passed through `parseFieldKey` from `@mastra/core/utils`, which
 * presumably validates/normalizes them (confirm there). The empty string is returned unchanged.
 *
 * NOTE(review): the result is interpolated directly into SQL text, so injection safety depends on
 * parseFieldKey rejecting quote characters. Some operators also place the result after the
 * single-key `->` operator or the `?` key-exists operator. There, a comma-joined nested path is
 * treated as one literal key.
 */
const parseJsonPathKey = (key: string) => {
  if (key === '') {
    return '';
  }
  return parseFieldKey(key).split('.').join(',');
};
|
|
219
248
|
|
|
220
|
-
|
|
221
|
-
|
|
249
|
+
/**
 * Escapes the LIKE/ILIKE wildcards `%` and `_`, and the backslash itself, by prefixing each with a
 * backslash. The input then matches literally in a pattern that declares `ESCAPE '\'`.
 */
function escapeLikePattern(str: string): string {
  return str.replace(/[%_\\]/g, ch => `\\${ch}`);
}
|
|
252
|
+
|
|
253
|
+
export function buildFilterQuery(filter: VectorFilter, minScore: number, topK: number): FilterResult {
|
|
254
|
+
const values = [minScore, topK];
|
|
222
255
|
|
|
223
256
|
function buildCondition(key: string, value: any, parentPath: string): string {
|
|
224
257
|
// Handle logical operators ($and/$or)
|
|
@@ -229,7 +262,7 @@ export function buildFilterQuery(filter: VectorFilter, minScore: number): Filter
|
|
|
229
262
|
// If condition is not a FilterCondition object, assume it's an equality check
|
|
230
263
|
if (!value || typeof value !== 'object') {
|
|
231
264
|
values.push(value);
|
|
232
|
-
return `metadata#>>'{${
|
|
265
|
+
return `metadata#>>'{${parseJsonPathKey(key)}}' = $${values.length}`;
|
|
233
266
|
}
|
|
234
267
|
|
|
235
268
|
// Handle operator conditions
|
|
@@ -240,11 +273,11 @@ export function buildFilterQuery(filter: VectorFilter, minScore: number): Filter
|
|
|
240
273
|
const entries = Object.entries(operatorValue as Record<string, unknown>);
|
|
241
274
|
const conditions = entries
|
|
242
275
|
.map(([nestedOp, nestedValue]) => {
|
|
243
|
-
if (!FILTER_OPERATORS[nestedOp as
|
|
276
|
+
if (!FILTER_OPERATORS[nestedOp as OperatorType]) {
|
|
244
277
|
throw new Error(`Invalid operator in $not condition: ${nestedOp}`);
|
|
245
278
|
}
|
|
246
|
-
const operatorFn = FILTER_OPERATORS[nestedOp]!;
|
|
247
|
-
const operatorResult = operatorFn(key, values.length + 1);
|
|
279
|
+
const operatorFn = FILTER_OPERATORS[nestedOp as OperatorType]!;
|
|
280
|
+
const operatorResult = operatorFn(key, values.length + 1, nestedValue);
|
|
248
281
|
if (operatorResult.needsValue) {
|
|
249
282
|
values.push(nestedValue as number);
|
|
250
283
|
}
|
|
@@ -254,7 +287,7 @@ export function buildFilterQuery(filter: VectorFilter, minScore: number): Filter
|
|
|
254
287
|
|
|
255
288
|
return `NOT (${conditions})`;
|
|
256
289
|
}
|
|
257
|
-
const operatorFn = FILTER_OPERATORS[operator as
|
|
290
|
+
const operatorFn = FILTER_OPERATORS[operator as OperatorType]!;
|
|
258
291
|
const operatorResult = operatorFn(key, values.length + 1, operatorValue);
|
|
259
292
|
if (operatorResult.needsValue) {
|
|
260
293
|
const transformedValue = operatorResult.transformValue ? operatorResult.transformValue() : operatorValue;
|