@mastra/mongodb 0.10.0 → 0.10.1-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,674 @@
1
+ import type { MetricResult, TestInfo } from '@mastra/core/eval';
2
+ import type { MessageType, StorageThreadType } from '@mastra/core/memory';
3
+ import type { EvalRow, StorageGetMessagesArg, TABLE_NAMES, WorkflowRun } from '@mastra/core/storage';
4
+ import {
5
+ MastraStorage,
6
+ TABLE_EVALS,
7
+ TABLE_MESSAGES,
8
+ TABLE_THREADS,
9
+ TABLE_TRACES,
10
+ TABLE_WORKFLOW_SNAPSHOT,
11
+ } from '@mastra/core/storage';
12
+ import type { WorkflowRunState } from '@mastra/core/workflows';
13
+ import type { Db } from 'mongodb';
14
+ import { MongoClient } from 'mongodb';
15
+
16
/**
 * Parses a JSON-encoded value without ever throwing.
 *
 * A value that is already a non-null object is returned untouched. Passing such
 * a value to `JSON.parse` would first coerce it to "[object Object]", fail to
 * parse, and silently replace real data with an empty object.
 *
 * @param jsonString - The raw value, normally a JSON string.
 * @returns The parsed value, the input itself when it is already a non-null
 *   object, or `{}` when the input cannot be parsed.
 */
