@mastra/upstash 0.1.0-alpha.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +32 -0
- package/LICENSE +44 -0
- package/README.md +59 -0
- package/dist/index.d.ts +88 -0
- package/dist/index.js +466 -0
- package/docker-compose.yaml +15 -0
- package/package.json +35 -0
- package/src/index.ts +2 -0
- package/src/storage/index.ts +286 -0
- package/src/storage/upstash.test.ts +381 -0
- package/src/vector/filter.test.ts +557 -0
- package/src/vector/filter.ts +253 -0
- package/src/vector/index.test.ts +872 -0
- package/src/vector/index.ts +101 -0
- package/tsconfig.json +5 -0
- package/vitest.config.ts +11 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Local test stack: a password-protected Redis plus serverless-redis-http (SRH),
# which serves an HTTP API on localhost:8079 (token: test_token) backed by the
# `redis` service below.
version: '3'
services:
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    command: redis-server --requirepass redis_password
  serverless-redis-http:
    image: hiett/serverless-redis-http:latest
    ports:
      - "8079:80"
    environment:
      SRH_MODE: env
      SRH_TOKEN: test_token
      SRH_CONNECTION_STRING: "redis://:redis_password@redis:6379"
    # SRH connects to the `redis` service above, so make Compose start it first.
    depends_on:
      - redis
|
package/package.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "@mastra/upstash",
|
|
3
|
+
"version": "0.1.0-alpha.2",
|
|
4
|
+
"description": "Upstash provider for Mastra - includes both vector and db storage capabilities",
|
|
5
|
+
"type": "module",
|
|
6
|
+
"main": "dist/index.js",
|
|
7
|
+
"types": "dist/index.d.ts",
|
|
8
|
+
"exports": {
|
|
9
|
+
".": {
|
|
10
|
+
"import": {
|
|
11
|
+
"types": "./dist/index.d.ts",
|
|
12
|
+
"default": "./dist/index.js"
|
|
13
|
+
}
|
|
14
|
+
},
|
|
15
|
+
"./package.json": "./package.json"
|
|
16
|
+
},
|
|
17
|
+
"dependencies": {
|
|
18
|
+
"@upstash/redis": "^1.28.3",
|
|
19
|
+
"@upstash/vector": "^1.1.7",
|
|
20
|
+
"@mastra/core": "^0.2.0-alpha.92"
|
|
21
|
+
},
|
|
22
|
+
"devDependencies": {
|
|
23
|
+
"@tsconfig/recommended": "^1.0.7",
|
|
24
|
+
"@types/node": "^22.9.0",
|
|
25
|
+
"tsup": "^8.0.1",
|
|
26
|
+
"vitest": "^3.0.4"
|
|
27
|
+
},
|
|
28
|
+
"scripts": {
|
|
29
|
+
"pretest": "docker compose up -d",
|
|
30
|
+
"test": "vitest run",
|
|
31
|
+
"posttest": "docker compose down",
|
|
32
|
+
"build": "tsup-node src/index.ts --format esm --dts --clean --treeshake",
|
|
33
|
+
"dev": "tsup-node src/index.ts --format esm --dts --clean --treeshake --watch"
|
|
34
|
+
}
|
|
35
|
+
}
|
package/src/storage/index.ts
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
import { type StorageThreadType, type MessageType } from '@mastra/core/memory';
|
|
2
|
+
import { MastraStorage, type TABLE_NAMES, type StorageColumn, type StorageGetMessagesArg } from '@mastra/core/storage';
|
|
3
|
+
import { type WorkflowRunState } from '@mastra/core/workflows';
|
|
4
|
+
import { Redis } from '@upstash/redis';
|
|
5
|
+
|
|
6
|
+
/**
 * Connection settings passed to `UpstashStore`; both values are forwarded
 * unchanged to the `@upstash/redis` client.
 */
export interface UpstashConfig {
  /** URL of the Upstash Redis REST endpoint. */
  url: string;
  /** Token used to authenticate requests against `url`. */
  token: string;
}
|
|
10
|
+
|
|
11
|
+
/**
 * Mastra storage adapter backed by Upstash Redis.
 *
 * Records are stored as JSON values under keys of the form
 * `<tableName>:<field>:<value>[:<field>:<value>...]` (see `getKey`).
 * Messages are additionally indexed per thread in a sorted set
 * (`thread:<threadId>:messages`) whose scores define message order.
 */
export class UpstashStore extends MastraStorage {
  private readonly redis: Redis;

  constructor(config: UpstashConfig) {
    super({ name: 'Upstash' });
    this.redis = new Redis({
      url: config.url,
      token: config.token,
    });
  }

  /** Builds a Redis key from a table name and the (ordered) key fields of a record. */
  private getKey(tableName: TABLE_NAMES, keys: Record<string, any>): string {
    const keyParts = Object.entries(keys).map(([key, value]) => `${key}:${value}`);
    return `${tableName}:${keyParts.join(':')}`;
  }

