@mastra/lance 1.0.3 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -129,10 +129,52 @@ function resolveLanceConfig(config) {
129
129
  }
130
130
  var LanceDB = class extends MastraBase {
131
131
  client;
132
+ /** Cache of actual table columns: tableName -> Set<columnName> */
133
+ /** Cache of actual table columns: tableName -> Promise<Set<columnName>> (stores in-flight promise to coalesce concurrent calls) */
134
+ tableColumnsCache = /* @__PURE__ */ new Map();
132
135
  constructor({ client }) {
133
136
  super({ name: "lance-db" });
134
137
  this.client = client;
135
138
  }
139
+ /**
140
+ * Gets the set of column names that actually exist in the database table.
141
+ * Results are cached; the cache is invalidated when alterTable() adds new columns.
142
+ */
143
+ async getTableColumns(tableName) {
144
+ const cached = this.tableColumnsCache.get(tableName);
145
+ if (cached) return cached;
146
+ const promise = (async () => {
147
+ try {
148
+ const table = await this.client.openTable(tableName);
149
+ const schema = await table.schema();
150
+ const columns = new Set(schema.fields.map((f) => f.name));
151
+ if (columns.size === 0) {
152
+ this.tableColumnsCache.delete(tableName);
153
+ }
154
+ return columns;
155
+ } catch {
156
+ this.tableColumnsCache.delete(tableName);
157
+ return /* @__PURE__ */ new Set();
158
+ }
159
+ })();
160
+ this.tableColumnsCache.set(tableName, promise);
161
+ return promise;
162
+ }
163
+ /**
164
+ * Filters a record to only include columns that exist in the actual database table.
165
+ * Unknown columns are silently dropped to ensure forward compatibility.
166
+ */
167
+ async filterRecordToKnownColumns(tableName, record) {
168
+ const knownColumns = await this.getTableColumns(tableName);
169
+ if (knownColumns.size === 0) return record;
170
+ const filtered = {};
171
+ for (const [key, value] of Object.entries(record)) {
172
+ if (knownColumns.has(key)) {
173
+ filtered[key] = value;
174
+ }
175
+ }
176
+ return filtered;
177
+ }
136
178
  getDefaultValue(type) {
137
179
  switch (type) {
138
180
  case "text":
@@ -232,6 +274,8 @@ var LanceDB = class extends MastraBase {
232
274
  },
233
275
  error
234
276
  );
277
+ } finally {
278
+ this.tableColumnsCache.delete(tableName);
235
279
  }
236
280
  }
237
281
  async dropTable({ tableName }) {
@@ -270,6 +314,8 @@ var LanceDB = class extends MastraBase {
270
314
  },
271
315
  error
272
316
  );
317
+ } finally {
318
+ this.tableColumnsCache.delete(tableName);
273
319
  }
274
320
  }
275
321
  async alterTable({
@@ -336,6 +382,8 @@ var LanceDB = class extends MastraBase {
336
382
  },
337
383
  error
338
384
  );
385
+ } finally {
386
+ this.tableColumnsCache.delete(tableName);
339
387
  }
340
388
  }
341
389
  async clearTable({ tableName }) {
@@ -406,7 +454,9 @@ var LanceDB = class extends MastraBase {
406
454
  processedRecord[key] = JSON.stringify(processedRecord[key]);
407
455
  }
408
456
  }