function safelyParseJSON(jsonString: unknown): any {
  if (typeof jsonString === 'object' && jsonString !== null) {
    return jsonString;
  }

  try {
    // String() mirrors the implicit coercion JSON.parse applies to non-strings.
    return JSON.parse(String(jsonString));
  } catch {
    return {};
  }
}
23
+
24
/** Connection settings accepted by the MongoDBStore constructor. */
export interface MongoDBConfig {
  /** Connection string passed to the `MongoClient` constructor; must not be blank. */
  url: string;
  /** Name of the database the store reads from and writes to; must not be blank. */
  dbName: string;
}
28
+
29
+ export class MongoDBStore extends MastraStorage {
30
+ #isConnected = false;
31
+ #client: MongoClient;
32
+ #db: Db | undefined;
33
+ readonly #dbName: string;
34
+
35
/**
 * Validates the configuration and prepares (but does not open) the client.
 *
 * @param config - Connection string and database name.
 * @throws Error when `url` or `dbName` is missing or contains only whitespace.
 */
constructor(config: MongoDBConfig) {
  super({ name: 'MongoDBStore' });
  this.#isConnected = false;

  // A value counts as blank when it is absent or trims down to nothing.
  const isBlank = (value: string | undefined): boolean => !value?.trim().length;

  if (isBlank(config.url)) {
    throw new Error(
      'MongoDBStore: url must be provided and cannot be empty. Passing an empty string may cause fallback to local MongoDB defaults.',
    );
  }

  if (isBlank(config.dbName)) {
    throw new Error(
      'MongoDBStore: dbName must be provided and cannot be empty. Passing an empty string may cause fallback to local MongoDB defaults.',
    );
  }

  this.#dbName = config.dbName;
  this.#client = new MongoClient(config.url);
}
54
+
55
/**
 * Returns the database handle, connecting the client the first time it is needed.
 *
 * @returns The `Db` for the configured database name.
 */
private async getConnection(): Promise<Db> {
  if (!this.#isConnected) {
    await this.#client.connect();
    this.#db = this.#client.db(this.#dbName);
    this.#isConnected = true;
  }

  // #db is always assigned before #isConnected is switched to true.
  return this.#db!;
}
65
+
66
/**
 * Resolves a collection on the store's database, connecting first if necessary.
 *
 * @param collectionName - Name of the collection to open.
 */
private async getCollection(collectionName: string) {
  const database = await this.getConnection();
  const collection = database.collection(collectionName);
  return collection;
}
70
+
71
/** Intentional no-op kept to satisfy the storage interface. */
async createTable(): Promise<void> {
  // MongoDB is schemaless, so there is nothing to create up front.
}
74
+
75
/**
 * Removes every document from the given collection.
 *
 * This is best-effort: failures are logged and never propagated to the caller.
 *
 * @param tableName - Collection to empty.
 */
async clearTable({ tableName }: { tableName: TABLE_NAMES }): Promise<void> {
  try {
    const collection = await this.getCollection(tableName);
    await collection.deleteMany({});
  } catch (error) {
    // Log whatever was thrown; non-Error values used to vanish without a trace.
    const reason = error instanceof Error ? error.message : String(error);
    this.logger.error(reason);
  }
}
85
+
86
/**
 * Inserts a single document into the given collection.
 *
 * @param tableName - Target collection.
 * @param record - Document to store as-is.
 * @throws Rethrows any driver error after logging it.
 */
async insert({ tableName, record }: { tableName: TABLE_NAMES; record: Record<string, any> }): Promise<void> {
  try {
    const collection = await this.getCollection(tableName);
    await collection.insertOne(record);
  } catch (error) {
    // This is a plain insert (insertOne), not an upsert, so the log says so.
    this.logger.error(`Error inserting into table ${tableName}: ${error}`);
    throw error;
  }
}
95
+
96
/**
 * Inserts several documents into the given collection with one `insertMany` call.
 *
 * @param tableName - Target collection.
 * @param records - Documents to store; an empty array is a no-op.
 * @throws Rethrows any driver error after logging it.
 */
async batchInsert({ tableName, records }: { tableName: TABLE_NAMES; records: Record<string, any>[] }): Promise<void> {
  // Nothing to write, so skip opening the collection.
  if (!records.length) {
    return;
  }

  try {
    const collection = await this.getCollection(tableName);
    await collection.insertMany(records);
  } catch (error) {
    // This is a plain insert (insertMany), not an upsert, so the log says so.
    this.logger.error(`Error inserting into table ${tableName}: ${error}`);
    throw error;
  }
}
109
+
110
/**
 * Looks up documents in a collection by exact key/value match.
 *
 * NOTE(review): despite the `R | null` signature, this resolves to the ARRAY of
 * all matching documents (empty when nothing matches), never `null`.
 * `loadWorkflowSnapshot` relies on that array shape.
 *
 * @param tableName - Collection to query.
 * @param keys - Field/value pairs used directly as the MongoDB filter.
 * @throws Rethrows any driver error after logging it.
 */
async load<R>({ tableName, keys }: { tableName: TABLE_NAMES; keys: Record<string, string> }): Promise<R | null> {
  const serializedKeys = JSON.stringify(keys);
  this.logger.info(`Loading ${tableName} with keys ${serializedKeys}`);

  try {
    const collection = await this.getCollection(tableName);
    const matches = await collection.find(keys).toArray();
    return matches as R;
  } catch (error) {
    this.logger.error(`Error loading ${tableName} with keys ${serializedKeys}: ${error}`);
    throw error;
  }
}
120
+
121
/**
 * Fetches one thread by its `id` field.
 *
 * @param threadId - Value of the thread's `id`.
 * @returns The stored thread with `metadata` decoded from JSON when it was
 *   stored as a string, or `null` when no thread matches.
 * @throws Rethrows any lookup or JSON parsing error after logging it.
 */
async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
  try {
    const threads = await this.getCollection(TABLE_THREADS);
    const doc = await threads.findOne<any>({ id: threadId });

    if (doc) {
      // saveThread serializes metadata; tolerate documents where it is already structured.
      const metadata = typeof doc.metadata === 'string' ? JSON.parse(doc.metadata) : doc.metadata;
      return { ...doc, metadata };
    }

    return null;
  } catch (error) {
    this.logger.error(`Error loading thread with ID ${threadId}: ${error}`);
    throw error;
  }
}
138
+
139
/**
 * Lists every thread whose `resourceId` equals the given value.
 *
 * @param resourceId - Resource identifier to filter on.
 * @returns Matching threads with `metadata` decoded from JSON when stored as a
 *   string; an empty array when there are none.
 * @throws Rethrows any lookup or JSON parsing error after logging it.
 */