  /** Returns `date` as a `Date`, parsing strings; falsy input yields `undefined`. */
  private ensureDate(date: Date | string | undefined): Date | undefined {
    if (!date) return undefined;
    return date instanceof Date ? date : new Date(date);
  }

  /** Returns `date` as an ISO-8601 string; falsy input yields `undefined`. */
  private serializeDate(date: Date | string | undefined): string | undefined {
    if (!date) return undefined;
    const dateObj = this.ensureDate(date);
    return dateObj?.toISOString();
  }

  /** Normalizes a thread read back from Redis: revives dates and parses string metadata. */
  private parseThread(thread: StorageThreadType): StorageThreadType {
    return {
      ...thread,
      createdAt: this.ensureDate(thread.createdAt)!,
      updatedAt: this.ensureDate(thread.updatedAt)!,
      metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
    };
  }

  /**
   * Redis is schemaless, so no table is created; the schema is only stored
   * under `schema:<tableName>` for reference.
   */
  async createTable({
    tableName,
    schema,
  }: {
    tableName: TABLE_NAMES;
    schema: Record<string, StorageColumn>;
  }): Promise<void> {
    await this.redis.set(`schema:${tableName}`, schema);
  }

  /**
   * Deletes every key stored under `tableName`. Clearing the messages table
   * also drops the per-thread ordering sets so no stale message IDs remain.
   * Note: uses KEYS, which scans the whole keyspace.
   */
  async clearTable({ tableName }: { tableName: TABLE_NAMES }): Promise<void> {
    const keys = await this.redis.keys(`${tableName}:*`);

    if (tableName === MastraStorage.TABLE_MESSAGES) {
      // Ordering sets live outside the table prefix (see getThreadMessagesKey).
      keys.push(...(await this.redis.keys('thread:*:messages')));
    }

    if (keys.length > 0) {
      await this.redis.del(...keys);
    }
  }

  /**
   * Stores `record` as JSON. Messages are keyed by `threadId` + `id`, every
   * other table by `id`. `createdAt`/`updatedAt` are serialized to ISO strings.
   */
  async insert({ tableName, record }: { tableName: TABLE_NAMES; record: Record<string, any> }): Promise<void> {
    const key =
      tableName === MastraStorage.TABLE_MESSAGES
        ? this.getKey(tableName, { threadId: record.threadId, id: record.id })
        : this.getKey(tableName, { id: record.id });

    const processedRecord = {
      ...record,
      createdAt: this.serializeDate(record.createdAt),
      updatedAt: this.serializeDate(record.updatedAt),
    };

    await this.redis.set(key, processedRecord);
  }

  /** Loads a single record by its key fields, or `null` when absent. */
  async load<R>({ tableName, keys }: { tableName: TABLE_NAMES; keys: Record<string, string> }): Promise<R | null> {
    const key = this.getKey(tableName, keys);
    const data = await this.redis.get<R>(key);
    return data || null;
  }

  async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
    const thread = await this.load<StorageThreadType>({
      tableName: MastraStorage.TABLE_THREADS,
      keys: { id: threadId },
    });

