@mastra/mongodb 0.0.2-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +23 -0
- package/CHANGELOG.md +15 -0
- package/LICENSE.md +46 -0
- package/README.md +140 -0
- package/dist/_tsup-dts-rollup.d.cts +88 -0
- package/dist/_tsup-dts-rollup.d.ts +88 -0
- package/dist/index.cjs +363 -0
- package/dist/index.d.cts +6 -0
- package/dist/index.d.ts +6 -0
- package/dist/index.js +361 -0
- package/docker-compose.yml +8 -0
- package/eslint.config.js +6 -0
- package/package.json +48 -0
- package/src/index.ts +1 -0
- package/src/vector/filter.test.ts +415 -0
- package/src/vector/filter.ts +124 -0
- package/src/vector/index.test.ts +448 -0
- package/src/vector/index.ts +380 -0
- package/tsconfig.json +5 -0
- package/vitest.config.ts +11 -0
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
import { MastraVector } from '@mastra/core/vector';
|
|
2
|
+
import type {
|
|
3
|
+
QueryResult,
|
|
4
|
+
IndexStats,
|
|
5
|
+
CreateIndexParams,
|
|
6
|
+
UpsertVectorParams,
|
|
7
|
+
QueryVectorParams,
|
|
8
|
+
ParamsToArgs,
|
|
9
|
+
QueryVectorArgs,
|
|
10
|
+
UpsertVectorArgs,
|
|
11
|
+
} from '@mastra/core/vector';
|
|
12
|
+
import type { VectorFilter } from '@mastra/core/vector/filter';
|
|
13
|
+
import { MongoClient } from 'mongodb';
|
|
14
|
+
import type { MongoClientOptions, Document, Db, Collection } from 'mongodb';
|
|
15
|
+
import { v4 as uuidv4 } from 'uuid';
|
|
16
|
+
|
|
17
|
+
import { MongoDBFilterTranslator } from './filter';
|
|
18
|
+
|
|
19
|
+
// Define necessary types and interfaces
|
|
20
|
+
/** Positional `upsert` arguments plus an optional trailing `string[]`. */
export type MongoDBUpsertArgs = [...UpsertVectorArgs, string[]?];

/** Positional `query` arguments plus an optional trailing `string`. */
export type MongoDBQueryArgs = [...QueryVectorArgs, string?];

/** Object-style parameters derived from {@link MongoDBUpsertArgs}. */
export type MongoDBUpsertParams = ParamsToArgs<MongoDBUpsertArgs>;

/** `upsert` parameters with an optional per-vector document string. */
export interface MongoDBUpsertVectorParams extends UpsertVectorParams {
  /** Strings stored alongside the vectors; matched to `vectors` by position. */
  documents?: string[];
}

/** `query` parameters with an optional filter applied to the stored document field. */
export interface MongoDBQueryVectorParams extends QueryVectorParams {
  /** Condition matched against the stored `document` field. */
  documentFilter?: VectorFilter;
}

/** Shape of the records kept in each index collection. */
interface MongoDBDocument extends Document {
  /** Record identifier; always a string rather than an ObjectId. */
  _id: string;
  /** The stored vector. */
  embedding?: number[];
  /** Caller-supplied metadata. */
  metadata?: Record<string, any>;
  /** Optional string stored with the vector. */
  document?: string;
  /** Any further properties (e.g. those on the index-metadata record). */
  [key: string]: any;
}
|
|
40
|
+
// The MongoDBVector class
|
|
41
|
+
/**
 * Vector store that keeps each index in its own MongoDB collection and
 * searches it with the `$vectorSearch` aggregation stage.
 *
 * Layout per index `<indexName>`:
 * - collection `<indexName>` holds one record per vector
 *   (`embedding`, `metadata`, optional `document`);
 * - search index `<indexName>_vector_index` is created on `embedding`;
 * - a record with `_id === '__index_metadata__'` stores the index
 *   `dimension` and `metric`.
 */
export class MongoDBVector extends MastraVector {
  /** `_id` of the per-collection record that stores `dimension` and `metric`. */
  private static readonly INDEX_METADATA_ID = '__index_metadata__';
  /** Candidate pool used for `$vectorSearch` when `topK` is small. */
  private static readonly MIN_NUM_CANDIDATES = 100;
  /** Upper bound MongoDB accepts for `$vectorSearch.numCandidates`. */
  private static readonly MAX_NUM_CANDIDATES = 10000;

  private client: MongoClient;
  private db: Db;
  /** Cache of collection handles that were seen to exist, keyed by index name. */
  private collections: Map<string, Collection<MongoDBDocument>>;
  private readonly embeddingFieldName = 'embedding';
  private readonly metadataFieldName = 'metadata';
  private readonly documentFieldName = 'document';
  /** Maps the public metric names to the `similarity` values used in the search index definition. */
  private mongoMetricMap: { [key: string]: string } = {
    cosine: 'cosine',
    euclidean: 'euclidean',
    dotproduct: 'dotProduct',
  };

