@mastra/rag 0.0.2-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +7 -0
- package/docker-compose.yaml +18 -0
- package/jest.config.ts +19 -0
- package/package.json +57 -0
- package/src/document/index.test.ts +229 -0
- package/src/document/index.ts +129 -0
- package/src/index.ts +4 -0
- package/src/pg/index.ts +255 -0
- package/src/pg/index_test.ts +212 -0
- package/src/pinecone/index.test.ts +130 -0
- package/src/pinecone/index.ts +118 -0
- package/src/qdrant/index.test.ts +119 -0
- package/src/qdrant/index.ts +116 -0
- package/tsconfig.json +10 -0
package/src/pg/index.ts
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import { IndexStats, QueryResult, MastraVector } from '@mastra/core';
|
|
2
|
+
import { Pool } from 'pg';
|
|
3
|
+
|
|
4
|
+
/**
 * Vector store backed by PostgreSQL tables that use the pgvector `vector` column type.
 *
 * Each "index" is a table named after the index, with a `vector_id` text key, an
 * `embedding` vector column and a JSONB `metadata` column, plus an ivfflat index
 * named `<indexName>_vector_idx`.
 *
 * Index names are spliced into SQL as identifiers (they cannot be bound as query
 * parameters), so every public method validates them first. All other caller
 * supplied values are passed as bound parameters.
 */
export class PgVector extends MastraVector {
  private pool: Pool;

  /**
   * @param connectionString PostgreSQL connection string handed to the `pg` pool.
   */
  constructor(connectionString: string) {
    super();

    const basePool = new Pool({
      connectionString,
      max: 20, // Maximum number of clients in the pool
      idleTimeoutMillis: 30000, // Close idle connections after 30 seconds
      connectionTimeoutMillis: 2000, // Fail fast if can't connect
    });

    // Wrap the pool with tracing when telemetry is available; otherwise use it directly.
    const telemetry = this.__getTelemetry();
    this.pool =
      telemetry?.traceClass(basePool, {
        spanNamePrefix: 'pg-vector',
        attributes: {
          'vector.type': 'postgres',
        },
      }) ?? basePool;
  }

  /**
   * Ensures an index name is a plain, unquoted SQL identifier before it is
   * interpolated into a statement.
   *
   * @returns The validated name, unchanged.
   * @throws Error when the name contains anything other than letters, digits, `_` or `$`,
   *   or does not start with a letter or `_`.
   */
  private static assertIdentifier(indexName: string): string {
    if (typeof indexName !== 'string' || !/^[\p{L}_][\p{L}\p{N}_$]*$/u.test(indexName)) {
      throw new Error(`Invalid index name: ${JSON.stringify(indexName)}`);
    }
    return indexName;
  }

