@mastra/pg 0.1.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +116 -0
- package/LICENSE +44 -0
- package/README.md +155 -0
- package/dist/index.d.ts +82 -0
- package/dist/index.js +886 -0
- package/docker-compose.yaml +14 -0
- package/package.json +39 -0
- package/src/index.ts +2 -0
- package/src/storage/index.test.ts +379 -0
- package/src/storage/index.ts +477 -0
- package/src/vector/filter.test.ts +967 -0
- package/src/vector/filter.ts +106 -0
- package/src/vector/index.test.ts +1205 -0
- package/src/vector/index.ts +282 -0
- package/src/vector/sql-builder.ts +285 -0
- package/tsconfig.json +15 -0
- package/vitest.config.ts +11 -0
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
import type { Filter } from '@mastra/core/filter';
|
|
2
|
+
import { type IndexStats, type QueryResult, MastraVector } from '@mastra/core/vector';
|
|
3
|
+
import pg from 'pg';
|
|
4
|
+
|
|
5
|
+
import { PGFilterTranslator } from './filter';
|
|
6
|
+
import { buildFilterQuery } from './sql-builder';
|
|
7
|
+
|
|
8
|
+
/**
 * Vector store backed by PostgreSQL with the pgvector extension.
 *
 * Each index is a table with the columns `id SERIAL`, `vector_id TEXT UNIQUE`,
 * `embedding vector(dimension)` and `metadata JSONB`, as created by `createIndex`.
 * Table names cannot be bound as query parameters, so every method that takes an
 * index name validates it as a plain SQL identifier before interpolating it.
 */
export class PgVector extends MastraVector {
  /** Accepted format for index (table) names: an unquoted SQL identifier. */
  private static readonly INDEX_NAME_PATTERN = /^[a-zA-Z_][a-zA-Z0-9_]*$/;

  private pool: pg.Pool;

  /**
   * @param connectionString - PostgreSQL connection string handed to `pg.Pool`.
   */
  constructor(connectionString: string) {
    super();

    const basePool = new pg.Pool({
      connectionString,
      max: 20, // Maximum number of clients in the pool
      idleTimeoutMillis: 30000, // Close idle connections after 30 seconds
      connectionTimeoutMillis: 2000, // Fail fast if can't connect
    });

    const telemetry = this.__getTelemetry();

    // Wrap the pool for tracing when telemetry is configured; otherwise use it directly.
    this.pool =
      telemetry?.traceClass(basePool, {
        spanNamePrefix: 'pg-vector',
        attributes: {
          'vector.type': 'postgres',
        },
      }) ?? basePool;
  }

  /**
   * Throws unless `indexName` is a plain, unquoted SQL identifier.
   * This is what makes interpolating `indexName` into SQL text safe.
   */
  private assertValidIndexName(indexName: string): void {
    if (!PgVector.INDEX_NAME_PATTERN.test(indexName)) {
      throw new Error('Invalid index name format');
    }
  }

  /**
   * Issues ROLLBACK and swallows any error, so it cannot mask the error that
   * triggered the rollback.
   *
   * @returns `false` if the ROLLBACK itself failed, so the caller can discard the connection.
   */
  private async tryRollback(client: { query(sql: string): Promise<unknown> }): Promise<boolean> {
    try {
      await client.query('ROLLBACK');
      return true;
    } catch {
      return false;
    }
  }

  /** Translates a filter with `PGFilterTranslator`; a missing filter is translated as `{}`. */
  transformFilter(filter?: Filter) {
    const pgFilter = new PGFilterTranslator();
    const translatedFilter = pgFilter.translate(filter ?? {});
    return translatedFilter;
  }

