@mastra/pg 0.1.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,282 @@
1
+ import type { Filter } from '@mastra/core/filter';
2
+ import { type IndexStats, type QueryResult, MastraVector } from '@mastra/core/vector';
3
+ import pg from 'pg';
4
+
5
+ import { PGFilterTranslator } from './filter';
6
+ import { buildFilterQuery } from './sql-builder';
7
+
8
/** Shape of an unquoted SQL identifier; index (table) names are interpolated into SQL, so only this form is accepted. */
const PG_VECTOR_INDEX_NAME = /^[a-zA-Z_][a-zA-Z0-9_]*$/;

/**
 * Guards every statement that has to splice the index name into SQL text
 * (identifiers cannot be sent as bind parameters).
 *
 * @throws Error when `indexName` is not a plain identifier.
 */
function assertValidIndexName(indexName: string): void {
  if (!PG_VECTOR_INDEX_NAME.test(indexName)) {
    throw new Error('Invalid index name format');
  }
}

/** Extracts a printable message from an arbitrary thrown value. */
function describeError(error: unknown): string {
  return error instanceof Error ? error.message : String(error);
}

/**
 * Vector store that keeps embeddings in PostgreSQL tables using the `vector` extension.
 *
 * Each index is a table with the columns `vector_id`, `embedding` and `metadata`
 * (see {@link PgVector.createIndex}). All statements run on clients checked out
 * from a shared connection pool.
 */
export class PgVector extends MastraVector {
  private pool: pg.Pool;

  /**
   * @param connectionString - PostgreSQL connection string handed to the `pg` pool.
   */
  constructor(connectionString: string) {
    super();

    const basePool = new pg.Pool({
      connectionString,
      max: 20, // Maximum number of clients in the pool
      idleTimeoutMillis: 30000, // Close idle connections after 30 seconds
      connectionTimeoutMillis: 2000, // Fail fast if can't connect
    });

    const telemetry = this.__getTelemetry();

    // Use the traced wrapper when telemetry is configured, otherwise the raw pool.
    this.pool =
      telemetry?.traceClass(basePool, {
        spanNamePrefix: 'pg-vector',
        attributes: {
          'vector.type': 'postgres',
        },
      }) ?? basePool;
  }

  /**
   * Translates a generic filter into the form consumed by `buildFilterQuery`.
   * A missing filter is translated as an empty filter object.
   */
  transformFilter(filter?: Filter) {
    const pgFilter = new PGFilterTranslator();
    const translatedFilter = pgFilter.translate(filter ?? {});
    return translatedFilter;
  }

  /**
   * Returns the rows closest to `queryVector`, scored as `1 - (embedding <=> query)`.
   *
   * @param indexName - Table to search; must be a plain SQL identifier.
   * @param queryVector - Query embedding.
   * @param topK - Maximum number of rows to return (non-negative integer).
   * @param filter - Optional metadata filter.
   * @param includeVector - When true, each result also carries the stored embedding.
   * @param minScore - Only rows with a score strictly greater than this are returned.
   * @throws Error when `indexName` or `topK` is invalid, or when the database rejects the query.
   */
  async query(
    indexName: string,
    queryVector: number[],
    topK: number = 10,
    filter?: Filter,
    includeVector: boolean = false,
    minScore: number = 0, // Optional minimum score threshold
  ): Promise<QueryResult[]> {
    assertValidIndexName(indexName);
    if (!Number.isInteger(topK) || topK < 0) {
      throw new Error('topK must be a non-negative integer');
    }

    // Build the statement before checking out a client so a bad filter never holds a connection.
    const translatedFilter = this.transformFilter(filter);
    const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore);

    // The query vector is bound as the last parameter instead of being spliced into the SQL text.
    // $1 is the minimum score supplied by buildFilterQuery.
    const params: unknown[] = [...filterValues, `[${queryVector.join(',')}]`];
    const vectorParam = `$${params.length}::vector`;