  /** Extracts a readable message from an arbitrary thrown value. */
  private static errorMessage(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /**
   * Returns the rows most similar to `queryVector`, best match first.
   *
   * The score is `1 - (embedding <=> queryVector)`, i.e. it is always computed with
   * the `<=>` distance operator regardless of the metric the index was created with.
   *
   * @param indexName Table to search.
   * @param queryVector Vector to compare against the stored embeddings.
   * @param topK Maximum number of rows to return.
   * @param filter Optional metadata equality filter; every key must match
   *   (`metadata->>key = value`, compared as text).
   * @param minScore Rows must score strictly greater than this value.
   * @throws Error when `indexName` is not a valid identifier, or when the database rejects the query.
   */
  async query(
    indexName: string,
    queryVector: number[],
    topK: number = 10,
    filter?: Record<string, any>,
    minScore: number = 0, // Optional minimum score threshold
  ): Promise<QueryResult[]> {
    const table = PgVector.assertIdentifier(indexName);
    const client = await this.pool.connect();
    try {
      // $1 = minScore, $2 = query vector, $3 = topK; filter key/value pairs follow.
      const params: any[] = [minScore, `[${queryVector.join(',')}]`, topK];
      const conditions: string[] = [];

      for (const [key, value] of Object.entries(filter ?? {})) {
        params.push(key, value);
        // Bind the key as well as the value so neither can alter the SQL text.
        conditions.push(`metadata->>$${params.length - 1}::text = $${params.length}`);
      }
      const filterQuery = conditions.length > 0 ? 'AND ' + conditions.join(' AND ') : '';

      const sql = `
        WITH vector_scores AS (
          SELECT
            vector_id as id,
            1 - (embedding <=> $2::vector) as score,
            metadata
          FROM ${table}
          WHERE true ${filterQuery}
        )
        SELECT *
        FROM vector_scores
        WHERE score > $1
        ORDER BY score DESC
        LIMIT $3;
      `;
      const result = await client.query(sql, params);

      return result.rows.map(row => ({
        id: row.id,
        score: row.score,
        metadata: row.metadata,
      }));
    } finally {
      client.release();
    }
  }

  /**
   * Inserts vectors, or overwrites the embedding and metadata of rows whose
   * `vector_id` already exists. All rows are written in a single transaction.
   *
   * @param indexName Table to write to.
   * @param vectors Embeddings to store.
   * @param metadata Optional metadata per vector (matched by position); missing entries become `{}`.
   * @param ids Optional ids per vector (matched by position); random UUIDs are generated when omitted.
   * @returns The ids used, in the same order as `vectors`.
   * @throws Error when `indexName` is invalid; database errors are rethrown after the transaction is rolled back.
   */
  async upsert(
    indexName: string,
    vectors: number[][],
    metadata?: Record<string, any>[],
    ids?: string[],
  ): Promise<string[]> {
    const table = PgVector.assertIdentifier(indexName);
    const client = await this.pool.connect();
    try {
      await client.query('BEGIN');

      const vectorIds = ids ?? vectors.map(() => crypto.randomUUID());

      const statement = `
        INSERT INTO ${table} (vector_id, embedding, metadata)
        VALUES ($1, $2::vector, $3::jsonb)
        ON CONFLICT (vector_id)
        DO UPDATE SET
          embedding = $2::vector,
          metadata = $3::jsonb
      `;

      for (let i = 0; i < vectors.length; i++) {
        await client.query(statement, [
          vectorIds[i],
          `[${vectors[i]?.join(',')}]`,
          JSON.stringify(metadata?.[i] || {}),
        ]);
      }

      // Without an explicit COMMIT the client would go back to the pool with the
      // transaction still open and the writes invisible to other connections.
      await client.query('COMMIT');
      return vectorIds;
    } catch (error) {
      try {
        await client.query('ROLLBACK');
      } catch {
        // A failed rollback must not hide the error that caused it.
      }
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Creates the pgvector extension, the backing table and its ivfflat index if
   * they do not exist yet.
   *
   * @param indexName Table name; the ivfflat index is named `<indexName>_vector_idx`.
   * @param dimension Number of dimensions of the `embedding` column; must be a positive integer.
   * @param metric Distance metric that selects the operator class of the ivfflat index.
   * @throws Error when `indexName` or `dimension` is invalid, or when the database rejects a statement.
   */
  async createIndex(
    indexName: string,
    dimension: number,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
  ): Promise<void> {
    const table = PgVector.assertIdentifier(indexName);
    // `dimension` is interpolated into DDL, so it must be a plain positive integer.
    if (!Number.isInteger(dimension) || dimension <= 0) {
      throw new Error(`Invalid vector dimension: ${dimension}`);
    }

    // Anything that is not cosine or euclidean falls back to inner product (dotproduct).
    const operatorClass =
      metric === 'cosine' ? 'vector_cosine_ops' : metric === 'euclidean' ? 'vector_l2_ops' : 'vector_ip_ops';

    const client = await this.pool.connect();
    try {
      // Create the extension if it doesn't exist
      await client.query('CREATE EXTENSION IF NOT EXISTS vector');

      // Create the table
      await client.query(`
        CREATE TABLE IF NOT EXISTS ${table} (
          id SERIAL PRIMARY KEY,
          vector_id TEXT UNIQUE NOT NULL,
          embedding vector(${dimension}),
          metadata JSONB DEFAULT '{}'::jsonb
        );
      `);

      // Create an index for vector similarity search based on metric
      await client.query(`
        CREATE INDEX IF NOT EXISTS ${table}_vector_idx
        ON ${table}
        USING ivfflat (embedding ${operatorClass})
        WITH (lists = 100);
      `);
    } finally {
      client.release();
    }
  }

  /**
   * Lists the tables in the `public` schema that have at least one column of type `vector`.
   */
  async listIndexes(): Promise<string[]> {
    const client = await this.pool.connect();
    try {
      const vectorTablesQuery = `
        SELECT DISTINCT table_name
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND udt_name = 'vector';
      `;
      const vectorTables = await client.query(vectorTablesQuery);
      return vectorTables.rows.map(row => row.table_name);
    } finally {
      client.release();
    }
  }

  /**
   * Reports the dimension, row count and metric of an index.
   *
   * The dimension is read from the type modifier (`atttypmod`) of the `embedding`
   * column. The metric is derived from the operator class of `<indexName>_vector_idx`
   * and defaults to `cosine` when that index is not found.
   *
   * @throws Error when `indexName` is invalid, or (with a
   *   `Failed to describe vector table: …` message) when the table cannot be inspected.
   */
  async describeIndex(indexName: string): Promise<IndexStats> {
    const table = PgVector.assertIdentifier(indexName);
    const client = await this.pool.connect();
    try {
      // Get vector dimension
      const dimensionQuery = `
        SELECT atttypmod as dimension
        FROM pg_attribute
        WHERE attrelid = $1::regclass
          AND attname = 'embedding';
      `;

      // Get row count
      const countQuery = `
        SELECT COUNT(*) as count
        FROM ${table};
      `;

      // Get index metric type
      const metricQuery = `
        SELECT
          am.amname as index_method,
          opclass.opcname as operator_class
        FROM pg_index i
        JOIN pg_class c ON i.indexrelid = c.oid
        JOIN pg_am am ON c.relam = am.oid
        JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
        WHERE c.relname = $1;
      `;

      // A single client runs one query at a time, so issue them in sequence
      // instead of stacking them on the connection.
      const dimResult = await client.query(dimensionQuery, [table]);
      const countResult = await client.query(countQuery);
      // Unquoted identifiers are stored lower-cased, so look the index up that way.
      const metricResult = await client.query(metricQuery, [`${table}_vector_idx`.toLowerCase()]);

      const dimensionRow = dimResult.rows[0];
      if (!dimensionRow) {
        throw new Error(`table "${table}" has no embedding column`);
      }

      // Convert pg_vector operator class to our metric type
      let metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine';
      if (metricResult.rows.length > 0) {
        const operatorClass: string = metricResult.rows[0].operator_class;
        if (operatorClass.includes('l2')) {
          metric = 'euclidean';
        } else if (operatorClass.includes('ip')) {
          metric = 'dotproduct';
        } else if (operatorClass.includes('cosine')) {
          metric = 'cosine';
        }
      }

      return {
        dimension: dimensionRow.dimension,
        // COUNT(*) is read back as a string here, hence the explicit base-10 parse.
        count: parseInt(countResult.rows[0].count, 10),
        metric,
      };
    } catch (e: unknown) {
      // No transaction is open in this method, so there is nothing to roll back.
      throw new Error(`Failed to describe vector table: ${PgVector.errorMessage(e)}`);
    } finally {
      client.release();
    }
  }

  /**
   * Drops the backing table (and dependent objects) if it exists.
   *
   * @throws Error when `indexName` is invalid, or (with a
   *   `Failed to delete vector table: …` message) when the drop fails.
   */
  async deleteIndex(indexName: string): Promise<void> {
    const table = PgVector.assertIdentifier(indexName);
    const client = await this.pool.connect();
    try {
      // Drop the table
      await client.query(`DROP TABLE IF EXISTS ${table} CASCADE`);
    } catch (error: unknown) {
      throw new Error(`Failed to delete vector table: ${PgVector.errorMessage(error)}`);
    } finally {
      client.release();
    }
  }

  /**
   * Removes every row from the backing table while keeping the table itself.
   *
   * @throws Error when `indexName` is invalid, or (with a
   *   `Failed to truncate vector table: …` message) when the truncate fails.
   */
  async truncateIndex(indexName: string) {
    const table = PgVector.assertIdentifier(indexName);
    const client = await this.pool.connect();
    try {
      await client.query(`TRUNCATE ${table}`);
    } catch (e: unknown) {
      throw new Error(`Failed to truncate vector table: ${PgVector.errorMessage(e)}`);
    } finally {
      client.release();
    }
  }

  /** Closes the connection pool; the instance must not be used afterwards. */
  async disconnect() {
    await this.pool.end();
  }
}
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import { describe, it, expect, beforeAll, afterAll, beforeEach, afterEach } from '@jest/globals';
|
|
2
|
+
|
|
3
|
+
import { PgVector } from '.';
|
|
4
|
+
|
|
5
|
+
// Integration suite for PgVector. It talks to a real PostgreSQL instance, taken from
// DB_URL or, failing that, the local default connection string below.
describe('PgVector', () => {
  const connectionString = process.env.DB_URL || 'postgresql://postgres:postgres@localhost:5433/mastra';
  const primaryTable = 'test_vectors';
  const secondaryTable = 'test_vectors1';
  let store: PgVector;

  beforeAll(async () => {
    // One shared store (and connection pool) for the whole suite.
    store = new PgVector(connectionString);
  });

  afterAll(async () => {
    // Remove the shared table, then close the pool.
    await store.deleteIndex(primaryTable);
    await store.disconnect();
  });

  describe('createIndex', () => {
    afterAll(async () => {
      await store.deleteIndex(secondaryTable);
    });

    it('should create a new vector table with specified dimensions', async () => {
      await store.createIndex(primaryTable, 3);

      const described = await store.describeIndex(primaryTable);
      expect(described?.dimension).toBe(3);
      expect(described?.count).toBe(0);
    });

    it('should create index with specified metric', async () => {
      await store.createIndex(secondaryTable, 3, 'euclidean');

      const described = await store.describeIndex(secondaryTable);
      expect(described.metric).toBe('euclidean');
    });

    it('should throw error if dimension is invalid', async () => {
      const attempt = store.createIndex(`testIndexNameFail`, 0);
      await expect(attempt).rejects.toThrow();
    });
  });

  describe('upsert', () => {
    // Every test in this group starts from a fresh, empty table.
    beforeEach(async () => {
      await store.createIndex(primaryTable, 3);
    });

    afterEach(async () => {
      await store.deleteIndex(primaryTable);
    });

    it('should insert new vectors', async () => {
      const insertedIds = await store.upsert(primaryTable, [
        [1, 2, 3],
        [4, 5, 6],
      ]);
      expect(insertedIds).toHaveLength(2);

      const described = await store.describeIndex(primaryTable);
      expect(described.count).toBe(2);
    });

    it('should update existing vectors', async () => {
      const [id] = await store.upsert(primaryTable, [[1, 2, 3]], [{ test: 'initial' }]);

      // Write again under the same id with a new embedding and new metadata.
      await store.upsert(primaryTable, [[4, 5, 6]], [{ test: 'updated' }], [id!]);

      const matches = await store.query(primaryTable, [4, 5, 6], 1);
      const best = matches[0];
      expect(best?.id).toBe(id);
      expect(best?.metadata).toEqual({ test: 'updated' });
    });

    it('should handle metadata correctly', async () => {
      const storedMetadata = { test: 'value', num: 123 };

      await store.upsert(primaryTable, [[1, 2, 3]], [storedMetadata]);
      const matches = await store.query(primaryTable, [1, 2, 3], 1);

      expect(matches[0]?.metadata).toEqual(storedMetadata);
    });

    it('should throw error if vector dimensions dont match', async () => {
      // The table holds 3-dimensional vectors; this one has 4 components.
      const tooWide = [[1, 2, 3, 4]];
      await expect(store.upsert(primaryTable, tooWide)).rejects.toThrow();
    });
  });

  describe('query', () => {
    const queryTable = 'test_query_2';

    beforeAll(async () => {
      // Start from a clean slate: drop any leftover table, ignoring failures.
      try {
        await store.deleteIndex(queryTable);
      } catch {
        // Ignore if doesn't exist
      }

      await store.createIndex(queryTable, 3);
    });

    beforeEach(async () => {
      // Reset the rows, then load the same three fixtures for every test.
      await store.truncateIndex(queryTable);

      const fixtureVectors = [
        [1, 0, 0],
        [0.8, 0.2, 0],
        [0, 1, 0],
      ];
      const fixtureMetadata = [
        { type: 'a', value: 1 },
        { type: 'b', value: 2 },
        { type: 'a', value: 3 },
      ];
      await store.upsert(queryTable, fixtureVectors, fixtureMetadata);
    });

    afterAll(async () => {
      console.log('deleting index');
      console.log(await store.listIndexes());
      await store.deleteIndex(queryTable);
      console.log(await store.listIndexes());
    });

    it('should return closest vectors', async () => {
      const matches = await store.query(queryTable, [1, 0, 0], 1);
      expect(matches).toHaveLength(1);
      expect(matches[0]?.score).toBeCloseTo(1, 5);
    });

    it('should respect topK parameter', async () => {
      const matches = await store.query(queryTable, [1, 0, 0], 2);
      expect(matches).toHaveLength(2);
    });

    it('should handle filters correctly', async () => {
      const matches = await store.query(queryTable, [1, 0, 0], 10, { type: 'a' });

      expect(matches).toHaveLength(1);
      for (const match of matches) {
        expect(match?.metadata?.type).toBe('a');
      }
    });
  });

  describe('listIndexes', () => {
    const listedTable = 'test_query_3';

    beforeAll(async () => {
      await store.createIndex(listedTable, 3);
    });

    afterAll(async () => {
      await store.deleteIndex(listedTable);
    });

    it('should list all vector tables', async () => {
      const tables = await store.listIndexes();
      expect(tables).toContain(listedTable);
    });

    it('should return empty array when no indexes exist', async () => {
      await store.deleteIndex(listedTable);
      const tables = await store.listIndexes();
      expect(tables).toEqual([]);
    });
  });

  describe('describeIndex', () => {
    const describedTable = 'test_query_4';

    beforeAll(async () => {
      await store.createIndex(describedTable, 3);
    });

    afterAll(async () => {
      await store.deleteIndex(describedTable);
    });

    it('should return correct index stats', async () => {
      await store.createIndex(describedTable, 3, 'cosine');
      await store.upsert(describedTable, [
        [1, 2, 3],
        [4, 5, 6],
      ]);

      const described = await store.describeIndex(describedTable);
      expect(described).toEqual({
        dimension: 3,
        count: 2,
        metric: 'cosine',
      });
    });

    it('should throw error for non-existent index', async () => {
      await expect(store.describeIndex('non_existent')).rejects.toThrow();
    });
  });
});
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import { describe, it, expect } from '@jest/globals';
|
|
2
|
+
import dotenv from 'dotenv';
|
|
3
|
+
|
|
4
|
+
import { PineconeVector } from './';
|
|
5
|
+
|
|
6
|
+
// Load environment variables (including the Pinecone API key) from a local .env file.
dotenv.config();

const PINECONE_API_KEY = process.env.PINECONE_API_KEY!;

// if (!PINECONE_API_KEY) {
//   throw new Error('Please set PINECONE_API_KEY and PINECONE_ENVIRONMENT in .env file');
// }
// TODO: skipped until we have the secrets on GitHub
describe.skip('PineconeVector Integration Tests', () => {
  const dimension = 3;
  // Timestamp suffix keeps the index name unique for each test run.
  const testIndexName = 'test-index-' + Date.now();
  let pineconeVector: PineconeVector;

  /** Resolves after the given number of milliseconds. */
  const pause = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));

  beforeAll(async () => {
    pineconeVector = new PineconeVector(PINECONE_API_KEY);
    await pineconeVector.createIndex(testIndexName, dimension);
    // Give the freshly created index a minute before the tests use it.
    await pause(60000);
  }, 500000);

  afterAll(async () => {
    // Cleanup: delete test index
    await pineconeVector.deleteIndex(testIndexName);
  }, 500000);

  describe('Index Operations', () => {
    it('should list indexes including our test index', async () => {
      const names = await pineconeVector.listIndexes();
      expect(names).toContain(testIndexName);
    }, 500000);

    it('should describe index with correct properties', async () => {
      const described = await pineconeVector.describeIndex(testIndexName);
      expect(described.dimension).toBe(dimension);
      expect(described.metric).toBe('cosine');
      expect(typeof described.count).toBe('number');
    }, 500000);
  });

  describe('Vector Operations', () => {
    // One unit vector per axis, each labelled with the axis it lies on.
    const testVectors = [
      [1.0, 0.0, 0.0],
      [0.0, 1.0, 0.0],
      [0.0, 0.0, 1.0],
    ];
    const testMetadata = [{ label: 'x-axis' }, { label: 'y-axis' }, { label: 'z-axis' }];
    let vectorIds: string[];

    it('should upsert vectors with metadata', async () => {
      vectorIds = await pineconeVector.upsert(testIndexName, testVectors, testMetadata);
      expect(vectorIds).toHaveLength(3);
      // Wait for vectors to be indexed
      await pause(5000);
    }, 500000);

    it.skip('should query vectors and return nearest neighbors', async () => {
      const matches = await pineconeVector.query(testIndexName, [1.0, 0.1, 0.1], 3);

      expect(matches).toHaveLength(3);
      expect(matches[0]!.score).toBeGreaterThan(0);
      expect(matches[0]!.metadata).toBeDefined();
    }, 500000);

    it.skip('should query vectors with metadata filter', async () => {
      const matches = await pineconeVector.query(testIndexName, [0.0, 1.0, 0.0], 1, { label: 'y-axis' });

      expect(matches).toHaveLength(1);
      expect(matches?.[0]?.metadata?.label).toBe('y-axis');
    }, 500000);
  });

  describe('Error Handling', () => {
    it('should handle non-existent index query gracefully', async () => {
      const attempt = pineconeVector.query('non-existent-index', [1, 0, 0]);
      await expect(attempt).rejects.toThrow();
    }, 500000);

    it('should handle incorrect dimension vectors', async () => {
      // The index holds 3-dimensional vectors; this one has only 2 components.
      const attempt = pineconeVector.upsert(testIndexName, [[1, 0]]);
      await expect(attempt).rejects.toThrow();
    }, 500000);
  });

  describe('Performance Tests', () => {
    it('should handle batch upsert of 1000 vectors', async () => {
      const batchSize = 1000;
      const vectors = Array.from({ length: batchSize }, () =>
        Array.from({ length: dimension }, () => Math.random()),
      );
      const metadata = vectors.map((_, position) => ({ id: position }));

      const startedAt = Date.now();
      const ids = await pineconeVector.upsert(testIndexName, vectors, metadata);
      const duration = Date.now() - startedAt;

      expect(ids).toHaveLength(batchSize);
      console.log(`Batch upsert of ${batchSize} vectors took ${duration}ms`);
    }, 300000); // 5 minute timeout

    it('should perform multiple concurrent queries', async () => {
      const numQueries = 10;
      const queryVector = [1, 0, 0];

      const startedAt = Date.now();
      // Start every query before awaiting any of them.
      const pending = Array.from({ length: numQueries }, () => pineconeVector.query(testIndexName, queryVector));
      const results = await Promise.all(pending);
      const duration = Date.now() - startedAt;

      expect(results).toHaveLength(numQueries);
      console.log(`${numQueries} concurrent queries took ${duration}ms`);
    }, 500000);
  });
});
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import { MastraVector, QueryResult, IndexStats } from '@mastra/core';
|
|
2
|
+
import { Pinecone } from '@pinecone-database/pinecone';
|
|
3
|
+
|
|
4
|
+
/**
 * Vector store backed by Pinecone, exposing the MastraVector operations on top of
 * the official Pinecone client.
 */
export class PineconeVector extends MastraVector {
  private client: Pinecone;

