@mastra/lance 0.1.1-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,999 @@
1
+ import { connect } from '@lancedb/lancedb';
2
+ import type { Connection, ConnectionOptions, SchemaLike, FieldLike } from '@lancedb/lancedb';
3
+ import type {
4
+ EvalRow,
5
+ MastraMessageV1,
6
+ MastraMessageV2,
7
+ StorageColumn,
8
+ StorageGetMessagesArg,
9
+ StorageThreadType,
10
+ TraceType,
11
+ WorkflowRun,
12
+ WorkflowRuns,
13
+ WorkflowRunState,
14
+ } from '@mastra/core';
15
+ import { MessageList } from '@mastra/core/agent';
16
+ import {
17
+ MastraStorage,
18
+ TABLE_EVALS,
19
+ TABLE_MESSAGES,
20
+ TABLE_THREADS,
21
+ TABLE_TRACES,
22
+ TABLE_WORKFLOW_SNAPSHOT,
23
+ } from '@mastra/core/storage';
24
+ import type { TABLE_NAMES } from '@mastra/core/storage';
25
+ import type { DataType } from 'apache-arrow';
26
+ import { Utf8, Int32, Float32, Binary, Schema, Field, Float64 } from 'apache-arrow';
27
+
28
+ export class LanceStorage extends MastraStorage {
29
+ private lanceClient!: Connection;
30
+
31
+ /**
32
+ * Creates a new instance of LanceStorage
33
+ * @param uri The URI to connect to LanceDB
34
+ * @param options connection options
35
+ *
36
+ * Usage:
37
+ *
38
+ * Connect to a local database
39
+ * ```ts
40
+ * const store = await LanceStorage.create('/path/to/db');
41
+ * ```
42
+ *
43
+ * Connect to a LanceDB cloud database
44
+ * ```ts
45
+ * const store = await LanceStorage.create('db://host:port');
46
+ * ```
47
+ *
48
+ * Connect to a cloud database
49
+ * ```ts
50
+ * const store = await LanceStorage.create('s3://bucket/db', { storageOptions: { timeout: '60s' } });
51
+ * ```
52
+ */
53
+ public static async create(name: string, uri: string, options?: ConnectionOptions): Promise<LanceStorage> {
54
+ const instance = new LanceStorage(name);
55
+ try {
56
+ instance.lanceClient = await connect(uri, options);
57
+ return instance;
58
+ } catch (e: any) {
59
+ throw new Error(`Failed to connect to LanceDB: ${e}`);
60
+ }
61
+ }
62
+
63
  /**
   * @internal
   * Private constructor to enforce using the create factory method.
   * The LanceDB connection is established in {@link LanceStorage.create};
   * the `lanceClient` field's definite-assignment (`!`) relies on that.
   */
  private constructor(name: string) {
    super({ name });
  }
70
+
71
+ async createTable({
72
+ tableName,
73
+ schema,
74
+ }: {
75
+ tableName: TABLE_NAMES;
76
+ schema: Record<string, StorageColumn>;
77
+ }): Promise<void> {
78
+ try {
79
+ const arrowSchema = this.translateSchema(schema);
80
+ await this.lanceClient.createEmptyTable(tableName, arrowSchema);
81
+ } catch (error: any) {
82
+ throw new Error(`Failed to create table: ${error}`);
83
+ }
84
+ }
85
+
86
+ private translateSchema(schema: Record<string, StorageColumn>): Schema {
87
+ const fields = Object.entries(schema).map(([name, column]) => {
88
+ // Convert string type to Arrow DataType
89
+ let arrowType: DataType;
90
+ switch (column.type.toLowerCase()) {
91
+ case 'text':
92
+ case 'uuid':
93
+ arrowType = new Utf8();
94
+ break;
95
+ case 'int':
96
+ case 'integer':
97
+ arrowType = new Int32();
98
+ break;
99
+ case 'bigint':
100
+ arrowType = new Float64();
101
+ break;
102
+ case 'float':
103
+ arrowType = new Float32();
104
+ break;
105
+ case 'jsonb':
106
+ case 'json':
107
+ arrowType = new Utf8();
108
+ break;
109
+ case 'binary':
110
+ arrowType = new Binary();
111
+ break;
112
+ case 'timestamp':
113
+ arrowType = new Float64();
114
+ break;
115
+ default:
116
+ // Default to string for unknown types
117
+ arrowType = new Utf8();
118
+ }
119
+
120
+ // Create a field with the appropriate arrow type
121
+ return new Field(name, arrowType, column.nullable ?? true);
122
+ });
123
+
124
+ return new Schema(fields);
125
+ }
126
+
127
+ /**
128
+ * Drop a table if it exists
129
+ * @param tableName Name of the table to drop
130
+ */
131
+ async dropTable(tableName: TABLE_NAMES): Promise<void> {
132
+ try {
133
+ await this.lanceClient.dropTable(tableName);
134
+ } catch (error: any) {
135
+ // Don't throw if the table doesn't exist
136
+ if (error.toString().includes('was not found')) {
137
+ this.logger.debug(`Table '${tableName}' does not exist, skipping drop`);
138
+ return;
139
+ }
140
+ throw new Error(`Failed to drop table: ${error}`);
141
+ }
142
+ }
143
+
144
+ /**
145
+ * Get table schema
146
+ * @param tableName Name of the table
147
+ * @returns Table schema
148
+ */
149
+ async getTableSchema(tableName: TABLE_NAMES): Promise<SchemaLike> {
150
+ try {
151
+ const table = await this.lanceClient.openTable(tableName);
152
+ const rawSchema = await table.schema();
153
+ const fields = rawSchema.fields as FieldLike[];
154
+
155
+ // Convert schema to SchemaLike format
156
+ return {
157
+ fields,
158
+ metadata: new Map<string, string>(),
159
+ get names() {
160
+ return fields.map((field: FieldLike) => field.name);
161
+ },
162
+ };
163
+ } catch (error: any) {
164
+ throw new Error(`Failed to get table schema: ${error}`);
165
+ }
166
+ }
167
+
168
+ protected getDefaultValue(type: StorageColumn['type']): string {
169
+ switch (type) {
170
+ case 'text':
171
+ return "''";
172
+ case 'timestamp':
173
+ return 'CURRENT_TIMESTAMP';
174
+ case 'integer':
175
+ case 'bigint':
176
+ return '0';
177
+ case 'jsonb':
178
+ return "'{}'";
179
+ case 'uuid':
180
+ return "''";
181
+ default:
182
+ return super.getDefaultValue(type);
183
+ }
184
+ }
185
+
186
+ /**
187
+ * Alters table schema to add columns if they don't exist
188
+ * @param tableName Name of the table
189
+ * @param schema Schema of the table
190
+ * @param ifNotExists Array of column names to add if they don't exist
191
+ */
192
+ async alterTable({
193
+ tableName,
194
+ schema,
195
+ ifNotExists,
196
+ }: {
197
+ tableName: string;
198
+ schema: Record<string, StorageColumn>;
199
+ ifNotExists: string[];
200
+ }): Promise<void> {
201
+ const table = await this.lanceClient.openTable(tableName);
202
+ const currentSchema = await table.schema();
203
+ const existingFields = new Set(currentSchema.fields.map((f: any) => f.name));
204
+
205
+ const typeMap: Record<string, string> = {
206
+ text: 'string',
207
+ integer: 'int',
208
+ bigint: 'bigint',
209
+ timestamp: 'timestamp',
210
+ jsonb: 'string',
211
+ uuid: 'string',
212
+ };
213
+
214
+ // Find columns to add
215
+ const columnsToAdd = ifNotExists
216
+ .filter(col => schema[col] && !existingFields.has(col))
217
+ .map(col => {
218
+ const colDef = schema[col];
219
+ return {
220
+ name: col,
221
+ valueSql: colDef?.nullable
222
+ ? `cast(NULL as ${typeMap[colDef.type ?? 'text']})`
223
+ : `cast(${this.getDefaultValue(colDef?.type ?? 'text')} as ${typeMap[colDef?.type ?? 'text']})`,
224
+ };
225
+ });
226
+
227
+ if (columnsToAdd.length > 0) {
228
+ await table.addColumns(columnsToAdd);
229
+ this.logger?.info?.(`Added columns [${columnsToAdd.map(c => c.name).join(', ')}] to table ${tableName}`);
230
+ }
231
+ }
232
+
233
  /**
   * Deletes every row from a table while keeping the table and its schema.
   * @param tableName Name of the table to clear
   */
  async clearTable({ tableName }: { tableName: TABLE_NAMES }): Promise<void> {
    const table = await this.lanceClient.openTable(tableName);

    // delete function always takes a predicate as an argument, so we use '1=1' to delete all records because it is always true.
    await table.delete('1=1');
  }
239
+
240
  /**
   * Insert a single record into a table. This function overwrites the existing record if it exists. Use this function for inserting records into tables with custom schemas.
   * @param tableName The name of the table to insert into.
   * @param record The record to insert.
   */
  async insert({ tableName, record }: { tableName: string; record: Record<string, any> }): Promise<void> {
    try {
      const table = await this.lanceClient.openTable(tableName);

      // Shallow copy so the caller's record object is not mutated below.
      const processedRecord = { ...record };

      // Serialize nested objects (but not Dates) to JSON strings, since the
      // Arrow schema stores JSON columns as UTF-8 text.
      for (const key in processedRecord) {
        if (
          processedRecord[key] !== null &&
          typeof processedRecord[key] === 'object' &&
          !(processedRecord[key] instanceof Date)
        ) {
          this.logger.debug('Converting object to JSON string: ', processedRecord[key]);
          processedRecord[key] = JSON.stringify(processedRecord[key]);
        }
      }

      // NOTE(review): add() with mode 'overwrite' appears to replace the
      // table's existing contents rather than upserting a single row —
      // confirm against the @lancedb/lancedb docs; a keyed mergeInsert may
      // be the intended operation here.
      await table.add([processedRecord], { mode: 'overwrite' });
    } catch (error: any) {
      throw new Error(`Failed to insert record: ${error}`);
    }
  }
267
+
268
  /**
   * Insert multiple records into a table. This function overwrites the existing records if they exist. Use this function for inserting records into tables with custom schemas.
   * @param tableName The name of the table to insert into.
   * @param records The records to insert.
   */
  async batchInsert({ tableName, records }: { tableName: string; records: Record<string, any>[] }): Promise<void> {
    try {
      const table = await this.lanceClient.openTable(tableName);

      const processedRecords = records.map(record => {
        // Shallow copy so caller-owned records are not mutated.
        const processedRecord = { ...record };

        // Convert values based on schema type
        for (const key in processedRecord) {
          // Skip null/undefined values
          if (processedRecord[key] == null) continue;

          // Serialize nested objects (but not Dates) to JSON strings, since
          // the Arrow schema stores JSON columns as UTF-8 text.
          if (
            processedRecord[key] !== null &&
            typeof processedRecord[key] === 'object' &&
            !(processedRecord[key] instanceof Date)
          ) {
            processedRecord[key] = JSON.stringify(processedRecord[key]);
          }
        }

        return processedRecord;
      });

      // NOTE(review): add() with mode 'overwrite' appears to replace the
      // table's existing contents rather than upserting these rows —
      // confirm against the @lancedb/lancedb docs.
      await table.add(processedRecords, { mode: 'overwrite' });
    } catch (error: any) {
      throw new Error(`Failed to batch insert records: ${error}`);
    }
  }
302
+
303
+ /**
304
+ * Load a record from the database by its key(s)
305
+ * @param tableName The name of the table to query
306
+ * @param keys Record of key-value pairs to use for lookup
307
+ * @throws Error if invalid types are provided for keys
308
+ * @returns The loaded record with proper type conversions, or null if not found
309
+ */
310
+ async load({ tableName, keys }: { tableName: TABLE_NAMES; keys: Record<string, any> }): Promise<any> {
311
+ try {
312
+ const table = await this.lanceClient.openTable(tableName);
313
+ const tableSchema = await this.getTableSchema(tableName);
314
+ const query = table.query();
315
+
316
+ // Build filter condition with 'and' between all conditions
317
+ if (Object.keys(keys).length > 0) {
318
+ // Validate key types against schema
319
+ this.validateKeyTypes(keys, tableSchema);
320
+
321
+ const filterConditions = Object.entries(keys)
322
+ .map(([key, value]) => {
323
+ // Check if key is in camelCase and wrap it in backticks if it is
324
+ const isCamelCase = /^[a-z][a-zA-Z]*$/.test(key) && /[A-Z]/.test(key);
325
+ const quotedKey = isCamelCase ? `\`${key}\`` : key;
326
+
327
+ // Handle different types appropriately
328
+ if (typeof value === 'string') {
329
+ return `${quotedKey} = '${value}'`;
330
+ } else if (value === null) {
331
+ return `${quotedKey} IS NULL`;
332
+ } else {
333
+ // For numbers, booleans, etc.
334
+ return `${quotedKey} = ${value}`;
335
+ }
336
+ })
337
+ .join(' AND ');
338
+
339
+ this.logger.debug('where clause generated: ' + filterConditions);
340
+ query.where(filterConditions);
341
+ }
342
+
343
+ const result = await query.limit(1).toArray();
344
+
345
+ if (result.length === 0) {
346
+ this.logger.debug('No record found');
347
+ return null;
348
+ }
349
+
350
+ // Process the result with type conversions
351
+ return this.processResultWithTypeConversion(result[0], tableSchema);
352
+ } catch (error: any) {
353
+ throw new Error(`Failed to load record: ${error}`);
354
+ }
355
+ }
356
+
357
+ /**
358
+ * Validates that key types match the schema definition
359
+ * @param keys The keys to validate
360
+ * @param tableSchema The table schema to validate against
361
+ * @throws Error if a key has an incompatible type
362
+ */
363
+ private validateKeyTypes(keys: Record<string, any>, tableSchema: SchemaLike): void {
364
+ // Create a map of field names to their expected types
365
+ const fieldTypes = new Map(
366
+ tableSchema.fields.map((field: any) => [field.name, field.type?.toString().toLowerCase()]),
367
+ );
368
+
369
+ for (const [key, value] of Object.entries(keys)) {
370
+ const fieldType = fieldTypes.get(key);
371
+
372
+ if (!fieldType) {
373
+ throw new Error(`Field '${key}' does not exist in table schema`);
374
+ }
375
+
376
+ // Type validation
377
+ if (value !== null) {
378
+ if ((fieldType.includes('int') || fieldType.includes('bigint')) && typeof value !== 'number') {
379
+ throw new Error(`Expected numeric value for field '${key}', got ${typeof value}`);
380
+ }
381
+
382
+ if (fieldType.includes('utf8') && typeof value !== 'string') {
383
+ throw new Error(`Expected string value for field '${key}', got ${typeof value}`);
384
+ }
385
+
386
+ if (fieldType.includes('timestamp') && !(value instanceof Date) && typeof value !== 'string') {
387
+ throw new Error(`Expected Date or string value for field '${key}', got ${typeof value}`);
388
+ }
389
+ }
390
+ }
391
+ }
392
+
393
+ /**
394
+ * Process a database result with appropriate type conversions based on the table schema
395
+ * @param rawResult The raw result object from the database
396
+ * @param tableSchema The schema of the table containing type information
397
+ * @returns Processed result with correct data types
398
+ */
399
+ private processResultWithTypeConversion(
400
+ rawResult: Record<string, any> | Record<string, any>[],
401
+ tableSchema: SchemaLike,
402
+ ): Record<string, any> | Record<string, any>[] {
403
+ // Build a map of field names to their schema types
404
+ const fieldTypeMap = new Map();
405
+ tableSchema.fields.forEach((field: any) => {
406
+ const fieldName = field.name;
407
+ const fieldTypeStr = field.type.toString().toLowerCase();
408
+ fieldTypeMap.set(fieldName, fieldTypeStr);
409
+ });
410
+
411
+ // Handle array case
412
+ if (Array.isArray(rawResult)) {
413
+ return rawResult.map(item => this.processResultWithTypeConversion(item, tableSchema));
414
+ }
415
+
416
+ // Handle single record case
417
+ const processedResult = { ...rawResult };
418
+
419
+ // Convert each field according to its schema type
420
+ for (const key in processedResult) {
421
+ const fieldTypeStr = fieldTypeMap.get(key);
422
+ if (!fieldTypeStr) continue;
423
+
424
+ // Skip conversion for ID fields - preserve their original format
425
+ // if (key === 'id') {
426
+ // continue;
427
+ // }
428
+
429
+ // Only try to convert string values
430
+ if (typeof processedResult[key] === 'string') {
431
+ // Numeric types
432
+ if (fieldTypeStr.includes('int32') || fieldTypeStr.includes('float32')) {
433
+ if (!isNaN(Number(processedResult[key]))) {
434
+ processedResult[key] = Number(processedResult[key]);
435
+ }
436
+ } else if (fieldTypeStr.includes('int64')) {
437
+ processedResult[key] = Number(processedResult[key]);
438
+ } else if (fieldTypeStr.includes('utf8')) {
439
+ try {
440
+ processedResult[key] = JSON.parse(processedResult[key]);
441
+ } catch (e) {
442
+ // If JSON parsing fails, keep the original string
443
+ this.logger.debug(`Failed to parse JSON for key ${key}: ${e}`);
444
+ }
445
+ }
446
+ } else if (typeof processedResult[key] === 'bigint') {
447
+ // Convert BigInt values to regular numbers for application layer
448
+ processedResult[key] = Number(processedResult[key]);
449
+ }
450
+ }
451
+
452
+ return processedResult;
453
+ }
454
+
455
+ getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
456
+ try {
457
+ return this.load({ tableName: TABLE_THREADS, keys: { id: threadId } });
458
+ } catch (error: any) {
459
+ throw new Error(`Failed to get thread by ID: ${error}`);
460
+ }
461
+ }
462
+
463
+ async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
464
+ try {
465
+ const table = await this.lanceClient.openTable(TABLE_THREADS);
466
+ // fetches all threads with the given resourceId
467
+ const query = table.query().where(`\`resourceId\` = '${resourceId}'`);
468
+
469
+ const records = await query.toArray();
470
+ return this.processResultWithTypeConversion(
471
+ records,
472
+ await this.getTableSchema(TABLE_THREADS),
473
+ ) as StorageThreadType[];
474
+ } catch (error: any) {
475
+ throw new Error(`Failed to get threads by resource ID: ${error}`);
476
+ }
477
+ }
478
+
479
+ /**
480
+ * Saves a thread to the database. This function doesn't overwrite existing threads.
481
+ * @param thread - The thread to save
482
+ * @returns The saved thread
483
+ */
484
+ async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
485
+ try {
486
+ const record = { ...thread, metadata: JSON.stringify(thread.metadata) };
487
+ const table = await this.lanceClient.openTable(TABLE_THREADS);
488
+ await table.add([record], { mode: 'append' });
489
+
490
+ return thread;
491
+ } catch (error: any) {
492
+ throw new Error(`Failed to save thread: ${error}`);
493
+ }
494
+ }
495
+
496
  /**
   * Updates a thread's title and metadata, then reads the row back.
   * NOTE(review): the written record contains only id/title/metadata —
   * other thread columns (e.g. resourceId, createdAt) are not carried over.
   * Also, add() with mode 'overwrite' may replace the entire table rather
   * than just this row — verify intended semantics against the LanceDB docs.
   */
  async updateThread({
    id,
    title,
    metadata,
  }: {
    id: string;
    title: string;
    metadata: Record<string, unknown>;
  }): Promise<StorageThreadType> {
    try {
      const record = { id, title, metadata: JSON.stringify(metadata) };
      const table = await this.lanceClient.openTable(TABLE_THREADS);
      await table.add([record], { mode: 'overwrite' });

      // Read the row back so the caller receives schema-converted values.
      const query = table.query().where(`id = '${id}'`);

      const records = await query.toArray();
      return this.processResultWithTypeConversion(
        records[0],
        await this.getTableSchema(TABLE_THREADS),
      ) as StorageThreadType;
    } catch (error: any) {
      throw new Error(`Failed to update thread: ${error}`);
    }
  }