  /**
   * Returns up to `topK` rows of `indexName` most similar to `queryVector`, highest score first.
   *
   * The score is `1 - (embedding <=> query)`. Only rows whose score is strictly
   * greater than `minScore` and that match `filter` are returned.
   *
   * @param indexName - Table created by `createIndex`; must be a plain SQL identifier.
   * @param queryVector - Query embedding.
   * @param topK - Maximum number of results; must be a non-negative integer.
   * @param filter - Optional metadata filter.
   * @param includeVector - When true, each result also carries its stored `vector`.
   * @param minScore - Exclusive lower bound on the score.
   * @throws If `indexName` or `topK` is invalid, or the query fails.
   */
  async query(
    indexName: string,
    queryVector: number[],
    topK: number = 10,
    filter?: Filter,
    includeVector: boolean = false,
    minScore: number = 0, // Optional minimum score threshold
  ): Promise<QueryResult[]> {
    this.assertValidIndexName(indexName);
    if (!Number.isInteger(topK) || topK < 0) {
      throw new Error('topK must be a non-negative integer');
    }

    const client = await this.pool.connect();
    try {
      const translatedFilter = this.transformFilter(filter);
      // filterValues[0] is minScore (bound as $1); filter conditions use $2 onwards.
      const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore);

      // Bind the query vector and the limit after the filter values instead of splicing them into the SQL text.
      const vectorParam = filterValues.length + 1;
      const limitParam = filterValues.length + 2;
      const params = [...filterValues, `[${queryVector.join(',')}]`, topK];

      // NOTE(review): the score always uses cosine distance (<=>), even for indexes created
      // with the 'euclidean' or 'dotproduct' metric in createIndex — confirm this is intended.
      const query = `
        WITH vector_scores AS (
          SELECT
            vector_id as id,
            1 - (embedding <=> $${vectorParam}::vector) as score,
            metadata
            ${includeVector ? ', embedding' : ''}
          FROM ${indexName}
          ${filterQuery}
        )
        SELECT *
        FROM vector_scores
        WHERE score > $1
        ORDER BY score DESC
        LIMIT $${limitParam}`;
      const result = await client.query(query, params);

      return result.rows.map(({ id, score, metadata, embedding }) => ({
        id,
        score,
        metadata,
        // The embedding is selected as its text form '[x,y,...]', which is valid JSON.
        ...(includeVector && embedding && { vector: JSON.parse(embedding) }),
      }));
    } finally {
      client.release();
    }
  }

  /**
   * Inserts or updates vectors (matched on `vector_id`) in a single transaction.
   *
   * @param indexName - Target table; must be a plain SQL identifier.
   * @param vectors - Embeddings to store.
   * @param metadata - Optional per-vector metadata; missing entries default to `{}`.
   * @param ids - Optional ids, one per vector; random UUIDs are generated when omitted.
   * @returns The ids of the stored vectors, in input order.
   * @throws If `indexName` is invalid, `ids` and `vectors` differ in length, or any write fails
   *   (in which case the whole batch is rolled back).
   */
  async upsert(
    indexName: string,
    vectors: number[][],
    metadata?: Record<string, any>[],
    ids?: string[],
  ): Promise<string[]> {
    this.assertValidIndexName(indexName);
    if (ids && ids.length !== vectors.length) {
      throw new Error(`Number of ids (${ids.length}) must match number of vectors (${vectors.length})`);
    }

    const vectorIds = ids ?? vectors.map(() => crypto.randomUUID());
    const query = `
      INSERT INTO ${indexName} (vector_id, embedding, metadata)
      VALUES ($1, $2::vector, $3::jsonb)
      ON CONFLICT (vector_id)
      DO UPDATE SET
        embedding = $2::vector,
        metadata = $3::jsonb
    `;

    const client = await this.pool.connect();
    let discardClient = false;
    try {
      await client.query('BEGIN');

      for (let i = 0; i < vectors.length; i++) {
        await client.query(query, [vectorIds[i], `[${vectors[i]?.join(',')}]`, JSON.stringify(metadata?.[i] || {})]);
      }

      await client.query('COMMIT');

      return vectorIds;
    } catch (error) {
      discardClient = !(await this.tryRollback(client));
      throw error;
    } finally {
      // Releasing with `true` makes pg destroy the client instead of returning it to the pool.
      client.release(discardClient);
    }
  }

  /**
   * Creates (if missing) the pgvector extension, the index table and an ivfflat index on it.
   *
   * @param indexName - Table name; must be a plain SQL identifier.
   * @param dimension - Embedding dimension; must be a positive integer.
   * @param metric - Selects the ivfflat operator class (cosine, L2 or inner product).
   * @throws If inputs are invalid, the vector extension is unavailable, or any DDL fails.
   */
  async createIndex(
    indexName: string,
    dimension: number,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
  ): Promise<void> {
    const client = await this.pool.connect();
    try {
      // Validate inputs
      this.assertValidIndexName(indexName);
      if (!Number.isInteger(dimension) || dimension <= 0) {
        throw new Error('Dimension must be a positive integer');
      }

      // First check if vector extension is available
      const extensionCheck = await client.query(`
        SELECT EXISTS (
          SELECT 1 FROM pg_available_extensions WHERE name = 'vector'
        );
      `);

      if (!extensionCheck.rows[0].exists) {
        throw new Error('PostgreSQL vector extension is not available. Please install it first.');
      }

      // Try to create extension
      await client.query('CREATE EXTENSION IF NOT EXISTS vector');

      // Create the table with explicit schema
      await client.query(`
        CREATE TABLE IF NOT EXISTS ${indexName} (
          id SERIAL PRIMARY KEY,
          vector_id TEXT UNIQUE NOT NULL,
          embedding vector(${dimension}),
          metadata JSONB DEFAULT '{}'::jsonb
        );
      `);

      // Create the index
      const indexMethod =
        metric === 'cosine' ? 'vector_cosine_ops' : metric === 'euclidean' ? 'vector_l2_ops' : 'vector_ip_ops';

      await client.query(`
        CREATE INDEX IF NOT EXISTS ${indexName}_vector_idx
        ON public.${indexName}
        USING ivfflat (embedding ${indexMethod})
        WITH (lists = 100);
      `);
    } catch (error: unknown) {
      console.error('Failed to create vector table:', error);
      throw error;
    } finally {
      client.release();
    }
  }

  /** Lists every table in the `public` schema that has a column of type `vector`. */
  async listIndexes(): Promise<string[]> {
    const client = await this.pool.connect();
    try {
      const vectorTablesQuery = `
        SELECT DISTINCT table_name
        FROM information_schema.columns
        WHERE table_schema = 'public'
        AND udt_name = 'vector';
      `;
      const vectorTables = await client.query(vectorTablesQuery);
      return vectorTables.rows.map(row => row.table_name);
    } finally {
      client.release();
    }
  }

  /**
   * Reports the dimension, row count and metric of an index.
   *
   * The metric is derived from the operator class of `<indexName>_vector_idx`,
   * and defaults to 'cosine' when that index is not found.
   *
   * @throws An Error prefixed with "Failed to describe vector table:" on any failure.
   */
  async describeIndex(indexName: string): Promise<IndexStats> {
    const client = await this.pool.connect();
    try {
      this.assertValidIndexName(indexName);

      // Get vector dimension: the atttypmod of the `embedding` column.
      const dimResult = await client.query(
        `SELECT atttypmod as dimension
         FROM pg_attribute
         WHERE attrelid = $1::regclass
           AND attname = 'embedding';`,
        [indexName],
      );
      if (dimResult.rows.length === 0) {
        throw new Error(`Table ${indexName} has no embedding column`);
      }

      // Get row count
      const countResult = await client.query(`SELECT COUNT(*) as count FROM ${indexName};`);

      // Get index metric type. Postgres folds unquoted identifiers to lower case,
      // so the stored relname of the index created by createIndex is lower-case.
      const metricResult = await client.query(
        `SELECT
           am.amname as index_method,
           opclass.opcname as operator_class
         FROM pg_index i
         JOIN pg_class c ON i.indexrelid = c.oid
         JOIN pg_am am ON c.relam = am.oid
         JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
         WHERE c.relname = $1;`,
        [`${indexName}_vector_idx`.toLowerCase()],
      );

      // Convert pg_vector operator class to our metric type
      let metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine';
      if (metricResult.rows.length > 0) {
        const operatorClass = metricResult.rows[0].operator_class;
        if (operatorClass.includes('l2')) {
          metric = 'euclidean';
        } else if (operatorClass.includes('ip')) {
          metric = 'dotproduct';
        } else if (operatorClass.includes('cosine')) {
          metric = 'cosine';
        }
      }

      return {
        dimension: dimResult.rows[0].dimension,
        count: parseInt(countResult.rows[0].count, 10),
        metric,
      };
    } catch (e: unknown) {
      // No transaction is open here, so there is nothing to roll back.
      const message = e instanceof Error ? e.message : String(e);
      throw new Error(`Failed to describe vector table: ${message}`);
    } finally {
      client.release();
    }
  }

  /**
   * Drops the index table (and, via CASCADE, dependent objects) if it exists.
   * @throws An Error prefixed with "Failed to delete vector table:" on any failure.
   */
  async deleteIndex(indexName: string): Promise<void> {
    const client = await this.pool.connect();
    try {
      this.assertValidIndexName(indexName);
      await client.query(`DROP TABLE IF EXISTS ${indexName} CASCADE`);
    } catch (error: unknown) {
      // No transaction is open here, so there is nothing to roll back.
      const message = error instanceof Error ? error.message : String(error);
      throw new Error(`Failed to delete vector table: ${message}`);
    } finally {
      client.release();
    }
  }

  /**
   * Removes all rows from the index table, keeping the table and its index.
   * @throws An Error prefixed with "Failed to truncate vector table:" on any failure.
   */
  async truncateIndex(indexName: string) {
    const client = await this.pool.connect();
    try {
      this.assertValidIndexName(indexName);
      await client.query(`TRUNCATE ${indexName}`);
    } catch (e: unknown) {
      // No transaction is open here, so there is nothing to roll back.
      const message = e instanceof Error ? e.message : String(e);
      throw new Error(`Failed to truncate vector table: ${message}`);
    } finally {
      client.release();
    }
  }

  /** Closes every pooled connection; the instance cannot be used afterwards. */
  async disconnect() {
    await this.pool.end();
  }
}
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
import {
|
|
2
|
+
BasicOperator,
|
|
3
|
+
NumericOperator,
|
|
4
|
+
ArrayOperator,
|
|
5
|
+
ElementOperator,
|
|
6
|
+
LogicalOperator,
|
|
7
|
+
RegexOperator,
|
|
8
|
+
Filter,
|
|
9
|
+
} from '@mastra/core/filter';
|
|
10
|
+
|
|
11
|
+
/**
 * Every operator name the PG SQL builder recognises: the shared logical,
 * element, array, numeric and basic operators, the regex operators except
 * `$options`, and the PG-specific `$contains`.
 */