  /**
   * @param apiKey Pinecone API key.
   * @param environment Optional value forwarded to the client as `controllerHostUrl`.
   */
  constructor(apiKey: string, environment?: string) {
    super();

    const clientOptions: { apiKey: string; controllerHostUrl?: string } = { apiKey };
    if (environment) {
      clientOptions.controllerHostUrl = environment;
    }

    const rawClient = new Pinecone(clientOptions);

    // Prefer a traced wrapper when telemetry is configured; otherwise use the raw client.
    const tracedClient = this.__getTelemetry()?.traceClass(rawClient, {
      spanNamePrefix: 'pinecone-vector',
      attributes: {
        'vector.type': 'pinecone',
      },
    });
    this.client = tracedClient ?? rawClient;
  }

  /**
   * Creates a serverless index (cloud `aws`, region `us-east-1`).
   *
   * @param indexName Name of the index to create.
   * @param dimension Dimension of the vectors the index will hold.
   * @param metric Similarity metric of the index.
   */
  async createIndex(
    indexName: string,
    dimension: number,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
  ): Promise<void> {
    const serverless = { cloud: 'aws', region: 'us-east-1' } as const;
    await this.client.createIndex({
      name: indexName,
      dimension,
      metric,
      spec: { serverless },
    });
  }

  /**
   * Writes vectors to an index in batches of 100 records.
   *
   * @param indexName Index to write to.
   * @param vectors Embeddings to store.
   * @param metadata Optional metadata per vector (matched by position); missing entries become `{}`.
   * @param ids Optional ids per vector (matched by position); random UUIDs are generated when omitted.
   * @returns The ids used, in the same order as `vectors`.
   */
  async upsert(
    indexName: string,
    vectors: number[][],
    metadata?: Record<string, any>[],
    ids?: string[],
  ): Promise<string[]> {
    const target = this.client.Index(indexName);

    // Fall back to generated ids when the caller supplied none.
    const vectorIds = ids || vectors.map(() => crypto.randomUUID());

    const records = vectors.map((values, position) => {
      return {
        id: vectorIds[position]!,
        values,
        metadata: metadata?.[position] || {},
      };
    });

    // Pinecone has a limit of 100 vectors per upsert request
    const batchSize = 100;
    let offset = 0;
    while (offset < records.length) {
      await target.upsert(records.slice(offset, offset + batchSize));
      offset += batchSize;
    }

    return vectorIds;
  }

