@mastra/mongodb 0.12.0 → 0.12.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +57 -0
- package/LICENSE.md +11 -42
- package/dist/_tsup-dts-rollup.d.cts +376 -54
- package/dist/_tsup-dts-rollup.d.ts +376 -54
- package/dist/index.cjs +1420 -424
- package/dist/index.d.cts +0 -1
- package/dist/index.d.ts +0 -1
- package/dist/index.js +1414 -418
- package/docker-compose.yaml +1 -1
- package/package.json +6 -6
- package/src/storage/ConnectorHandler.ts +7 -0
- package/src/storage/MongoDBConnector.ts +93 -0
- package/src/storage/connectors/MongoDBConnector.ts +93 -0
- package/src/storage/connectors/base.ts +7 -0
- package/src/storage/domains/legacy-evals/index.ts +193 -0
- package/src/storage/domains/memory/index.ts +741 -0
- package/src/storage/domains/operations/index.ts +152 -0
- package/src/storage/domains/scores/index.ts +379 -0
- package/src/storage/domains/traces/index.ts +142 -0
- package/src/storage/domains/utils.ts +43 -0
- package/src/storage/domains/workflows/index.ts +196 -0
- package/src/storage/index.test.ts +24 -1226
- package/src/storage/index.ts +218 -776
- package/src/storage/types.ts +14 -0
- package/src/vector/index.test.ts +16 -1
- package/src/vector/index.ts +34 -11
|
@@ -0,0 +1,741 @@
|
|
|
1
|
+
import { MessageList } from '@mastra/core/agent';
|
|
2
|
+
import type { MastraMessageContentV2 } from '@mastra/core/agent';
|
|
3
|
+
import { ErrorCategory, ErrorDomain, MastraError } from '@mastra/core/error';
|
|
4
|
+
import type { MastraMessageV1, MastraMessageV2, StorageThreadType } from '@mastra/core/memory';
|
|
5
|
+
import {
|
|
6
|
+
MemoryStorage,
|
|
7
|
+
resolveMessageLimit,
|
|
8
|
+
safelyParseJSON,
|
|
9
|
+
TABLE_MESSAGES,
|
|
10
|
+
TABLE_RESOURCES,
|
|
11
|
+
TABLE_THREADS,
|
|
12
|
+
} from '@mastra/core/storage';
|
|
13
|
+
import type { PaginationInfo, StorageGetMessagesArg, StorageResourceType } from '@mastra/core/storage';
|
|
14
|
+
import type { StoreOperationsMongoDB } from '../operations';
|
|
15
|
+
import { formatDateForMongoDB } from '../utils';
|
|
16
|
+
|
|
17
|
+
export class MemoryStorageMongoDB extends MemoryStorage {
|
|
18
|
+
private operations: StoreOperationsMongoDB;
|
|
19
|
+
|
|
20
|
+
/**
 * Creates the memory storage domain.
 *
 * @param deps.operations - Operations helper; the methods of this class call its
 *   `getCollection` to obtain the MongoDB collections they read and write.
 */
constructor(deps: { operations: StoreOperationsMongoDB }) {
  super();
  const { operations } = deps;
  this.operations = operations;
}
|
|
24
|
+
|
|
25
|
+
/**
 * Converts a raw document from the messages collection into a `MastraMessageV2`.
 *
 * String content is JSON-decoded when it parses and kept verbatim otherwise.
 * The stored `thread_id` field is exposed as `threadId`, and `type` is copied
 * only when it is set to something other than `'v2'`.
 *
 * @param row - Raw document as read from the collection.
 * @returns The message in V2 shape.
 */
private parseRow(row: any): MastraMessageV2 {
  const decodeContent = (raw: any): any => {
    if (typeof raw !== 'string') return raw;
    try {
      return JSON.parse(raw);
    } catch {
      // Not JSON: hand the original string back unchanged.
      return raw;
    }
  };

  const content = decodeContent(row.content);

  const message = {
    id: row.id,
    content,
    role: row.role,
    createdAt: formatDateForMongoDB(row.createdAt),
    threadId: row.thread_id,
    resourceId: row.resourceId,
  } as MastraMessageV2;

  const hasExplicitType = row.type && row.type !== 'v2';
  if (hasExplicitType) {
    message.type = row.type;
  }

  return message;
}
|
|
47
|
+
|
|
48
|
+
/**
 * Resolves `selectBy.include`: for every requested message id, returns that
 * message together with the configured number of neighbours before and after
 * it (ordered by `createdAt`) from the same thread.
 *
 * @param threadId - Default thread to search when an include entry names none.
 * @param selectBy - Selection options; only `include` is read here.
 * @returns `null` when no `include` list was supplied, otherwise the parsed,
 *   de-duplicated messages. Entries whose id is not found in their thread are skipped.
 */
private async _getIncludedMessages({
  threadId,
  selectBy,
}: {
  threadId: string;
  selectBy: StorageGetMessagesArg['selectBy'];
}) {
  const include = selectBy?.include;
  if (!include) return null;

  const collection = await this.operations.getCollection(TABLE_MESSAGES);

  // Ordered messages per thread, so several include entries that target the
  // same thread share a single query instead of re-reading the whole thread.
  const messagesByThread = new Map<string, any[]>();
  const includedMessages: any[] = [];

  for (const inc of include) {
    const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
    const searchThreadId = inc.threadId || threadId;

    let allMessages = messagesByThread.get(searchThreadId);
    if (!allMessages) {
      // Get all messages for the search thread ordered by creation date
      allMessages = await collection.find({ thread_id: searchThreadId }).sort({ createdAt: 1 }).toArray();
      messagesByThread.set(searchThreadId, allMessages);
    }

    // Find the target message
    const targetIndex = allMessages.findIndex((msg: any) => msg.id === id);
    if (targetIndex === -1) continue;

    // Clamp the window to the bounds of the thread
    const startIndex = Math.max(0, targetIndex - withPreviousMessages);
    const endIndex = Math.min(allMessages.length - 1, targetIndex + withNextMessages);

    for (let i = startIndex; i <= endIndex; i++) {
      includedMessages.push(allMessages[i]);
    }
  }

  // Overlapping windows can select the same message twice; keep the first occurrence.
  const seen = new Set<string>();
  const dedupedMessages = includedMessages.filter(msg => {
    if (seen.has(msg.id)) return false;
    seen.add(msg.id);
    return true;
  });

  return dedupedMessages.map(row => this.parseRow(row));
}
|
|
95
|
+
|
|
96
|
+
/**
 * Loads messages of a thread: the explicitly included messages (with their
 * surrounding context) plus the most recent messages up to the resolved limit,
 * sorted by creation date ascending.
 *
 * @param args.threadId - Thread whose messages are read.
 * @param args.selectBy - Optional `include` list and `last` count (default limit 40).
 * @param args.format - `'v2'` returns V2 messages; anything else returns V1.
 * @throws MastraError with id `MONGODB_STORE_GET_MESSAGES_FAILED` when any step fails.
 * @deprecated use getMessagesPaginated instead for paginated results.
 */
public async getMessages(args: StorageGetMessagesArg & { format?: 'v1' }): Promise<MastraMessageV1[]>;
public async getMessages(args: StorageGetMessagesArg & { format: 'v2' }): Promise<MastraMessageV2[]>;
public async getMessages({
  threadId,
  selectBy,
  format,
}: StorageGetMessagesArg & {
  format?: 'v1' | 'v2';
}): Promise<MastraMessageV1[] | MastraMessageV2[]> {
  try {
    const recentLimit = resolveMessageLimit({ last: selectBy?.last, defaultLimit: 40 });

    const collected: MastraMessageV2[] = [];
    if (selectBy?.include?.length) {
      const included = await this._getIncludedMessages({ threadId, selectBy });
      if (included) {
        collected.push(...included);
      }
    }

    // Messages already pulled in through `include` must not be fetched again.
    const alreadyLoadedIds = collected.map(message => message.id);
    const messagesCollection = await this.operations.getCollection(TABLE_MESSAGES);

    const filter: any = { thread_id: threadId };
    if (alreadyLoadedIds.length > 0) {
      filter.id = { $nin: alreadyLoadedIds };
    }

    // A limit of zero means "no recent messages", so the query is skipped entirely.
    if (recentLimit > 0) {
      const recentRows = await messagesCollection
        .find(filter)
        .sort({ createdAt: -1 })
        .limit(recentLimit)
        .toArray();

      messages: for (const row of recentRows) {
        collected.push(this.parseRow(row));
        continue messages;
      }
    }

    // Oldest first, regardless of which query produced the message.
    collected.sort((left, right) => left.createdAt.getTime() - right.createdAt.getTime());

    const list = new MessageList().add(collected, 'memory');
    return format === 'v2' ? list.get.all.v2() : list.get.all.v1();
  } catch (error) {
    throw new MastraError(
      {
        id: 'MONGODB_STORE_GET_MESSAGES_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { threadId },
      },
      error,
    );
  }
}
|
|
152
|
+
|
|
153
|
+
/**
 * Returns one page of a thread's messages (newest first in the paged query),
 * preceded by any messages requested through `selectBy.include`.
 *
 * `total` counts the messages matching the thread and optional date range; it is
 * computed before the included ids are excluded from the paged query.
 *
 * Error handling is asymmetric: a failure while loading the included messages is
 * thrown as a `MastraError`, whereas a failure in the paged query is logged and an
 * empty page is returned.
 *
 * @param args.threadId - Thread whose messages are read.
 * @param args.selectBy - `pagination` (`page`, `perPage`, `dateRange`), `include` and `last`.
 * @param args.format - `'v1'` returns V1 messages; anything else returns V2.
 */
public async getMessagesPaginated(
  args: StorageGetMessagesArg & {
    format?: 'v1' | 'v2';
  },
): Promise<PaginationInfo & { messages: MastraMessageV1[] | MastraMessageV2[] }> {
  const { threadId, format, selectBy } = args;
  const { page = 0, perPage: perPageInput, dateRange } = selectBy?.pagination || {};
  const perPage =
    perPageInput !== undefined ? perPageInput : resolveMessageLimit({ last: selectBy?.last, defaultLimit: 40 });
  const fromDate = dateRange?.start;
  const toDate = dateRange?.end;

  const messages: MastraMessageV2[] = [];

  if (selectBy?.include?.length) {
    try {
      const includeMessages = await this._getIncludedMessages({ threadId, selectBy });
      if (includeMessages) {
        messages.push(...includeMessages);
      }
    } catch (error) {
      throw new MastraError(
        {
          id: 'MONGODB_STORE_GET_MESSAGES_PAGINATED_GET_INCLUDE_MESSAGES_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: { threadId },
        },
        error,
      );
    }
  }

  try {
    const currentOffset = page * perPage;
    const collection = await this.operations.getCollection(TABLE_MESSAGES);

    const query: any = { thread_id: threadId };

    if (fromDate) {
      query.createdAt = { ...query.createdAt, $gte: fromDate };
    }
    if (toDate) {
      query.createdAt = { ...query.createdAt, $lte: toDate };
    }

    const total = await collection.countDocuments(query);

    if (total === 0 && messages.length === 0) {
      return {
        messages: [],
        total: 0,
        page,
        perPage,
        hasMore: false,
      };
    }

    // Included messages are already in `messages`; keep them out of the page.
    const excludeIds = messages.map(m => m.id);
    if (excludeIds.length > 0) {
      query.id = { $nin: excludeIds };
    }

    // MongoDB interprets limit(0) as "no limit", which would return the whole
    // thread for an empty page size. Only run the paged query for a positive
    // page size, mirroring the `limit > 0` guard in getMessages.
    if (perPage > 0) {
      const dataResult = await collection
        .find(query)
        .sort({ createdAt: -1 })
        .skip(currentOffset)
        .limit(perPage)
        .toArray();

      messages.push(...dataResult.map((row: any) => this.parseRow(row)));
    }

    const messagesToReturn =
      format === 'v1'
        ? new MessageList().add(messages, 'memory').get.all.v1()
        : new MessageList().add(messages, 'memory').get.all.v2();

    return {
      messages: messagesToReturn,
      total,
      page,
      perPage,
      hasMore: (page + 1) * perPage < total,
    };
  } catch (error) {
    const mastraError = new MastraError(
      {
        id: 'MONGODB_STORE_GET_MESSAGES_PAGINATED_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { threadId },
      },
      error,
    );
    // Best effort: report the failure and fall back to an empty page instead of throwing.
    this.logger?.trackException?.(mastraError);
    this.logger?.error?.(mastraError.toString());
    return { messages: [], total: 0, page, perPage, hasMore: false };
  }
}
|
|
252
|
+
|
|
253
|
+
/**
 * Upserts the given messages (matched by `id`) and bumps `updatedAt` on every
 * thread that received at least one message.
 *
 * Object content is stored as a JSON string; a missing `createdAt` defaults to
 * the current time and a missing `type` to `'v2'`.
 *
 * @param args.messages - Messages to persist; each needs a `threadId` and a `resourceId`.
 * @param args.format - `'v2'` returns V2 messages; anything else returns V1.
 * @returns The saved messages in the requested format (the input array itself when empty).
 * @throws MastraError with id `MONGODB_STORE_SAVE_MESSAGES_FAILED` when validation or a write fails.
 */
async saveMessages(args: { messages: MastraMessageV1[]; format?: undefined | 'v1' }): Promise<MastraMessageV1[]>;
async saveMessages(args: { messages: MastraMessageV2[]; format: 'v2' }): Promise<MastraMessageV2[]>;
async saveMessages({
  messages,
  format,
}:
  | { messages: MastraMessageV1[]; format?: undefined | 'v1' }
  | { messages: MastraMessageV2[]; format: 'v2' }): Promise<MastraMessageV2[] | MastraMessageV1[]> {
  if (messages.length === 0) return messages;

  try {
    const threadId = messages[0]?.threadId;
    if (!threadId) {
      throw new Error('Thread ID is required');
    }

    const collection = await this.operations.getCollection(TABLE_MESSAGES);
    const threadsCollection = await this.operations.getCollection(TABLE_THREADS);

    // Every thread that receives a message, not just the first message's thread.
    const touchedThreadIds = new Set<string>();

    // Prepare messages for insertion; validation failures abort before any write.
    const messagesToInsert = messages.map(message => {
      const time = message.createdAt || new Date();
      if (!message.threadId) {
        throw new Error(
          "Expected to find a threadId for message, but couldn't find one. An unexpected error has occurred.",
        );
      }
      if (!message.resourceId) {
        throw new Error(
          "Expected to find a resourceId for message, but couldn't find one. An unexpected error has occurred.",
        );
      }

      touchedThreadIds.add(message.threadId);

      return {
        updateOne: {
          filter: { id: message.id },
          update: {
            $set: {
              id: message.id,
              thread_id: message.threadId,
              content: typeof message.content === 'object' ? JSON.stringify(message.content) : message.content,
              role: message.role,
              type: message.type || 'v2',
              createdAt: formatDateForMongoDB(time),
              resourceId: message.resourceId,
            },
          },
          upsert: true,
        },
      };
    });

    // Execute message upserts and thread timestamp updates in parallel
    await Promise.all([
      collection.bulkWrite(messagesToInsert),
      threadsCollection.updateMany(
        { id: { $in: Array.from(touchedThreadIds) } },
        { $set: { updatedAt: new Date() } },
      ),
    ]);

    const list = new MessageList().add(messages, 'memory');
    if (format === 'v2') return list.get.all.v2();
    return list.get.all.v1();
  } catch (error) {
    throw new MastraError(
      {
        id: 'MONGODB_STORE_SAVE_MESSAGES_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
      },
      error,
    );
  }
}
|
|
325
|
+
|
|
326
|
+
/**
 * Applies partial updates to existing messages and returns their stored state.
 *
 * Ids that do not exist are ignored. A `content` update is merged into the
 * stored content (with `metadata` merged one level deeper when both sides have
 * it) rather than replacing it. Every thread a changed message belongs to, or is
 * moved to, gets its `updatedAt` refreshed.
 *
 * @param messages - Update payloads; `id` selects the message, other keys are the changes.
 * @returns The re-read messages for the requested ids, or `[]` when nothing matched.
 */
async updateMessages({
  messages,
}: {
  messages: (Partial<Omit<MastraMessageV2, 'createdAt'>> & {
    id: string;
    content?: { metadata?: MastraMessageContentV2['metadata']; content?: MastraMessageContentV2['content'] };
  })[];
}): Promise<MastraMessageV2[]> {
  if (messages.length === 0) {
    return [];
  }

  const requestedIds = messages.map(payload => payload.id);
  const messagesCollection = await this.operations.getCollection(TABLE_MESSAGES);

  const storedRows = await messagesCollection.find({ id: { $in: requestedIds } }).toArray();
  const storedMessages: MastraMessageV2[] = storedRows.map((row: any) => this.parseRow(row));

  if (storedMessages.length === 0) {
    return [];
  }

  const touchedThreadIds = new Set<string>();
  const writes = [];

  for (const stored of storedMessages) {
    const patch = messages.find(candidate => candidate.id === stored.id);
    if (!patch) continue;

    const { id, ...changes } = patch;
    if (Object.keys(changes).length === 0) continue;

    // Both the current thread and, on a move, the destination thread are touched.
    touchedThreadIds.add(stored.threadId!);
    if (patch.threadId && patch.threadId !== stored.threadId) {
      touchedThreadIds.add(patch.threadId);
    }

    const setDoc: any = {};
    const pending = { ...changes };

    // Content is merged into what is stored instead of overwriting it.
    if (pending.content) {
      const mergedContent: any = {
        ...stored.content,
        ...pending.content,
      };
      // Deep merge metadata if it exists on both sides
      if (stored.content?.metadata && pending.content.metadata) {
        mergedContent.metadata = {
          ...stored.content.metadata,
          ...pending.content.metadata,
        };
      }
      setDoc.content = JSON.stringify(mergedContent);
      delete pending.content;
    }

    // Remaining fields map one-to-one, except `threadId` which is stored as `thread_id`.
    for (const [field, raw] of Object.entries(pending)) {
      const column = field === 'threadId' ? 'thread_id' : field;
      const isObjectValue = typeof raw === 'object' && raw !== null;
      setDoc[column] = isObjectValue ? JSON.stringify(raw) : raw;
    }

    if (Object.keys(setDoc).length > 0) {
      writes.push({
        updateOne: {
          filter: { id },
          update: { $set: setDoc },
        },
      });
    }
  }

  if (writes.length > 0) {
    await messagesCollection.bulkWrite(writes);
  }

  // Update thread timestamps
  if (touchedThreadIds.size > 0) {
    const threadsCollection = await this.operations.getCollection(TABLE_THREADS);
    const touchedAt = { $set: { updatedAt: new Date() } };
    await threadsCollection.updateMany({ id: { $in: Array.from(touchedThreadIds) } }, touchedAt);
  }

  // Re-fetch so the caller sees exactly what is now stored.
  const refreshedRows = await messagesCollection.find({ id: { $in: requestedIds } }).toArray();

  return refreshedRows.map((row: any) => this.parseRow(row));
}
|
|
427
|
+
|
|
428
|
+
/**
 * Looks up a resource by its id.
 *
 * @param resourceId - Id of the resource to read.
 * @returns The resource, or `null` when no document has that id. `workingMemory`
 *   falls back to an empty string and string metadata is JSON-decoded.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_GET_RESOURCE_BY_ID_FAILED` on failure.
 */
async getResourceById({ resourceId }: { resourceId: string }): Promise<StorageResourceType | null> {
  try {
    const resources = await this.operations.getCollection(TABLE_RESOURCES);
    const doc = await resources.findOne<any>({ id: resourceId });

    if (!doc) return null;

    // Metadata may have been persisted as a JSON string; decode it in that case.
    const metadata = typeof doc.metadata === 'string' ? safelyParseJSON(doc.metadata) : doc.metadata;

    return {
      id: doc.id,
      workingMemory: doc.workingMemory || '',
      metadata,
      createdAt: formatDateForMongoDB(doc.createdAt),
      updatedAt: formatDateForMongoDB(doc.updatedAt),
    };
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_GET_RESOURCE_BY_ID_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { resourceId },
      },
      error,
    );
  }
}
|
|
456
|
+
|
|
457
|
+
/**
 * Creates or replaces the stored fields of a resource (upsert by `id`).
 * Metadata is persisted as a JSON string.
 *
 * @param resource - Resource to persist.
 * @returns The same `resource` object that was passed in.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_SAVE_RESOURCE_FAILED` on failure.
 */
