@mastra/upstash 1.0.0-beta.2 → 1.0.0-beta.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +25 -0
- package/dist/index.cjs +138 -80
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +139 -81
- package/dist/index.js.map +1 -1
- package/dist/storage/domains/memory/index.d.ts +4 -0
- package/dist/storage/domains/memory/index.d.ts.map +1 -1
- package/dist/storage/domains/scores/index.d.ts.map +1 -1
- package/package.json +3 -3
package/dist/index.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { MastraStorage, StoreOperations, ScoresStorage, TABLE_SCORERS, normalizePerPage, calculatePagination, WorkflowsStorage, TABLE_WORKFLOW_SNAPSHOT, MemoryStorage, TABLE_THREADS, TABLE_RESOURCES, TABLE_MESSAGES, serializeDate } from '@mastra/core/storage';
|
|
1
|
+
import { MastraStorage, StoreOperations, ScoresStorage, TABLE_SCORERS, normalizePerPage, calculatePagination, WorkflowsStorage, TABLE_WORKFLOW_SNAPSHOT, MemoryStorage, TABLE_THREADS, TABLE_RESOURCES, TABLE_MESSAGES, serializeDate, transformScoreRow as transformScoreRow$1 } from '@mastra/core/storage';
|
|
2
2
|
import { Redis } from '@upstash/redis';
|
|
3
3
|
import { MessageList } from '@mastra/core/agent';
|
|
4
4
|
import { MastraError, ErrorCategory, ErrorDomain } from '@mastra/core/error';
|
|
@@ -52,6 +52,9 @@ function getMessageKey(threadId, messageId) {
|
|
|
52
52
|
const key = getKey(TABLE_MESSAGES, { threadId, id: messageId });
|
|
53
53
|
return key;
|
|
54
54
|
}
|
|
55
|
+
function getMessageIndexKey(messageId) {
|
|
56
|
+
return `msg-idx:${messageId}`;
|
|
57
|
+
}
|
|
55
58
|
var StoreMemoryUpstash = class extends MemoryStorage {
|
|
56
59
|
client;
|
|
57
60
|
operations;
|
|
@@ -318,6 +321,7 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
318
321
|
}
|
|
319
322
|
}
|
|
320
323
|
pipeline.set(key, message);
|
|
324
|
+
pipeline.set(getMessageIndexKey(message.id), message.threadId);
|
|
321
325
|
pipeline.zadd(getThreadMessagesKey(message.threadId), {
|
|
322
326
|
score,
|
|
323
327
|
member: message.id
|
|
@@ -348,43 +352,60 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
348
352
|
);
|
|
349
353
|
}
|
|
350
354
|
}
|
|
351
|
-
|
|
352
|
-
|
|
355
|
+
/**
|
|
356
|
+
* Lookup threadId for a message - tries index first (O(1)), falls back to scan (backwards compatible)
|
|
357
|
+
*/
|
|
358
|
+
async _getThreadIdForMessage(messageId) {
|
|
359
|
+
const indexedThreadId = await this.client.get(getMessageIndexKey(messageId));
|
|
360
|
+
if (indexedThreadId) {
|
|
361
|
+
return indexedThreadId;
|
|
362
|
+
}
|
|
363
|
+
const existingKeyPattern = getMessageKey("*", messageId);
|
|
364
|
+
const keys = await this.operations.scanKeys(existingKeyPattern);
|
|
365
|
+
if (keys.length === 0) return null;
|
|
366
|
+
const messageData = await this.client.get(keys[0]);
|
|
367
|
+
if (!messageData) return null;
|
|
368
|
+
if (messageData.threadId) {
|
|
369
|
+
await this.client.set(getMessageIndexKey(messageId), messageData.threadId);
|
|
370
|
+
}
|
|
371
|
+
return messageData.threadId || null;
|
|
372
|
+
}
|
|
373
|
+
async _getIncludedMessages(include) {
|
|
374
|
+
if (!include?.length) return [];
|
|
353
375
|
const messageIds = /* @__PURE__ */ new Set();
|
|
354
376
|
const messageIdToThreadIds = {};
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
}
|
|
377
|
+
for (const item of include) {
|
|
378
|
+
const itemThreadId = await this._getThreadIdForMessage(item.id);
|
|
379
|
+
if (!itemThreadId) continue;
|
|
380
|
+
messageIds.add(item.id);
|
|
381
|
+
messageIdToThreadIds[item.id] = itemThreadId;
|
|
382
|
+
const itemThreadMessagesKey = getThreadMessagesKey(itemThreadId);
|
|
383
|
+
const rank = await this.client.zrank(itemThreadMessagesKey, item.id);
|
|
384
|
+
if (rank === null) continue;
|
|
385
|
+
if (item.withPreviousMessages) {
|
|
386
|
+
const start = Math.max(0, rank - item.withPreviousMessages);
|
|
387
|
+
const prevIds = rank === 0 ? [] : await this.client.zrange(itemThreadMessagesKey, start, rank - 1);
|
|
388
|
+
prevIds.forEach((id) => {
|
|
389
|
+
messageIds.add(id);
|
|
390
|
+
messageIdToThreadIds[id] = itemThreadId;
|
|
391
|
+
});
|
|
392
|
+
}
|
|
393
|
+
if (item.withNextMessages) {
|
|
394
|
+
const nextIds = await this.client.zrange(itemThreadMessagesKey, rank + 1, rank + item.withNextMessages);
|
|
395
|
+
nextIds.forEach((id) => {
|
|
396
|
+
messageIds.add(id);
|
|
397
|
+
messageIdToThreadIds[id] = itemThreadId;
|
|
398
|
+
});
|
|
378
399
|
}
|
|
379
|
-
const pipeline = this.client.pipeline();
|
|
380
|
-
Array.from(messageIds).forEach((id) => {
|
|
381
|
-
const tId = messageIdToThreadIds[id] || threadId;
|
|
382
|
-
pipeline.get(getMessageKey(tId, id));
|
|
383
|
-
});
|
|
384
|
-
const results = await pipeline.exec();
|
|
385
|
-
return results.filter((result) => result !== null);
|
|
386
400
|
}
|
|
387
|
-
return [];
|
|
401
|
+
if (messageIds.size === 0) return [];
|
|
402
|
+
const pipeline = this.client.pipeline();
|
|
403
|
+
Array.from(messageIds).forEach((id) => {
|
|
404
|
+
const tId = messageIdToThreadIds[id];
|
|
405
|
+
pipeline.get(getMessageKey(tId, id));
|
|
406
|
+
});
|
|
407
|
+
const results = await pipeline.exec();
|
|
408
|
+
return results.filter((result) => result !== null);
|
|
388
409
|
}
|
|
389
410
|
parseStoredMessage(storedMessage) {
|
|
390
411
|
const defaultMessageContent = { format: 2, parts: [{ type: "text", text: "" }] };
|
|
@@ -398,17 +419,49 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
398
419
|
async listMessagesById({ messageIds }) {
|
|
399
420
|
if (messageIds.length === 0) return { messages: [] };
|
|
400
421
|
try {
|
|
401
|
-
const
|
|
402
|
-
const
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
422
|
+
const rawMessages = [];
|
|
423
|
+
const indexPipeline = this.client.pipeline();
|
|
424
|
+
messageIds.forEach((id) => indexPipeline.get(getMessageIndexKey(id)));
|
|
425
|
+
const indexResults = await indexPipeline.exec();
|
|
426
|
+
const indexedIds = [];
|
|
427
|
+
const unindexedIds = [];
|
|
428
|
+
messageIds.forEach((id, i) => {
|
|
429
|
+
const threadId = indexResults[i];
|
|
430
|
+
if (threadId) {
|
|
431
|
+
indexedIds.push({ messageId: id, threadId });
|
|
432
|
+
} else {
|
|
433
|
+
unindexedIds.push(id);
|
|
434
|
+
}
|
|
435
|
+
});
|
|
436
|
+
if (indexedIds.length > 0) {
|
|
437
|
+
const messagePipeline = this.client.pipeline();
|
|
438
|
+
indexedIds.forEach(({ messageId, threadId }) => messagePipeline.get(getMessageKey(threadId, messageId)));
|
|
439
|
+
const messageResults = await messagePipeline.exec();
|
|
440
|
+
rawMessages.push(...messageResults.filter((msg) => msg !== null));
|
|
441
|
+
}
|
|
442
|
+
if (unindexedIds.length > 0) {
|
|
443
|
+
const threadKeys = await this.client.keys("thread:*");
|
|
444
|
+
const result = await Promise.all(
|
|
445
|
+
threadKeys.map((threadKey) => {
|
|
446
|
+
const threadId = threadKey.split(":")[1];
|
|
447
|
+
if (!threadId) throw new Error(`Failed to parse thread ID from thread key "${threadKey}"`);
|
|
448
|
+
return this.client.mget(
|
|
449
|
+
unindexedIds.map((id) => getMessageKey(threadId, id))
|
|
450
|
+
);
|
|
451
|
+
})
|
|
452
|
+
);
|
|
453
|
+
const foundMessages = result.flat(1).filter((msg) => !!msg);
|
|
454
|
+
rawMessages.push(...foundMessages);
|
|
455
|
+
if (foundMessages.length > 0) {
|
|
456
|
+
const backfillPipeline = this.client.pipeline();
|
|
457
|
+
foundMessages.forEach((msg) => {
|
|
458
|
+
if (msg.threadId) {
|
|
459
|
+
backfillPipeline.set(getMessageIndexKey(msg.id), msg.threadId);
|
|
460
|
+
}
|
|
461
|
+
});
|
|
462
|
+
await backfillPipeline.exec();
|
|
463
|
+
}
|
|
464
|
+
}
|
|
412
465
|
const list = new MessageList().add(rawMessages.map(this.parseStoredMessage), "memory");
|
|
413
466
|
return { messages: list.get.all.db() };
|
|
414
467
|
} catch (error) {
|
|
@@ -427,18 +480,18 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
427
480
|
}
|
|
428
481
|
async listMessages(args) {
|
|
429
482
|
const { threadId, resourceId, include, filter, perPage: perPageInput, page = 0, orderBy } = args;
|
|
430
|
-
|
|
483
|
+
const threadIds = Array.isArray(threadId) ? threadId : [threadId];
|
|
484
|
+
if (threadIds.length === 0 || threadIds.some((id) => !id.trim())) {
|
|
431
485
|
throw new MastraError(
|
|
432
486
|
{
|
|
433
487
|
id: "STORAGE_UPSTASH_LIST_MESSAGES_INVALID_THREAD_ID",
|
|
434
488
|
domain: ErrorDomain.STORAGE,
|
|
435
489
|
category: ErrorCategory.THIRD_PARTY,
|
|
436
|
-
details: { threadId }
|
|
490
|
+
details: { threadId: Array.isArray(threadId) ? threadId.join(",") : threadId }
|
|
437
491
|
},
|
|
438
|
-
new Error("threadId must be a non-empty string")
|
|
492
|
+
new Error("threadId must be a non-empty string or array of non-empty strings")
|
|
439
493
|
);
|
|
440
494
|
}
|
|
441
|
-
const threadMessagesKey = getThreadMessagesKey(threadId);
|
|
442
495
|
const perPage = normalizePerPage(perPageInput, 40);
|
|
443
496
|
const { offset, perPage: perPageForResponse } = calculatePagination(page, perPageInput, perPage);
|
|
444
497
|
try {
|
|
@@ -455,11 +508,18 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
455
508
|
}
|
|
456
509
|
let includedMessages = [];
|
|
457
510
|
if (include && include.length > 0) {
|
|
458
|
-
const included = await this._getIncludedMessages(
|
|
511
|
+
const included = await this._getIncludedMessages(include);
|
|
459
512
|
includedMessages = included.map(this.parseStoredMessage);
|
|
460
513
|
}
|
|
461
|
-
const
|
|
462
|
-
|
|
514
|
+
const allMessageIdsWithThreads = [];
|
|
515
|
+
for (const tid of threadIds) {
|
|
516
|
+
const threadMessagesKey = getThreadMessagesKey(tid);
|
|
517
|
+
const messageIds2 = await this.client.zrange(threadMessagesKey, 0, -1);
|
|
518
|
+
for (const mid of messageIds2) {
|
|
519
|
+
allMessageIdsWithThreads.push({ threadId: tid, messageId: mid });
|
|
520
|
+
}
|
|
521
|
+
}
|
|
522
|
+
if (allMessageIdsWithThreads.length === 0) {
|
|
463
523
|
return {
|
|
464
524
|
messages: [],
|
|
465
525
|
total: 0,
|
|
@@ -469,7 +529,7 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
469
529
|
};
|
|
470
530
|
}
|
|
471
531
|
const pipeline = this.client.pipeline();
|
|
472
|
-
|
|
532
|
+
allMessageIdsWithThreads.forEach(({ threadId: tid, messageId }) => pipeline.get(getMessageKey(tid, messageId)));
|
|
473
533
|
const results = await pipeline.exec();
|
|
474
534
|
let messagesData = results.filter((msg) => msg !== null).map(this.parseStoredMessage);
|
|
475
535
|
if (resourceId) {
|
|
@@ -545,7 +605,7 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
545
605
|
domain: ErrorDomain.STORAGE,
|
|
546
606
|
category: ErrorCategory.THIRD_PARTY,
|
|
547
607
|
details: {
|
|
548
|
-
threadId,
|
|
608
|
+
threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
|
|
549
609
|
resourceId: resourceId ?? ""
|
|
550
610
|
}
|
|
551
611
|
},
|
|
@@ -750,13 +810,33 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
750
810
|
try {
|
|
751
811
|
const threadIds = /* @__PURE__ */ new Set();
|
|
752
812
|
const messageKeys = [];
|
|
753
|
-
|
|
813
|
+
const foundMessageIds = [];
|
|
814
|
+
const indexPipeline = this.client.pipeline();
|
|
815
|
+
messageIds.forEach((id) => indexPipeline.get(getMessageIndexKey(id)));
|
|
816
|
+
const indexResults = await indexPipeline.exec();
|
|
817
|
+
const indexedMessages = [];
|
|
818
|
+
const unindexedMessageIds = [];
|
|
819
|
+
messageIds.forEach((id, i) => {
|
|
820
|
+
const threadId = indexResults[i];
|
|
821
|
+
if (threadId) {
|
|
822
|
+
indexedMessages.push({ messageId: id, threadId });
|
|
823
|
+
} else {
|
|
824
|
+
unindexedMessageIds.push(id);
|
|
825
|
+
}
|
|
826
|
+
});
|
|
827
|
+
for (const { messageId, threadId } of indexedMessages) {
|
|
828
|
+
messageKeys.push(getMessageKey(threadId, messageId));
|
|
829
|
+
foundMessageIds.push(messageId);
|
|
830
|
+
threadIds.add(threadId);
|
|
831
|
+
}
|
|
832
|
+
for (const messageId of unindexedMessageIds) {
|
|
754
833
|
const pattern = getMessageKey("*", messageId);
|
|
755
834
|
const keys = await this.operations.scanKeys(pattern);
|
|
756
835
|
for (const key of keys) {
|
|
757
836
|
const message = await this.client.get(key);
|
|
758
837
|
if (message && message.id === messageId) {
|
|
759
838
|
messageKeys.push(key);
|
|
839
|
+
foundMessageIds.push(messageId);
|
|
760
840
|
if (message.threadId) {
|
|
761
841
|
threadIds.add(message.threadId);
|
|
762
842
|
}
|
|
@@ -771,6 +851,9 @@ var StoreMemoryUpstash = class extends MemoryStorage {
|
|
|
771
851
|
for (const key of messageKeys) {
|
|
772
852
|
pipeline.del(key);
|
|
773
853
|
}
|
|
854
|
+
for (const messageId of foundMessageIds) {
|
|
855
|
+
pipeline.del(getMessageIndexKey(messageId));
|
|
856
|
+
}
|
|
774
857
|
if (threadIds.size > 0) {
|
|
775
858
|
for (const threadId of threadIds) {
|
|
776
859
|
const threadKey = getKey(TABLE_THREADS, { id: threadId });
|
|
@@ -946,32 +1029,7 @@ var StoreOperationsUpstash = class extends StoreOperations {
|
|
|
946
1029
|
}
|
|
947
1030
|
};
|
|
948
1031
|
function transformScoreRow(row) {
|
|
949
|
-
|
|
950
|
-
if (typeof v === "string") {
|
|
951
|
-
try {
|
|
952
|
-
return JSON.parse(v);
|
|
953
|
-
} catch {
|
|
954
|
-
return v;
|
|
955
|
-
}
|
|
956
|
-
}
|
|
957
|
-
return v;
|
|
958
|
-
};
|
|
959
|
-
return {
|
|
960
|
-
...row,
|
|
961
|
-
scorer: parseField(row.scorer),
|
|
962
|
-
preprocessStepResult: parseField(row.preprocessStepResult),
|
|
963
|
-
generateScorePrompt: row.generateScorePrompt,
|
|
964
|
-
generateReasonPrompt: row.generateReasonPrompt,
|
|
965
|
-
analyzeStepResult: parseField(row.analyzeStepResult),
|
|
966
|
-
metadata: parseField(row.metadata),
|
|
967
|
-
input: parseField(row.input),
|
|
968
|
-
output: parseField(row.output),
|
|
969
|
-
additionalContext: parseField(row.additionalContext),
|
|
970
|
-
requestContext: parseField(row.requestContext),
|
|
971
|
-
entity: parseField(row.entity),
|
|
972
|
-
createdAt: row.createdAt,
|
|
973
|
-
updatedAt: row.updatedAt
|
|
974
|
-
};
|
|
1032
|
+
return transformScoreRow$1(row);
|
|
975
1033
|
}
|
|
976
1034
|
var ScoresUpstash = class extends ScoresStorage {
|
|
977
1035
|
client;
|