  /**
   * Runs a similarity search and returns the matches with their metadata.
   *
   * @param indexName Index to search.
   * @param queryVector Vector to compare against.
   * @param topK Maximum number of matches to return.
   * @param filter Optional metadata filter passed through to Pinecone unchanged.
   * @returns One entry per match; a missing score is reported as 0.
   */
  async query(
    indexName: string,
    queryVector: number[],
    topK: number = 10,
    filter?: Record<string, any>,
  ): Promise<QueryResult[]> {
    const target = this.client.Index(indexName);

    const response = await target.query({
      vector: queryVector,
      topK,
      filter,
      includeMetadata: true,
    });

    return response.matches.map(({ id, score, metadata }) => ({
      id,
      score: score || 0,
      metadata: metadata as Record<string, any>,
    }));
  }

  /** Lists the names of all indexes, or an empty array when none are reported. */
  async listIndexes(): Promise<string[]> {
    const listing = await this.client.listIndexes();
    const names = listing?.indexes?.map(entry => entry.name);
    return names || [];
  }

  /**
   * Reports the dimension, record count and metric of an index.
   *
   * The count comes from the index statistics (0 when absent); dimension and metric
   * come from the index description.
   */
  async describeIndex(indexName: string): Promise<IndexStats> {
    const target = this.client.Index(indexName);
    const stats = await target.describeIndexStats();
    const description = await this.client.describeIndex(indexName);

    const count = stats.totalRecordCount || 0;
    const metric = description.metric as 'cosine' | 'euclidean' | 'dotproduct';

    return { dimension: description.dimension, count, metric };
  }

  /**
   * Deletes an index.
   *
   * @throws Error with a `Failed to delete Pinecone index: …` message when the client call fails.
   */
  async deleteIndex(indexName: string): Promise<void> {
    try {
      await this.client.deleteIndex(indexName);
    } catch (error: any) {
      throw new Error(`Failed to delete Pinecone index: ${error.message}`);
    }
  }
}
|