409
- await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute([processedRecord]);
457
+ const filteredRecord = await this.filterRecordToKnownColumns(tableName, processedRecord);
458
+ if (Object.keys(filteredRecord).length === 0) return;
459
+ await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute([filteredRecord]);
410
460
  } catch (error) {
411
461
  throw new MastraError(
412
462
  {
@@ -455,7 +505,12 @@ var LanceDB = class extends MastraBase {
455
505
  }
456
506
  return processedRecord;
457
507
  });
458
- await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute(processedRecords);
508
+ const filteredRecords = await Promise.all(
509
+ processedRecords.map((r) => this.filterRecordToKnownColumns(tableName, r))
510
+ );
511
+ const nonEmptyRecords = filteredRecords.filter((r) => Object.keys(r).length > 0);
512
+ if (nonEmptyRecords.length === 0) return;
513
+ await table.mergeInsert(primaryId).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute(nonEmptyRecords);
459
514
  } catch (error) {
460
515
  throw new MastraError(
461
516
  {
@@ -806,6 +861,20 @@ var StoreMemoryLance = class extends MemoryStorage {
806
861
  conditions.push(`\`createdAt\` ${endOp} ${endTime}`);
807
862
  }
808
863
  const whereClause = conditions.join(" AND ");
864
+ if (perPage === 0 && (!include || include.length === 0)) {
865
+ return { messages: [], total: 0, page, perPage: perPageForResponse, hasMore: false };
866
+ }
867
+ if (perPage === 0 && include && include.length > 0) {
868
+ const includedMessages = await this._getIncludedMessages(table, include);
869
+ const list2 = new MessageList().add(includedMessages, "memory");
870
+ return {
871
+ messages: this._sortMessages(list2.get.all.db(), field, direction),
872
+ total: 0,
873
+ page,
874
+ perPage: perPageForResponse,
875
+ hasMore: false
876
+ };
877
+ }
809
878
  const total = await table.countRows(whereClause);
810
879
  const query = table.query().where(whereClause);
811
880
  let allRecords = await query.toArray();
@@ -833,16 +902,7 @@ var StoreMemoryLance = class extends MemoryStorage {
833
902
  }
834
903
  const messageIds = new Set(messages.map((m) => m.id));
835
904
  if (include && include.length > 0) {
836
- const threadIds2 = [...new Set(include.map((item) => item.threadId || threadId))];
837
- const allThreadMessages = [];
838
- for (const tid of threadIds2) {
839
- const threadQuery = table.query().where(`thread_id = '${tid}'`);
840
- let threadRecords = await threadQuery.toArray();
841
- allThreadMessages.push(...threadRecords);
842
- }
843
- allThreadMessages.sort((a, b) => a.createdAt - b.createdAt);
844
- const contextMessages = this.processMessagesWithContext(allThreadMessages, include);
845
- const includedMessages = contextMessages.map((row) => this.normalizeMessage(row));
905
+ const includedMessages = await this._getIncludedMessages(table, include);
846
906
  for (const includeMsg of includedMessages) {
847
907
  if (!messageIds.has(includeMsg.id)) {
848
908
  messages.push(includeMsg);
@@ -852,17 +912,7 @@ var StoreMemoryLance = class extends MemoryStorage {
852
912
  }
853
913
  const list = new MessageList().add(messages, "memory");
854
914
  let finalMessages = list.get.all.db();
855
- finalMessages = finalMessages.sort((a, b) => {
856
- const aValue = field === "createdAt" ? new Date(a.createdAt).getTime() : a[field];
857
- const bValue = field === "createdAt" ? new Date(b.createdAt).getTime() : b[field];
858
- if (aValue == null && bValue == null) return 0;
859
- if (aValue == null) return direction === "ASC" ? -1 : 1;
860
- if (bValue == null) return direction === "ASC" ? 1 : -1;
861
- if (typeof aValue === "string" && typeof bValue === "string") {
862
- return direction === "ASC" ? aValue.localeCompare(bValue) : bValue.localeCompare(aValue);
863
- }
864
- return direction === "ASC" ? aValue - bValue : bValue - aValue;
865
- });
915
+ finalMessages = this._sortMessages(finalMessages, field, direction);
866
916
  const returnedThreadMessageIds = new Set(finalMessages.filter((m) => m.threadId === threadId).map((m) => m.id));
867
917
  const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
868
918
  const fetchedAll = perPageInput === false || allThreadMessagesReturned;
@@ -1041,6 +1091,62 @@ var StoreMemoryLance = class extends MemoryStorage {
1041
1091
  );
1042
1092
  }
1043
1093
  }
1094
+ _sortMessages(messages, field, direction) {
1095
+ return messages.sort((a, b) => {
1096
+ const aValue = field === "createdAt" ? new Date(a.createdAt).getTime() : a[field];
1097
+ const bValue = field === "createdAt" ? new Date(b.createdAt).getTime() : b[field];
1098
+ if (aValue == null && bValue == null) return 0;
1099
+ if (aValue == null) return direction === "ASC" ? -1 : 1;
1100
+ if (bValue == null) return direction === "ASC" ? 1 : -1;
1101
+ if (typeof aValue === "string" && typeof bValue === "string") {
1102
+ return direction === "ASC" ? aValue.localeCompare(bValue) : bValue.localeCompare(aValue);
1103
+ }
1104
+ return direction === "ASC" ? aValue - bValue : bValue - aValue;
1105
+ });
1106
+ }
1107
+ async _getIncludedMessages(table, include) {
1108
+ if (include.length === 0) return [];
1109
+ const targetIds = include.map((item) => item.id);
1110
+ const idCondition = targetIds.length === 1 ? `id = '${this.escapeSql(targetIds[0])}'` : `id IN (${targetIds.map((id) => `'${this.escapeSql(id)}'`).join(", ")})`;
1111
+ const targetRecords = await table.query().where(idCondition).toArray();
1112
+ const needsContext = include.some((item) => item.withPreviousMessages || item.withNextMessages);
1113
+ if (!needsContext) {
1114
+ return targetRecords.map((row) => this.normalizeMessage(row));
1115
+ }
1116
+ const threadIdsToFetch = [...new Set(targetRecords.map((r) => r.thread_id))];
1117
+ const threadCache = /* @__PURE__ */ new Map();
1118
+ for (const tid of threadIdsToFetch) {
1119
+ const threadRecords = await table.query().where(`thread_id = '${this.escapeSql(tid)}'`).toArray();
1120
+ threadRecords.sort((a, b) => a.createdAt - b.createdAt);
1121
+ threadCache.set(tid, threadRecords);
1122
+ }
1123
+ const targetThreadMap = /* @__PURE__ */ new Map();
1124
+ for (const r of targetRecords) {
1125
+ targetThreadMap.set(r.id, r.thread_id);
1126
+ }
1127
+ const includeByThread = /* @__PURE__ */ new Map();
1128
+ for (const item of include) {
1129
+ const tid = targetThreadMap.get(item.id);
1130
+ if (!tid) continue;
1131
+ const items = includeByThread.get(tid) ?? [];
1132
+ items.push(item);
1133
+ includeByThread.set(tid, items);
1134
+ }
1135
+ const seen = /* @__PURE__ */ new Set();
1136
+ const allContextMessages = [];
1137
+ for (const [tid, threadInclude] of includeByThread) {
1138
+ const threadMessages = threadCache.get(tid) ?? [];
1139
+ const contextMessages = this.processMessagesWithContext(threadMessages, threadInclude);
1140
+ for (const msg of contextMessages) {
1141
+ if (!seen.has(msg.id)) {
1142
+ seen.add(msg.id);
1143
+ allContextMessages.push(msg);
1144
+ }
1145
+ }
1146
+ }
1147
+ allContextMessages.sort((a, b) => a.createdAt - b.createdAt);
1148
+ return allContextMessages.map((row) => this.normalizeMessage(row));
1149
+ }
1044
1150
  /**
1045
1151
  * Processes messages to include context messages based on withPreviousMessages and withNextMessages
1046
1152
  * @param records - The sorted array of records to process
@@ -1050,7 +1156,8 @@ var StoreMemoryLance = class extends MemoryStorage {
1050
1156
  processMessagesWithContext(records, include) {
1051
1157
  const messagesWithContext = include.filter((item) => item.withPreviousMessages || item.withNextMessages);
1052
1158
  if (messagesWithContext.length === 0) {
1053
- return records;
1159
+ const targetIds = new Set(include.map((item) => item.id));
1160
+ return records.filter((record) => targetIds.has(record.id));
1054
1161
  }
1055
1162
  const messageIndexMap = /* @__PURE__ */ new Map();
1056
1163
  records.forEach((message, index) => {
@@ -1075,7 +1182,8 @@ var StoreMemoryLance = class extends MemoryStorage {
1075
1182
  }
1076
1183
  }
1077
1184
  if (additionalIndices.size === 0) {
1078
- return records;
1185
+ const targetIds = new Set(include.map((item) => item.id));
1186
+ return records.filter((record) => targetIds.has(record.id));
1079
1187
  }
1080
1188
  const originalMatchIds = new Set(include.map((item) => item.id));
1081
1189
  const allIndices = /* @__PURE__ */ new Set();