521
+
522
+ async deleteThread({ threadId }: { threadId: string }): Promise<void> {
523
+ try {
524
+ const table = await this.lanceClient.openTable(TABLE_THREADS);
525
+ await table.delete(`id = '${threadId}'`);
526
+ } catch (error: any) {
527
+ throw new Error(`Failed to delete thread: ${error}`);
528
+ }
529
+ }
530
+
531
+ /**
532
+ * Processes messages to include context messages based on withPreviousMessages and withNextMessages
533
+ * @param records - The sorted array of records to process
534
+ * @param include - The array of include specifications with context parameters
535
+ * @returns The processed array with context messages included
536
+ */
537
+ private processMessagesWithContext(
538
+ records: any[],
539
+ include: { id: string; withPreviousMessages?: number; withNextMessages?: number }[],
540
+ ): any[] {
541
+ const messagesWithContext = include.filter(item => item.withPreviousMessages || item.withNextMessages);
542
+
543
+ if (messagesWithContext.length === 0) {
544
+ return records;
545
+ }
546
+
547
+ // Create a map of message id to index in the sorted array for quick lookup
548
+ const messageIndexMap = new Map<string, number>();
549
+ records.forEach((message, index) => {
550
+ messageIndexMap.set(message.id, index);
551
+ });
552
+
553
+ // Keep track of additional indices to include
554
+ const additionalIndices = new Set<number>();
555
+
556
+ for (const item of messagesWithContext) {
557
+ const messageIndex = messageIndexMap.get(item.id);
558
+ if (messageIndex !== undefined) {
559
+ // Add previous messages if requested
560
+ if (item.withPreviousMessages) {
561
+ const startIdx = Math.max(0, messageIndex - item.withPreviousMessages);
562
+ for (let i = startIdx; i < messageIndex; i++) {
563
+ additionalIndices.add(i);
564
+ }
565
+ }
566
+
567
+ // Add next messages if requested
568
+ if (item.withNextMessages) {
569
+ const endIdx = Math.min(records.length - 1, messageIndex + item.withNextMessages);
570
+ for (let i = messageIndex + 1; i <= endIdx; i++) {
571
+ additionalIndices.add(i);
572
+ }
573
+ }
574
+ }
575
+ }
576
+
577
+ // If we need to include additional messages, create a new set of records
578
+ if (additionalIndices.size === 0) {
579
+ return records;
580
+ }
581
+
582
+ // Get IDs of the records that matched the original query
583
+ const originalMatchIds = new Set(include.map(item => item.id));
584
+
585
+ // Create a set of all indices we need to include
586
+ const allIndices = new Set<number>();
587
+
588
+ // Add indices of originally matched messages
589
+ records.forEach((record, index) => {
590
+ if (originalMatchIds.has(record.id)) {
591
+ allIndices.add(index);
592
+ }
593
+ });
594
+
595
+ // Add the additional context message indices
596
+ additionalIndices.forEach(index => {
597
+ allIndices.add(index);
598
+ });
599
+
600
+ // Create a new filtered array with only the required messages
601
+ // while maintaining chronological order
602
+ return Array.from(allIndices)
603
+ .sort((a, b) => a - b)
604
+ .map(index => records[index]);
605
+ }
606
+
607
  public async getMessages(args: StorageGetMessagesArg & { format?: 'v1' }): Promise<MastraMessageV1[]>;
  public async getMessages(args: StorageGetMessagesArg & { format: 'v2' }): Promise<MastraMessageV2[]>;
  /**
   * Fetches a thread's messages, optionally including specific messages
   * (with surrounding context) and/or only the last N, returned in
   * chronological order in v1 or v2 format.
   * @throws Error when threadConfig is supplied (unsupported here)
   */
  public async getMessages({
    threadId,
    resourceId,
    selectBy,
    format,
    threadConfig,
  }: StorageGetMessagesArg & { format?: 'v1' | 'v2' }): Promise<MastraMessageV1[] | MastraMessageV2[]> {
    try {
      if (threadConfig) {
        throw new Error('ThreadConfig is not supported by LanceDB storage');
      }

      const table = await this.lanceClient.openTable(TABLE_MESSAGES);
      let query = table.query().where(`\`threadId\` = '${threadId}'`);

      // Apply selectBy filters if provided
      if (selectBy) {
        // Handle 'include' to fetch specific messages
        if (selectBy.include && selectBy.include.length > 0) {
          const includeIds = selectBy.include.map(item => item.id);
          // Add additional query to include specific message IDs
          // This will be combined with the threadId filter
          // NOTE(review): the second where() call repeats the threadId
          // condition, which suggests where() replaces (rather than ANDs
          // with) the earlier filter — confirm against the LanceDB query API.
          const includeClause = includeIds.map(id => `\`id\` = '${id}'`).join(' OR ');
          query = query.where(`(\`threadId\` = '${threadId}' OR (${includeClause}))`);

          // Note: The surrounding messages (withPreviousMessages/withNextMessages) will be
          // handled after we retrieve the results
        }
      }

      // Fetch all records matching the query
      let records = await query.toArray();

      // Sort the records chronologically
      records.sort((a, b) => {
        const dateA = new Date(a.createdAt).getTime();
        const dateB = new Date(b.createdAt).getTime();
        return dateA - dateB; // Ascending order
      });

      // Process the include.withPreviousMessages and include.withNextMessages if specified
      if (selectBy?.include && selectBy.include.length > 0) {
        records = this.processMessagesWithContext(records, selectBy.include);
      }

      // If we're fetching the last N messages, take only the last N after sorting
      // NOTE(review): when selectBy.last === 0, slice(-0) returns ALL records
      // instead of none — confirm whether 0 is a legal value here.
      if (selectBy?.last !== undefined && selectBy.last !== false) {
        records = records.slice(-selectBy.last);
      }

      // Convert stored values back to application types, then parse the
      // JSON-serialized content column (plain strings are kept as-is).
      const messages = this.processResultWithTypeConversion(records, await this.getTableSchema(TABLE_MESSAGES));
      const normalized = messages.map((msg: MastraMessageV2 | MastraMessageV1) => ({
        ...msg,
        content:
          typeof msg.content === 'string'
            ? (() => {
                try {
                  return JSON.parse(msg.content);
                } catch {
                  return msg.content;
                }
              })()
            : msg.content,
      }));
      const list = new MessageList({ threadId, resourceId }).add(normalized, 'memory');
      if (format === 'v2') return list.get.all.v2();
      return list.get.all.v1();
    } catch (error: any) {
      throw new Error(`Failed to get messages: ${error}`);
    }
  }
