@mastra/upstash 1.0.0-beta.2 → 1.0.0-beta.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,4 +1,4 @@
1
- import { MastraStorage, StoreOperations, ScoresStorage, TABLE_SCORERS, normalizePerPage, calculatePagination, WorkflowsStorage, TABLE_WORKFLOW_SNAPSHOT, MemoryStorage, TABLE_THREADS, TABLE_RESOURCES, TABLE_MESSAGES, serializeDate } from '@mastra/core/storage';
1
+ import { MastraStorage, StoreOperations, ScoresStorage, TABLE_SCORERS, normalizePerPage, calculatePagination, WorkflowsStorage, TABLE_WORKFLOW_SNAPSHOT, MemoryStorage, TABLE_THREADS, TABLE_RESOURCES, TABLE_MESSAGES, serializeDate, transformScoreRow as transformScoreRow$1 } from '@mastra/core/storage';
2
2
  import { Redis } from '@upstash/redis';
3
3
  import { MessageList } from '@mastra/core/agent';
4
4
  import { MastraError, ErrorCategory, ErrorDomain } from '@mastra/core/error';
@@ -52,6 +52,9 @@ function getMessageKey(threadId, messageId) {
52
52
  const key = getKey(TABLE_MESSAGES, { threadId, id: messageId });
53
53
  return key;
54
54
  }
55
+ function getMessageIndexKey(messageId) {
56
+ return `msg-idx:${messageId}`;
57
+ }
55
58
  var StoreMemoryUpstash = class extends MemoryStorage {
56
59
  client;
57
60
  operations;
@@ -318,6 +321,7 @@ var StoreMemoryUpstash = class extends MemoryStorage {
318
321
  }
319
322
  }
320
323
  pipeline.set(key, message);
324
+ pipeline.set(getMessageIndexKey(message.id), message.threadId);
321
325
  pipeline.zadd(getThreadMessagesKey(message.threadId), {
322
326
  score,
323
327
  member: message.id
@@ -348,43 +352,60 @@ var StoreMemoryUpstash = class extends MemoryStorage {
348
352
  );
349
353
  }
350
354
  }