export type OperatorType =
  | LogicalOperator
  | ElementOperator
  | ArrayOperator
  | NumericOperator
  | BasicOperator
  | Exclude<RegexOperator, '$options'>
  | '$contains';

/** SQL produced for one operator, plus how its operand is bound. */
interface FilterOperator {
  /** SQL condition text; may reference `$N` placeholders. */
  sql: string;
  /** Whether the caller must bind a value for the placeholder(s) in `sql`. */
  needsValue: boolean;
  /** Optional conversion applied to the operand before it is bound. */
  transformValue?: (value: any) => any;
}

/** Builds the SQL for one operator on `key`, numbering its first placeholder `paramIndex`. */
type OperatorFn = (key: string, paramIndex: number, value?: any) => FilterOperator;
|
|
27
|
+
|
|
28
|
+
// Helper functions to create operators
|
|
29
|
+
/**
 * Returns an operator builder that compares a metadata path, as text, with the
 * bound value using `symbol`. When the bound value is NULL the check becomes
 * `IS NULL` for `=` and `IS NOT NULL` for any other symbol.
 */
function createBasicOperator(symbol: string) {
  const nullCheck = symbol === '=' ? '' : 'NOT';
  return function basicOperator(key: string, paramIndex: number) {
    const path = handleKey(key);
    const placeholder = `$${paramIndex}`;
    return {
      sql: `CASE
WHEN ${placeholder}::text IS NULL THEN metadata#>>'{${path}}' IS ${nullCheck} NULL
ELSE metadata#>>'{${path}}' ${symbol} ${placeholder}::text
END`,
      needsValue: true,
    };
  };
}
|
|
38
|
+
|
|
39
|
+
/**
 * Returns an operator builder that casts a metadata path to `numeric` and
 * compares it with the bound value using `symbol`.
 */