async saveResource({ resource }: { resource: StorageResourceType }): Promise<StorageResourceType> {
  try {
    const resources = await this.operations.getCollection(TABLE_RESOURCES);

    const persistedFields = {
      ...resource,
      metadata: JSON.stringify(resource.metadata),
    };
    await resources.updateOne({ id: resource.id }, { $set: persistedFields }, { upsert: true });

    return resource;
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_SAVE_RESOURCE_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { resourceId: resource.id },
      },
      error,
    );
  }
}
|
|
484
|
+
|
|
485
|
+
/**
 * Updates a resource's working memory and/or metadata, creating the resource
 * when it does not exist yet.
 *
 * Supplied metadata is merged over the stored metadata; an omitted
 * `workingMemory` or `metadata` leaves the stored value untouched.
 *
 * @param resourceId - Id of the resource to update.
 * @param workingMemory - New working memory, if it should change.
 * @param metadata - Metadata keys to merge in, if any.
 * @returns The resource as it looks after the update.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_UPDATE_RESOURCE_FAILED` on failure.
 */
async updateResource({
  resourceId,
  workingMemory,
  metadata,
}: {
  resourceId: string;
  workingMemory?: string;
  metadata?: Record<string, unknown>;
}): Promise<StorageResourceType> {
  try {
    const current = await this.getResourceById({ resourceId });

    if (!current) {
      // Nothing stored yet: persist a fresh resource built from the given values.
      const created: StorageResourceType = {
        id: resourceId,
        workingMemory: workingMemory || '',
        metadata: metadata || {},
        createdAt: new Date(),
        updatedAt: new Date(),
      };
      return this.saveResource({ resource: created });
    }

    const nextWorkingMemory = workingMemory !== undefined ? workingMemory : current.workingMemory;
    const nextMetadata = metadata ? { ...current.metadata, ...metadata } : current.metadata;
    const updatedResource = {
      ...current,
      workingMemory: nextWorkingMemory,
      metadata: nextMetadata,
      updatedAt: new Date(),
    };

    const resources = await this.operations.getCollection(TABLE_RESOURCES);

    // Only the fields the caller actually supplied are written, plus the timestamp.
    const changes: any = {
      updatedAt: updatedResource.updatedAt,
      ...(workingMemory !== undefined ? { workingMemory } : {}),
      ...(metadata ? { metadata: JSON.stringify(updatedResource.metadata) } : {}),
    };

    await resources.updateOne({ id: resourceId }, { $set: changes });

    return updatedResource;
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_UPDATE_RESOURCE_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { resourceId },
      },
      error,
    );
  }
}
|
|
542
|
+
|
|
543
|
+
/**
 * Reads a single thread by id.
 *
 * @param threadId - Id of the thread to read.
 * @returns The stored document with string metadata JSON-decoded, or `null`
 *   when no thread has that id.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_GET_THREAD_BY_ID_FAILED` on failure.
 */
