@mastra/mongodb 0.10.0 → 0.10.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,675 @@
1
+ import type { MetricResult, TestInfo } from '@mastra/core/eval';
2
+ import type { MessageType, StorageThreadType } from '@mastra/core/memory';
3
+ import type { EvalRow, StorageGetMessagesArg, TABLE_NAMES, WorkflowRun } from '@mastra/core/storage';
4
+ import {
5
+ MastraStorage,
6
+ TABLE_EVALS,
7
+ TABLE_MESSAGES,
8
+ TABLE_THREADS,
9
+ TABLE_TRACES,
10
+ TABLE_WORKFLOW_SNAPSHOT,
11
+ } from '@mastra/core/storage';
12
+ import type { WorkflowRunState } from '@mastra/core/workflows';
13
+ import type { Db, MongoClientOptions } from 'mongodb';
14
+ import { MongoClient } from 'mongodb';
15
+
16
/**
 * Best-effort JSON decoding for trace fields read back from MongoDB.
 *
 * - Non-null objects (values that were stored as BSON documents/arrays rather than
 *   JSON strings) are returned unchanged instead of being coerced to
 *   `"[object Object]"` and lost.
 * - Anything else is passed to `JSON.parse`; on a parse failure `{}` is returned.
 *
 * @param jsonString - the stored value, normally a JSON string.
 * @returns the decoded value, the original object, or `{}` when decoding fails.
 */
function safelyParseJSON(jsonString: unknown): any {
  if (jsonString !== null && typeof jsonString === 'object') {
    return jsonString;
  }
  try {
    return JSON.parse(jsonString as string);
  } catch {
    return {};
  }
}
23
+
24
/**
 * Connection settings accepted by `MongoDBStore`.
 */
export interface MongoDBConfig {
  /** Name of the database that holds the Mastra collections; must be non-blank. */
  dbName: string;
  /** MongoDB connection string handed to `MongoClient`; must be non-blank. */
  url: string;
  /** Extra driver options forwarded verbatim to the `MongoClient` constructor. */
  options?: MongoClientOptions;
}
29
+
30
/**
 * `MastraStorage` implementation backed by MongoDB.
 *
 * Each Mastra table name is used directly as a collection name. MongoDB is
 * schemaless, so `createTable` does nothing. Thread metadata, non-string message
 * content and workflow snapshots are written as JSON strings and decoded on read.
 *
 * The client connects lazily: the first storage call triggers `MongoClient.connect()`.
 */
export class MongoDBStore extends MastraStorage {
  #isConnected = false;
  #client: MongoClient;
  #db: Db | undefined;
  readonly #dbName: string;

  /**
   * @param config - connection settings; `url` and `dbName` must be non-blank.
   * @throws {Error} if `url` or `dbName` is missing, empty, or whitespace only.
   */
  constructor(config: MongoDBConfig) {
    super({ name: 'MongoDBStore' });
    this.#isConnected = false;

    if (!config.url?.trim().length) {
      throw new Error(
        'MongoDBStore: url must be provided and cannot be empty. Passing an empty string may cause fallback to local MongoDB defaults.',
      );
    }

    if (!config.dbName?.trim().length) {
      throw new Error(
        'MongoDBStore: dbName must be provided and cannot be empty. Passing an empty string may cause fallback to local MongoDB defaults.',
      );
    }

    this.#dbName = config.dbName;
    this.#client = new MongoClient(config.url, config.options);
  }

  /** Returns the database handle, connecting the client on first use. */
  private async getConnection(): Promise<Db> {
    if (this.#isConnected && this.#db) {
      return this.#db;
    }

    await this.#client.connect();
    const db = this.#client.db(this.#dbName);
    this.#db = db;
    this.#isConnected = true;
    return db;
  }

  /** Resolves a collection by name on the lazily connected database. */
  private async getCollection(collectionName: string) {
    const db = await this.getConnection();
    return db.collection(collectionName);
  }

  /** No-op: MongoDB enforces no schema, so there is nothing to create up front. */
  async createTable(): Promise<void> {
    // Nothing to do here, MongoDB is schemaless
  }