680
+
681
+ async saveMessages(args: { messages: MastraMessageV1[]; format?: undefined | 'v1' }): Promise<MastraMessageV1[]>;
682
+ async saveMessages(args: { messages: MastraMessageV2[]; format: 'v2' }): Promise<MastraMessageV2[]>;
683
+ async saveMessages(
684
+ args: { messages: MastraMessageV1[]; format?: undefined | 'v1' } | { messages: MastraMessageV2[]; format: 'v2' },
685
+ ): Promise<MastraMessageV2[] | MastraMessageV1[]> {
686
+ try {
687
+ const { messages, format = 'v1' } = args;
688
+ if (messages.length === 0) {
689
+ return [];
690
+ }
691
+
692
+ const threadId = messages[0]?.threadId;
693
+
694
+ if (!threadId) {
695
+ throw new Error('Thread ID is required');
696
+ }
697
+
698
+ const transformedMessages = messages.map((message: MastraMessageV2 | MastraMessageV1) => ({
699
+ ...message,
700
+ content: JSON.stringify(message.content),
701
+ }));
702
+
703
+ const table = await this.lanceClient.openTable(TABLE_MESSAGES);
704
+ await table.add(transformedMessages, { mode: 'overwrite' });
705
+ const list = new MessageList().add(messages, 'memory');
706
+ if (format === `v2`) return list.get.all.v2();
707
+ return list.get.all.v1();
708
+ } catch (error: any) {
709
+ throw new Error(`Failed to save messages: ${error}`);
710
+ }
711
+ }
712
+
713
+ async saveTrace({ trace }: { trace: TraceType }): Promise<TraceType> {
714
+ try {
715
+ const table = await this.lanceClient.openTable(TABLE_TRACES);
716
+ const record = {
717
+ ...trace,
718
+ attributes: JSON.stringify(trace.attributes),
719
+ status: JSON.stringify(trace.status),
720
+ events: JSON.stringify(trace.events),
721
+ links: JSON.stringify(trace.links),
722
+ other: JSON.stringify(trace.other),
723
+ };
724
+ await table.add([record], { mode: 'append' });
725
+
726
+ return trace;
727
+ } catch (error: any) {
728
+ throw new Error(`Failed to save trace: ${error}`);
729
+ }
730
+ }
731
+
732
+ async getTraceById({ traceId }: { traceId: string }): Promise<TraceType> {
733
+ try {
734
+ const table = await this.lanceClient.openTable(TABLE_TRACES);
735
+ const query = table.query().where(`id = '${traceId}'`);
736
+ const records = await query.toArray();
737
+ return this.processResultWithTypeConversion(records[0], await this.getTableSchema(TABLE_TRACES)) as TraceType;
738
+ } catch (error: any) {
739
+ throw new Error(`Failed to get trace by ID: ${error}`);
740
+ }
741
+ }
742
+
743
  /**
   * Fetches trace rows with optional name/scope/attributes filters and
   * page/perPage pagination, parsing JSON columns back into objects.
   * NOTE(review): JSON.parse below will throw if any stored column is not
   * valid JSON, and the attributes filter matches only the exact serialized
   * string.
   */
  async getTraces({
    name,
    scope,
    page = 1,
    perPage = 10,
    attributes,
  }: {
    name?: string;
    scope?: string;
    page: number;
    perPage: number;
    attributes?: Record<string, string>;
  }): Promise<TraceType[]> {
    try {
      const table = await this.lanceClient.openTable(TABLE_TRACES);
      const query = table.query();

      // NOTE(review): successive where() calls are assumed to AND together;
      // if the LanceDB API replaces the previous filter instead, only the
      // last condition applied below takes effect — verify.
      if (name) {
        query.where(`name = '${name}'`);
      }

      if (scope) {
        query.where(`scope = '${scope}'`);
      }

      if (attributes) {
        query.where(`attributes = '${JSON.stringify(attributes)}'`);
      }

      // Calculate offset based on page and perPage
      const offset = (page - 1) * perPage;

      // Apply limit for pagination
      query.limit(perPage);

      // Apply offset if greater than 0
      if (offset > 0) {
        query.offset(offset);
      }

      const records = await query.toArray();
      return records.map(record => {
        return {
          ...record,
          attributes: JSON.parse(record.attributes),
          status: JSON.parse(record.status),
          events: JSON.parse(record.events),
          links: JSON.parse(record.links),
          other: JSON.parse(record.other),
          startTime: new Date(record.startTime),
          endTime: new Date(record.endTime),
          createdAt: new Date(record.createdAt),
        };
      }) as TraceType[];
    } catch (error: any) {
      throw new Error(`Failed to get traces: ${error}`);
    }
  }