function createNumericOperator(symbol: string) {
  return function numericOperator(key: string, paramIndex: number) {
    const numericPath = `(metadata#>>'{${handleKey(key)}}')::numeric`;
    return { sql: `${numericPath} ${symbol} $${paramIndex}`, needsValue: true };
  };
}
|
|
45
|
+
|
|
46
|
+
/**
 * Builds the condition applied to each array element inside `$elemMatch`.
 *
 * Each entry of `value` contributes one or more conditions, and all conditions
 * are ANDed:
 * - `{ $gt: 5 }` compares the element itself;
 * - `{ price: { $gt: 5, $lt: 10 } }` applies every listed operator to the element's field;
 * - `{ price: 5 }` (or `null`) is an implicit `$eq`.
 *
 * @param value - The `$elemMatch` condition object.
 * @param paramIndex - Placeholder number ($N) for the first bound value.
 * @returns SQL that refers to the current element as `elem`, and the values to bind in placeholder order.
 * @throws If `value` is not a plain object or uses an unknown operator.
 */
function buildElemMatchConditions(value: any, paramIndex: number): { sql: string; values: any[] } {
  if (value === null || typeof value !== 'object' || Array.isArray(value)) {
    throw new Error('$elemMatch requires an object with conditions');
  }

  const conditions: string[] = [];
  const values: any[] = [];

  for (const [field, val] of Object.entries(value)) {
    // Each clause is [operator, key inside the element, operand].
    let clauses: Array<[string, string, unknown]>;
    if (field.startsWith('$')) {
      // Operator on the element itself: empty path.
      clauses = [[field, '', val]];
    } else if (val !== null && typeof val === 'object' && !Array.isArray(val)) {
      clauses = Object.entries(val).map(([op, opValue]): [string, string, unknown] => [op, field, opValue]);
      if (clauses.length === 0) {
        throw new Error(`Invalid operator: undefined`);
      }
    } else {
      clauses = [['$eq', field, val]];
    }

    for (const [operator, key, operand] of clauses) {
      const operatorFn = FILTER_OPERATORS[operator];
      if (!operatorFn) {
        throw new Error(`Invalid operator: ${operator}`);
      }
      const result = operatorFn(key, paramIndex + values.length, operand);

      // Re-target JSON path lookups (#>, #>>, ->, ->>) from the row's metadata column
      // to the current array element. Operators that use other metadata operators
      // (e.g. `?` or `@>`) still refer to the whole row.
      conditions.push(result.sql.replace(/\bmetadata(?=#>|->)/g, 'elem'));

      if (result.needsValue) {
        const bound = result.transformValue ? result.transformValue(operand) : operand;
        // A nested $elemMatch binds one value per inner condition.
        if (operator === '$elemMatch' && Array.isArray(bound)) {
          values.push(...bound);
        } else {
          values.push(bound);
        }
      }
    }
  }

  return {
    sql: conditions.join(' AND '),
    values,
  };
}
|
|
94
|
+
|
|
95
|
+
// Define all filter operators
|
|
96
|
+
/**
 * SQL builders for every supported filter operator, keyed by operator name.
 *
 * Each builder receives the metadata key (dots denote nesting), the placeholder
 * number for its first bound value, and the raw operand. The logical operators
 * instead receive already-built SQL in `key` and only wrap it.
 */
export const FILTER_OPERATORS: Record<string, OperatorFn> = {
  // Basic operators: text comparison, with a NULL operand mapped to IS [NOT] NULL.
  $eq: createBasicOperator('='),
  $ne: createBasicOperator('!='),

  // Numeric operators: the metadata value is cast to numeric before comparing.
  $gt: createNumericOperator('>'),
  $gte: createNumericOperator('>='),
  $lt: createNumericOperator('<'),
  $lte: createNumericOperator('<='),

  // Array Operators
  $in: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' = ANY($${paramIndex}::text[])`,
    needsValue: true,
  }),
  $nin: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' != ALL($${paramIndex}::text[])`,
    needsValue: true,
  }),
  // An empty $all list matches nothing.
  $all: (key, paramIndex) => ({
    sql: `CASE WHEN array_length($${paramIndex}::text[], 1) IS NULL THEN false
      ELSE (metadata#>'{${handleKey(key)}}')::jsonb ?& $${paramIndex}::text[] END`,
    needsValue: true,
  }),
  // Binds one value per inner condition; transformValue hands them to the caller as an array.
  $elemMatch: (key: string, paramIndex: number, value: any): FilterOperator => {
    const { sql, values } = buildElemMatchConditions(value, paramIndex);
    return {
      sql: `(
        CASE
          WHEN jsonb_typeof(metadata->'${handleKey(key)}') = 'array' THEN
            EXISTS (
              SELECT 1
              FROM jsonb_array_elements(metadata->'${handleKey(key)}') as elem
              WHERE ${sql}
            )
          ELSE FALSE
        END
      )`,
      needsValue: true,
      transformValue: () => values,
    };
  },

  // Element Operators
  // `#>` yields SQL NULL only when the path is absent (a JSON null comes back as
  // the jsonb value 'null'), so IS [NOT] NULL tests path existence. `$exists: false`
  // matches rows where the path is absent.
  $exists: (key, _paramIndex, value) => ({
    sql: `metadata#>'{${handleKey(key)}}' IS ${value === false ? '' : 'NOT '}NULL`,
    needsValue: false,
  }),

  // Logical Operators
  $and: key => ({ sql: `(${key})`, needsValue: false }),
  $or: key => ({ sql: `(${key})`, needsValue: false }),
  $not: key => ({ sql: `NOT (${key})`, needsValue: false }),
  $nor: key => ({ sql: `NOT (${key})`, needsValue: false }),

  // Regex Operators
  $regex: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' ~ $${paramIndex}`,
    needsValue: true,
  }),

  // Containment: the operand is nested under the dotted key path, e.g. key 'a.b'
  // with value 1 binds '{"a":{"b":1}}'.
  $contains: (key, paramIndex) => ({
    sql: `metadata @> $${paramIndex}::jsonb`,
    needsValue: true,
    transformValue: value => {
      const parts = key.split('.');
      return JSON.stringify(parts.reduceRight((acc: any, part) => ({ [part]: acc }), value));
    },
  }),
  // Non-array values never match.
  $size: (key: string, paramIndex: number) => ({
    sql: `(
      CASE
        WHEN jsonb_typeof(metadata#>'{${handleKey(key)}}') = 'array' THEN
          jsonb_array_length(metadata#>'{${handleKey(key)}}') = $${paramIndex}
        ELSE FALSE
      END
    )`,
    needsValue: true,
  }),
};
|
|
173
|
+
|
|
174
|
+
/** Output of `buildFilterQuery`. */
export interface FilterResult {
  /** A `WHERE ...` clause, or an empty string when there are no conditions. */
  sql: string;
  /** Values for the `$N` placeholders; index 0 is the `minScore` given to `buildFilterQuery`. */
  values: any[];
}