async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
  try {
    const threads = await this.operations.getCollection(TABLE_THREADS);
    const doc = await threads.findOne<any>({ id: threadId });

    if (!doc) return null;

    const metadata = typeof doc.metadata === 'string' ? safelyParseJSON(doc.metadata) : doc.metadata;
    return { ...doc, metadata };
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_GET_THREAD_BY_ID_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { threadId },
      },
      error,
    );
  }
}
|
|
567
|
+
|
|
568
|
+
/**
 * Lists every thread that belongs to a resource.
 *
 * @param resourceId - Resource whose threads are read.
 * @returns The stored documents with string metadata JSON-decoded; an empty
 *   array when the resource has no threads.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_GET_THREADS_BY_RESOURCE_ID_FAILED` on failure.
 */
async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
  try {
    const threads = await this.operations.getCollection(TABLE_THREADS);
    const docs = await threads.find<any>({ resourceId }).toArray();

    const decodeMetadata = (raw: any) => (typeof raw === 'string' ? safelyParseJSON(raw) : raw);

    const found: StorageThreadType[] = [];
    for (const doc of docs) {
      found.push({ ...doc, metadata: decodeMetadata(doc.metadata) });
    }
    return found;
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_GET_THREADS_BY_RESOURCE_ID_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { resourceId },
      },
      error,
    );
  }
}
|
|
592
|
+
|
|
593
|
+
/**
 * Returns one page of a resource's threads, most recently updated first.
 *
 * @param args.resourceId - Resource whose threads are read.
 * @param args.page - Zero-based page index (the query skips `page * perPage` documents).
 * @param args.perPage - Page size.
 * @returns The page of threads plus `total`, `page`, `perPage` and `hasMore`.
 * @throws MastraError with id `MONGODB_STORE_GET_THREADS_BY_RESOURCE_ID_PAGINATED_FAILED` on failure.
 */