  /**
   * @param uri - MongoDB connection string.
   * @param dbName - Database that holds the index collections.
   * @param options - Optional driver options forwarded to `MongoClient`.
   */
  constructor({ uri, dbName, options }: { uri: string; dbName: string; options?: MongoClientOptions }) {
    super();
    this.client = new MongoClient(uri, options);
    this.db = this.client.db(dbName);
    this.collections = new Map();
  }

  // Public methods

  /** Opens the client connection. */
  async connect(): Promise<void> {
    await this.client.connect();
  }

  /** Closes the client connection. */
  async disconnect(): Promise<void> {
    await this.client.close();
  }

  /**
   * Creates the collection (if missing) and its vector search index, then
   * records `dimension` and `metric` in the index-metadata record.
   *
   * An `IndexAlreadyExists` error from the server is ignored, so the call can
   * be repeated.
   *
   * @throws Error if `dimension` is not a positive integer or `metric` is unknown.
   */
  async createIndex(params: CreateIndexParams): Promise<void> {
    const { indexName, dimension, metric = 'cosine' } = params;

    if (!Number.isInteger(dimension) || dimension <= 0) {
      throw new Error('Dimension must be a positive integer');
    }

    const mongoMetric = this.mongoMetricMap[metric];
    if (!mongoMetric) {
      throw new Error(`Invalid metric: "${metric}". Must be one of: cosine, euclidean, dotproduct`);
    }

    // The collection has to exist before getCollection() will hand it out.
    const collectionExists = await this.db.listCollections({ name: indexName }).hasNext();
    if (!collectionExists) {
      await this.db.createCollection(indexName);
    }
    const collection = await this.getCollection(indexName);

    try {
      await (collection as any).createSearchIndex({
        definition: {
          fields: [
            {
              type: 'vector',
              path: this.embeddingFieldName,
              numDimensions: dimension,
              similarity: mongoMetric,
            },
          ],
        },
        name: this.searchIndexName(indexName),
        type: 'vectorSearch',
      });
    } catch (error: unknown) {
      const codeName = (error as { codeName?: string } | null | undefined)?.codeName;
      if (codeName !== 'IndexAlreadyExists') {
        throw error;
      }
    }

    // Store the dimension and metric in a special metadata document
    await collection.updateOne(
      { _id: MongoDBVector.INDEX_METADATA_ID },
      { $set: { dimension, metric } },
      { upsert: true },
    );
  }

  /**
   * Polls the collection's search indexes until the vector index reports
   * status `READY`.
   *
   * @param indexName - Index (collection) name.
   * @param timeoutMs - Give up after this many milliseconds.
   * @param checkIntervalMs - Delay between polls.
   * @throws Error if the index is not ready before the timeout elapses.
   */
  async waitForIndexReady(indexName: string, timeoutMs: number = 60000, checkIntervalMs: number = 2000): Promise<void> {
    const collection = await this.getCollection(indexName, true);
    const indexNameInternal = this.searchIndexName(indexName);

    const startTime = Date.now();
    while (Date.now() - startTime < timeoutMs) {
      const indexInfo: any[] = await (collection as any).listSearchIndexes().toArray();
      const indexData = indexInfo.find((idx: any) => idx.name === indexNameInternal);
      if (indexData?.status === 'READY') {
        return;
      }
      await new Promise(resolve => setTimeout(resolve, checkIntervalMs));
    }
    throw new Error(`Index "${indexNameInternal}" did not become ready within timeout`);
  }

  /**
   * Inserts or replaces vectors (with metadata and optional documents) by id.
   *
   * Ids are generated with uuid v4 when none are supplied. `Date` values at
   * the top level of a metadata object are stored as ISO strings.
   *
   * @returns The ids used, in the same order as `vectors`.
   * @throws Error if `vectors` is empty, a vector has the wrong dimension, or
   *   `ids` is supplied with a different length than `vectors`.
   */
  async upsert(params: MongoDBUpsertVectorParams): Promise<string[]> {
    const { indexName, vectors, metadata, ids, documents } = params;

    const collection = await this.getCollection(indexName);

    // Validate against the stored dimension. The collection is passed
    // explicitly so concurrent upserts on different indexes cannot interfere.
    const stats = await this.describeIndex(indexName);
    await this.validateVectorDimensions(vectors, stats.dimension, collection);

    // A shorter id list would otherwise produce writes keyed on `undefined`.
    if (ids && ids.length !== vectors.length) {
      throw new Error(`Number of ids (${ids.length}) must match number of vectors (${vectors.length})`);
    }

    const generatedIds = ids ?? vectors.map(() => uuidv4());

    const operations = vectors.map((vector, idx) => {
      const id = generatedIds[idx];
      const doc = documents?.[idx];

      const updateDoc: Partial<MongoDBDocument> = {
        [this.embeddingFieldName]: vector,
        [this.metadataFieldName]: this.normalizeMetadata(metadata?.[idx] || {}),
      };
      if (doc !== undefined) {
        updateDoc[this.documentFieldName] = doc;
      }

      return {
        updateOne: {
          filter: { _id: id }, // '_id' is a string as per MongoDBDocument interface
          update: { $set: updateDoc },
          upsert: true,
        },
      };
    });

    await collection.bulkWrite(operations);

    return generatedIds;
  }