/**
 * Converts a dotted metadata key (`a.b.c`) into the body of a Postgres text-array
 * path (`a,b,c`), as used inside `'{...}'` with the `#>` / `#>>` operators.
 *
 * The result is spliced into single-quoted SQL literals, so single quotes are
 * doubled. Without this, a key containing `'` could break out of the literal and
 * inject SQL.
 */
export const handleKey = (key: string) => {
  return key.replace(/'/g, "''").replace(/\./g, ',');
};
|
|
182
|
+
|
|
183
|
+
/**
 * Compiles a translated filter into a SQL `WHERE` clause over the JSONB `metadata`
 * column. All operand values are bound as parameters.
 *
 * `$1` is reserved for `minScore`: it is always `values[0]`, and filter placeholders
 * start at `$2`. Conditions at the same level, including multiple operators on one
 * field and multiple fields inside one element of `$and`/`$or`/`$nor`, are ANDed.
 *
 * @param filter - Filter produced by the PG filter translator; may be empty.
 * @param minScore - Bound as the first value (`$1`).
 * @returns `sql` is `WHERE ...`, or '' when there are no conditions; `values` are the bound values.
 * @throws If the filter uses an operator that is not in `FILTER_OPERATORS`.
 */
export function buildFilterQuery(filter: Filter, minScore: number): FilterResult {
  const values: any[] = [minScore];
  const logicalOperators = ['$and', '$or', '$not', '$nor'];

  /** Builds one operator's SQL and binds its (possibly transformed) operand. */
  function applyOperator(key: string, operator: string, operand: unknown, errorPrefix: string): string {
    const operatorFn = FILTER_OPERATORS[operator];
    if (!operatorFn) {
      throw new Error(`${errorPrefix}: ${operator}`);
    }
    const result = operatorFn(key, values.length + 1, operand);
    if (result.needsValue) {
      const bound = result.transformValue ? result.transformValue(operand) : operand;
      // $elemMatch binds one value per inner condition.
      if (operator === '$elemMatch' && Array.isArray(bound)) {
        values.push(...bound);
      } else {
        values.push(bound);
      }
    }
    return result.sql;
  }

  /** ANDs conditions, parenthesising when there are several; no conditions means "match all". */
  function andAll(parts: string[]): string {
    if (parts.length === 0) return 'true';
    if (parts.length === 1) return parts[0] ?? 'true';
    return `(${parts.join(' AND ')})`;
  }

  /** Field-level `$not`: negates the AND of the operators in `operand`. */
  function buildNotCondition(key: string, operand: unknown): string {
    const parts = Object.entries(operand as Record<string, unknown>).map(([nestedOp, nestedValue]) =>
      applyOperator(key, nestedOp, nestedValue, 'Invalid operator in $not condition'),
    );
    if (parts.length === 0) {
      throw new Error(`$not condition for "${key}" must contain at least one operator`);
    }
    return `NOT (${parts.join(' AND ')})`;
  }

  function buildCondition(key: string, value: any, parentPath: string): string {
    // Handle logical operators ($and/$or/$not/$nor)
    if (logicalOperators.includes(key)) {
      return handleLogicalOperator(key as '$and' | '$or' | '$not' | '$nor', value, parentPath);
    }

    // A bare value is an equality check. Routing it through $eq makes `null` produce IS NULL.
    if (value === null || typeof value !== 'object') {
      return applyOperator(key, '$eq', value, 'Invalid operator');
    }

    // Operator object: every operator listed for the field applies.
    const entries = Object.entries(value);
    if (entries.length === 0) {
      throw new Error(`Filter for "${key}" must contain at least one operator`);
    }
    return andAll(
      entries.map(([operator, operand]) =>
        operator === '$not' ? buildNotCondition(key, operand) : applyOperator(key, operator, operand, 'Invalid operator'),
      ),
    );
  }

  function handleLogicalOperator(key: '$and' | '$or' | '$not' | '$nor', value: Filter[], parentPath: string): string {
    if (key === '$not') {
      // Top-level $not takes an object of field conditions rather than an array.
      const conditions = Object.entries(value)
        .map(([fieldKey, fieldValue]) => buildCondition(fieldKey, fieldValue, key))
        .join(' AND ');
      return `NOT (${conditions})`;
    }

    // Empty $and/$nor match everything; empty $or matches nothing.
    if (!value || value.length === 0) {
      return key === '$or' ? 'false' : 'true';
    }

    const joinOperator = key === '$or' || key === '$nor' ? 'OR' : 'AND';
    // Each array element is a filter object of its own: its fields are ANDed,
    // and an empty object matches everything.
    const conditions = value.map((subFilter: Filter) =>
      andAll(Object.entries(subFilter).map(([k, v]) => buildCondition(k, v, parentPath))),
    );

    const operatorFn = FILTER_OPERATORS[key]!;
    return operatorFn(conditions.join(` ${joinOperator} `), 0, value).sql;
  }

  if (!filter) {
    return { sql: '', values };
  }

  const conditions = Object.entries(filter)
    .map(([key, value]) => buildCondition(key, value, ''))
    .filter(Boolean)
    .join(' AND ');

  return { sql: conditions ? `WHERE ${conditions}` : '', values };
}
|
package/tsconfig.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
{
|
|
2
|
+
"extends": "../../tsconfig.node.json",
|
|
3
|
+
"compilerOptions": {
|
|
4
|
+
"moduleResolution": "bundler",
|
|
5
|
+
"outDir": "./dist",
|
|
6
|
+
"rootDir": "./src",
|
|
7
|
+
"module": "ES2022",
|
|
8
|
+
"target": "ES2020",
|
|
9
|
+
"declaration": true,
|
|
10
|
+
"declarationMap": true,
|
|
11
|
+
"noEmit": false
|
|
12
|
+
},
|
|
13
|
+
"include": ["src/**/*"],
|
|
14
|
+
"exclude": ["node_modules", "**/*.test.ts"]
|
|
15
|
+
}
|