async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
  try {
    const threads = await this.getCollection(TABLE_THREADS);
    const docs = await threads.find<any>({ resourceId }).toArray();

    const decoded: StorageThreadType[] = [];
    for (const doc of docs) {
      const metadata = typeof doc.metadata === 'string' ? JSON.parse(doc.metadata) : doc.metadata;
      decoded.push({ ...doc, metadata });
    }
    return decoded;
  } catch (error) {
    this.logger.error(`Error loading threads by resourceId ${resourceId}: ${error}`);
    throw error;
  }
}
156
+
157
/**
 * Creates or replaces the fields of a thread, keyed on its `id`.
 *
 * @param thread - Thread to persist; `metadata` is stored as a JSON string.
 * @returns The same thread object that was passed in.
 * @throws Rethrows any driver error after logging it.
 */
async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
  try {
    const threads = await this.getCollection(TABLE_THREADS);

    const filter = { id: thread.id };
    const fields = { ...thread, metadata: JSON.stringify(thread.metadata) };

    // upsert: insert the thread when no document with this id exists yet.
    await threads.updateOne(filter, { $set: fields }, { upsert: true });
    return thread;
  } catch (error) {
    this.logger.error(`Error saving thread ${thread.id}: ${error}`);
    throw error;
  }
}
176
+
177
/**
 * Updates a thread's title and merges new keys into its metadata.
 *
 * @param id - `id` of the thread to update.
 * @param title - New title.
 * @param metadata - Keys to merge over the existing metadata (new values win).
 * @returns The thread as it looks after the update.
 * @throws Error when the thread does not exist; rethrows driver errors after logging.
 */
async updateThread({
  id,
  title,
  metadata,
}: {
  id: string;
  title: string;
  metadata: Record<string, unknown>;
}): Promise<StorageThreadType> {
  const thread = await this.getThreadById({ threadId: id });
  if (!thread) {
    throw new Error(`Thread ${id} not found`);
  }

  const updatedThread = {
    ...thread,
    title,
    metadata: {
      ...thread.metadata,
      ...metadata,
    },
  };

  try {
    const collection = await this.getCollection(TABLE_THREADS);
    // Only title and metadata are written; metadata is stored as a JSON string.
    await collection.updateOne(
      { id },
      {
        $set: {
          title,
          metadata: JSON.stringify(updatedThread.metadata),
        },
      },
    );
  } catch (error) {
    // The previous message contained a stray ")" after the thread id.
    this.logger.error(`Error updating thread ${id}: ${error}`);
    throw error;
  }

  return updatedThread;
}
218
+
219
/**
 * Deletes a thread together with all of its messages.
 *
 * @param threadId - `id` of the thread to remove.
 * @throws Rethrows any driver error after logging it.
 */
async deleteThread({ threadId }: { threadId: string }): Promise<void> {
  try {
    // Messages go first so a failure never leaves messages without their thread.
    const messages = await this.getCollection(TABLE_MESSAGES);
    await messages.deleteMany({ thread_id: threadId });

    const threads = await this.getCollection(TABLE_THREADS);
    await threads.deleteOne({ id: threadId });
  } catch (error) {
    this.logger.error(`Error deleting thread ${threadId}: ${error}`);
    throw error;
  }
}
232
+
233
/**
 * Returns messages of a thread, oldest first.
 *
 * All messages of the thread are loaded newest-first. Messages named in
 * `selectBy.include` (plus their requested neighbours) are picked first, then
 * the newest remaining messages are added until `limit` is reached. The result
 * is sorted by `createdAt` ascending and cut to `limit` entries.
 *
 * @param threadId - Thread whose messages are read.
 * @param selectBy - Optional selection: `last` (numeric limit, default 40) and
 *   `include` (message ids with `withPreviousMessages` / `withNextMessages`).
 * @throws Rethrows any error after logging it.
 */
