@mastra/cloudflare 1.0.0-beta.1 → 1.0.0-beta.3

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { MastraError, ErrorCategory, ErrorDomain } from '@mastra/core/error';
2
- import { MastraStorage, TABLE_THREADS, TABLE_MESSAGES, TABLE_WORKFLOW_SNAPSHOT, TABLE_SCORERS, StoreOperations, TABLE_TRACES, WorkflowsStorage, ensureDate, normalizePerPage, MemoryStorage, calculatePagination, serializeDate, TABLE_RESOURCES, ScoresStorage, safelyParseJSON } from '@mastra/core/storage';
2
+ import { MastraStorage, TABLE_THREADS, TABLE_MESSAGES, TABLE_WORKFLOW_SNAPSHOT, TABLE_SCORERS, StoreOperations, TABLE_TRACES, WorkflowsStorage, ensureDate, normalizePerPage, MemoryStorage, calculatePagination, serializeDate, TABLE_RESOURCES, ScoresStorage, transformScoreRow as transformScoreRow$1 } from '@mastra/core/storage';
3
3
  import Cloudflare from 'cloudflare';
4
4
  import { MessageList } from '@mastra/core/agent';
5
5
  import { saveScorePayloadSchema } from '@mastra/core/evals';
@@ -207,6 +207,17 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
207
207
  );
208
208
  }
209
209
  }
