@mastra/cloudflare-d1 1.0.1 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -285,6 +285,8 @@ var D1DB = class extends MastraBase {
285
285
  client;
286
286
  binding;
287
287
  tablePrefix;
288
+ /** Cache of actual table columns: tableName -> Promise<Set<columnName>> (stores in-flight promise to coalesce concurrent calls) */
289
+ tableColumnsCache = /* @__PURE__ */ new Map();
288
290
  constructor(config) {
289
291
  super({
290
292
  component: "STORAGE",
@@ -393,6 +395,38 @@ var D1DB = class extends MastraBase {
393
395
  throw new Error("Neither binding nor client is configured");
394
396
  }
395
397
  }
398
+ /**
399
+ * Gets the set of column names that actually exist in the database table.
400
+ * Results are cached; the cache is invalidated when alterTable() adds new columns.
401
+ */
402
+ async getKnownColumnNames(tableName) {
403
+ const cached = this.tableColumnsCache.get(tableName);
404
+ if (cached) return cached;
405
+ const promise = this.getTableColumns(tableName).then((columns) => {
406
+ const names = new Set(columns.map((c) => c.name));
407
+ if (names.size === 0) {
408
+ this.tableColumnsCache.delete(tableName);
409
+ }
410
+ return names;
411
+ });
412
+ this.tableColumnsCache.set(tableName, promise);
413
+ return promise;
414
+ }
415
+ /**
416
+ * Filters a record to only include columns that exist in the actual database table.
417
+ * Unknown columns are silently dropped to ensure forward compatibility.
418
+ */
419
+ async filterRecordToKnownColumns(tableName, record) {
420
+ const knownColumns = await this.getKnownColumnNames(tableName);
421
+ if (knownColumns.size === 0) return record;
422
+ const filtered = {};
423
+ for (const [key, value] of Object.entries(record)) {
424
+ if (knownColumns.has(key)) {
425
+ filtered[key] = value;
426
+ }
427
+ }
428
+ return filtered;
429
+ }
396
430
  async getTableColumns(tableName) {
397
431
  try {
398
432
  const sql = `PRAGMA table_info(${tableName})`;
@@ -459,6 +493,7 @@ var D1DB = class extends MastraBase {
459
493
  const { sql, params } = query.build();
460
494
  await this.executeQuery({ sql, params });
461
495
  this.logger.debug(`Created table ${fullTableName}`);
496
+ this.tableColumnsCache.delete(fullTableName);
462
497
  } catch (error) {
463
498
  throw new MastraError(
464
499
  {
@@ -491,8 +526,8 @@ var D1DB = class extends MastraBase {
491
526
  }
492
527
  }
493
528
  async dropTable({ tableName }) {
529
+ const fullTableName = this.getTableName(tableName);
494
530
  try {
495
- const fullTableName = this.getTableName(tableName);
496
531
  const sql = `DROP TABLE IF EXISTS ${fullTableName}`;
497
532
  await this.executeQuery({ sql });
498
533
  this.logger.debug(`Dropped table ${fullTableName}`);
@@ -506,11 +541,13 @@ var D1DB = class extends MastraBase {
506
541
  },
507
542
  error
508
543
  );
544
+ } finally {
545
+ this.tableColumnsCache.delete(fullTableName);
509
546
  }
510
547
  }
511
548
  async alterTable(args) {
549
+ const fullTableName = this.getTableName(args.tableName);
512
550
  try {
513
- const fullTableName = this.getTableName(args.tableName);
514
551
  const existingColumns = await this.getTableColumns(fullTableName);
515
552
  const existingColumnNames = new Set(existingColumns.map((col) => col.name));
516
553
  for (const [columnName, column] of Object.entries(args.schema)) {
@@ -532,14 +569,18 @@ var D1DB = class extends MastraBase {
532
569
  },
533
570
  error
534
571
  );
572
+ } finally {
573
+ this.tableColumnsCache.delete(fullTableName);
535
574
  }
536
575
  }
537
576
  async insert({ tableName, record }) {
538
577
  try {
539
578
  const fullTableName = this.getTableName(tableName);
540
579
  const processedRecord = await this.processRecord(record);
541
- const columns = Object.keys(processedRecord);
542
- const values = Object.values(processedRecord);
580
+ const filteredRecord = await this.filterRecordToKnownColumns(fullTableName, processedRecord);
581
+ const columns = Object.keys(filteredRecord);
582
+ if (columns.length === 0) return;
583
+ const values = Object.values(filteredRecord);
543
584
  const query = createSqlBuilder().insert(fullTableName, columns, values);
544
585
  const { sql, params } = query.build();
545
586
  await this.executeQuery({ sql, params });
@@ -560,8 +601,12 @@ var D1DB = class extends MastraBase {
560
601
  if (records.length === 0) return;
561
602
  const fullTableName = this.getTableName(tableName);
562
603
  const processedRecords = await Promise.all(records.map((record) => this.processRecord(record)));
563
- const columns = Object.keys(processedRecords[0] || {});
564
- for (const record of processedRecords) {
604
+ const filteredRecords = await Promise.all(
605
+ processedRecords.map((r) => this.filterRecordToKnownColumns(fullTableName, r))
606
+ );
607
+ for (const record of filteredRecords) {
608
+ const columns = Object.keys(record);
609
+ if (columns.length === 0) continue;
565
610
  const values = Object.values(record);
566
611
  const query = createSqlBuilder().insert(fullTableName, columns, values);
567
612
  const { sql, params } = query.build();
@@ -637,12 +682,14 @@ var D1DB = class extends MastraBase {
637
682
  const batch = records.slice(i, i + batchSize);
638
683
  const recordsToInsert = batch;
639
684
  if (recordsToInsert.length > 0) {
640
- const firstRecord = recordsToInsert[0];
641
- const columns = Object.keys(firstRecord || {});
642
- for (const record of recordsToInsert) {
685
+ const filteredRecords = await Promise.all(
686
+ recordsToInsert.map((r) => this.filterRecordToKnownColumns(fullTableName, r || {}))
687
+ );
688
+ for (const record of filteredRecords) {
689
+ const columns = Object.keys(record);
690
+ if (columns.length === 0) continue;
643
691
  const values = columns.map((col) => {
644
- if (!record) return null;
645
- const value = typeof col === "string" ? record[col] : null;
692
+ const value = record[col];
646
693
  return this.serializeValue(value);
647
694
  });
648
695
  const recordToUpsert = columns.reduce(
@@ -1128,56 +1175,79 @@ var MemoryStorageD1 = class extends MemoryStorage {
1128
1175
  );
1129
1176
  }
1130
1177
  }
1178
+ _sortMessages(messages, field, direction) {
1179
+ return messages.sort((a, b) => {
1180
+ const isDateField = field === "createdAt" || field === "updatedAt";
1181
+ const aValue = isDateField ? new Date(a[field]).getTime() : a[field];
1182
+ const bValue = isDateField ? new Date(b[field]).getTime() : b[field];
1183
+ if (aValue === bValue) {
1184
+ return a.id.localeCompare(b.id);
1185
+ }
1186
+ if (typeof aValue === "number" && typeof bValue === "number") {
1187
+ return direction === "ASC" ? aValue - bValue : bValue - aValue;
1188
+ }
1189
+ return direction === "ASC" ? String(aValue).localeCompare(String(bValue)) : String(bValue).localeCompare(String(aValue));
1190
+ });
1191
+ }
1131
1192
  async _getIncludedMessages(include) {
1132
1193
  if (!include || include.length === 0) return null;
1194
+ const tableName = this.#db.getTableName(TABLE_MESSAGES);
1195
+ const targetIds = include.map((inc) => inc.id).filter(Boolean);
1196
+ if (targetIds.length === 0) return null;
1197
+ const idPlaceholders = targetIds.map(() => "?").join(", ");
1198
+ const targetResult = await this.#db.executeQuery({
1199
+ sql: `SELECT id, thread_id, createdAt FROM ${tableName} WHERE id IN (${idPlaceholders})`,
1200
+ params: targetIds
1201
+ });
1202
+ if (!Array.isArray(targetResult) || targetResult.length === 0) return null;
1203
+ const targetMap = new Map(
1204
+ targetResult.map((r) => [r.id, { threadId: r.thread_id, createdAt: r.createdAt }])
1205
+ );
1133
1206
  const unionQueries = [];
1134
1207
  const params = [];
1135
- let paramIdx = 1;
1136
- const tableName = this.#db.getTableName(TABLE_MESSAGES);
1137
1208
  for (const inc of include) {
1138
1209
  const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
1139
- unionQueries.push(`
1140
- SELECT * FROM (
1141
- WITH target_thread AS (
1142
- SELECT thread_id FROM ${tableName} WHERE id = ?
1143
- ),
1144
- ordered_messages AS (
1145
- SELECT
1146
- *,
1147
- ROW_NUMBER() OVER (ORDER BY createdAt ASC) AS row_num
1148
- FROM ${tableName}
1149
- WHERE thread_id = (SELECT thread_id FROM target_thread)
1150
- )
1151
- SELECT
1152
- m.id,
1153
- m.content,
1154
- m.role,
1155
- m.type,
1156
- m.createdAt,
1157
- m.thread_id AS threadId,
1158
- m.resourceId
1159
- FROM ordered_messages m
1160
- WHERE m.id = ?
1161
- OR EXISTS (
1162
- SELECT 1 FROM ordered_messages target
1163
- WHERE target.id = ?
1164
- AND (
1165
- (m.row_num <= target.row_num + ? AND m.row_num > target.row_num)
1166
- OR
1167
- (m.row_num >= target.row_num - ? AND m.row_num < target.row_num)
1168
- )
1169
- )
1170
- ) AS query_${paramIdx}
1171
- `);
1172
- params.push(id, id, id, withNextMessages, withPreviousMessages);
1173
- paramIdx++;
1174
- }
1175
- const finalQuery = unionQueries.join(" UNION ALL ") + " ORDER BY createdAt ASC";
1210
+ const target = targetMap.get(id);
1211
+ if (!target) continue;
1212
+ unionQueries.push(`SELECT * FROM (
1213
+ SELECT id, content, role, type, createdAt, thread_id AS threadId, resourceId
1214
+ FROM ${tableName}
1215
+ WHERE thread_id = ?
1216
+ AND createdAt <= ?
1217
+ ORDER BY createdAt DESC, id DESC
1218
+ LIMIT ?
1219
+ )`);
1220
+ params.push(target.threadId, target.createdAt, withPreviousMessages + 1);
1221
+ if (withNextMessages > 0) {
1222
+ unionQueries.push(`SELECT * FROM (
1223
+ SELECT id, content, role, type, createdAt, thread_id AS threadId, resourceId
1224
+ FROM ${tableName}
1225
+ WHERE thread_id = ?
1226
+ AND createdAt > ?
1227
+ ORDER BY createdAt ASC, id ASC
1228
+ LIMIT ?
1229
+ )`);
1230
+ params.push(target.threadId, target.createdAt, withNextMessages);
1231
+ }
1232
+ }
1233
+ if (unionQueries.length === 0) return null;
1234
+ let finalQuery;
1235
+ if (unionQueries.length === 1) {
1236
+ finalQuery = unionQueries[0];
1237
+ } else {
1238
+ finalQuery = `${unionQueries.join(" UNION ALL ")} ORDER BY createdAt ASC, id ASC`;
1239
+ }
1176
1240
  const messages = await this.#db.executeQuery({ sql: finalQuery, params });
1177
1241
  if (!Array.isArray(messages)) {
1178
1242
  return [];
1179
1243
  }
1180
- const processedMessages = messages.map((message) => {
1244
+ const seen = /* @__PURE__ */ new Set();
1245
+ const processedMessages = messages.filter((message) => {
1246
+ const id = message.id;
1247
+ if (seen.has(id)) return false;
1248
+ seen.add(id);
1249
+ return true;
1250
+ }).map((message) => {
1181
1251
  const processedMsg = {};
1182
1252
  for (const [key, value] of Object.entries(message)) {
1183
1253
  if (key === `type` && value === `v2`) continue;
@@ -1277,6 +1347,23 @@ var MemoryStorageD1 = class extends MemoryStorage {
1277
1347
  queryParams.push(endDate);
1278
1348
  }
1279
1349
  const { field, direction } = this.parseOrderBy(orderBy, "ASC");
1350
+ if (perPage === 0 && (!include || include.length === 0)) {
1351
+ return { messages: [], total: 0, page, perPage: perPageForResponse, hasMore: false };
1352
+ }
1353
+ if (perPage === 0 && include && include.length > 0) {
1354
+ const includeResult = await this._getIncludedMessages(include);
1355
+ if (!Array.isArray(includeResult) || includeResult.length === 0) {
1356
+ return { messages: [], total: 0, page, perPage: perPageForResponse, hasMore: false };
1357
+ }
1358
+ const list2 = new MessageList().add(includeResult, "memory");
1359
+ return {
1360
+ messages: this._sortMessages(list2.get.all.db(), field, direction),
1361
+ total: 0,
1362
+ page,
1363
+ perPage: perPageForResponse,
1364
+ hasMore: false
1365
+ };
1366
+ }
1280
1367
  query += ` ORDER BY "${field}" ${direction}`;
1281
1368
  if (perPage !== Number.MAX_SAFE_INTEGER) {
1282
1369
  query += ` LIMIT ? OFFSET ?`;
@@ -1336,19 +1423,7 @@ var MemoryStorageD1 = class extends MemoryStorage {
1336
1423
  }
1337
1424
  }
1338
1425
  const list = new MessageList().add(paginatedMessages, "memory");
1339
- let finalMessages = list.get.all.db();
1340
- finalMessages = finalMessages.sort((a, b) => {
1341
- const isDateField = field === "createdAt" || field === "updatedAt";
1342
- const aValue = isDateField ? new Date(a[field]).getTime() : a[field];
1343
- const bValue = isDateField ? new Date(b[field]).getTime() : b[field];
1344
- if (aValue === bValue) {
1345
- return a.id.localeCompare(b.id);
1346
- }
1347
- if (typeof aValue === "number" && typeof bValue === "number") {
1348
- return direction === "ASC" ? aValue - bValue : bValue - aValue;
1349
- }
1350
- return direction === "ASC" ? String(aValue).localeCompare(String(bValue)) : String(bValue).localeCompare(String(aValue));
1351
- });
1426
+ const finalMessages = this._sortMessages(list.get.all.db(), field, direction);
1352
1427
  const returnedThreadMessageIds = new Set(finalMessages.filter((m) => m.threadId === threadId).map((m) => m.id));
1353
1428
  const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
1354
1429
  const hasMore = perPageInput === false ? false : allThreadMessagesReturned ? false : offset + paginatedCount < total;