async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T[]> {
  try {
    const limit = typeof selectBy?.last === 'number' ? selectBy.last : 40;
    const include = selectBy?.include || [];

    const collection = await this.getCollection(TABLE_MESSAGES);
    const rows = await collection.find({ thread_id: threadId }).sort({ createdAt: -1 }).toArray();
    // Index 0 is the newest message; larger indexes are older.
    const newestFirst: MessageType[] = rows.map((row: any) => this.parseRow(row));

    const picked: MessageType[] = [];

    if (include.length) {
      const positionById = new Map<string, number>();
      for (const [position, message] of newestFirst.entries()) {
        positionById.set(message.id, position);
      }

      const wanted = new Set<number>();
      for (const item of include) {
        const center = positionById.get(item.id);
        if (center === undefined) continue;

        // "Previous" messages are older, i.e. at higher indexes.
        const olderCount = item.withPreviousMessages || 0;
        for (let step = 1; step <= olderCount; step++) {
          if (center + step < newestFirst.length) wanted.add(center + step);
        }

        wanted.add(center);

        // "Next" messages are newer, i.e. at lower indexes.
        const newerCount = item.withNextMessages || 0;
        for (let step = 1; step <= newerCount; step++) {
          if (center - step >= 0) wanted.add(center - step);
        }
      }

      for (const position of wanted) {
        const message = newestFirst[position];
        if (message) picked.push(message);
      }
    }

    // Top up with the newest messages that were not explicitly included.
    const alreadyPicked = new Set(picked.map(message => message.id));
    for (const message of newestFirst) {
      if (picked.length >= limit) break;
      if (alreadyPicked.has(message.id)) continue;
      picked.push(message);
    }

    picked.sort((left, right) => left.createdAt.getTime() - right.createdAt.getTime());

    return picked.slice(0, limit) as T[];
  } catch (error) {
    this.logger.error('Error getting messages:', error as Error);
    throw error;
  }
}
294
+
295
/**
 * Persists a batch of messages with a single `insertMany`.
 *
 * Every message is stored under the `threadId` of the FIRST message in the
 * batch. Non-string content is JSON-serialized and `Date` timestamps are
 * stored as ISO strings.
 *
 * @param messages - Messages to store; an empty array is returned unchanged.
 * @returns The messages that were passed in.
 * @throws Error when the first message has no `threadId`; rethrows driver errors.
 */
async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
  if (!messages.length) {
    return messages;
  }

  const threadId = messages[0]?.threadId;
  if (!threadId) {
    this.logger.error('Thread ID is required to save messages');
    throw new Error('Thread ID is required');
  }

  try {
    const documents = messages.map(message => {
      const timestamp = message.createdAt || new Date();
      const content = typeof message.content === 'string' ? message.content : JSON.stringify(message.content);
      const createdAt = timestamp instanceof Date ? timestamp.toISOString() : timestamp;

      return {
        id: message.id,
        thread_id: threadId,
        content,
        role: message.role,
        type: message.type,
        resourceId: message.resourceId,
        createdAt,
      };
    });

    const collection = await this.getCollection(TABLE_MESSAGES);
    await collection.insertMany(documents);
    return messages;
  } catch (error) {
    this.logger.error('Failed to save messages in database: ' + (error as { message: string })?.message);
    throw error;
  }
}
329
+
330
/**
 * Returns one page of traces, newest `startTime` first.
 *
 * @param name - Optional substring the trace name must contain (case-sensitive).
 * @param scope - Optional exact scope to match.
 * @param page - Zero-based page index.
 * @param perPage - Page size.
 * @param attributes - Optional exact matches on `attributes.<key>`.
 * @param filters - Optional extra field/value pairs merged into the query as-is.
 * @returns The matching trace rows with their JSON-encoded fields decoded.
 */
async getTraces(
  {
    name,
    scope,
    page,
    perPage,
    attributes,
    filters,
  }: {
    name?: string;
    scope?: string;
    page: number;
    perPage: number;
    attributes?: Record<string, string>;
    filters?: Record<string, any>;
  } = {
    page: 0,
    perPage: 100,
  },
): Promise<any[]> {
  const limit = perPage;
  const offset = page * perPage;

  const query: any = {};
  if (name) {
    // `%name%` is SQL LIKE syntax; as a MongoDB equality value it only matched
    // names literally containing percent signs. Use an escaped regular
    // expression to get the intended "contains" match.
    const escapedName = name.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
    query['name'] = { $regex: escapedName };
  }

  if (scope) {
    query['scope'] = scope;
  }

  if (attributes) {
    for (const [key, value] of Object.entries(attributes)) {
      query[`attributes.${key}`] = value;
    }
  }

  if (filters) {
    for (const [key, value] of Object.entries(filters)) {
      query[key] = value;
    }
  }

  const collection = await this.getCollection(TABLE_TRACES);
  const result = await collection
    .find(query, {
      sort: { startTime: -1 },
    })
    .skip(offset)
    .limit(limit)
    .toArray();

  return result.map(row => ({
    id: row.id,
    parentSpanId: row.parentSpanId,
    traceId: row.traceId,
    name: row.name,
    scope: row.scope,
    kind: row.kind,
    status: safelyParseJSON(row.status as string),
    events: safelyParseJSON(row.events as string),
    links: safelyParseJSON(row.links as string),
    attributes: safelyParseJSON(row.attributes as string),
    startTime: row.startTime,
    endTime: row.endTime,
    other: safelyParseJSON(row.other as string),
    createdAt: row.createdAt,
  })) as any;
}
400
+
401
/**
 * Lists workflow runs, newest first, optionally filtered and paginated.
 *
 * @param workflowName - Optional exact workflow name.
 * @param fromDate - Optional inclusive lower bound on `createdAt`.
 * @param toDate - Optional inclusive upper bound on `createdAt`.
 * @param limit - Optional maximum number of runs to return.
 * @param offset - Optional number of runs to skip.
 * @returns The runs plus `total`: the number of matching documents when both
 *   `limit` and `offset` are given, otherwise the number of returned runs.
 */