801
+
802
+ async saveEvals({ evals }: { evals: EvalRow[] }): Promise<EvalRow[]> {
803
+ try {
804
+ const table = await this.lanceClient.openTable(TABLE_EVALS);
805
+ const transformedEvals = evals.map(evalRecord => ({
806
+ input: evalRecord.input,
807
+ output: evalRecord.output,
808
+ agent_name: evalRecord.agentName,
809
+ metric_name: evalRecord.metricName,
810
+ result: JSON.stringify(evalRecord.result),
811
+ instructions: evalRecord.instructions,
812
+ test_info: JSON.stringify(evalRecord.testInfo),
813
+ global_run_id: evalRecord.globalRunId,
814
+ run_id: evalRecord.runId,
815
+ created_at: new Date(evalRecord.createdAt).getTime(),
816
+ }));
817
+
818
+ await table.add(transformedEvals, { mode: 'append' });
819
+ return evals;
820
+ } catch (error: any) {
821
+ throw new Error(`Failed to save evals: ${error}`);
822
+ }
823
+ }
824
+
825
+ async getEvalsByAgentName(agentName: string, type?: 'test' | 'live'): Promise<EvalRow[]> {
826
+ try {
827
+ if (type) {
828
+ this.logger.warn('Type is not implemented yet in LanceDB storage');
829
+ }
830
+ const table = await this.lanceClient.openTable(TABLE_EVALS);
831
+ const query = table.query().where(`agent_name = '${agentName}'`);
832
+ const records = await query.toArray();
833
+ return records.map(record => {
834
+ return {
835
+ id: record.id,
836
+ input: record.input,
837
+ output: record.output,
838
+ agentName: record.agent_name,
839
+ metricName: record.metric_name,
840
+ result: JSON.parse(record.result),
841
+ instructions: record.instructions,
842
+ testInfo: JSON.parse(record.test_info),
843
+ globalRunId: record.global_run_id,
844
+ runId: record.run_id,
845
+ createdAt: new Date(record.created_at).toString(),
846
+ };
847
+ }) as EvalRow[];
848
+ } catch (error: any) {
849
+ throw new Error(`Failed to get evals by agent name: ${error}`);
850
+ }
851
+ }
852
+
853
+ private parseWorkflowRun(row: any): WorkflowRun {
854
+ let parsedSnapshot: WorkflowRunState | string = row.snapshot as string;
855
+ if (typeof parsedSnapshot === 'string') {
856
+ try {
857
+ parsedSnapshot = JSON.parse(row.snapshot as string) as WorkflowRunState;
858
+ } catch (e) {
859
+ // If parsing fails, return the raw snapshot string
860
+ console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
861
+ }
862
+ }
863
+
864
+ return {
865
+ workflowName: row.workflow_name,
866
+ runId: row.run_id,
867
+ snapshot: parsedSnapshot,
868
+ createdAt: this.ensureDate(row.createdAt)!,
869
+ updatedAt: this.ensureDate(row.updatedAt)!,
870
+ resourceId: row.resourceId,
871
+ };
872
+ }
873
+
874
  /**
   * Lists workflow runs, optionally filtered by workflow name and
   * created-at date range, with limit/offset pagination.
   * NOTE(review): `total` is the number of rows in THIS page (after
   * limit/offset), not the full matching count. Successive where() calls
   * are assumed to AND together — verify against the LanceDB query API.
   */
  async getWorkflowRuns(args?: {
    namespace?: string;
    workflowName?: string;
    fromDate?: Date;
    toDate?: Date;
    limit?: number;
    offset?: number;
  }): Promise<WorkflowRuns> {
    try {
      const table = await this.lanceClient.openTable(TABLE_WORKFLOW_SNAPSHOT);
      const query = table.query();

      if (args?.workflowName) {
        query.where(`workflow_name = '${args.workflowName}'`);
      }

      // createdAt is compared as epoch milliseconds, matching how
      // persistWorkflowSnapshot stores it.
      if (args?.fromDate) {
        query.where(`\`createdAt\` >= ${args.fromDate.getTime()}`);
      }

      if (args?.toDate) {
        query.where(`\`createdAt\` <= ${args.toDate.getTime()}`);
      }

      if (args?.limit) {
        query.limit(args.limit);
      }

      if (args?.offset) {
        query.offset(args.offset);
      }

      const records = await query.toArray();
      return {
        runs: records.map(record => this.parseWorkflowRun(record)),
        total: records.length,
      };
    } catch (error: any) {
      throw new Error(`Failed to get workflow runs: ${error}`);
    }
  }