  /**
   * Deletes every document in `tableName`.
   * Failures are logged and swallowed; this method does not rethrow.
   */
  async clearTable({ tableName }: { tableName: TABLE_NAMES }): Promise<void> {
    try {
      const collection = await this.getCollection(tableName);
      await collection.deleteMany({});
    } catch (error) {
      // Non-Error throwables used to be dropped without any log line.
      this.logger.error(error instanceof Error ? error.message : `Error clearing table ${tableName}: ${String(error)}`);
    }
  }

  /**
   * Inserts `record` as a new document (no upsert: repeated calls create duplicates).
   * @throws the driver error, after logging it.
   */
  async insert({ tableName, record }: { tableName: TABLE_NAMES; record: Record<string, any> }): Promise<void> {
    try {
      const collection = await this.getCollection(tableName);
      await collection.insertOne(record);
    } catch (error) {
      this.logger.error(`Error inserting into table ${tableName}: ${error}`);
      throw error;
    }
  }

  /**
   * Inserts all `records` with a single `insertMany`; an empty list is a no-op.
   * @throws the driver error, after logging it.
   */
  async batchInsert({ tableName, records }: { tableName: TABLE_NAMES; records: Record<string, any>[] }): Promise<void> {
    if (!records.length) {
      return;
    }

    try {
      const collection = await this.getCollection(tableName);
      await collection.insertMany(records);
    } catch (error) {
      this.logger.error(`Error inserting into table ${tableName}: ${error}`);
      throw error;
    }
  }

  /**
   * Finds every document in `tableName` whose fields equal `keys`.
   *
   * NOTE: this returns the whole array of matches (an empty array when nothing
   * matches), cast to `R`, not a single row. `loadWorkflowSnapshot` depends on
   * that array shape.
   */
  async load<R>({ tableName, keys }: { tableName: TABLE_NAMES; keys: Record<string, string> }): Promise<R | null> {
    this.logger.info(`Loading ${tableName} with keys ${JSON.stringify(keys)}`);
    try {
      const collection = await this.getCollection(tableName);
      return (await collection.find(keys).toArray()) as R;
    } catch (error) {
      this.logger.error(`Error loading ${tableName} with keys ${JSON.stringify(keys)}: ${error}`);
      throw error;
    }
  }

