@mastra/lance 1.0.3 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -131,10 +131,52 @@ function resolveLanceConfig(config) {
131
131
  }
132
132
  var LanceDB = class extends base.MastraBase {
133
133
  client;
134
+ /** Cache of actual table columns: tableName -> Set<columnName> */
135
+ /** Cache of actual table columns: tableName -> Promise<Set<columnName>> (stores in-flight promise to coalesce concurrent calls) */
136
+ tableColumnsCache = /* @__PURE__ */ new Map();
134
137
  constructor({ client }) {
135
138
  super({ name: "lance-db" });
136
139
  this.client = client;
137
140
  }
141
+ /**
142
+ * Gets the set of column names that actually exist in the database table.
143
+ * Results are cached; the cache is invalidated when alterTable() adds new columns.
144
+ */
145
+ async getTableColumns(tableName) {
146
+ const cached = this.tableColumnsCache.get(tableName);
147
+ if (cached) return cached;
148
+ const promise = (async () => {
149
+ try {
150
+ const table = await this.client.openTable(tableName);
151
+ const schema = await table.schema();
152
+ const columns = new Set(schema.fields.map((f) => f.name));
153
+ if (columns.size === 0) {
154
+ this.tableColumnsCache.delete(tableName);
155
+ }
156
+ return columns;
157
+ } catch {
158
+ this.tableColumnsCache.delete(tableName);
159
+ return /* @__PURE__ */ new Set();
160
+ }
161
+ })();
162
+ this.tableColumnsCache.set(tableName, promise);
163
+ return promise;
164
+ }
165
+ /**
166
+ * Filters a record to only include columns that exist in the actual database table.
167
+ * Unknown columns are silently dropped to ensure forward compatibility.
168
+ */
169
+ async filterRecordToKnownColumns(tableName, record) {
170
+ const knownColumns = await this.getTableColumns(tableName);
171
+ if (knownColumns.size === 0) return record;
172
+ const filtered = {};
173
+ for (const [key, value] of Object.entries(record)) {
174
+ if (knownColumns.has(key)) {
175
+ filtered[key] = value;
176
+ }
177
+ }
178
+ return filtered;
179
+ }
138
180
  getDefaultValue(type) {
139
181
  switch (type) {
140
182
  case "text":
@@ -234,6 +276,8 @@ var LanceDB = class extends base.MastraBase {
234
276
  },
235
277
  error$1
236
278
  );
279
+ } finally {
280
+ this.tableColumnsCache.delete(tableName);
237
281
  }
238
282
  }
239
283
  async dropTable({ tableName }) {
@@ -272,6 +316,8 @@ var LanceDB = class extends base.MastraBase {
272
316
  },
273
317
  error$1
274
318
  );
319
+ } finally {
320
+ this.tableColumnsCache.delete(tableName);
275
321
  }
276
322
  }
277
323
  async alterTable({
@@ -338,6 +384,8 @@ var LanceDB = class extends base.MastraBase {
338
384
  },
339
385
  error$1
340
386
  );
387
+ } finally {
388
+ this.tableColumnsCache.delete(tableName);
341
389
  }
342
390
  }
343
391
  async clearTable({ tableName }) {
@@ -408,7 +456,9 @@ var LanceDB = class extends base.MastraBase {
408
456
  processedRecord[key] = JSON.stringify(processedRecord[key]);
409
457
  }
410
458
  }