public async getThreadsByResourceIdPaginated(args: {
  resourceId: string;
  page: number;
  perPage: number;
}): Promise<PaginationInfo & { threads: StorageThreadType[] }> {
  try {
    const { resourceId, page, perPage } = args;
    const collection = await this.operations.getCollection(TABLE_THREADS);

    const query = { resourceId };
    const total = await collection.countDocuments(query);

    const threads = await collection
      .find(query)
      .sort({ updatedAt: -1 })
      .skip(page * perPage)
      .limit(perPage)
      .toArray();

    return {
      threads: threads.map((thread: any) => ({
        id: thread.id,
        title: thread.title,
        resourceId: thread.resourceId,
        createdAt: formatDateForMongoDB(thread.createdAt),
        updatedAt: formatDateForMongoDB(thread.updatedAt),
        // Decode metadata persisted as a JSON string, as getThreadById and
        // getThreadsByResourceId do, so all thread readers return the same shape.
        metadata: (typeof thread.metadata === 'string' ? safelyParseJSON(thread.metadata) : thread.metadata) || {},
      })),
      total,
      page,
      perPage,
      hasMore: (page + 1) * perPage < total,
    };
  } catch (error) {
    throw new MastraError(
      {
        id: 'MONGODB_STORE_GET_THREADS_BY_RESOURCE_ID_PAGINATED_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { resourceId: args.resourceId },
      },
      error,
    );
  }
}
|
|
638
|
+
|
|
639
|
+
/**
 * Creates or overwrites the stored fields of a thread (upsert by `id`).
 *
 * @param thread - Thread to persist; metadata is written as-is.
 * @returns The same `thread` object that was passed in.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_SAVE_THREAD_FAILED` on failure.
 */