  /**
   * Runs a `$vectorSearch` for `queryVector` and returns up to `topK` matches.
   *
   * `filter` is translated by `MongoDBFilterTranslator`; `documentFilter` is
   * applied to the stored document field. Both are applied in a `$match`
   * stage after the vector search, so fewer than `topK` results may be
   * returned when a filter is present.
   *
   * @throws Rethrows any aggregation error after logging it.
   */
  async query(params: MongoDBQueryVectorParams): Promise<QueryResult[]> {
    const { indexName, queryVector, topK = 10, filter, includeVector = false, documentFilter } = params;

    const collection = await this.getCollection(indexName, true);

    const mongoFilter = this.transformFilter(filter);
    const hasMetadataFilter = Object.keys(mongoFilter).length > 0;
    const hasDocumentFilter = Boolean(documentFilter);
    const documentMongoFilter = hasDocumentFilter ? { [this.documentFieldName]: documentFilter } : {};

    let combinedFilter: Record<string, unknown> | undefined;
    if (hasMetadataFilter && hasDocumentFilter) {
      combinedFilter = { $and: [mongoFilter, documentMongoFilter] };
    } else if (hasMetadataFilter) {
      combinedFilter = mongoFilter;
    } else if (hasDocumentFilter) {
      combinedFilter = documentMongoFilter;
    }

    // $vectorSearch rejects limit > numCandidates, so grow the candidate pool
    // with topK instead of pinning it at 100 (capped at the documented maximum).
    const numCandidates = Math.min(
      MongoDBVector.MAX_NUM_CANDIDATES,
      Math.max(MongoDBVector.MIN_NUM_CANDIDATES, topK),
    );

    const pipeline: Document[] = [
      {
        $vectorSearch: {
          index: this.searchIndexName(indexName),
          queryVector: queryVector,
          path: this.embeddingFieldName,
          numCandidates,
          limit: topK,
        },
      },
      ...(combinedFilter ? [{ $match: combinedFilter }] : []),
      {
        $set: { score: { $meta: 'vectorSearchScore' } },
      },
      {
        $project: {
          _id: 1,
          score: 1,
          metadata: `$${this.metadataFieldName}`,
          document: `$${this.documentFieldName}`,
          ...(includeVector ? { vector: `$${this.embeddingFieldName}` } : {}),
        },
      },
    ];

    try {
      const results = await collection.aggregate(pipeline).toArray();

      return results.map((result: any) => ({
        id: result._id,
        score: result.score,
        metadata: result.metadata,
        vector: includeVector ? result.vector : undefined,
        document: result.document,
      }));
    } catch (error) {
      console.error('Error during vector search:', error);
      throw error;
    }
  }

  /** Returns the names of all collections in the database. */
  async listIndexes(): Promise<string[]> {
    const collections = await this.db.listCollections().toArray();
    return collections.map(col => col.name);
  }

  /**
   * Reports the stored dimension and metric plus the number of vector records
   * (the index-metadata record is not counted).
   *
   * Falls back to dimension `0` and metric `'cosine'` when no metadata record exists.
   */
  async describeIndex(indexName: string): Promise<IndexStats> {
    const collection = await this.getCollection(indexName, true);

    const count = await collection.countDocuments({ _id: { $ne: MongoDBVector.INDEX_METADATA_ID } });
    const metadataDoc = await collection.findOne({ _id: MongoDBVector.INDEX_METADATA_ID });

    return {
      dimension: metadataDoc?.dimension ?? 0,
      count,
      metric: (metadataDoc?.metric ?? 'cosine') as 'cosine' | 'euclidean' | 'dotproduct',
    };
  }

  /**
   * Drops the collection backing `indexName` and evicts it from the handle cache.
   *
   * No existence check is made first; how a drop of a missing collection
   * behaves is left to the driver/server, as before.
   */
  async deleteIndex(indexName: string): Promise<void> {
    await this.db.collection<MongoDBDocument>(indexName).drop();
    this.collections.delete(indexName);
  }