async getWorkflowRuns({
  workflowName,
  fromDate,
  toDate,
  limit,
  offset,
}: {
  workflowName?: string;
  fromDate?: Date;
  toDate?: Date;
  limit?: number;
  offset?: number;
} = {}): Promise<{
  runs: Array<{
    workflowName: string;
    runId: string;
    snapshot: WorkflowRunState | string;
    createdAt: Date;
    updatedAt: Date;
  }>;
  total: number;
}> {
  const query: any = {};
  if (workflowName) {
    query['workflow_name'] = workflowName;
  }

  if (fromDate || toDate) {
    // persistWorkflowSnapshot stores createdAt as an ISO-8601 string, and
    // MongoDB range operators only match values of the same BSON type, so a
    // Date-only range never matched those documents. Match both the Date and
    // the ISO-string representation (ISO strings sort chronologically).
    const dateRange: Record<string, Date> = {};
    const isoRange: Record<string, string> = {};
    if (fromDate) {
      dateRange['$gte'] = fromDate;
      isoRange['$gte'] = fromDate.toISOString();
    }
    if (toDate) {
      dateRange['$lte'] = toDate;
      isoRange['$lte'] = toDate.toISOString();
    }
    query['$or'] = [{ createdAt: dateRange }, { createdAt: isoRange }];
  }

  const collection = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
  let total = 0;
  // Only get total count when using pagination
  if (limit !== undefined && offset !== undefined) {
    total = await collection.countDocuments(query);
  }

  // Get results
  const request = collection.find(query).sort({ createdAt: 'desc' });
  if (limit) {
    request.limit(limit);
  }

  if (offset) {
    request.skip(offset);
  }

  const result = await request.toArray();
  const runs = result.map(row => {
    let parsedSnapshot: WorkflowRunState | string = row.snapshot;
    if (typeof parsedSnapshot === 'string') {
      try {
        parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
      } catch (e) {
        // If parsing fails, return the raw snapshot string
        console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
      }
    }

    return {
      workflowName: row.workflow_name as string,
      runId: row.run_id as string,
      snapshot: parsedSnapshot,
      createdAt: new Date(row.createdAt as string),
      updatedAt: new Date(row.updatedAt as string),
    };
  });

  // Use runs.length as total when not paginating
  return { runs, total: total || runs.length };
}
479
+
480
/**
 * Returns the eval rows recorded for an agent, newest `created_at` first.
 *
 * `test_info` cannot be inspected by the database query, so the test/live
 * split is decided after loading: a row is a "test" eval when its decoded
 * `testInfo.testPath` is set, and a "live" eval otherwise. The two groups are
 * exact complements.
 *
 * @param agentName - Agent whose evals are requested.
 * @param type - Optional filter: 'test', 'live', or omitted for all rows.
 * @throws Rethrows any error after logging it.
 */