915
+
916
+ /**
917
+ * Retrieve a single workflow run by its runId.
918
+ * @param args The ID of the workflow run to retrieve
919
+ * @returns The workflow run object or null if not found
920
+ */
921
+ async getWorkflowRunById(args: { runId: string; workflowName?: string }): Promise<{
922
+ workflowName: string;
923
+ runId: string;
924
+ snapshot: any;
925
+ createdAt: Date;
926
+ updatedAt: Date;
927
+ } | null> {
928
+ try {
929
+ const table = await this.lanceClient.openTable(TABLE_WORKFLOW_SNAPSHOT);
930
+ let whereClause = `run_id = '${args.runId}'`;
931
+ if (args.workflowName) {
932
+ whereClause += ` AND workflow_name = '${args.workflowName}'`;
933
+ }
934
+ const query = table.query().where(whereClause);
935
+ const records = await query.toArray();
936
+ if (records.length === 0) return null;
937
+ const record = records[0];
938
+ return this.parseWorkflowRun(record);
939
+ } catch (error: any) {
940
+ throw new Error(`Failed to get workflow run by id: ${error}`);
941
+ }
942
+ }
943
+
944
  /**
   * Persists a workflow snapshot keyed by (workflow_name, run_id),
   * preserving the original createdAt when a row already exists.
   * NOTE(review): the update path uses add() with mode 'overwrite', which
   * may replace the entire table's contents rather than just this row —
   * verify against the @lancedb/lancedb docs (a keyed mergeInsert may be
   * the intended operation).
   */
  async persistWorkflowSnapshot({
    workflowName,
    runId,
    snapshot,
  }: {
    workflowName: string;
    runId: string;
    snapshot: WorkflowRunState;
  }): Promise<void> {
    try {
      const table = await this.lanceClient.openTable(TABLE_WORKFLOW_SNAPSHOT);

      // Try to find the existing record
      const query = table.query().where(`workflow_name = '${workflowName}' AND run_id = '${runId}'`);
      const records = await query.toArray();
      let createdAt: number;
      const now = Date.now();
      let mode: 'append' | 'overwrite' = 'append';

      if (records.length > 0) {
        // Keep the original creation time when updating an existing row.
        createdAt = records[0].createdAt ?? now;
        mode = 'overwrite';
      } else {
        createdAt = now;
      }

      const record = {
        workflow_name: workflowName,
        run_id: runId,
        snapshot: JSON.stringify(snapshot),
        createdAt,
        updatedAt: now,
      };

      await table.add([record], { mode });
    } catch (error: any) {
      throw new Error(`Failed to persist workflow snapshot: ${error}`);
    }
  }
983
+ async loadWorkflowSnapshot({
984
+ workflowName,
985
+ runId,
986
+ }: {
987
+ workflowName: string;
988
+ runId: string;
989
+ }): Promise<WorkflowRunState | null> {
990
+ try {
991
+ const table = await this.lanceClient.openTable(TABLE_WORKFLOW_SNAPSHOT);
992
+ const query = table.query().where(`workflow_name = '${workflowName}' AND run_id = '${runId}'`);
993
+ const records = await query.toArray();
994
+ return records.length > 0 ? JSON.parse(records[0].snapshot) : null;
995
+ } catch (error: any) {
996
+ throw new Error(`Failed to load workflow snapshot: ${error}`);
997
+ }
998
+ }
999
+ }