411
- await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute([processedRecord]);
459
+ const filteredRecord = await this.filterRecordToKnownColumns(tableName, processedRecord);
460
+ if (Object.keys(filteredRecord).length === 0) return;
461
+ await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute([filteredRecord]);
412
462
  } catch (error$1) {
413
463
  throw new error.MastraError(
414
464
  {
@@ -457,7 +507,12 @@ var LanceDB = class extends base.MastraBase {
457
507
  }
458
508
  return processedRecord;
459
509
  });
460
- await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute(processedRecords);
510
+ const filteredRecords = await Promise.all(
511
+ processedRecords.map((r) => this.filterRecordToKnownColumns(tableName, r))
512
+ );
513
+ const nonEmptyRecords = filteredRecords.filter((r) => Object.keys(r).length > 0);
514
+ if (nonEmptyRecords.length === 0) return;
515
+ await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute(nonEmptyRecords);
461
516
  } catch (error$1) {
462
517
  throw new error.MastraError(
463
518
  {
@@ -808,6 +863,20 @@ var StoreMemoryLance = class extends storage.MemoryStorage {
808
863
  conditions.push(`\`createdAt\` ${endOp} ${endTime}`);
809
864
  }
810
865
  const whereClause = conditions.join(" AND ");
866
+ if (perPage === 0 && (!include || include.length === 0)) {
867
+ return { messages: [], total: 0, page, perPage: perPageForResponse, hasMore: false };
868
+ }
869
+ if (perPage === 0 && include && include.length > 0) {
870
+ const includedMessages = await this._getIncludedMessages(table, include);
871
+ const list2 = new agent.MessageList().add(includedMessages, "memory");
872
+ return {
873
+ messages: this._sortMessages(list2.get.all.db(), field, direction),
874
+ total: 0,
875
+ page,
876
+ perPage: perPageForResponse,
877
+ hasMore: false
878
+ };
879
+ }
811
880
  const total = await table.countRows(whereClause);
812
881
  const query = table.query().where(whereClause);
813
882
  let allRecords = await query.toArray();
@@ -835,16 +904,7 @@ var StoreMemoryLance = class extends storage.MemoryStorage {
835
904
  }
836
905
  const messageIds = new Set(messages.map((m) => m.id));
837
906
  if (include && include.length > 0) {
838
- const threadIds2 = [...new Set(include.map((item) => item.threadId || threadId))];
839
- const allThreadMessages = [];
840
- for (const tid of threadIds2) {
841
- const threadQuery = table.query().where(`thread_id = '${tid}'`);
842
- let threadRecords = await threadQuery.toArray();
843
- allThreadMessages.push(...threadRecords);
844
- }
845
- allThreadMessages.sort((a, b) => a.createdAt - b.createdAt);
846
- const contextMessages = this.processMessagesWithContext(allThreadMessages, include);
847
- const includedMessages = contextMessages.map((row) => this.normalizeMessage(row));
907
+ const includedMessages = await this._getIncludedMessages(table, include);
848
908
  for (const includeMsg of includedMessages) {
849
909
  if (!messageIds.has(includeMsg.id)) {
850
910
  messages.push(includeMsg);
@@ -854,17 +914,7 @@ var StoreMemoryLance = class extends storage.MemoryStorage {
854
914
  }
855
915
  const list = new agent.MessageList().add(messages, "memory");
856
916
  let finalMessages = list.get.all.db();
857
- finalMessages = finalMessages.sort((a, b) => {
858
- const aValue = field === "createdAt" ? new Date(a.createdAt).getTime() : a[field];
859
- const bValue = field === "createdAt" ? new Date(b.createdAt).getTime() : b[field];
860
- if (aValue == null && bValue == null) return 0;
861
- if (aValue == null) return direction === "ASC" ? -1 : 1;
862
- if (bValue == null) return direction === "ASC" ? 1 : -1;
863
- if (typeof aValue === "string" && typeof bValue === "string") {
864
- return direction === "ASC" ? aValue.localeCompare(bValue) : bValue.localeCompare(aValue);
865
- }
866
- return direction === "ASC" ? aValue - bValue : bValue - aValue;
867
- });
917
+ finalMessages = this._sortMessages(finalMessages, field, direction);
868
918
  const returnedThreadMessageIds = new Set(finalMessages.filter((m) => m.threadId === threadId).map((m) => m.id));
869
919
  const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
870
920
  const fetchedAll = perPageInput === false || allThreadMessagesReturned;
@@ -1043,6 +1093,62 @@ var StoreMemoryLance = class extends storage.MemoryStorage {
1043
1093
  );
1044
1094
  }
1045
1095
  }
1096
+ _sortMessages(messages, field, direction) {
1097
+ return messages.sort((a, b) => {
1098
+ const aValue = field === "createdAt" ? new Date(a.createdAt).getTime() : a[field];
1099
+ const bValue = field === "createdAt" ? new Date(b.createdAt).getTime() : b[field];
1100
+ if (aValue == null && bValue == null) return 0;
1101
+ if (aValue == null) return direction === "ASC" ? -1 : 1;
1102
+ if (bValue == null) return direction === "ASC" ? 1 : -1;
1103
+ if (typeof aValue === "string" && typeof bValue === "string") {
1104
+ return direction === "ASC" ? aValue.localeCompare(bValue) : bValue.localeCompare(aValue);
1105
+ }
1106
+ return direction === "ASC" ? aValue - bValue : bValue - aValue;
1107
+ });
1108
+ }
1109
+ async _getIncludedMessages(table, include) {
1110
+ if (include.length === 0) return [];
1111
+ const targetIds = include.map((item) => item.id);
1112
+ const idCondition = targetIds.length === 1 ? `id = '${this.escapeSql(targetIds[0])}'` : `id IN (${targetIds.map((id) => `'${this.escapeSql(id)}'`).join(", ")})`;
1113
+ const targetRecords = await table.query().where(idCondition).toArray();
1114
+ const needsContext = include.some((item) => item.withPreviousMessages || item.withNextMessages);
1115
+ if (!needsContext) {
1116
+ return targetRecords.map((row) => this.normalizeMessage(row));
1117
+ }
1118
+ const threadIdsToFetch = [...new Set(targetRecords.map((r) => r.thread_id))];
1119
+ const threadCache = /* @__PURE__ */ new Map();
1120
+ for (const tid of threadIdsToFetch) {
1121
+ const threadRecords = await table.query().where(`thread_id = '${this.escapeSql(tid)}'`).toArray();
1122
+ threadRecords.sort((a, b) => a.createdAt - b.createdAt);
1123
+ threadCache.set(tid, threadRecords);
1124
+ }
1125
+ const targetThreadMap = /* @__PURE__ */ new Map();
1126
+ for (const r of targetRecords) {
1127
+ targetThreadMap.set(r.id, r.thread_id);
1128
+ }
1129
+ const includeByThread = /* @__PURE__ */ new Map();
1130
+ for (const item of include) {
1131
+ const tid = targetThreadMap.get(item.id);
1132
+ if (!tid) continue;
1133
+ const items = includeByThread.get(tid) ?? [];
1134
+ items.push(item);
1135
+ includeByThread.set(tid, items);
1136
+ }
1137
+ const seen = /* @__PURE__ */ new Set();
1138
+ const allContextMessages = [];
1139
+ for (const [tid, threadInclude] of includeByThread) {
1140
+ const threadMessages = threadCache.get(tid) ?? [];
1141
+ const contextMessages = this.processMessagesWithContext(threadMessages, threadInclude);
1142
+ for (const msg of contextMessages) {
1143
+ if (!seen.has(msg.id)) {
1144
+ seen.add(msg.id);
1145
+ allContextMessages.push(msg);
1146
+ }
1147
+ }
1148
+ }
1149
+ allContextMessages.sort((a, b) => a.createdAt - b.createdAt);
1150
+ return allContextMessages.map((row) => this.normalizeMessage(row));
1151
+ }
1046
1152
  /**
1047
1153
  * Processes messages to include context messages based on withPreviousMessages and withNextMessages
1048
1154
  * @param records - The sorted array of records to process
@@ -1052,7 +1158,8 @@ var StoreMemoryLance = class extends storage.MemoryStorage {
1052
1158
  processMessagesWithContext(records, include) {
1053
1159
  const messagesWithContext = include.filter((item) => item.withPreviousMessages || item.withNextMessages);
1054
1160
  if (messagesWithContext.length === 0) {
1055
- return records;
1161
+ const targetIds = new Set(include.map((item) => item.id));
1162
+ return records.filter((record) => targetIds.has(record.id));
1056
1163
  }
1057
1164
  const messageIndexMap = /* @__PURE__ */ new Map();
1058
1165
  records.forEach((message, index) => {
@@ -1077,7 +1184,8 @@ var StoreMemoryLance = class extends storage.MemoryStorage {
1077
1184
  }
1078
1185
  }
1079
1186
  if (additionalIndices.size === 0) {
1080
- return records;
1187
+ const targetIds = new Set(include.map((item) => item.id));
1188
+ return records.filter((record) => targetIds.has(record.id));
1081
1189
  }
1082
1190
  const originalMatchIds = new Set(include.map((item) => item.id));
1083
1191
  const allIndices = /* @__PURE__ */ new Set();