351
- async _getIncludedMessages(threadId, include) {
352
- if (!threadId.trim()) throw new Error("threadId must be a non-empty string");
355
+ /**
356
+ * Lookup threadId for a message - tries index first (O(1)), falls back to scan (backwards compatible)
357
+ */
358
+ async _getThreadIdForMessage(messageId) {
359
+ const indexedThreadId = await this.client.get(getMessageIndexKey(messageId));
360
+ if (indexedThreadId) {
361
+ return indexedThreadId;
362
+ }
363
+ const existingKeyPattern = getMessageKey("*", messageId);
364
+ const keys = await this.operations.scanKeys(existingKeyPattern);
365
+ if (keys.length === 0) return null;
366
+ const messageData = await this.client.get(keys[0]);
367
+ if (!messageData) return null;
368
+ if (messageData.threadId) {
369
+ await this.client.set(getMessageIndexKey(messageId), messageData.threadId);
370
+ }
371
+ return messageData.threadId || null;
372
+ }
373
+ async _getIncludedMessages(include) {
374
+ if (!include?.length) return [];
353
375
  const messageIds = /* @__PURE__ */ new Set();
354
376
  const messageIdToThreadIds = {};
355
- if (include?.length) {
356
- for (const item of include) {
357
- messageIds.add(item.id);
358
- const itemThreadId = item.threadId || threadId;
359
- messageIdToThreadIds[item.id] = itemThreadId;
360
- const itemThreadMessagesKey = getThreadMessagesKey(itemThreadId);
361
- const rank = await this.client.zrank(itemThreadMessagesKey, item.id);
362
- if (rank === null) continue;
363
- if (item.withPreviousMessages) {
364
- const start = Math.max(0, rank - item.withPreviousMessages);
365
- const prevIds = rank === 0 ? [] : await this.client.zrange(itemThreadMessagesKey, start, rank - 1);
366
- prevIds.forEach((id) => {
367
- messageIds.add(id);
368
- messageIdToThreadIds[id] = itemThreadId;
369
- });
370
- }
371
- if (item.withNextMessages) {
372
- const nextIds = await this.client.zrange(itemThreadMessagesKey, rank + 1, rank + item.withNextMessages);
373
- nextIds.forEach((id) => {
374
- messageIds.add(id);
375
- messageIdToThreadIds[id] = itemThreadId;
376
- });
377
- }
377
+ for (const item of include) {
378
+ const itemThreadId = await this._getThreadIdForMessage(item.id);
379
+ if (!itemThreadId) continue;
380
+ messageIds.add(item.id);
381
+ messageIdToThreadIds[item.id] = itemThreadId;
382
+ const itemThreadMessagesKey = getThreadMessagesKey(itemThreadId);
383
+ const rank = await this.client.zrank(itemThreadMessagesKey, item.id);
384
+ if (rank === null) continue;
385
+ if (item.withPreviousMessages) {
386
+ const start = Math.max(0, rank - item.withPreviousMessages);
387
+ const prevIds = rank === 0 ? [] : await this.client.zrange(itemThreadMessagesKey, start, rank - 1);
388
+ prevIds.forEach((id) => {
389
+ messageIds.add(id);
390
+ messageIdToThreadIds[id] = itemThreadId;
391
+ });
392
+ }
393
+ if (item.withNextMessages) {
394
+ const nextIds = await this.client.zrange(itemThreadMessagesKey, rank + 1, rank + item.withNextMessages);
395
+ nextIds.forEach((id) => {
396
+ messageIds.add(id);
397
+ messageIdToThreadIds[id] = itemThreadId;
398
+ });
378
399
  }
379
- const pipeline = this.client.pipeline();
380
- Array.from(messageIds).forEach((id) => {
381
- const tId = messageIdToThreadIds[id] || threadId;
382
- pipeline.get(getMessageKey(tId, id));
383
- });
384
- const results = await pipeline.exec();
385
- return results.filter((result) => result !== null);
386
400
  }
387
- return [];
401
+ if (messageIds.size === 0) return [];
402
+ const pipeline = this.client.pipeline();
403
+ Array.from(messageIds).forEach((id) => {
404
+ const tId = messageIdToThreadIds[id];
405
+ pipeline.get(getMessageKey(tId, id));
406
+ });
407
+ const results = await pipeline.exec();
408
+ return results.filter((result) => result !== null);
388
409
  }
389
410
  parseStoredMessage(storedMessage) {
390
411
  const defaultMessageContent = { format: 2, parts: [{ type: "text", text: "" }] };
@@ -398,17 +419,49 @@ var StoreMemoryUpstash = class extends MemoryStorage {
398
419
  async listMessagesById({ messageIds }) {
399
420
  if (messageIds.length === 0) return { messages: [] };
400
421
  try {
401
- const threadKeys = await this.client.keys("thread:*");
402
- const result = await Promise.all(
403
- threadKeys.map((threadKey) => {
404
- const threadId = threadKey.split(":")[1];
405
- if (!threadId) throw new Error(`Failed to parse thread ID from thread key "${threadKey}"`);
406
- return this.client.mget(
407
- messageIds.map((id) => getMessageKey(threadId, id))
408
- );
409
- })
410
- );
411
- const rawMessages = result.flat(1).filter((msg) => !!msg);
422
+ const rawMessages = [];
423
+ const indexPipeline = this.client.pipeline();
424
+ messageIds.forEach((id) => indexPipeline.get(getMessageIndexKey(id)));
425
+ const indexResults = await indexPipeline.exec();
426
+ const indexedIds = [];
427
+ const unindexedIds = [];
428
+ messageIds.forEach((id, i) => {
429
+ const threadId = indexResults[i];
430
+ if (threadId) {
431
+ indexedIds.push({ messageId: id, threadId });
432
+ } else {
433
+ unindexedIds.push(id);
434
+ }
435
+ });
436
+ if (indexedIds.length > 0) {
437
+ const messagePipeline = this.client.pipeline();
438
+ indexedIds.forEach(({ messageId, threadId }) => messagePipeline.get(getMessageKey(threadId, messageId)));
439
+ const messageResults = await messagePipeline.exec();
440
+ rawMessages.push(...messageResults.filter((msg) => msg !== null));
441
+ }
442
+ if (unindexedIds.length > 0) {
443
+ const threadKeys = await this.client.keys("thread:*");
444
+ const result = await Promise.all(
445
+ threadKeys.map((threadKey) => {
446
+ const threadId = threadKey.split(":")[1];
447
+ if (!threadId) throw new Error(`Failed to parse thread ID from thread key "${threadKey}"`);
448
+ return this.client.mget(
449
+ unindexedIds.map((id) => getMessageKey(threadId, id))
450
+ );
451
+ })
452
+ );
453
+ const foundMessages = result.flat(1).filter((msg) => !!msg);
454
+ rawMessages.push(...foundMessages);
455
+ if (foundMessages.length > 0) {
456
+ const backfillPipeline = this.client.pipeline();
457
+ foundMessages.forEach((msg) => {
458
+ if (msg.threadId) {
459
+ backfillPipeline.set(getMessageIndexKey(msg.id), msg.threadId);
460
+ }
461
+ });
462
+ await backfillPipeline.exec();
463
+ }
464
+ }
412
465
  const list = new MessageList().add(rawMessages.map(this.parseStoredMessage), "memory");
413
466
  return { messages: list.get.all.db() };
414
467
  } catch (error) {
@@ -427,18 +480,18 @@ var StoreMemoryUpstash = class extends MemoryStorage {
427
480
  }
428
481
  async listMessages(args) {
429
482
  const { threadId, resourceId, include, filter, perPage: perPageInput, page = 0, orderBy } = args;
430
- if (!threadId.trim()) {
483
+ const threadIds = Array.isArray(threadId) ? threadId : [threadId];
484
+ if (threadIds.length === 0 || threadIds.some((id) => !id.trim())) {
431
485
  throw new MastraError(
432
486
  {
433
487
  id: "STORAGE_UPSTASH_LIST_MESSAGES_INVALID_THREAD_ID",
434
488
  domain: ErrorDomain.STORAGE,
435
489
  category: ErrorCategory.THIRD_PARTY,
436
- details: { threadId }
490
+ details: { threadId: Array.isArray(threadId) ? threadId.join(",") : threadId }
437
491
  },
438
- new Error("threadId must be a non-empty string")
492
+ new Error("threadId must be a non-empty string or array of non-empty strings")
439
493
  );
440
494
  }
441
- const threadMessagesKey = getThreadMessagesKey(threadId);
442
495
  const perPage = normalizePerPage(perPageInput, 40);
443
496
  const { offset, perPage: perPageForResponse } = calculatePagination(page, perPageInput, perPage);
444
497
  try {
@@ -455,11 +508,18 @@ var StoreMemoryUpstash = class extends MemoryStorage {
455
508
  }
456
509
  let includedMessages = [];
457
510
  if (include && include.length > 0) {
458
- const included = await this._getIncludedMessages(threadId, include);
511
+ const included = await this._getIncludedMessages(include);
459
512
  includedMessages = included.map(this.parseStoredMessage);
460
513
  }
461
- const allMessageIds = await this.client.zrange(threadMessagesKey, 0, -1);
462
- if (allMessageIds.length === 0) {
514
+ const allMessageIdsWithThreads = [];
515
+ for (const tid of threadIds) {
516
+ const threadMessagesKey = getThreadMessagesKey(tid);
517
+ const messageIds2 = await this.client.zrange(threadMessagesKey, 0, -1);
518
+ for (const mid of messageIds2) {
519
+ allMessageIdsWithThreads.push({ threadId: tid, messageId: mid });
520
+ }
521
+ }
522
+ if (allMessageIdsWithThreads.length === 0) {
463
523
  return {
464
524
  messages: [],
465
525
  total: 0,
@@ -469,7 +529,7 @@ var StoreMemoryUpstash = class extends MemoryStorage {
469
529
  };
470
530
  }
471
531
  const pipeline = this.client.pipeline();
472
- allMessageIds.forEach((id) => pipeline.get(getMessageKey(threadId, id)));
532
+ allMessageIdsWithThreads.forEach(({ threadId: tid, messageId }) => pipeline.get(getMessageKey(tid, messageId)));
473
533
  const results = await pipeline.exec();
474
534
  let messagesData = results.filter((msg) => msg !== null).map(this.parseStoredMessage);
475
535
  if (resourceId) {
@@ -545,7 +605,7 @@ var StoreMemoryUpstash = class extends MemoryStorage {
545
605
  domain: ErrorDomain.STORAGE,
546
606
  category: ErrorCategory.THIRD_PARTY,
547
607
  details: {
548
- threadId,
608
+ threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
549
609
  resourceId: resourceId ?? ""
550
610
  }
551
611
  },
@@ -750,13 +810,33 @@ var StoreMemoryUpstash = class extends MemoryStorage {
750
810
  try {
751
811
  const threadIds = /* @__PURE__ */ new Set();
752
812
  const messageKeys = [];
753
- for (const messageId of messageIds) {
813
+ const foundMessageIds = [];
814
+ const indexPipeline = this.client.pipeline();
815
+ messageIds.forEach((id) => indexPipeline.get(getMessageIndexKey(id)));
816
+ const indexResults = await indexPipeline.exec();
817
+ const indexedMessages = [];
818
+ const unindexedMessageIds = [];
819
+ messageIds.forEach((id, i) => {
820
+ const threadId = indexResults[i];
821
+ if (threadId) {
822
+ indexedMessages.push({ messageId: id, threadId });
823
+ } else {
824
+ unindexedMessageIds.push(id);
825
+ }
826
+ });
827
+ for (const { messageId, threadId } of indexedMessages) {
828
+ messageKeys.push(getMessageKey(threadId, messageId));
829
+ foundMessageIds.push(messageId);
830
+ threadIds.add(threadId);
831
+ }
832
+ for (const messageId of unindexedMessageIds) {
754
833
  const pattern = getMessageKey("*", messageId);
755
834
  const keys = await this.operations.scanKeys(pattern);
756
835
  for (const key of keys) {
757
836
  const message = await this.client.get(key);
758
837
  if (message && message.id === messageId) {
759
838
  messageKeys.push(key);
839
+ foundMessageIds.push(messageId);
760
840
  if (message.threadId) {
761
841
  threadIds.add(message.threadId);
762
842
  }
@@ -771,6 +851,9 @@ var StoreMemoryUpstash = class extends MemoryStorage {
771
851
  for (const key of messageKeys) {
772
852
  pipeline.del(key);
773
853
  }
854
+ for (const messageId of foundMessageIds) {
855
+ pipeline.del(getMessageIndexKey(messageId));
856
+ }
774
857
  if (threadIds.size > 0) {
775
858
  for (const threadId of threadIds) {
776
859
  const threadKey = getKey(TABLE_THREADS, { id: threadId });
@@ -946,32 +1029,7 @@ var StoreOperationsUpstash = class extends StoreOperations {
946
1029
  }
947
1030
  };
948
1031
  function transformScoreRow(row) {
949
- const parseField = (v) => {
950
- if (typeof v === "string") {
951
- try {
952
- return JSON.parse(v);
953
- } catch {
954
- return v;
955
- }
956
- }
957
- return v;
958
- };
959
- return {
960
- ...row,
961
- scorer: parseField(row.scorer),
962
- preprocessStepResult: parseField(row.preprocessStepResult),
963
- generateScorePrompt: row.generateScorePrompt,
964
- generateReasonPrompt: row.generateReasonPrompt,
965
- analyzeStepResult: parseField(row.analyzeStepResult),
966
- metadata: parseField(row.metadata),
967
- input: parseField(row.input),
968
- output: parseField(row.output),
969
- additionalContext: parseField(row.additionalContext),
970
- requestContext: parseField(row.requestContext),
971
- entity: parseField(row.entity),
972
- createdAt: row.createdAt,
973
- updatedAt: row.updatedAt
974
- };
1032
+ return transformScoreRow$1(row);
975
1033
  }
976
1034
  var ScoresUpstash = class extends ScoresStorage {
977
1035
  client;