210
+ /**
211
+ * Searches all threads in the KV store to find a message by its ID.
212
+ *
213
+ * **Performance Warning**: This method sequentially scans all threads to locate
214
+ * the message. For stores with many threads, this can result in significant
215
+ * latency and API calls. When possible, callers should provide the `threadId`
216
+ * directly to avoid this full scan.
217
+ *
218
+ * @param messageId - The globally unique message ID to search for
219
+ * @returns The message with its threadId if found, null otherwise
220
+ */
210
221
  async findMessageInAnyThread(messageId) {
211
222
  try {
212
223
  const prefix = this.operations.namespacePrefix ? `${this.operations.namespacePrefix}:` : "";
@@ -436,10 +447,25 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
436
447
  async getFullOrder(orderKey) {
437
448
  return this.getRange(orderKey, 0, -1);
438
449
  }
439
- async getIncludedMessagesWithContext(threadId, include, messageIds) {
450
+ /**
451
+ * Retrieves messages specified in the include array along with their surrounding context.
452
+ *
453
+ * **Performance Note**: When `threadId` is not provided in an include entry, this method
454
+ * must call `findMessageInAnyThread` which sequentially scans all threads in the KV store.
455
+ * For optimal performance, callers should provide `threadId` in include entries when known.
456
+ *
457
+ * @param include - Array of message IDs to include, optionally with context windows
458
+ * @param messageIds - Set to accumulate the message IDs that should be fetched
459
+ */
460
+ async getIncludedMessagesWithContext(include, messageIds) {
440
461
  await Promise.all(
441
462
  include.map(async (item) => {
442
- const targetThreadId = item.threadId || threadId;
463
+ let targetThreadId = item.threadId;
464
+ if (!targetThreadId) {
465
+ const foundMessage = await this.findMessageInAnyThread(item.id);
466
+ if (!foundMessage) return;
467
+ targetThreadId = foundMessage.threadId;
468
+ }
443
469
  if (!targetThreadId) return;
444
470
  const threadMessagesKey = this.getThreadMessagesKey(targetThreadId);
445
471
  messageIds.add(item.id);
@@ -472,6 +498,13 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
472
498
  this.logger?.debug(`No message order found for thread ${threadId}, skipping latest messages`);
473
499
  }
474
500
  }
501
+ /**
502
+ * Fetches and parses messages from one or more threads.
503
+ *
504
+ * **Performance Note**: When neither `include` entries with `threadId` nor `targetThreadId`
505
+ * are provided, this method falls back to `findMessageInAnyThread` which scans all threads.
506
+ * For optimal performance, provide `threadId` in include entries or specify `targetThreadId`.
507
+ */
475
508
  async fetchAndParseMessagesFromMultipleThreads(messageIds, include, targetThreadId) {
476
509
  const messageIdToThreadId = /* @__PURE__ */ new Map();
477
510
  if (include) {
@@ -513,6 +546,14 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
513
546
  );
514
547
  return messages.filter((msg) => msg !== null);
515
548
  }
549
+ /**
550
+ * Retrieves messages by their IDs.
551
+ *
552
+ * **Performance Warning**: This method calls `findMessageInAnyThread` for each message ID,
553
+ * which scans all threads in the KV store. For large numbers of messages or threads,
554
+ * this can result in significant latency. Consider using `listMessages` with specific
555
+ * thread IDs when the thread context is known.
556
+ */
516
557
  async listMessagesById({ messageIds }) {
517
558
  if (messageIds.length === 0) return { messages: [] };
518
559
  try {
@@ -546,15 +587,17 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
546
587
  }
547
588
  async listMessages(args) {
548
589
  const { threadId, resourceId, include, filter, perPage: perPageInput, page = 0, orderBy } = args;
549
- if (!threadId.trim()) {
590
+ const threadIds = Array.isArray(threadId) ? threadId : [threadId];
591
+ const isValidThreadId = (id) => typeof id === "string" && id.trim().length > 0;
592
+ if (threadIds.length === 0 || threadIds.some((id) => !isValidThreadId(id))) {
550
593
  throw new MastraError(
551
594
  {
552
595
  id: "STORAGE_CLOUDFLARE_LIST_MESSAGES_INVALID_THREAD_ID",
553
596
  domain: ErrorDomain.STORAGE,
554
597
  category: ErrorCategory.THIRD_PARTY,
555
- details: { threadId }
598
+ details: { threadId: Array.isArray(threadId) ? JSON.stringify(threadId) : String(threadId) }
556
599
  },
557
- new Error("threadId must be a non-empty string")
600
+ new Error("threadId must be a non-empty string or array of non-empty strings")
558
601
  );
559
602
  }
560
603
  const perPage = normalizePerPage(perPageInput, 40);
@@ -572,69 +615,34 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
572
615
  );
573
616
  }
574
617
  const { field, direction } = this.parseOrderBy(orderBy, "ASC");
575
- const messageIds = /* @__PURE__ */ new Set();
576
- const hasFilters = !!resourceId || !!filter?.dateRange;
577
- if (hasFilters || perPage === Number.MAX_SAFE_INTEGER) {
618
+ const threadMessageIds = /* @__PURE__ */ new Set();
619
+ for (const tid of threadIds) {
578
620
  try {
579
- const threadMessagesKey = this.getThreadMessagesKey(threadId);
621
+ const threadMessagesKey = this.getThreadMessagesKey(tid);
580
622
  const allIds = await this.getFullOrder(threadMessagesKey);
581
- allIds.forEach((id) => messageIds.add(id));
623
+ allIds.forEach((id) => threadMessageIds.add(id));
582
624
  } catch {
583
625
  }
584
- } else {
585
- if (perPage > 0) {
586
- try {
587
- const threadMessagesKey = this.getThreadMessagesKey(threadId);
588
- const fullOrder = await this.getFullOrder(threadMessagesKey);
589
- const totalMessages = fullOrder.length;
590
- let start;
591
- let end;
592
- if (direction === "ASC") {
593
- start = offset;
594
- end = Math.min(offset + perPage - 1, totalMessages - 1);
595
- } else {
596
- start = Math.max(totalMessages - offset - perPage, 0);
597
- end = totalMessages - offset - 1;
598
- }
599
- const paginatedIds = await this.getRange(threadMessagesKey, start, end);
600
- paginatedIds.forEach((id) => messageIds.add(id));
601
- } catch {
602
- }
603
- }
604
626
  }
605
- if (include && include.length > 0) {
606
- await this.getIncludedMessagesWithContext(threadId, include, messageIds);
607
- }
608
- const messages = await this.fetchAndParseMessagesFromMultipleThreads(
609
- Array.from(messageIds),
610
- include,
611
- include && include.length > 0 ? void 0 : threadId
627
+ const threadMessages = await this.fetchAndParseMessagesFromMultipleThreads(
628
+ Array.from(threadMessageIds),
629
+ void 0,
630
+ threadIds.length === 1 ? threadIds[0] : void 0
612
631
  );
613
- let filteredMessages = messages;
632
+ let filteredThreadMessages = threadMessages;
614
633
  if (resourceId) {
615
- filteredMessages = filteredMessages.filter((msg) => msg.resourceId === resourceId);
634
+ filteredThreadMessages = filteredThreadMessages.filter((msg) => msg.resourceId === resourceId);
616
635
  }
617
636
  const dateRange = filter?.dateRange;
618
637
  if (dateRange) {
619
- filteredMessages = filteredMessages.filter((msg) => {
638
+ filteredThreadMessages = filteredThreadMessages.filter((msg) => {
620
639
  const messageDate = new Date(msg.createdAt);
621
640
  if (dateRange.start && messageDate < new Date(dateRange.start)) return false;
622
641
  if (dateRange.end && messageDate > new Date(dateRange.end)) return false;
623
642
  return true;
624
643
  });
625
644
  }
626
- let total;
627
- if (hasFilters) {
628
- total = filteredMessages.length;
629
- } else {
630
- try {
631
- const threadMessagesKey = this.getThreadMessagesKey(threadId);
632
- const fullOrder = await this.getFullOrder(threadMessagesKey);
633
- total = fullOrder.length;
634
- } catch {
635
- total = filteredMessages.length;
636
- }
637
- }
645
+ const total = filteredThreadMessages.length;
638
646
  if (perPage === 0 && (!include || include.length === 0)) {
639
647
  return {
640
648
  messages: [],
@@ -644,45 +652,58 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
644
652
  hasMore: offset < total
645
653
  };
646
654
  }
647
- if (hasFilters && perPage !== Number.MAX_SAFE_INTEGER && perPage > 0) {
648
- if (direction === "ASC") {
649
- filteredMessages = filteredMessages.slice(offset, offset + perPage);
650
- } else {
651
- const start = Math.max(filteredMessages.length - offset - perPage, 0);
652
- const end = filteredMessages.length - offset;
653
- filteredMessages = filteredMessages.slice(start, end);
655
+ filteredThreadMessages.sort((a, b) => {
656
+ const timeA = new Date(a.createdAt).getTime();
657
+ const timeB = new Date(b.createdAt).getTime();
658
+ const timeDiff = direction === "ASC" ? timeA - timeB : timeB - timeA;
659
+ if (timeDiff === 0) {
660
+ return a.id.localeCompare(b.id);
654
661
  }
662
+ return timeDiff;
663
+ });
664
+ let paginatedMessages;
665
+ if (perPage === 0) {
666
+ paginatedMessages = [];
667
+ } else if (perPage === Number.MAX_SAFE_INTEGER) {
668
+ paginatedMessages = filteredThreadMessages;
669
+ } else {
670
+ paginatedMessages = filteredThreadMessages.slice(offset, offset + perPage);
655
671
  }
656
- const paginatedCount = hasFilters && perPage !== Number.MAX_SAFE_INTEGER && perPage > 0 ? filteredMessages.length : filteredMessages.length;
657
- try {
658
- const threadMessagesKey = this.getThreadMessagesKey(threadId);
659
- const messageOrder = await this.getFullOrder(threadMessagesKey);
660
- const orderMap = new Map(messageOrder.map((id, index) => [id, index]));
661
- filteredMessages.sort((a, b) => {
662
- const indexA = orderMap.get(a.id);
663
- const indexB = orderMap.get(b.id);
664
- if (indexA !== void 0 && indexB !== void 0) {
665
- return direction === "ASC" ? indexA - indexB : indexB - indexA;
666
- }
667
- const timeA = new Date(a.createdAt).getTime();
668
- const timeB = new Date(b.createdAt).getTime();
669
- const timeDiff = direction === "ASC" ? timeA - timeB : timeB - timeA;
670
- if (timeDiff === 0) {
671
- return a.id.localeCompare(b.id);
672
- }
673
- return timeDiff;
674
- });
675
- } catch {
676
- filteredMessages.sort((a, b) => {
677
- const timeA = new Date(a.createdAt).getTime();
678
- const timeB = new Date(b.createdAt).getTime();
679
- const timeDiff = direction === "ASC" ? timeA - timeB : timeB - timeA;
680
- if (timeDiff === 0) {
681
- return a.id.localeCompare(b.id);
682
- }
683
- return timeDiff;
684
- });
672
+ let includedMessages = [];
673
+ if (include && include.length > 0) {
674
+ const includedMessageIds = /* @__PURE__ */ new Set();
675
+ await this.getIncludedMessagesWithContext(include, includedMessageIds);
676
+ const paginatedIds = new Set(paginatedMessages.map((m) => m.id));
677
+ const idsToFetch = Array.from(includedMessageIds).filter((id) => !paginatedIds.has(id));
678
+ if (idsToFetch.length > 0) {
679
+ includedMessages = await this.fetchAndParseMessagesFromMultipleThreads(idsToFetch, include, void 0);
680
+ }
681
+ }
682
+ const seenIds = /* @__PURE__ */ new Set();
683
+ const allMessages = [];
684
+ for (const msg of paginatedMessages) {
685
+ if (!seenIds.has(msg.id)) {
686
+ allMessages.push(msg);
687
+ seenIds.add(msg.id);
688
+ }
685
689
  }
690
+ for (const msg of includedMessages) {
691
+ if (!seenIds.has(msg.id)) {
692
+ allMessages.push(msg);
693
+ seenIds.add(msg.id);
694
+ }
695
+ }
696
+ allMessages.sort((a, b) => {
697
+ const timeA = new Date(a.createdAt).getTime();
698
+ const timeB = new Date(b.createdAt).getTime();
699
+ const timeDiff = direction === "ASC" ? timeA - timeB : timeB - timeA;
700
+ if (timeDiff === 0) {
701
+ return a.id.localeCompare(b.id);
702
+ }
703
+ return timeDiff;
704
+ });
705
+ let filteredMessages = allMessages;
706
+ const paginatedCount = paginatedMessages.length;
686
707
  if (total === 0 && filteredMessages.length === 0 && (!include || include.length === 0)) {
687
708
  return {
688
709
  messages: [],
@@ -697,7 +718,11 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
697
718
  type: message.type !== "v2" ? message.type : void 0,
698
719
  createdAt: ensureDate(message.createdAt)
699
720
  }));
700
- const list = new MessageList({ threadId, resourceId }).add(prepared, "memory");
721
+ const primaryThreadId = Array.isArray(threadId) ? threadId[0] : threadId;
722
+ const list = new MessageList({ threadId: primaryThreadId, resourceId }).add(
723
+ prepared,
724
+ "memory"
725
+ );
701
726
  let finalMessages = list.get.all.db();
702
727
  finalMessages = finalMessages.sort((a, b) => {
703
728
  const isDateField = field === "createdAt" || field === "updatedAt";
@@ -713,16 +738,12 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
713
738
  const cmp = direction === "ASC" ? String(aVal).localeCompare(String(bVal)) : String(bVal).localeCompare(String(aVal));
714
739
  return cmp !== 0 ? cmp : a.id.localeCompare(b.id);
715
740
  });
716
- const returnedThreadMessageIds = new Set(finalMessages.filter((m) => m.threadId === threadId).map((m) => m.id));
741
+ const threadIdSet = new Set(threadIds);
742
+ const returnedThreadMessageIds = new Set(
743
+ finalMessages.filter((m) => m.threadId && threadIdSet.has(m.threadId)).map((m) => m.id)
744
+ );
717
745
  const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
718
- let hasMore;
719
- if (perPageInput === false || allThreadMessagesReturned) {
720
- hasMore = false;
721
- } else if (direction === "ASC") {
722
- hasMore = offset + paginatedCount < total;
723
- } else {
724
- hasMore = total - offset - perPage > 0;
725
- }
746
+ const hasMore = perPageInput !== false && !allThreadMessagesReturned && offset + paginatedCount < total;
726
747
  return {
727
748
  messages: finalMessages,
728
749
  total,
@@ -736,9 +757,9 @@ var MemoryStorageCloudflare = class extends MemoryStorage {
736
757
  id: "CLOUDFLARE_STORAGE_LIST_MESSAGES_FAILED",
737
758
  domain: ErrorDomain.STORAGE,
738
759
  category: ErrorCategory.THIRD_PARTY,
739
- text: `Failed to list messages for thread ${threadId}: ${error instanceof Error ? error.message : String(error)}`,
760
+ text: `Failed to list messages for thread ${Array.isArray(threadId) ? threadId.join(",") : threadId}: ${error instanceof Error ? error.message : String(error)}`,
740
761
  details: {
741
- threadId,
762
+ threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
742
763
  resourceId: resourceId ?? ""
743
764
  }
744
765
  },
@@ -1490,17 +1511,7 @@ var StoreOperationsCloudflare = class extends StoreOperations {
1490
1511
  }
1491
1512
  };
1492
1513
  function transformScoreRow(row) {
1493
- const deserialized = { ...row };
1494
- deserialized.input = safelyParseJSON(row.input);
1495
- deserialized.output = safelyParseJSON(row.output);
1496
- deserialized.scorer = safelyParseJSON(row.scorer);
1497
- deserialized.preprocessStepResult = safelyParseJSON(row.preprocessStepResult);
1498
- deserialized.analyzeStepResult = safelyParseJSON(row.analyzeStepResult);
1499
- deserialized.metadata = safelyParseJSON(row.metadata);
1500
- deserialized.additionalContext = safelyParseJSON(row.additionalContext);
1501
- deserialized.requestContext = safelyParseJSON(row.requestContext);
1502
- deserialized.entity = safelyParseJSON(row.entity);
1503
- return deserialized;
1514
+ return transformScoreRow$1(row);
1504
1515
  }
1505
1516
  var ScoresStorageCloudflare = class extends ScoresStorage {
1506
1517
  operations;