@mastra/dynamodb 1.0.2 → 1.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1345,6 +1345,20 @@ var MemoryStorageDynamoDB = class extends MemoryStorage {
1345
1345
  field,
1346
1346
  direction
1347
1347
  });
1348
+ if (perPage === 0 && (!include || include.length === 0)) {
1349
+ return { messages: [], total: 0, page, perPage: perPageForResponse, hasMore: false };
1350
+ }
1351
+ if (perPage === 0 && include && include.length > 0) {
1352
+ const includeMessages2 = await this._getIncludedMessages({ include });
1353
+ const list2 = new MessageList().add(includeMessages2, "memory");
1354
+ return {
1355
+ messages: this._sortMessages(list2.get.all.db(), field, direction),
1356
+ total: 0,
1357
+ page,
1358
+ perPage: perPageForResponse,
1359
+ hasMore: false
1360
+ };
1361
+ }
1348
1362
  const query = this.service.entities.message.query.byThread({ entity: "message", threadId });
1349
1363
  const results = await query.go();
1350
1364
  let allThreadMessages = results.data.map((data) => this.parseMessageData(data)).filter((msg) => "content" in msg && typeof msg.content === "object");
@@ -1379,8 +1393,7 @@ var MemoryStorageDynamoDB = class extends MemoryStorage {
1379
1393
  const messageIds = new Set(paginatedMessages.map((m) => m.id));
1380
1394
  let includeMessages = [];
1381
1395
  if (include && include.length > 0) {
1382
- const selectBy = { include };
1383
- includeMessages = await this._getIncludedMessages(selectBy);
1396
+ includeMessages = await this._getIncludedMessages({ include });
1384
1397
  for (const includeMsg of includeMessages) {
1385
1398
  if (!messageIds.has(includeMsg.id)) {
1386
1399
  paginatedMessages.push(includeMsg);
@@ -1390,14 +1403,7 @@ var MemoryStorageDynamoDB = class extends MemoryStorage {
1390
1403
  }
1391
1404
  const list = new MessageList().add(paginatedMessages, "memory");
1392
1405
  let finalMessages = list.get.all.db();
1393
- finalMessages = finalMessages.sort((a, b) => {
1394
- const aValue = field === "createdAt" ? new Date(a.createdAt).getTime() : a[field];
1395
- const bValue = field === "createdAt" ? new Date(b.createdAt).getTime() : b[field];
1396
- if (aValue === bValue) {
1397
- return a.id.localeCompare(b.id);
1398
- }
1399
- return direction === "ASC" ? aValue - bValue : bValue - aValue;
1400
- });
1406
+ finalMessages = this._sortMessages(finalMessages, field, direction);
1401
1407
  const returnedThreadMessageIds = new Set(finalMessages.filter((m) => m.threadId === threadId).map((m) => m.id));
1402
1408
  const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
1403
1409
  let hasMore = false;
@@ -1584,68 +1590,74 @@ var MemoryStorageDynamoDB = class extends MemoryStorage {
1584
1590
  }
1585
1591
  }
1586
1592
  // Helper method to get included messages with context
1587
- async _getIncludedMessages(selectBy) {
1588
- if (!selectBy?.include?.length) {
1593
+ _sortMessages(messages, field, direction) {
1594
+ return messages.sort((a, b) => {
1595
+ const aValue = field === "createdAt" ? new Date(a.createdAt).getTime() : a[field];
1596
+ const bValue = field === "createdAt" ? new Date(b.createdAt).getTime() : b[field];
1597
+ if (aValue === bValue) {
1598
+ return a.id.localeCompare(b.id);
1599
+ }
1600
+ return direction === "ASC" ? aValue - bValue : bValue - aValue;
1601
+ });
1602
+ }
1603
+ async _getIncludedMessages({
1604
+ include
1605
+ }) {
1606
+ if (!include?.length) {
1589
1607
  return [];
1590
1608
  }
1591
- const includeMessages = [];
1592
- for (const includeItem of selectBy.include) {
1593
- try {
1594
- const { id, withPreviousMessages = 0, withNextMessages = 0 } = includeItem;
1595
- const targetResult = await this.service.entities.message.get({ entity: "message", id }).go();
1596
- if (!targetResult.data) {
1597
- this.logger.warn("Target message not found", { id });
1598
- continue;
1599
- }
1600
- const targetMessageData = targetResult.data;
1601
- const searchThreadId = targetMessageData.threadId;
1602
- this.logger.debug("Getting included messages for", {
1603
- id,
1604
- searchThreadId,
1605
- withPreviousMessages,
1606
- withNextMessages
1607
- });
1608
- const query = this.service.entities.message.query.byThread({ entity: "message", threadId: searchThreadId });
1609
- const results = await query.go();
1610
- const allMessages = results.data.map((data) => this.parseMessageData(data)).filter((msg) => "content" in msg && typeof msg.content === "object");
1611
- this.logger.debug("Found messages in thread", {
1612
- threadId: searchThreadId,
1613
- messageCount: allMessages.length,
1614
- messageIds: allMessages.map((m) => m.id)
1615
- });
1616
- allMessages.sort((a, b) => {
1617
- const timeA = a.createdAt.getTime();
1618
- const timeB = b.createdAt.getTime();
1619
- if (timeA === timeB) {
1620
- return a.id.localeCompare(b.id);
1621
- }
1622
- return timeA - timeB;
1623
- });
1624
- const targetIndex = allMessages.findIndex((msg) => msg.id === id);
1625
- if (targetIndex === -1) {
1626
- this.logger.warn("Target message not found in thread", { id, threadId: searchThreadId });
1627
- continue;
1628
- }
1629
- this.logger.debug("Found target message at index", { id, targetIndex, totalMessages: allMessages.length });
1630
- const startIndex = Math.max(0, targetIndex - withPreviousMessages);
1631
- const endIndex = Math.min(allMessages.length, targetIndex + withNextMessages + 1);
1632
- const contextMessages = allMessages.slice(startIndex, endIndex);
1633
- this.logger.debug("Context messages", {
1634
- startIndex,
1635
- endIndex,
1636
- contextCount: contextMessages.length,
1637
- contextIds: contextMessages.map((m) => m.id)
1638
- });
1639
- includeMessages.push(...contextMessages);
1640
- } catch (error) {
1641
- this.logger.warn("Failed to get included message", { messageId: includeItem.id, error });
1609
+ const targetResults = await Promise.all(
1610
+ include.map(
1611
+ (inc) => this.service.entities.message.get({ entity: "message", id: inc.id }).go().then((r) => ({ id: inc.id, data: r.data })).catch(() => ({ id: inc.id, data: null }))
1612
+ )
1613
+ );
1614
+ const targetMap = /* @__PURE__ */ new Map();
1615
+ for (const { id, data } of targetResults) {
1616
+ if (data) {
1617
+ targetMap.set(id, { threadId: data.threadId });
1642
1618
  }
1643
1619
  }
1644
- this.logger.debug("Total included messages", {
1645
- count: includeMessages.length,
1646
- ids: includeMessages.map((m) => m.id)
1620
+ if (targetMap.size === 0) return [];
1621
+ const threadCache = /* @__PURE__ */ new Map();
1622
+ const uniqueThreadIds = [...new Set([...targetMap.values()].map((t) => t.threadId))];
1623
+ await Promise.all(
1624
+ uniqueThreadIds.map(async (threadId) => {
1625
+ try {
1626
+ const query = this.service.entities.message.query.byThread({ entity: "message", threadId });
1627
+ const results = await query.go();
1628
+ const messages = results.data.map((data) => this.parseMessageData(data)).filter(
1629
+ (msg) => "content" in msg && typeof msg.content === "object"
1630
+ );
1631
+ messages.sort((a, b) => {
1632
+ const timeA = a.createdAt.getTime();
1633
+ const timeB = b.createdAt.getTime();
1634
+ if (timeA === timeB) return a.id.localeCompare(b.id);
1635
+ return timeA - timeB;
1636
+ });
1637
+ threadCache.set(threadId, messages);
1638
+ } catch {
1639
+ }
1640
+ })
1641
+ );
1642
+ const includeMessages = [];
1643
+ for (const includeItem of include) {
1644
+ const { id, withPreviousMessages = 0, withNextMessages = 0 } = includeItem;
1645
+ const target = targetMap.get(id);
1646
+ if (!target) continue;
1647
+ const allMessages = threadCache.get(target.threadId);
1648
+ if (!allMessages) continue;
1649
+ const targetIndex = allMessages.findIndex((msg) => msg.id === id);
1650
+ if (targetIndex === -1) continue;
1651
+ const startIndex = Math.max(0, targetIndex - withPreviousMessages);
1652
+ const endIndex = Math.min(allMessages.length, targetIndex + withNextMessages + 1);
1653
+ includeMessages.push(...allMessages.slice(startIndex, endIndex));
1654
+ }
1655
+ const seen = /* @__PURE__ */ new Set();
1656
+ return includeMessages.filter((msg) => {
1657
+ if (seen.has(msg.id)) return false;
1658
+ seen.add(msg.id);
1659
+ return true;
1647
1660
  });
1648
- return includeMessages;
1649
1661
  }
1650
1662
  async updateMessages(args) {
1651
1663
  const { messages } = args;