async getEvalsByAgentName(agentName: string, type?: 'test' | 'live'): Promise<EvalRow[]> {
  try {
    const query: any = {
      agent_name: agentName,
    };

    if (type === 'test') {
      // Cheap pre-filter: a test eval must at least have test_info. The
      // testPath check happens below, after test_info has been decoded.
      query['test_info'] = { $ne: null };
    }
    // No pre-filter for 'live': rows that carry test_info without a testPath
    // are live evals too, and only the post-filter can recognise them.

    const collection = await this.getCollection(TABLE_EVALS);
    const documents = await collection.find(query).sort({ created_at: 'desc' }).toArray();
    const result = documents.map(row => this.transformEvalRow(row));

    if (type === 'test') {
      // Previously `testPath !== null`, which also let `undefined` through.
      return result.filter(row => Boolean(row.testInfo?.testPath));
    }

    if (type === 'live') {
      return result.filter(row => !row.testInfo?.testPath);
    }

    return result;
  } catch (error) {
    // Handle case where table doesn't exist yet
    if (error instanceof Error && error.message.includes('no such table')) {
      return [];
    }
    this.logger.error('Failed to get evals for the specified agent: ' + (error as any)?.message);
    throw error;
  }
}
520
+
521
/**
 * Stores the latest snapshot of a workflow run, creating the document if needed.
 *
 * The snapshot is saved as a JSON string. `updatedAt` is refreshed on every
 * call, while `createdAt` is written only when the document is first inserted.
 * Both timestamps are ISO strings.
 *
 * @param workflowName - Name of the workflow.
 * @param runId - Identifier of the run.
 * @param snapshot - Run state to persist.
 * @throws Rethrows any driver error after logging it.
 */
async persistWorkflowSnapshot({
  workflowName,
  runId,
  snapshot,
}: {
  workflowName: string;
  runId: string;
  snapshot: WorkflowRunState;
}): Promise<void> {
  try {
    const timestamp = new Date().toISOString();
    const snapshots = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);

    const filter = { workflow_name: workflowName, run_id: runId };
    const update = {
      $set: { snapshot: JSON.stringify(snapshot), updatedAt: timestamp },
      $setOnInsert: { createdAt: timestamp },
    };

    await snapshots.updateOne(filter, update, { upsert: true });
  } catch (error) {
    this.logger.error(`Error persisting workflow snapshot: ${error}`);
    throw error;
  }
}
551
+
552
/**
 * Loads the stored snapshot of a workflow run.
 *
 * @param workflowName - Name of the workflow.
 * @param runId - Identifier of the run.
 * @returns The decoded snapshot, or `null` when no document matches or the
 *   matching document has no snapshot.
 * @throws Rethrows any lookup or JSON parsing error after logging it.
 */
async loadWorkflowSnapshot({
  workflowName,
  runId,
}: {
  workflowName: string;
  runId: string;
}): Promise<WorkflowRunState | null> {
  try {
    // load() resolves to the array of matching documents.
    const result = await this.load<any[]>({
      tableName: TABLE_WORKFLOW_SNAPSHOT,
      keys: {
        workflow_name: workflowName,
        run_id: runId,
      },
    });

    if (!result?.length) {
      return null;
    }

    const snapshot = result[0].snapshot;
    // Only decode when the snapshot is a JSON string; a snapshot stored as a
    // structured document is returned as-is instead of failing in JSON.parse.
    if (typeof snapshot === 'string') {
      return JSON.parse(snapshot);
    }
    return snapshot ?? null;
  } catch (error) {
    // Use the store logger, like the other methods, instead of console.error.
    this.logger.error(`Error loading workflow snapshot: ${error}`);
    throw error;
  }
}
578
+
579
/**
 * Finds a single workflow run by run id, optionally narrowed by workflow name.
 *
 * @param runId - Identifier of the run; ignored in the filter when empty.
 * @param workflowName - Optional workflow name to match as well.
 * @returns The parsed run, or `null` when nothing matches.
 * @throws Rethrows any error after logging it.
 */
async getWorkflowRunById({
  runId,
  workflowName,
}: {
  runId: string;
  workflowName?: string;
}): Promise<WorkflowRun | null> {
  try {
    // Only truthy criteria become part of the filter.
    const query: any = {
      ...(runId ? { run_id: runId } : {}),
      ...(workflowName ? { workflow_name: workflowName } : {}),
    };

    const snapshots = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
    const doc = await snapshots.findOne(query);

    return doc ? this.parseWorkflowRun(doc) : null;
  } catch (error) {
    console.error('Error getting workflow run by ID:', error);
    throw error;
  }
}
608
+
609
/**
 * Converts a stored workflow snapshot document into a `WorkflowRun`.
 *
 * @param row - Raw document from the workflow snapshot collection.
 * @returns The run with its snapshot decoded (the raw string is kept when it is
 *   not valid JSON) and its timestamps converted to `Date` objects.
 */