  /** Returns the thread with `id === threadId`, or `null` when none exists. */
  async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
    try {
      const collection = await this.getCollection(TABLE_THREADS);
      const result = await collection.findOne<any>({ id: threadId });
      return result ? this.parseThread(result) : null;
    } catch (error) {
      this.logger.error(`Error loading thread with ID ${threadId}: ${error}`);
      throw error;
    }
  }

  /** Returns every thread owned by `resourceId` (an empty array when there are none). */
  async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
    try {
      const collection = await this.getCollection(TABLE_THREADS);
      const results = await collection.find<any>({ resourceId }).toArray();
      return results.map(result => this.parseThread(result));
    } catch (error) {
      this.logger.error(`Error loading threads by resourceId ${resourceId}: ${error}`);
      throw error;
    }
  }

  /**
   * Upserts a thread keyed by `thread.id`. Metadata is stored as a JSON string.
   * @returns the thread exactly as passed in.
   */
  async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
    try {
      const collection = await this.getCollection(TABLE_THREADS);
      await collection.updateOne(
        { id: thread.id },
        {
          $set: {
            ...thread,
            metadata: JSON.stringify(thread.metadata),
          },
        },
        { upsert: true },
      );
      return thread;
    } catch (error) {
      this.logger.error(`Error saving thread ${thread.id}: ${error}`);
      throw error;
    }
  }

  /**
   * Sets a thread's title and shallow-merges `metadata` into its existing metadata.
   * @throws {Error} when the thread does not exist, or the driver error on write failure.
   */
  async updateThread({
    id,
    title,
    metadata,
  }: {
    id: string;
    title: string;
    metadata: Record<string, unknown>;
  }): Promise<StorageThreadType> {
    const thread = await this.getThreadById({ threadId: id });
    if (!thread) {
      throw new Error(`Thread ${id} not found`);
    }

    const updatedThread = {
      ...thread,
      title,
      metadata: {
        ...thread.metadata,
        ...metadata,
      },
    };

    try {
      const collection = await this.getCollection(TABLE_THREADS);
      await collection.updateOne(
        { id },
        {
          $set: {
            title,
            metadata: JSON.stringify(updatedThread.metadata),
          },
        },
      );
    } catch (error) {
      this.logger.error(`Error updating thread ${id}: ${error}`);
      throw error;
    }

    return updatedThread;
  }

  /** Deletes a thread and, first, every message whose `thread_id` references it. */
  async deleteThread({ threadId }: { threadId: string }): Promise<void> {
    try {
      // First, delete all messages associated with the thread
      const collectionMessages = await this.getCollection(TABLE_MESSAGES);
      await collectionMessages.deleteMany({ thread_id: threadId });
      // Then delete the thread itself
      const collectionThreads = await this.getCollection(TABLE_THREADS);
      await collectionThreads.deleteOne({ id: threadId });
    } catch (error) {
      this.logger.error(`Error deleting thread ${threadId}: ${error}`);
      throw error;
    }
  }

  /**
   * Returns up to `selectBy.last` (default 40) messages of a thread, sorted by
   * ascending `createdAt`.
   *
   * Messages named in `selectBy.include` are selected first, together with the
   * requested number of neighbouring messages. The remaining slots are filled
   * with the thread's most recent messages.
   */
  async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T[]> {
    try {
      const limit = typeof selectBy?.last === 'number' ? selectBy.last : 40;
      const include = selectBy?.include || [];
      const messages: MessageType[] = [];
      const collection = await this.getCollection(TABLE_MESSAGES);
      // Newest first: index 0 is the most recent message of the thread.
      const allMessages: MessageType[] = (
        await collection.find({ thread_id: threadId }).sort({ createdAt: -1 }).toArray()
      ).map((row: any) => this.parseRow(row));

      if (include.length) {
        const idToIndex = new Map<string, number>();
        allMessages.forEach((msg, idx) => {
          idToIndex.set(msg.id, idx);
        });

        const selectedIndexes = new Set<number>();
        for (const inc of include) {
          const idx = idToIndex.get(inc.id);
          if (idx === undefined) continue;
          // Older ("previous") messages sit at higher indexes in the newest-first array.
          for (let i = 1; i <= (inc.withPreviousMessages || 0); i++) {
            if (idx + i < allMessages.length) selectedIndexes.add(idx + i);
          }
          selectedIndexes.add(idx);
          // Newer ("next") messages sit at lower indexes.
          for (let i = 1; i <= (inc.withNextMessages || 0); i++) {
            if (idx - i >= 0) selectedIndexes.add(idx - i);
          }
        }
        messages.push(
          ...Array.from(selectedIndexes)
            .map(i => allMessages[i])
            .filter((m): m is MessageType => !!m),
        );
      }

      // Fill the remaining slots with the most recent messages not already selected.
      const excludeIds = new Set(messages.map(m => m.id));
      for (const msg of allMessages) {
        if (messages.length >= limit) break;
        if (!excludeIds.has(msg.id)) {
          messages.push(msg);
        }
      }

      messages.sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());

      // If the included messages alone exceed `limit`, keep the most recent ones
      // (the tail of the ascending list). Taking the head would drop the newest.
      return messages.slice(Math.max(messages.length - limit, 0)) as T[];
    } catch (error) {
      this.logger.error('Error getting messages:', error as Error);
      throw error;
    }
  }

  /**
   * Inserts `messages` in one batch. Every message is stored under the first
   * message's `threadId`; non-string content is stored as a JSON string.
   * @throws {Error} when the first message has no `threadId`, or the driver error.
   */
  async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
    if (!messages.length) {
      return messages;
    }

    const threadId = messages[0]?.threadId;
    if (!threadId) {
      this.logger.error('Thread ID is required to save messages');
      throw new Error('Thread ID is required');
    }
    try {
      const messagesToInsert = messages.map(message => {
        const time = message.createdAt || new Date();
        return {
          id: message.id,
          thread_id: threadId,
          content: typeof message.content === 'string' ? message.content : JSON.stringify(message.content),
          role: message.role,
          type: message.type,
          resourceId: message.resourceId,
          createdAt: time instanceof Date ? time.toISOString() : time,
        };
      });

      const collection = await this.getCollection(TABLE_MESSAGES);
      await collection.insertMany(messagesToInsert);
      return messages;
    } catch (error) {
      this.logger.error('Failed to save messages in database: ' + (error as { message: string })?.message);
      throw error;
    }
  }

  /**
   * Returns a page of trace spans, newest `startTime` first.
   *
   * `name` is a literal substring match; regex metacharacters in it are escaped.
   * `attributes` match `attributes.<key>` fields, and `filters` are merged into
   * the query verbatim. JSON-encoded span fields are decoded with `safelyParseJSON`.
   */
  async getTraces(
    {
      name,
      scope,
      page,
      perPage,
      attributes,
      filters,
    }: {
      name?: string;
      scope?: string;
      page: number;
      perPage: number;
      attributes?: Record<string, string>;
      filters?: Record<string, any>;
    } = {
      page: 0,
      perPage: 100,
    },
  ): Promise<any[]> {
    const limit = perPage;
    const offset = page * perPage;

    const query: any = {};
    if (name) {
      // The SQL pattern `%name%` was previously compared as an exact string and never matched.
      query['name'] = { $regex: MongoDBStore.escapeRegExp(name) };
    }

    if (scope) {
      query['scope'] = scope;
    }

    if (attributes) {
      for (const [key, value] of Object.entries(attributes)) {
        query[`attributes.${key}`] = value;
      }
    }

    if (filters) {
      for (const [key, value] of Object.entries(filters)) {
        query[key] = value;
      }
    }

    const collection = await this.getCollection(TABLE_TRACES);
    const result = await collection
      .find(query, {
        sort: { startTime: -1 },
      })
      .limit(limit)
      .skip(offset)
      .toArray();

    return result.map(row => ({
      id: row.id,
      parentSpanId: row.parentSpanId,
      traceId: row.traceId,
      name: row.name,
      scope: row.scope,
      kind: row.kind,
      status: safelyParseJSON(row.status as string),
      events: safelyParseJSON(row.events as string),
      links: safelyParseJSON(row.links as string),
      attributes: safelyParseJSON(row.attributes as string),
      startTime: row.startTime,
      endTime: row.endTime,
      other: safelyParseJSON(row.other as string),
      createdAt: row.createdAt,
    })) as any;
  }

  /**
   * Lists workflow runs, newest `createdAt` first, optionally filtered by
   * workflow name and a `createdAt` range.
   *
   * `total` is a full `countDocuments` only when both `limit` and `offset` are
   * provided. Otherwise it is the number of runs returned.
   */
  async getWorkflowRuns({
    workflowName,
    fromDate,
    toDate,
    limit,
    offset,
  }: {
    workflowName?: string;
    fromDate?: Date;
    toDate?: Date;
    limit?: number;
    offset?: number;
  } = {}): Promise<{
    runs: Array<{
      workflowName: string;
      runId: string;
      snapshot: WorkflowRunState | string;
      createdAt: Date;
      updatedAt: Date;
    }>;
    total: number;
  }> {
    const query: any = {};
    if (workflowName) {
      query['workflow_name'] = workflowName;
    }

    if (fromDate || toDate) {
      // persistWorkflowSnapshot stores createdAt as an ISO string, while records written
      // directly through insert() can carry Date values. MongoDB range operators only
      // compare values of the same BSON type, so both representations are queried.
      const dateRange: Record<string, Date> = {};
      const isoRange: Record<string, string> = {};
      if (fromDate) {
        dateRange['$gte'] = fromDate;
        isoRange['$gte'] = fromDate.toISOString();
      }
      if (toDate) {
        dateRange['$lte'] = toDate;
        isoRange['$lte'] = toDate.toISOString();
      }
      query['$or'] = [{ createdAt: dateRange }, { createdAt: isoRange }];
    }

    const collection = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
    let total = 0;
    // Only get total count when using pagination
    if (limit !== undefined && offset !== undefined) {
      total = await collection.countDocuments(query);
    }

    const request = collection.find(query).sort({ createdAt: 'desc' });
    if (limit) {
      request.limit(limit);
    }

    if (offset) {
      request.skip(offset);
    }

    const result = await request.toArray();
    const runs = result.map(row => ({
      workflowName: row.workflow_name as string,
      runId: row.run_id as string,
      snapshot: this.parseSnapshot(row),
      createdAt: new Date(row.createdAt as string),
      updatedAt: new Date(row.updatedAt as string),
    }));

    // Use runs.length as total when not paginating
    return { runs, total: total || runs.length };
  }

  /**
   * Returns evals recorded for `agentName`, newest `created_at` first.
   *
   * `type === 'test'` keeps rows whose test info has a `testPath`. `type === 'live'`
   * keeps the rest, including rows that have test info but no `testPath`.
   * `testPath` is checked after loading because `test_info` may be stored as a JSON string.
   */
  async getEvalsByAgentName(agentName: string, type?: 'test' | 'live'): Promise<EvalRow[]> {
    try {
      const query: any = {
        agent_name: agentName,
      };

      if (type === 'test') {
        // Cheap pre-filter; the testPath requirement is enforced below.
        query['test_info'] = { $ne: null };
      }
      // No pre-filter for 'live': filtering on `test_info: null` would wrongly drop
      // rows that have test_info but no testPath.

      const collection = await this.getCollection(TABLE_EVALS);
      const documents = await collection.find(query).sort({ created_at: 'desc' }).toArray();
      const result = documents.map(row => this.transformEvalRow(row));
      return result.filter(row => {
        if (type === 'live') {
          return !row.testInfo?.testPath;
        }
        if (type === 'test') {
          return Boolean(row.testInfo?.testPath);
        }
        return true;
      });
    } catch (error) {
      this.logger.error('Failed to get evals for the specified agent: ' + (error as any)?.message);
      throw error;
    }
  }

  /**
   * Upserts the snapshot for (`workflowName`, `runId`), stored as a JSON string.
   * `updatedAt` is refreshed on every call; `createdAt` is set only on insert.
   */
  async persistWorkflowSnapshot({
    workflowName,
    runId,
    snapshot,
  }: {
    workflowName: string;
    runId: string;
    snapshot: WorkflowRunState;
  }): Promise<void> {
    try {
      const now = new Date().toISOString();
      const collection = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
      await collection.updateOne(
        { workflow_name: workflowName, run_id: runId },
        {
          $set: {
            snapshot: JSON.stringify(snapshot),
            updatedAt: now,
          },
          $setOnInsert: {
            createdAt: now,
          },
        },
        { upsert: true },
      );
    } catch (error) {
      this.logger.error(`Error persisting workflow snapshot: ${error}`);
      throw error;
    }
  }

  /** Returns the decoded snapshot for (`workflowName`, `runId`), or `null` when absent. */
  async loadWorkflowSnapshot({
    workflowName,
    runId,
  }: {
    workflowName: string;
    runId: string;
  }): Promise<WorkflowRunState | null> {
    try {
      const result = await this.load<any[]>({
        tableName: TABLE_WORKFLOW_SNAPSHOT,
        keys: {
          workflow_name: workflowName,
          run_id: runId,
        },
      });

      if (!result?.length) {
        return null;
      }

      const snapshot = result[0].snapshot;
      // persistWorkflowSnapshot writes a JSON string; tolerate documents holding the object itself.
      return typeof snapshot === 'string' ? JSON.parse(snapshot) : (snapshot ?? null);
    } catch (error) {
      this.logger.error(`Error loading workflow snapshot: ${error}`);
      throw error;
    }
  }

  /** Returns the first run matching `runId` (and `workflowName` when given), or `null`. */
  async getWorkflowRunById({
    runId,
    workflowName,
  }: {
    runId: string;
    workflowName?: string;
  }): Promise<WorkflowRun | null> {
    try {
      const query: any = {};
      if (runId) {
        query['run_id'] = runId;
      }

      if (workflowName) {
        query['workflow_name'] = workflowName;
      }

      const collection = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
      const result = await collection.findOne(query);
      if (!result) {
        return null;
      }

      return this.parseWorkflowRun(result);
    } catch (error) {
      this.logger.error(`Error getting workflow run by ID: ${error}`);
      throw error;
    }
  }

  /** Escapes regex metacharacters so user input is matched literally by `$regex`. */
  private static escapeRegExp(value: string): string {
    return value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  }

  /** Converts a stored thread document into a `StorageThreadType`, decoding string metadata. */
  private parseThread(row: any): StorageThreadType {
    return {
      ...row,
      metadata: typeof row.metadata === 'string' ? JSON.parse(row.metadata) : row.metadata,
    };
  }

  /**
   * Decodes a stored workflow snapshot. String snapshots are JSON-parsed; if that
   * fails, the raw string is returned and a warning is logged. Other values pass through.
   */
  private parseSnapshot(row: any): WorkflowRunState | string {
    let parsedSnapshot: WorkflowRunState | string = row.snapshot;
    if (typeof parsedSnapshot === 'string') {
      try {
        parsedSnapshot = JSON.parse(parsedSnapshot) as WorkflowRunState;
      } catch (e) {
        this.logger.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
      }
    }
    return parsedSnapshot;
  }

  /** Maps a workflow-snapshot document to a `WorkflowRun`; dates are returned as stored. */
  private parseWorkflowRun(row: any): WorkflowRun {
    return {
      workflowName: row.workflow_name,
      runId: row.run_id,
      snapshot: this.parseSnapshot(row),
      createdAt: row.createdAt,
      updatedAt: row.updatedAt,
      resourceId: row.resourceId,
    };
  }

  /**
   * Maps a stored message document to a `MessageType`. Content is JSON-decoded
   * when it parses; otherwise the raw value is kept.
   */
  private parseRow(row: any): MessageType {
    let content = row.content;
    try {
      content = JSON.parse(row.content);
    } catch {
      // use content as is if it's not JSON
    }
    return {
      id: row.id,
      content,
      role: row.role,
      type: row.type,
      createdAt: new Date(row.createdAt as string),
      threadId: row.thread_id,
      // saveMessages persists resourceId; it was previously dropped on read.
      resourceId: row.resourceId,
    } as MessageType;
  }

  /** Maps a stored eval document to an `EvalRow`, decoding `test_info` when it is a string. */
  private transformEvalRow(row: Record<string, any>): EvalRow {
    let testInfoValue = null;
    if (row.test_info) {
      try {
        testInfoValue = typeof row.test_info === 'string' ? JSON.parse(row.test_info) : row.test_info;
      } catch (e) {
        this.logger.warn(`Failed to parse test_info: ${e}`);
      }
    }

    return {
      input: row.input as string,
      output: row.output as string,
      result: row.result as MetricResult,
      agentName: row.agent_name as string,
      metricName: row.metric_name as string,
      instructions: row.instructions as string,
      testInfo: testInfoValue as TestInfo,
      globalRunId: row.global_run_id as string,
      runId: row.run_id as string,
      createdAt: row.created_at as string,
    };
  }

  /** Closes the client and resets the cached connection state. */
  async close(): Promise<void> {
    await this.#client.close();
    this.#isConnected = false;
    this.#db = undefined;
  }
}
@@ -1,8 +0,0 @@
1
- services:
2
- mongodb:
3
- image: mongodb/mongodb-atlas-local
4
- environment:
5
- MONGODB_INITDB_ROOT_USERNAME: mongodb
6
- MONGODB_INITDB_ROOT_PASSWORD: mongodb
7
- ports:
8
- - 27018:27017