async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
  try {
    const threads = await this.operations.getCollection(TABLE_THREADS);

    const persistedFields = { ...thread, metadata: thread.metadata };
    await threads.updateOne({ id: thread.id }, { $set: persistedFields }, { upsert: true });

    return thread;
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_SAVE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { threadId: thread.id },
      },
      error,
    );
  }
}
|
|
665
|
+
|
|
666
|
+
/**
 * Sets a thread's title and merges the given metadata over its stored metadata.
 *
 * @param id - Id of the thread to update.
 * @param title - New title.
 * @param metadata - Metadata keys to merge in; existing keys not named here are kept.
 * @returns The thread as it looks after the update.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_UPDATE_THREAD_NOT_FOUND` when the
 *   thread does not exist, or `STORAGE_MONGODB_STORE_UPDATE_THREAD_FAILED` when the write fails.
 */
async updateThread({
  id,
  title,
  metadata,
}: {
  id: string;
  title: string;
  metadata: Record<string, unknown>;
}): Promise<StorageThreadType> {
  const existing = await this.getThreadById({ threadId: id });
  if (!existing) {
    throw new MastraError({
      id: 'STORAGE_MONGODB_STORE_UPDATE_THREAD_NOT_FOUND',
      domain: ErrorDomain.STORAGE,
      category: ErrorCategory.THIRD_PARTY,
      details: { threadId: id, status: 404 },
      text: `Thread ${id} not found`,
    });
  }

  const mergedMetadata = {
    ...existing.metadata,
    ...metadata,
  };
  const updatedThread = { ...existing, title, metadata: mergedMetadata };

  try {
    const threads = await this.operations.getCollection(TABLE_THREADS);
    // Only title and metadata are written; the other stored fields stay as they are.
    await threads.updateOne({ id }, { $set: { title, metadata: mergedMetadata } });
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_UPDATE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { threadId: id },
      },
      error,
    );
  }

  return updatedThread;
}
|
|
720
|
+
|
|
721
|
+
/**
 * Deletes a thread together with all of its messages.
 *
 * The messages are removed first and the thread document second, as two
 * separate operations.
 *
 * @param threadId - Id of the thread to delete.
 * @throws MastraError with id `STORAGE_MONGODB_STORE_DELETE_THREAD_FAILED` on failure.
 */
async deleteThread({ threadId }: { threadId: string }): Promise<void> {
  try {
    // Messages go first so no message outlives the thread it belongs to.
    const messages = await this.operations.getCollection(TABLE_MESSAGES);
    await messages.deleteMany({ thread_id: threadId });

    const threads = await this.operations.getCollection(TABLE_THREADS);
    await threads.deleteOne({ id: threadId });
  } catch (error) {
    throw new MastraError(
      {
        id: 'STORAGE_MONGODB_STORE_DELETE_THREAD_FAILED',
        domain: ErrorDomain.STORAGE,
        category: ErrorCategory.THIRD_PARTY,
        details: { threadId },
      },
      error,
    );
  }
}
|
|
741
|
+
}
|