private parseWorkflowRun(row: any): WorkflowRun {
  let parsedSnapshot: WorkflowRunState | string = row.snapshot as string;
  if (typeof parsedSnapshot === 'string') {
    try {
      parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
    } catch (e) {
      // If parsing fails, return the raw snapshot string
      console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
    }
  }

  return {
    workflowName: row.workflow_name,
    runId: row.run_id,
    snapshot: parsedSnapshot,
    // Timestamps are persisted as ISO strings; convert them to Date objects,
    // the same way getWorkflowRuns does.
    createdAt: new Date(row.createdAt),
    updatedAt: new Date(row.updatedAt),
    resourceId: row.resourceId,
  };
}
629
+
630
/**
 * Converts a stored message document into a `MessageType`.
 *
 * saveMessages writes string content verbatim and JSON-serializes everything
 * else. The content is therefore decoded only when it parses to an object or
 * array; plain text such as "42" or "true" stays a string instead of becoming
 * a number or boolean.
 *
 * @param row - Raw document from the messages collection.
 * @returns The message, including the `resourceId` that saveMessages persisted.
 */
private parseRow(row: any): MessageType {
  let content = row.content;
  if (typeof row.content === 'string') {
    try {
      const decoded: unknown = JSON.parse(row.content);
      if (typeof decoded === 'object' && decoded !== null) {
        content = decoded;
      }
    } catch {
      // use content as is if it's not JSON
    }
  }

  return {
    id: row.id,
    content,
    role: row.role,
    type: row.type,
    createdAt: new Date(row.createdAt as string),
    threadId: row.thread_id,
    resourceId: row.resourceId,
  } as MessageType;
}
646
+
647
/**
 * Converts a stored eval document into an `EvalRow`.
 *
 * `test_info` and `result` are decoded when they were stored as JSON strings
 * and passed through when they are already structured values.
 *
 * @param row - Raw document from the evals collection.
 * @returns The eval row; `testInfo` is `null` when absent or unparsable, and
 *   `result` keeps its raw value when it is a string that is not valid JSON.
 */
private transformEvalRow(row: Record<string, any>): EvalRow {
  let testInfoValue = null;
  if (row.test_info) {
    try {
      testInfoValue = typeof row.test_info === 'string' ? JSON.parse(row.test_info) : row.test_info;
    } catch (e) {
      console.warn('Failed to parse test_info:', e);
    }
  }

  // Decode `result` the same way as test_info so callers receive the
  // structured MetricResult the return type promises.
  let resultValue = row.result;
  if (typeof row.result === 'string') {
    try {
      resultValue = JSON.parse(row.result);
    } catch (e) {
      console.warn('Failed to parse result:', e);
    }
  }

  return {
    input: row.input as string,
    output: row.output as string,
    result: resultValue as MetricResult,
    agentName: row.agent_name as string,
    metricName: row.metric_name as string,
    instructions: row.instructions as string,
    testInfo: testInfoValue as TestInfo,
    globalRunId: row.global_run_id as string,
    runId: row.run_id as string,
    createdAt: row.created_at as string,
  };
}
670
+
671
/**
 * Closes the underlying MongoDB client and forgets the cached connection.
 *
 * Resetting the cached state makes the next `getConnection()` call reconnect
 * instead of handing out the database handle of a closed client.
 */
async close(): Promise<void> {
  try {
    await this.#client.close();
  } finally {
    this.#isConnected = false;
    this.#db = undefined;
  }
}
674
+ }
@@ -1,8 +0,0 @@
1
- services:
2
- mongodb:
3
- image: mongodb/mongodb-atlas-local
4
- environment:
5
- MONGODB_INITDB_ROOT_USERNAME: mongodb
6
- MONGODB_INITDB_ROOT_PASSWORD: mongodb
7
- ports:
8
- - 27018:27017