@mastra/mongodb 0.10.0 → 0.10.1-alpha.0
This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +9 -0
- package/README.md +50 -0
- package/dist/_tsup-dts-rollup.d.cts +103 -0
- package/dist/_tsup-dts-rollup.d.ts +103 -0
- package/dist/index.cjs +520 -0
- package/dist/index.d.cts +2 -0
- package/dist/index.d.ts +2 -0
- package/dist/index.js +520 -1
- package/docker-compose.yaml +30 -0
- package/package.json +3 -3
- package/src/index.ts +1 -0
- package/src/storage/index.test.ts +779 -0
- package/src/storage/index.ts +674 -0
- package/docker-compose.yml +0 -8
package/src/storage/index.ts
@@ -0,0 +1,674 @@
import type { MetricResult, TestInfo } from '@mastra/core/eval';
import type { MessageType, StorageThreadType } from '@mastra/core/memory';
import type { EvalRow, StorageGetMessagesArg, TABLE_NAMES, WorkflowRun } from '@mastra/core/storage';
import {
  MastraStorage,
  TABLE_EVALS,
  TABLE_MESSAGES,
  TABLE_THREADS,
  TABLE_TRACES,
  TABLE_WORKFLOW_SNAPSHOT,
} from '@mastra/core/storage';
import type { WorkflowRunState } from '@mastra/core/workflows';
import type { Db } from 'mongodb';
import { MongoClient } from 'mongodb';

function safelyParseJSON(jsonString: string): any {
  try {
    return JSON.parse(jsonString);
  } catch {
    return {};
  }
}

export interface MongoDBConfig {
  url: string;
  dbName: string;
}
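// Illustrative usage (hypothetical values): both fields are validated up front
// in the constructor below, and the actual connection is only opened lazily on
// first use.
//
//   const store = new MongoDBStore({
//     url: 'mongodb://localhost:27017',
//     dbName: 'mastra',
//   });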
export class MongoDBStore extends MastraStorage {
  #isConnected = false;
  #client: MongoClient;
  #db: Db | undefined;
  readonly #dbName: string;

  constructor(config: MongoDBConfig) {
    super({ name: 'MongoDBStore' });
    this.#isConnected = false;

    if (!config.url?.trim().length) {
      throw new Error(
        'MongoDBStore: url must be provided and cannot be empty. Passing an empty string may cause fallback to local MongoDB defaults.',
      );
    }

    if (!config.dbName?.trim().length) {
      throw new Error(
        'MongoDBStore: dbName must be provided and cannot be empty. Passing an empty string may cause fallback to local MongoDB defaults.',
      );
    }

    this.#dbName = config.dbName;
    this.#client = new MongoClient(config.url);
  }

  private async getConnection(): Promise<Db> {
    if (this.#isConnected) {
      return this.#db!;
    }

    await this.#client.connect();
    this.#db = this.#client.db(this.#dbName);
    this.#isConnected = true;
    return this.#db;
  }
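  // Note: getConnection() connects lazily and caches the Db handle, so the
  // public methods below never need an explicit connect step; the first call
  // pays the connection cost.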
  private async getCollection(collectionName: string) {
    const db = await this.getConnection();
    return db.collection(collectionName);
  }

  async createTable(): Promise<void> {
    // Nothing to do here, MongoDB is schemaless
  }

  async clearTable({ tableName }: { tableName: TABLE_NAMES }): Promise<void> {
    try {
      const collection = await this.getCollection(tableName);
      await collection.deleteMany({});
    } catch (error) {
      if (error instanceof Error) {
        this.logger.error(error.message);
      }
    }
  }

  async insert({ tableName, record }: { tableName: TABLE_NAMES; record: Record<string, any> }): Promise<void> {
    try {
      const collection = await this.getCollection(tableName);
      await collection.insertOne(record);
    } catch (error) {
      this.logger.error(`Error inserting into table ${tableName}: ${error}`);
      throw error;
    }
  }
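  // Illustrative call (hypothetical record): records are stored as-is via
  // insertOne, so callers serialize any values they want persisted as strings.
  //
  //   await store.insert({ tableName: TABLE_TRACES, record: { id: 'trace-1', name: 'my-span' } });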
  async batchInsert({ tableName, records }: { tableName: TABLE_NAMES; records: Record<string, any>[] }): Promise<void> {
    if (!records.length) {
      return;
    }

    try {
      const collection = await this.getCollection(tableName);
      await collection.insertMany(records);
    } catch (error) {
      this.logger.error(`Error inserting into table ${tableName}: ${error}`);
      throw error;
    }
  }

  async load<R>({ tableName, keys }: { tableName: TABLE_NAMES; keys: Record<string, string> }): Promise<R | null> {
    this.logger.info(`Loading ${tableName} with keys ${JSON.stringify(keys)}`);
    try {
      const collection = await this.getCollection(tableName);
      return (await collection.find(keys).toArray()) as R;
    } catch (error) {
      this.logger.error(`Error loading ${tableName} with keys ${JSON.stringify(keys)}: ${error}`);
      throw error;
    }
  }
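  // Note: `keys` is passed straight to collection.find() as an equality
  // filter, and the full result array is cast to R, so callers request an
  // array type (see loadWorkflowSnapshot below, which uses load<any[]>).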
  async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
    try {
      const collection = await this.getCollection(TABLE_THREADS);
      const result = await collection.findOne<any>({ id: threadId });
      if (!result) {
        return null;
      }

      return {
        ...result,
        metadata: typeof result.metadata === 'string' ? JSON.parse(result.metadata) : result.metadata,
      };
    } catch (error) {
      this.logger.error(`Error loading thread with ID ${threadId}: ${error}`);
      throw error;
    }
  }

  async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
    try {
      const collection = await this.getCollection(TABLE_THREADS);
      const results = await collection.find<any>({ resourceId }).toArray();
      if (!results.length) {
        return [];
      }

      return results.map(result => ({
        ...result,
        metadata: typeof result.metadata === 'string' ? JSON.parse(result.metadata) : result.metadata,
      }));
    } catch (error) {
      this.logger.error(`Error loading threads by resourceId ${resourceId}: ${error}`);
      throw error;
    }
  }

  async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
    try {
      const collection = await this.getCollection(TABLE_THREADS);
      await collection.updateOne(
        { id: thread.id },
        {
          $set: {
            ...thread,
            metadata: JSON.stringify(thread.metadata),
          },
        },
        { upsert: true },
      );
      return thread;
    } catch (error) {
      this.logger.error(`Error saving thread ${thread.id}: ${error}`);
      throw error;
    }
  }
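  // Note: saveThread is an upsert keyed on the thread id, and metadata is
  // serialized to a JSON string on write; the thread readers above parse it
  // back when it comes out of MongoDB as a string.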
  async updateThread({
    id,
    title,
    metadata,
  }: {
    id: string;
    title: string;
    metadata: Record<string, unknown>;
  }): Promise<StorageThreadType> {
    const thread = await this.getThreadById({ threadId: id });
    if (!thread) {
      throw new Error(`Thread ${id} not found`);
    }

    const updatedThread = {
      ...thread,
      title,
      metadata: {
        ...thread.metadata,
        ...metadata,
      },
    };

    try {
      const collection = await this.getCollection(TABLE_THREADS);
      await collection.updateOne(
        { id },
        {
          $set: {
            title,
            metadata: JSON.stringify(updatedThread.metadata),
          },
        },
      );
    } catch (error) {
      this.logger.error(`Error updating thread ${id}: ${error}`);
      throw error;
    }

    return updatedThread;
  }

  async deleteThread({ threadId }: { threadId: string }): Promise<void> {
    try {
      // First, delete all messages associated with the thread
      const collectionMessages = await this.getCollection(TABLE_MESSAGES);
      await collectionMessages.deleteMany({ thread_id: threadId });
      // Then delete the thread itself
      const collectionThreads = await this.getCollection(TABLE_THREADS);
      await collectionThreads.deleteOne({ id: threadId });
    } catch (error) {
      this.logger.error(`Error deleting thread ${threadId}: ${error}`);
      throw error;
    }
  }

  async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T[]> {
    try {
      const limit = typeof selectBy?.last === 'number' ? selectBy.last : 40;
      const include = selectBy?.include || [];
      let messages: MessageType[] = [];
      let allMessages: MessageType[] = [];
      const collection = await this.getCollection(TABLE_MESSAGES);
      // Get all messages from the thread ordered by creation date descending
      allMessages = (await collection.find({ thread_id: threadId }).sort({ createdAt: -1 }).toArray()).map((row: any) =>
        this.parseRow(row),
      );

      // If there are messages to include, select the messages around the included IDs
      if (include.length) {
        // Map IDs to their position in the ordered array
        const idToIndex = new Map<string, number>();
        allMessages.forEach((msg, idx) => {
          idToIndex.set(msg.id, idx);
        });

        const selectedIndexes = new Set<number>();
        for (const inc of include) {
          const idx = idToIndex.get(inc.id);
          if (idx === undefined) continue;
          // Previous messages (earlier in time, so later in the descending array)
          for (let i = 1; i <= (inc.withPreviousMessages || 0); i++) {
            if (idx + i < allMessages.length) selectedIndexes.add(idx + i);
          }
          // Included message
          selectedIndexes.add(idx);
          // Next messages (later in time, so earlier in the descending array)
          for (let i = 1; i <= (inc.withNextMessages || 0); i++) {
            if (idx - i >= 0) selectedIndexes.add(idx - i);
          }
        }
        // Add the selected messages, filtering out undefined
        messages.push(
          ...Array.from(selectedIndexes)
            .map(i => allMessages[i])
            .filter((m): m is MessageType => !!m),
        );
      }

      // Get the remaining messages, excluding those already selected
      const excludeIds = new Set(messages.map(m => m.id));
      for (const msg of allMessages) {
        if (messages.length >= limit) break;
        if (!excludeIds.has(msg.id)) {
          messages.push(msg);
        }
      }

      // Sort all messages by creation date ascending
      messages.sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());

      return messages.slice(0, limit) as T[];
    } catch (error) {
      this.logger.error('Error getting messages:', error as Error);
      throw error;
    }
  }
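  // Illustrative call (hypothetical ids): fetch the last 10 messages plus two
  // messages of context on either side of one specific message.
  //
  //   const msgs = await store.getMessages({
  //     threadId: 'thread-1',
  //     selectBy: {
  //       last: 10,
  //       include: [{ id: 'msg-42', withPreviousMessages: 2, withNextMessages: 2 }],
  //     },
  //   });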
  async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
    if (!messages.length) {
      return messages;
    }

    const threadId = messages[0]?.threadId;
    if (!threadId) {
      this.logger.error('Thread ID is required to save messages');
      throw new Error('Thread ID is required');
    }
    try {
      // Prepare batch statements for all messages
      const messagesToInsert = messages.map(message => {
        const time = message.createdAt || new Date();
        return {
          id: message.id,
          thread_id: threadId,
          content: typeof message.content === 'string' ? message.content : JSON.stringify(message.content),
          role: message.role,
          type: message.type,
          resourceId: message.resourceId,
          createdAt: time instanceof Date ? time.toISOString() : time,
        };
      });

      // Execute all inserts in a single batch
      const collection = await this.getCollection(TABLE_MESSAGES);
      await collection.insertMany(messagesToInsert);
      return messages;
    } catch (error) {
      this.logger.error('Failed to save messages in database: ' + (error as { message: string })?.message);
      throw error;
    }
  }

  async getTraces(
    {
      name,
      scope,
      page,
      perPage,
      attributes,
      filters,
    }: {
      name?: string;
      scope?: string;
      page: number;
      perPage: number;
      attributes?: Record<string, string>;
      filters?: Record<string, any>;
    } = {
      page: 0,
      perPage: 100,
    },
  ): Promise<any[]> {
    const limit = perPage;
    const offset = page * perPage;

    const query: any = {};
    if (name) {
      // Substring match on the trace name
      query['name'] = { $regex: name };
    }

    if (scope) {
      query['scope'] = scope;
    }

    if (attributes) {
      Object.keys(attributes).forEach(key => {
        query[`attributes.${key}`] = attributes[key];
      });
    }

    if (filters) {
      Object.entries(filters).forEach(([key, value]) => {
        query[key] = value;
      });
    }

    const collection = await this.getCollection(TABLE_TRACES);
    const result = await collection
      .find(query, {
        sort: { startTime: -1 },
      })
      .limit(limit)
      .skip(offset)
      .toArray();

    return result.map(row => ({
      id: row.id,
      parentSpanId: row.parentSpanId,
      traceId: row.traceId,
      name: row.name,
      scope: row.scope,
      kind: row.kind,
      status: safelyParseJSON(row.status as string),
      events: safelyParseJSON(row.events as string),
      links: safelyParseJSON(row.links as string),
      attributes: safelyParseJSON(row.attributes as string),
      startTime: row.startTime,
      endTime: row.endTime,
      other: safelyParseJSON(row.other as string),
      createdAt: row.createdAt,
    })) as any;
  }
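  // Note: pagination is zero-based; with page = 2 and perPage = 100 the query
  // skips 200 documents and returns at most 100, newest startTime first.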
  async getWorkflowRuns({
    workflowName,
    fromDate,
    toDate,
    limit,
    offset,
  }: {
    workflowName?: string;
    fromDate?: Date;
    toDate?: Date;
    limit?: number;
    offset?: number;
  } = {}): Promise<{
    runs: Array<{
      workflowName: string;
      runId: string;
      snapshot: WorkflowRunState | string;
      createdAt: Date;
      updatedAt: Date;
    }>;
    total: number;
  }> {
    const query: any = {};
    if (workflowName) {
      query['workflow_name'] = workflowName;
    }

    if (fromDate || toDate) {
      query['createdAt'] = {};
      if (fromDate) {
        query['createdAt']['$gte'] = fromDate;
      }
      if (toDate) {
        query['createdAt']['$lte'] = toDate;
      }
    }

    const collection = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
    let total = 0;
    // Only get total count when using pagination
    if (limit !== undefined && offset !== undefined) {
      total = await collection.countDocuments(query);
    }

    // Get results
    const request = collection.find(query).sort({ createdAt: 'desc' });
    if (limit) {
      request.limit(limit);
    }

    if (offset) {
      request.skip(offset);
    }

    const result = await request.toArray();
    const runs = result.map(row => {
      let parsedSnapshot: WorkflowRunState | string = row.snapshot;
      if (typeof parsedSnapshot === 'string') {
        try {
          parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
        } catch (e) {
          // If parsing fails, return the raw snapshot string
          console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
        }
      }

      return {
        workflowName: row.workflow_name as string,
        runId: row.run_id as string,
        snapshot: parsedSnapshot,
        createdAt: new Date(row.createdAt as string),
        updatedAt: new Date(row.updatedAt as string),
      };
    });

    // Use runs.length as total when not paginating
    return { runs, total: total || runs.length };
  }

  async getEvalsByAgentName(agentName: string, type?: 'test' | 'live'): Promise<EvalRow[]> {
    try {
      const query: any = {
        agent_name: agentName,
      };

      if (type === 'test') {
        query['test_info'] = { $ne: null };
        // It is not possible to filter by test_info.testPath because it is not a JSON field
        // query['test_info.testPath'] = { $ne: null };
      }

      if (type === 'live') {
        // It is not possible to filter by test_info.testPath because it is not a JSON field
        query['test_info'] = null;
      }

      const collection = await this.getCollection(TABLE_EVALS);
      const documents = await collection.find(query).sort({ created_at: 'desc' }).toArray();
      const result = documents.map(row => this.transformEvalRow(row));
      // Post-filter to drop rows whose test_info.testPath is null
      return result.filter(row => {
        if (type === 'live') {
          return !Boolean(row.testInfo?.testPath);
        }

        if (type === 'test') {
          return row.testInfo?.testPath !== null;
        }
        return true;
      });
    } catch (error) {
      // Handle case where table doesn't exist yet
      if (error instanceof Error && error.message.includes('no such table')) {
        return [];
      }
      this.logger.error('Failed to get evals for the specified agent: ' + (error as any)?.message);
      throw error;
    }
  }

  async persistWorkflowSnapshot({
    workflowName,
    runId,
    snapshot,
  }: {
    workflowName: string;
    runId: string;
    snapshot: WorkflowRunState;
  }): Promise<void> {
    try {
      const now = new Date().toISOString();
      const collection = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
      await collection.updateOne(
        { workflow_name: workflowName, run_id: runId },
        {
          $set: {
            snapshot: JSON.stringify(snapshot),
            updatedAt: now,
          },
          $setOnInsert: {
            createdAt: now,
          },
        },
        { upsert: true },
      );
    } catch (error) {
      this.logger.error(`Error persisting workflow snapshot: ${error}`);
      throw error;
    }
  }
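  // Note: snapshots are upserted keyed on (workflow_name, run_id) and stored
  // as JSON strings, which is why the readers below JSON.parse them and fall
  // back to the raw string when parsing fails.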
  async loadWorkflowSnapshot({
    workflowName,
    runId,
  }: {
    workflowName: string;
    runId: string;
  }): Promise<WorkflowRunState | null> {
    try {
      const result = await this.load<any[]>({
        tableName: TABLE_WORKFLOW_SNAPSHOT,
        keys: {
          workflow_name: workflowName,
          run_id: runId,
        },
      });

      if (!result?.length) {
        return null;
      }

      return JSON.parse(result[0].snapshot);
    } catch (error) {
      console.error('Error loading workflow snapshot:', error);
      throw error;
    }
  }

  async getWorkflowRunById({
    runId,
    workflowName,
  }: {
    runId: string;
    workflowName?: string;
  }): Promise<WorkflowRun | null> {
    try {
      const query: any = {};
      if (runId) {
        query['run_id'] = runId;
      }

      if (workflowName) {
        query['workflow_name'] = workflowName;
      }

      const collection = await this.getCollection(TABLE_WORKFLOW_SNAPSHOT);
      const result = await collection.findOne(query);
      if (!result) {
        return null;
      }

      return this.parseWorkflowRun(result);
    } catch (error) {
      console.error('Error getting workflow run by ID:', error);
      throw error;
    }
  }

  private parseWorkflowRun(row: any): WorkflowRun {
    let parsedSnapshot: WorkflowRunState | string = row.snapshot as string;
    if (typeof parsedSnapshot === 'string') {
      try {
        parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
      } catch (e) {
        // If parsing fails, return the raw snapshot string
        console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
      }
    }

    return {
      workflowName: row.workflow_name,
      runId: row.run_id,
      snapshot: parsedSnapshot,
      createdAt: row.createdAt,
      updatedAt: row.updatedAt,
      resourceId: row.resourceId,
    };
  }

  private parseRow(row: any): MessageType {
    let content = row.content;
    try {
      content = JSON.parse(row.content);
    } catch {
      // use content as is if it's not JSON
    }
    return {
      id: row.id,
      content,
      role: row.role,
      type: row.type,
      createdAt: new Date(row.createdAt as string),
      threadId: row.thread_id,
    } as MessageType;
  }

  private transformEvalRow(row: Record<string, any>): EvalRow {
    let testInfoValue = null;
    if (row.test_info) {
      try {
        testInfoValue = typeof row.test_info === 'string' ? JSON.parse(row.test_info) : row.test_info;
      } catch (e) {
        console.warn('Failed to parse test_info:', e);
      }
    }

    return {
      input: row.input as string,
      output: row.output as string,
      result: row.result as MetricResult,
      agentName: row.agent_name as string,
      metricName: row.metric_name as string,
      instructions: row.instructions as string,
      testInfo: testInfoValue as TestInfo,
      globalRunId: row.global_run_id as string,
      runId: row.run_id as string,
      createdAt: row.created_at as string,
    };
  }

  async close(): Promise<void> {
    await this.#client.close();
  }
}
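For reference, a minimal end-to-end sketch of the new store, assuming the package's one-line src/index.ts change re-exports MongoDBStore and that a MongoDB instance (such as the one the added docker-compose.yaml provides) is reachable at the hypothetical URL below:

```ts
import { MongoDBStore } from '@mastra/mongodb';

async function main() {
  // Hypothetical connection values; adjust to your environment.
  const store = new MongoDBStore({
    url: 'mongodb://localhost:27017',
    dbName: 'mastra',
  });

  // Upsert a thread, then read it back; metadata round-trips through the
  // JSON.stringify / JSON.parse pair in saveThread and getThreadById.
  await store.saveThread({
    thread: {
      id: 'thread-1',
      resourceId: 'resource-1',
      title: 'Example thread',
      metadata: { topic: 'demo' },
      createdAt: new Date(),
      updatedAt: new Date(),
    },
  });

  const thread = await store.getThreadById({ threadId: 'thread-1' });
  console.log(thread?.title);

  await store.close();
}

main().catch(console.error);
```

Because saveThread is an upsert, running the sketch repeatedly overwrites the same document rather than creating duplicates.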