    const query = `
      WITH vector_scores AS (
        SELECT
          vector_id as id,
          1 - (embedding <=> ${vectorParam}) as score,
          metadata
          ${includeVector ? ', embedding' : ''}
        FROM ${indexName}
        ${filterQuery}
      )
      SELECT *
      FROM vector_scores
      WHERE score > $1
      ORDER BY score DESC
      LIMIT ${topK}`;

    const client = await this.pool.connect();
    try {
      const result = await client.query(query, params);

      return result.rows.map(({ id, score, metadata, embedding }) => ({
        id,
        score,
        metadata,
        ...(includeVector && embedding && { vector: JSON.parse(embedding) }),
      }));
    } finally {
      client.release();
    }
  }

  /**
   * Inserts or replaces vectors inside a single transaction.
   *
   * @param indexName - Target table; must be a plain SQL identifier.
   * @param vectors - Embeddings to store.
   * @param metadata - Optional metadata per vector; missing entries are stored as `{}`.
   * @param ids - Optional ids per vector; random UUIDs are generated when omitted.
   * @returns The ids the vectors were stored under.
   * @throws Error when the name is invalid, when fewer ids than vectors are given,
   *         or when any statement fails (the transaction is rolled back).
   */
  async upsert(
    indexName: string,
    vectors: number[][],
    metadata?: Record<string, any>[],
    ids?: string[],
  ): Promise<string[]> {
    assertValidIndexName(indexName);
    if (ids && ids.length < vectors.length) {
      throw new Error('ids must contain one entry per vector');
    }

    const vectorIds = ids ?? vectors.map(() => crypto.randomUUID());

    // Loop-invariant statement; only the bound values change per row.
    const statement = `
      INSERT INTO ${indexName} (vector_id, embedding, metadata)
      VALUES ($1, $2::vector, $3::jsonb)
      ON CONFLICT (vector_id)
      DO UPDATE SET
        embedding = $2::vector,
        metadata = $3::jsonb
    `;

    const client = await this.pool.connect();
    try {
      await client.query('BEGIN');

      for (const [i, vector] of vectors.entries()) {
        await client.query(statement, [vectorIds[i], `[${vector.join(',')}]`, JSON.stringify(metadata?.[i] || {})]);
      }

      await client.query('COMMIT');

      return vectorIds;
    } catch (error) {
      try {
        await client.query('ROLLBACK');
      } catch {
        // A failed rollback must not hide the error that caused it.
      }
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Creates the backing table and an ivfflat index for it, if they do not exist yet.
   *
   * @param indexName - Table name; must be a plain SQL identifier.
   * @param dimension - Embedding dimension (positive integer).
   * @param metric - Distance metric that selects the operator class of the index.
   * @throws Error on invalid input, when the `vector` extension is unavailable,
   *         or when any statement fails. Failures are also logged to `console.error`.
   */
  async createIndex(
    indexName: string,
    dimension: number,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
  ): Promise<void> {
    const client = await this.pool.connect();
    try {
      // Validate inputs
      assertValidIndexName(indexName);
      if (!Number.isInteger(dimension) || dimension <= 0) {
        throw new Error('Dimension must be a positive integer');
      }

      // First check if vector extension is available
      const extensionCheck = await client.query(`
        SELECT EXISTS (
          SELECT 1 FROM pg_available_extensions WHERE name = 'vector'
        );
      `);

      if (!extensionCheck.rows[0].exists) {
        throw new Error('PostgreSQL vector extension is not available. Please install it first.');
      }

      // Try to create extension
      await client.query('CREATE EXTENSION IF NOT EXISTS vector');

      // Create the table with explicit schema
      await client.query(`
        CREATE TABLE IF NOT EXISTS ${indexName} (
          id SERIAL PRIMARY KEY,
          vector_id TEXT UNIQUE NOT NULL,
          embedding vector(${dimension}),
          metadata JSONB DEFAULT '{}'::jsonb
        );
      `);

      // Create the index
      const indexMethod =
        metric === 'cosine' ? 'vector_cosine_ops' : metric === 'euclidean' ? 'vector_l2_ops' : 'vector_ip_ops';

      await client.query(`
        CREATE INDEX IF NOT EXISTS ${indexName}_vector_idx
        ON public.${indexName}
        USING ivfflat (embedding ${indexMethod})
        WITH (lists = 100);
      `);
    } catch (error: unknown) {
      console.error('Failed to create vector table:', error);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Lists the tables in the `public` schema that have at least one `vector` column.
   */
  async listIndexes(): Promise<string[]> {
    const client = await this.pool.connect();
    try {
      const vectorTablesQuery = `
        SELECT DISTINCT table_name
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND udt_name = 'vector';
      `;
      const vectorTables = await client.query(vectorTablesQuery);
      return vectorTables.rows.map(row => row.table_name);
    } finally {
      client.release();
    }
  }

  /**
   * Reports dimension, row count and metric for an index.
   *
   * The metric is derived from the operator class of `<indexName>_vector_idx`
   * and falls back to `'cosine'` when that index is not found.
   *
   * @throws Error when the name is invalid, or with the prefix
   *         `Failed to describe vector table:` when a lookup fails.
   */
  async describeIndex(indexName: string): Promise<IndexStats> {
    assertValidIndexName(indexName);

    const client = await this.pool.connect();
    try {
      // Get vector dimension
      const dimensionQuery = `
        SELECT atttypmod as dimension
        FROM pg_attribute
        WHERE attrelid = $1::regclass
          AND attname = 'embedding';
      `;

      // Get row count
      const countQuery = `
        SELECT COUNT(*) as count
        FROM ${indexName};
      `;

      // Get index metric type
      const metricQuery = `
        SELECT
          am.amname as index_method,
          opclass.opcname as operator_class
        FROM pg_index i
        JOIN pg_class c ON i.indexrelid = c.oid
        JOIN pg_am am ON c.relam = am.oid
        JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
        WHERE c.relname = $1;
      `;

      // Run the lookups one after another: they share a single client, and awaiting
      // each one guarantees nothing is still pending when the client is released.
      const dimResult = await client.query(dimensionQuery, [indexName]);
      const countResult = await client.query(countQuery);
      const metricResult = await client.query(metricQuery, [`${indexName}_vector_idx`]);

      const dimensionRow = dimResult.rows[0];
      if (!dimensionRow) {
        throw new Error(`Table ${indexName} has no embedding column`);
      }

      // Convert pg_vector index method to our metric type
      let metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine';
      if (metricResult.rows.length > 0) {
        const operatorClass = metricResult.rows[0].operator_class;
        if (operatorClass.includes('l2')) {
          metric = 'euclidean';
        } else if (operatorClass.includes('ip')) {
          metric = 'dotproduct';
        } else if (operatorClass.includes('cosine')) {
          metric = 'cosine';
        }
      }

      return {
        dimension: dimensionRow.dimension,
        count: parseInt(countResult.rows[0].count, 10),
        metric,
      };
    } catch (e: unknown) {
      // No transaction is open here, so there is nothing to roll back.
      throw new Error(`Failed to describe vector table: ${describeError(e)}`);
    } finally {
      client.release();
    }
  }

  /**
   * Drops the index table (with `CASCADE`) if it exists.
   *
   * @throws Error when the name is invalid, or with the prefix
   *         `Failed to delete vector table:` when the statement fails.
   */
  async deleteIndex(indexName: string): Promise<void> {
    assertValidIndexName(indexName);

    const client = await this.pool.connect();
    try {
      // Drop the table
      await client.query(`DROP TABLE IF EXISTS ${indexName} CASCADE`);
    } catch (error: unknown) {
      throw new Error(`Failed to delete vector table: ${describeError(error)}`);
    } finally {
      client.release();
    }
  }

  /**
   * Removes every row from the index table while keeping the table itself.
   *
   * @throws Error when the name is invalid, or with the prefix
   *         `Failed to truncate vector table:` when the statement fails.
   */
  async truncateIndex(indexName: string): Promise<void> {
    assertValidIndexName(indexName);

    const client = await this.pool.connect();
    try {
      await client.query(`TRUNCATE ${indexName}`);
    } catch (e: unknown) {
      throw new Error(`Failed to truncate vector table: ${describeError(e)}`);
    } finally {
      client.release();
    }
  }

  /** Closes the connection pool. */
  async disconnect(): Promise<void> {
    await this.pool.end();
  }
}
@@ -0,0 +1,285 @@
1
+ import {
2
+ BasicOperator,
3
+ NumericOperator,
4
+ ArrayOperator,
5
+ ElementOperator,
6
+ LogicalOperator,
7
+ RegexOperator,
8
+ Filter,
9
+ } from '@mastra/core/filter';
10
+
11
/**
 * Union of every operator name the SQL builder accepts: the operator families
 * from `@mastra/core/filter`, the regex operators except `$options`, and the
 * builder-specific `$contains`.
 */
export type OperatorType =
  | LogicalOperator
  | BasicOperator
  | NumericOperator
  | ArrayOperator
  | ElementOperator
  | Exclude<RegexOperator, '$options'>
  | '$contains';

/** SQL fragment produced for one operator, plus instructions for binding its operand. */
interface FilterOperator {
  /** SQL text of the condition, with `$n` placeholders for bound values. */
  sql: string;
  /** Whether the caller has to append a bound value for this fragment. */
  needsValue: boolean;
  /** Optional conversion applied to the operand before it is bound. */
  transformValue?: (value: any) => any;
}

/**
 * Factory signature shared by all operators: receives the metadata key, the
 * index of the next free placeholder, and optionally the operand.
 */
type OperatorFn = (key: string, paramIndex: number, value?: any) => FilterOperator;
27
+
28
+ // Helper functions to create operators
29
/**
 * Builds the operator factory for text (in)equality on a metadata path.
 *
 * The generated SQL compares the path's text value with the bound parameter;
 * when the parameter is NULL it turns into an `IS NULL` check for `=` and an
 * `IS NOT NULL` check for any other symbol.
 */
function createBasicOperator(symbol: string) {
  const nullKeyword = symbol === '=' ? '' : 'NOT';

  return function basicOperator(key: string, paramIndex: number) {
    const column = `metadata#>>'{${handleKey(key)}}'`;
    const param = `$${paramIndex}::text`;

    return {
      sql: `CASE
      WHEN ${param} IS NULL THEN ${column} IS ${nullKeyword} NULL
      ELSE ${column} ${symbol} ${param}
      END`,
      needsValue: true,
    };
  };
}
38
+
39
/**
 * Builds the operator factory for numeric comparisons: the metadata path is
 * read as text, cast to `numeric`, and compared with the bound parameter.
 */
function createNumericOperator(symbol: string) {
  return function numericOperator(key: string, paramIndex: number) {
    const path = handleKey(key);
    const sql = `(metadata#>>'{${path}}')::numeric ${symbol} $${paramIndex}`;
    return { sql, needsValue: true };
  };
}
45
+
46
/**
 * Builds the `WHERE` body evaluated against each array element (`elem`) of an
 * `$elemMatch` condition.
 *
 * Every entry of `value` becomes one condition, and the conditions are joined with AND:
 * - `{ $op: operand }` applies the operator to the element itself (empty key);
 * - `{ field: { $op: operand } }` applies the first operator of the object to `field`;
 * - `{ field: literal }`, including `null` and arrays, is an equality check.
 *
 * Conditions are generated by the regular operators and then re-targeted from
 * `metadata#>>` to `elem#>>`.
 * NOTE(review): operators whose SQL does not read through `metadata#>>`
 * (e.g. `$size`, `$all`, `$contains`) keep addressing `metadata` — confirm whether
 * they are meant to be usable inside `$elemMatch`.
 *
 * @param value - The `$elemMatch` operand.
 * @param paramIndex - Index of the first free `$n` placeholder.
 * @returns The SQL fragment and the values to bind, in placeholder order.
 * @throws Error when `value` is not a plain object or an operator is unknown.
 */
function buildElemMatchConditions(value: any, paramIndex: number): { sql: string; values: any[] } {
  // `typeof null` is 'object', so null has to be rejected explicitly.
  if (value === null || typeof value !== 'object' || Array.isArray(value)) {
    throw new Error('$elemMatch requires an object with conditions');
  }

  const conditions: string[] = [];
  const values: any[] = [];

  for (const [field, val] of Object.entries(value)) {
    // Placeholders already consumed by earlier conditions shift the next index.
    const nextParamIndex = paramIndex + values.length;

    let paramOperator: string | undefined;
    let paramKey: string;
    let paramValue: unknown;

    if (field.startsWith('$')) {
      paramOperator = field;
      paramKey = '';
      paramValue = val;
    } else if (val !== null && typeof val === 'object' && !Array.isArray(val)) {
      const [op, opValue] = Object.entries(val)[0] ?? [];
      paramOperator = op;
      paramKey = field;
      paramValue = opValue;
    } else {
      paramOperator = '$eq';
      paramKey = field;
      paramValue = val;
    }

    const operatorFn = paramOperator === undefined ? undefined : FILTER_OPERATORS[paramOperator];
    if (!operatorFn) {
      throw new Error(`Invalid operator: ${paramOperator}`);
    }
    const result = operatorFn(paramKey, nextParamIndex, paramValue);

    // Global regex instead of String.prototype.replaceAll, which is not part of ES2020.
    conditions.push(result.sql.replace(/metadata#>>/g, 'elem#>>'));
    if (result.needsValue) {
      values.push(paramValue);
    }
  }

  return {
    sql: conditions.join(' AND '),
    values,
  };
}
94
+
95
+ // Define all filter operators
96
/**
 * Operator table: maps each supported operator name to a factory that returns
 * the SQL fragment for one condition on the `metadata` JSONB column.
 *
 * Factories receive the metadata key (dot-separated for nested paths), the
 * index of the next free `$n` placeholder, and optionally the operand. The
 * logical operators are different: their "key" is already-built SQL that they
 * only wrap.
 */
export const FILTER_OPERATORS: Record<string, OperatorFn> = {
  $eq: createBasicOperator('='),
  $ne: createBasicOperator('!='),
  $gt: createNumericOperator('>'),
  $gte: createNumericOperator('>='),
  $lt: createNumericOperator('<'),
  $lte: createNumericOperator('<='),

  // Array Operators
  $in: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' = ANY($${paramIndex}::text[])`,
    needsValue: true,
  }),
  $nin: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' != ALL($${paramIndex}::text[])`,
    needsValue: true,
  }),
  $all: (key, paramIndex) => ({
    // An empty operand array never matches.
    sql: `CASE WHEN array_length($${paramIndex}::text[], 1) IS NULL THEN false
      ELSE (metadata#>'{${handleKey(key)}}')::jsonb ?& $${paramIndex}::text[] END`,
    needsValue: true,
  }),
  $elemMatch: (key: string, paramIndex: number, value: any): FilterOperator => {
    const { sql, values } = buildElemMatchConditions(value, paramIndex);
    // `#>` with a path array also resolves nested (dotted) keys; `->` would
    // look up the comma-joined path as one literal key and never match.
    const arrayPath = `metadata#>'{${handleKey(key)}}'`;
    return {
      sql: `(
        CASE
          WHEN jsonb_typeof(${arrayPath}) = 'array' THEN
            EXISTS (
              SELECT 1
              FROM jsonb_array_elements(${arrayPath}) as elem
              WHERE ${sql}
            )
          ELSE FALSE
        END
      )`,
      needsValue: true,
      // The element conditions may bind several values; the caller spreads them.
      transformValue: () => values,
    };
  },

  // Element Operators
  $exists: (key, _paramIndex, value) => {
    // The key sits inside a single-quoted SQL literal, so embedded quotes are doubled.
    const present = `metadata ? '${key.replace(/'/g, "''")}'`;
    return {
      // `$exists: false` asks for the key to be absent; any other operand keeps the presence check.
      sql: value === false ? `NOT (${present})` : present,
      needsValue: false,
    };
  },

  // Logical Operators
  $and: key => ({ sql: `(${key})`, needsValue: false }),
  $or: key => ({ sql: `(${key})`, needsValue: false }),
  $not: key => ({ sql: `NOT (${key})`, needsValue: false }),
  $nor: key => ({ sql: `NOT (${key})`, needsValue: false }),

  // Regex Operators
  $regex: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' ~ $${paramIndex}`,
    needsValue: true,
  }),

  $contains: (key, paramIndex) => ({
    sql: `metadata @> $${paramIndex}::jsonb`,
    needsValue: true,
    // Rebuild the nested object described by the dotted key, e.g. `a.b` + v -> {"a":{"b":v}}.
    transformValue: value => {
      const parts = key.split('.');
      return JSON.stringify(parts.reduceRight((nested, part) => ({ [part]: nested }), value));
    },
  }),
  $size: (key: string, paramIndex: number) => ({
    sql: `(
      CASE
        WHEN jsonb_typeof(metadata#>'{${handleKey(key)}}') = 'array' THEN
          jsonb_array_length(metadata#>'{${handleKey(key)}}') = $${paramIndex}
        ELSE FALSE
      END
    )`,
    needsValue: true,
  }),
};
173
+
174
/** Outcome of translating a filter: a SQL fragment together with its bound values. */
export interface FilterResult {
  /** Values for the `$n` placeholders, in placeholder order. */
  values: any[];
  /** SQL text; an empty string when there is nothing to filter on. */
  sql: string;
}
178
+
179
/**
 * Converts a dot-separated metadata key into the comma-separated element list
 * of a PostgreSQL path array, e.g. `a.b.c` becomes `a,b,c`.
 *
 * The result is embedded inside single-quoted SQL literals (`'{...}'`), so any
 * single quote in the key is doubled to keep it from terminating the literal.
 *
 * @param key - Metadata key, with `.` separating nested levels.
 * @returns The path elements, safe to place inside a single-quoted SQL literal.
 */