  /**
   * Updates the vector and/or metadata of one record.
   *
   * @throws Error if neither `vector` nor `metadata` is given, if the index
   *   does not exist, or if the new vector does not match the stored dimension.
   */
  async updateIndexById(
    indexName: string,
    id: string,
    update: { vector?: number[]; metadata?: Record<string, any> },
  ): Promise<void> {
    if (!update.vector && !update.metadata) {
      throw new Error('No updates provided');
    }

    const collection = await this.getCollection(indexName, true);
    const updateDoc: Record<string, any> = {};

    if (update.vector) {
      // Apply the same dimension check as upsert when a dimension is recorded.
      const metadataDoc = await collection.findOne({ _id: MongoDBVector.INDEX_METADATA_ID });
      const dimension: number = metadataDoc?.dimension ?? 0;
      if (dimension > 0 && update.vector.length !== dimension) {
        throw new Error(`Vector has invalid dimension ${update.vector.length}. Expected ${dimension} dimensions.`);
      }
      updateDoc[this.embeddingFieldName] = update.vector;
    }

    if (update.metadata) {
      updateDoc[this.metadataFieldName] = this.normalizeMetadata(update.metadata);
    }

    // The updated document is not needed, so a plain updateOne is enough.
    await collection.updateOne({ _id: id }, { $set: updateDoc });
  }

  /**
   * Deletes one record by id.
   *
   * @throws Error wrapping any failure, including a missing index.
   */
  async deleteIndexById(indexName: string, id: string): Promise<void> {
    try {
      const collection = await this.getCollection(indexName, true);
      await collection.deleteOne({ _id: id });
    } catch (error: unknown) {
      const message = error instanceof Error ? error.message : String(error);
      throw new Error(`Failed to delete index by id: ${id} for index name: ${indexName}: ${message}`);
    }
  }

  // Private methods

  /** Name of the search index that belongs to `indexName`. */
  private searchIndexName(indexName: string): string {
    return `${indexName}_vector_index`;
  }

  /** Returns a shallow copy of `metadata` with top-level `Date` values converted to ISO strings. */
  private normalizeMetadata(metadata: Record<string, any>): Record<string, any> {
    const normalized: Record<string, any> = {};
    for (const [key, value] of Object.entries(metadata)) {
      normalized[key] = value instanceof Date ? value.toISOString() : value;
    }
    return normalized;
  }

  /**
   * Returns the collection handle for `indexName`, using the cache when possible.
   *
   * Only collections that were seen to exist are cached, so a later call with
   * `throwIfNotExists = true` still performs the existence check.
   *
   * @throws Error if the collection does not exist and `throwIfNotExists` is true.
   */
  private async getCollection(
    indexName: string,
    throwIfNotExists: boolean = true,
  ): Promise<Collection<MongoDBDocument>> {
    const cached = this.collections.get(indexName);
    if (cached) {
      return cached;
    }

    const collection = this.db.collection<MongoDBDocument>(indexName);

    const collectionExists = await this.db.listCollections({ name: indexName }).hasNext();
    if (!collectionExists) {
      if (throwIfNotExists) {
        throw new Error(`Index (Collection) "${indexName}" does not exist`);
      }
      return collection;
    }

    this.collections.set(indexName, collection);
    return collection;
  }

  /**
   * Checks that every vector has `dimension` entries.
   *
   * When `dimension` is `0` (none recorded yet) the length of the first vector
   * is adopted and written to the index-metadata record of `collection`.
   *
   * @throws Error if `vectors` is empty or any vector has a different length.
   */
  private async validateVectorDimensions(
    vectors: number[][],
    dimension: number,
    collection: Collection<MongoDBDocument>,
  ): Promise<void> {
    if (vectors.length === 0) {
      throw new Error('No vectors provided for validation');
    }

    let expected = dimension;
    if (expected === 0) {
      expected = vectors[0] ? vectors[0].length : 0;
      await this.setIndexDimension(expected, collection);
    }

    for (let i = 0; i < vectors.length; i++) {
      const actual = vectors[i]?.length;
      if (actual !== expected) {
        throw new Error(`Vector at index ${i} has invalid dimension ${actual}. Expected ${expected} dimensions.`);
      }
    }
  }

  /** Writes `dimension` to the index-metadata record of `collection`, creating the record if needed. */
  private async setIndexDimension(dimension: number, collection: Collection<MongoDBDocument>): Promise<void> {
    await collection.updateOne({ _id: MongoDBVector.INDEX_METADATA_ID }, { $set: { dimension } }, { upsert: true });
  }

  /** Translates a filter with `MongoDBFilterTranslator`; returns `{}` when no filter is given. */
  private transformFilter(filter?: VectorFilter): any {
    if (!filter) return {};
    return new MongoDBFilterTranslator().translate(filter);
  }
}
|
package/tsconfig.json
ADDED