    return thread ? this.parseThread(thread) : null;
  }

  /** Scans all thread keys and returns the threads owned by `resourceId`. */
  async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
    const keys = await this.redis.keys(`${MastraStorage.TABLE_THREADS}:*`);
    const threads = await Promise.all(keys.map(key => this.redis.get<StorageThreadType>(key)));

    return threads
      .filter((thread): thread is StorageThreadType => thread !== null && thread.resourceId === resourceId)
      .map(thread => this.parseThread(thread));
  }

  async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
    await this.insert({
      tableName: MastraStorage.TABLE_THREADS,
      record: thread,
    });
    return thread;
  }

  /**
   * Replaces the thread's title and shallow-merges `metadata` into the
   * existing metadata.
   * @throws Error when the thread does not exist.
   */
  async updateThread({
    id,
    title,
    metadata,
  }: {
    id: string;
    title: string;
    metadata: Record<string, unknown>;
  }): Promise<StorageThreadType> {
    const thread = await this.getThreadById({ threadId: id });
    if (!thread) {
      throw new Error(`Thread ${id} not found`);
    }

    const updatedThread = {
      ...thread,
      title,
      metadata: {
        ...thread.metadata,
        ...metadata,
      },
    };

    await this.__saveThread({ thread: updatedThread });
    return updatedThread;
  }

  /**
   * Deletes the thread together with its messages and its message-ordering
   * set, so no orphaned message data is left behind.
   */
  async deleteThread({ threadId }: { threadId: string }): Promise<void> {
    const threadMessagesKey = this.getThreadMessagesKey(threadId);
    const messageIds = await this.redis.zrange(threadMessagesKey, 0, -1);

    const keys = [
      this.getKey(MastraStorage.TABLE_THREADS, { id: threadId }),
      threadMessagesKey,
      ...messageIds.map(id => this.getMessageKey(threadId, String(id))),
    ];

    await this.redis.del(...keys);
  }

  private getMessageKey(threadId: string, messageId: string): string {
    return this.getKey(MastraStorage.TABLE_MESSAGES, { threadId, id: messageId });
  }

  /** Key of the sorted set holding a thread's message IDs in display order. */
  private getThreadMessagesKey(threadId: string): string {
    return `thread:${threadId}:messages`;
  }

  /**
   * Stores messages and appends them to their thread's ordering set.
   *
   * Scores continue from the number of entries the thread already has, so
   * messages saved in separate calls stay in insertion order. (Restarting at
   * 0 on every call makes equal scores tie-break lexicographically by ID.)
   * Re-saving an existing message moves it to the end of the thread.
   * Concurrent saves to one thread may still receive equal scores.
   */
  async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
    if (messages.length === 0) return [];

    const nextScoreByThread = new Map<string, number>();
    const threadIds = Array.from(new Set(messages.map(message => message.threadId)));
    await Promise.all(
      threadIds.map(async threadId => {
        const existing = await this.redis.zcard(this.getThreadMessagesKey(threadId));
        nextScoreByThread.set(threadId, existing);
      }),
    );

    const pipeline = this.redis.pipeline();
    for (const message of messages) {
      const score = nextScoreByThread.get(message.threadId) ?? 0;
      nextScoreByThread.set(message.threadId, score + 1);

      pipeline.set(this.getMessageKey(message.threadId, message.id), message);
      pipeline.zadd(this.getThreadMessagesKey(message.threadId), {
        score,
        member: message.id,
      });
    }

    await pipeline.exec();
    return messages;
  }

  /**
   * Returns the thread's `selectBy.last` (default 40) most recent messages
   * plus any `selectBy.include` messages with their requested neighbours,
   * deduplicated and ordered by position in the thread's ordering set.
   */
  async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T[]> {
    const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
    const messageIds = new Set<string>();
    const threadMessagesKey = this.getThreadMessagesKey(threadId);

    if (limit === 0 && !selectBy?.include) {
      return [];
    }

    // Specifically included messages, plus their requested context window.
    if (selectBy?.include?.length) {
      for (const item of selectBy.include) {
        messageIds.add(item.id);

        if (item.withPreviousMessages || item.withNextMessages) {
          const rank = await this.redis.zrank(threadMessagesKey, item.id);
          if (rank === null) continue;

          if (item.withPreviousMessages) {
            const start = Math.max(0, rank - item.withPreviousMessages);
            const prevIds = rank === 0 ? [] : await this.redis.zrange(threadMessagesKey, start, rank - 1);
            prevIds.forEach(id => messageIds.add(String(id)));
          }

          if (item.withNextMessages) {
            const nextIds = await this.redis.zrange(threadMessagesKey, rank + 1, rank + item.withNextMessages);
            nextIds.forEach(id => messageIds.add(String(id)));
          }
        }
      }
    }

    // Then the most recent messages.
    const latestIds = limit === 0 ? [] : await this.redis.zrange(threadMessagesKey, -limit, -1);
    // Members come back auto-deserialized; normalize to strings for consistent dedupe.
    latestIds.forEach(id => messageIds.add(String(id)));

    // `_index` may be present on messages written by earlier versions.
    type StoredMessage = MessageType & { _index?: number };
    const messages = (
      await Promise.all(
        Array.from(messageIds).map(id => this.redis.get<StoredMessage>(this.getMessageKey(threadId, id))),
      )
    ).filter((msg): msg is StoredMessage => msg !== null);

    // Order by position in the sorted set (rank lookup instead of repeated indexOf).
    const messageOrder = await this.redis.zrange(threadMessagesKey, 0, -1);
    const rankById = new Map(messageOrder.map((id, rank) => [String(id), rank] as const));
    const rankOf = (id: string): number => rankById.get(id) ?? -1;
    messages.sort((a, b) => rankOf(a.id) - rankOf(b.id));

    return messages.map(({ _index, ...message }) => message as unknown as T);
  }

  /** Key under which a workflow run's snapshot is stored. */
  private getWorkflowSnapshotKey(namespace: string, workflowName: string, runId: string): string {
    return this.getKey(MastraStorage.TABLE_WORKFLOW_SNAPSHOT, {
      namespace,
      workflow_name: workflowName,
      run_id: runId,
    });
  }

  async persistWorkflowSnapshot(params: {
    namespace: string;
    workflowName: string;
    runId: string;
    snapshot: WorkflowRunState;
  }): Promise<void> {
    const { namespace, workflowName, runId, snapshot } = params;
    const key = this.getWorkflowSnapshotKey(namespace, workflowName, runId);
    await this.redis.set(key, snapshot); // Store snapshot directly without wrapping
  }

  async loadWorkflowSnapshot(params: {
    namespace: string;
    workflowName: string;
    runId: string;
  }): Promise<WorkflowRunState | null> {
    const { namespace, workflowName, runId } = params;
    const key = this.getWorkflowSnapshotKey(namespace, workflowName, runId);
    const data = await this.redis.get<WorkflowRunState>(key);
    return data || null;
  }

  async close(): Promise<void> {
    // No explicit cleanup needed for Upstash Redis
  }
}
|