export const handleKey = (key: string) => {
  return key.replace(/\./g, ',').replace(/'/g, "''");
};
182
+
183
/**
 * Translates a metadata filter into a SQL `WHERE` clause with bound values.
 *
 * `minScore` always occupies the first bound value (`$1`); filter operands are
 * appended after it, so the placeholders generated here start at `$2`.
 * Top-level fields are combined with AND.
 *
 * @param filter - Filter object; a falsy filter yields an empty clause.
 * @param minScore - Value reserved as `$1`.
 * @returns The clause (empty string when there are no conditions) and its values.
 * @throws Error when a condition uses an operator that is not in `FILTER_OPERATORS`,
 *         or when a logical operator receives an operand of the wrong shape.
 */
export function buildFilterQuery(filter: Filter, minScore: number): FilterResult {
  const values: any[] = [minScore];
  const logicalOperators = ['$and', '$or', '$not', '$nor'];

  /** Looks up an operator by name, ignoring properties inherited from Object.prototype. */
  function lookupOperator(operator: string | undefined): OperatorFn | undefined {
    if (operator === undefined || !Object.prototype.hasOwnProperty.call(FILTER_OPERATORS, operator)) {
      return undefined;
    }
    return FILTER_OPERATORS[operator];
  }

  /** Generates the SQL for one operator and appends the value(s) it binds. */
  function applyOperator(operatorFn: OperatorFn, operator: string, key: string, operand: unknown): string {
    const result = operatorFn(key, values.length + 1, operand);
    if (result.needsValue) {
      const bound = result.transformValue ? result.transformValue(operand) : operand;
      // $elemMatch hands back one value per placeholder it generated.
      if (operator === '$elemMatch' && Array.isArray(bound)) {
        values.push(...bound);
      } else {
        values.push(bound);
      }
    }
    return result.sql;
  }

  /** Builds the SQL for a single `key: value` entry of a filter object. */
  function buildCondition(key: string, value: any, parentPath: string): string {
    // Handle logical operators ($and/$or/$not/$nor)
    if (logicalOperators.includes(key)) {
      return handleLogicalOperator(key as '$and' | '$or' | '$not' | '$nor', value, parentPath);
    }

    // Anything that is not an object is an equality check on the key
    if (!value || typeof value !== 'object') {
      values.push(value);
      return `metadata#>>'{${handleKey(key)}}' = $${values.length}`;
    }

    // Only the first operator of a condition object is applied.
    const [operator, operand] = Object.entries(value)[0] ?? [];

    // Field-level $not: negate the AND of the nested operator conditions
    if (operator === '$not') {
      if (!operand || typeof operand !== 'object') {
        throw new Error('$not requires an object with operator conditions');
      }
      const negated = Object.entries(operand as Record<string, unknown>)
        .map(([nestedOp, nestedValue]) => {
          const nestedFn = lookupOperator(nestedOp);
          if (!nestedFn) {
            throw new Error(`Invalid operator in $not condition: ${nestedOp}`);
          }
          // Pass the operand through so operators that need it ($elemMatch, $contains)
          // behave the same as outside $not.
          return applyOperator(nestedFn, nestedOp, key, nestedValue);
        })
        .join(' AND ');

      return `NOT (${negated})`;
    }

    const operatorFn = lookupOperator(operator);
    if (!operatorFn || operator === undefined) {
      throw new Error(`Invalid operator: ${operator}`);
    }
    return applyOperator(operatorFn, operator, key, operand);
  }

  /** SQL for a logical operator that has no usable conditions. */
  function emptyLogicalResult(key: '$and' | '$or' | '$nor'): string {
    // Empty $and/$nor match everything, empty $or matches nothing.
    return key === '$or' ? 'false' : 'true';
  }

  /** Builds the SQL for `$and`, `$or`, `$nor` and top-level `$not`. */
  function handleLogicalOperator(key: '$and' | '$or' | '$not' | '$nor', value: Filter[], parentPath: string): string {
    if (key === '$not') {
      // Top-level $not takes a filter object whose fields are ANDed, then negated
      const conditions = Object.entries(value)
        .map(([fieldKey, fieldValue]) => buildCondition(fieldKey, fieldValue, key))
        .join(' AND ');
      return `NOT (${conditions})`;
    }

    // Handle empty conditions
    if (!value || value.length === 0) {
      return emptyLogicalResult(key);
    }
    if (!Array.isArray(value)) {
      throw new Error(`${key} requires an array of conditions`);
    }

    const joinOperator = key === '$or' || key === '$nor' ? 'OR' : 'AND';
    const conditions = value
      .map((subFilter: Filter) => {
        // Fields of one sub-filter always combine with AND, whatever the enclosing
        // operator is; parentheses keep that group intact inside an OR.
        const parts = Object.entries(subFilter)
          .map(([k, v]) => buildCondition(k, v, parentPath))
          .filter(Boolean);
        if (parts.length <= 1) {
          return parts[0] ?? '';
        }
        return `(${parts.join(' AND ')})`;
      })
      // Empty sub-filters contribute nothing instead of leaving a dangling AND/OR.
      .filter(Boolean);

    if (conditions.length === 0) {
      return emptyLogicalResult(key);
    }

    const operatorFn = FILTER_OPERATORS[key]!;
    return operatorFn(conditions.join(` ${joinOperator} `), 0, value).sql;
  }

  if (!filter) {
    return { sql: '', values };
  }

  const conditions = Object.entries(filter)
    .map(([key, value]) => buildCondition(key, value, ''))
    .filter(Boolean)
    .join(' AND ');

  return { sql: conditions ? `WHERE ${conditions}` : '', values };
}
package/tsconfig.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "extends": "../../tsconfig.node.json",
3
+ "compilerOptions": {
4
+ "moduleResolution": "bundler",
5
+ "outDir": "./dist",
6
+ "rootDir": "./src",
7
+ "module": "ES2022",
8
+ "target": "ES2020",
9
+ "declaration": true,
10
+ "declarationMap": true,
11
+ "noEmit": false
12
+ },
13
+ "include": ["src/**/*"],
14
+ "exclude": ["node_modules", "**/*.test.ts"]
15
+ }
@@ -0,0 +1,11 @@
1
+ import { defineConfig } from 'vitest/config';
2
+
3
// Vitest settings: run the test files under src/ in a Node environment and
// emit coverage as text, JSON and HTML reports.
const vitestConfig = defineConfig({
  test: {
    environment: 'node',
    include: ['src/**/*.test.ts'],
    coverage: {
      reporter: ['text', 'json', 'html'],
    },
  },
});

export default vitestConfig;