@mastra/memory 0.0.2-alpha.61 → 0.0.2-alpha.63
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +13 -0
- package/dist/index.d.ts +2 -3
- package/dist/index.js +0 -7
- package/dist/kv/upstash.d.ts +6 -5
- package/dist/kv/upstash.js +204 -0
- package/dist/postgres/index.d.ts +5 -3
- package/dist/postgres/index.js +414 -0
- package/package.json +21 -13
- package/src/index.ts +0 -2
- package/dist/index.d.ts.map +0 -1
- package/dist/kv/upstash.d.ts.map +0 -1
- package/dist/memory.cjs.development.js +0 -1575
- package/dist/memory.cjs.development.js.map +0 -1
- package/dist/memory.cjs.production.min.js +0 -2
- package/dist/memory.cjs.production.min.js.map +0 -1
- package/dist/memory.esm.js +0 -1570
- package/dist/memory.esm.js.map +0 -1
- package/dist/postgres/index.d.ts.map +0 -1
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,18 @@
|
|
|
1
1
|
# @mastra/memory
|
|
2
2
|
|
|
3
|
+
## 0.0.2-alpha.63
|
|
4
|
+
|
|
5
|
+
### Patch Changes
|
|
6
|
+
|
|
7
|
+
- Updated dependencies [9fb3039]
|
|
8
|
+
- @mastra/core@0.1.27-alpha.81
|
|
9
|
+
|
|
10
|
+
## 0.0.2-alpha.62
|
|
11
|
+
|
|
12
|
+
### Patch Changes
|
|
13
|
+
|
|
14
|
+
- 7f5b1b2: @mastra/memory tsup bundling
|
|
15
|
+
|
|
3
16
|
## 0.0.2-alpha.61
|
|
4
17
|
|
|
5
18
|
### Patch Changes
|
package/dist/index.d.ts
CHANGED
|
@@ -1,3 +1,2 @@
|
|
|
1
|
-
|
|
2
|
-
export
|
|
3
|
-
//# sourceMappingURL=index.d.ts.map
|
|
1
|
+
|
|
2
|
+
export { }
|
package/dist/index.js
CHANGED
package/dist/kv/upstash.d.ts
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
|
-
import { MastraMemory,
|
|
1
|
+
import { MastraMemory, ThreadType, MessageResponse, AiMessageType, MessageType as MessageType$1 } from '@mastra/core';
|
|
2
2
|
import { Redis } from '@upstash/redis';
|
|
3
3
|
import { ToolResultPart } from 'ai';
|
|
4
|
-
|
|
4
|
+
|
|
5
|
+
/**
 * Message shape persisted by the Upstash KV store: the core message type plus an
 * optional token estimate attached when the message is saved.
 */
interface MessageType extends MessageType$1 {
    /** Token estimate recorded at save time, if one was computed. */
    tokens?: number;
}
/**
 * Thread as stored in KV: identical to `ThreadType` except that both timestamps
 * are kept as strings (they are serialized with `toISOString()` on write).
 */
interface SerializedThreadType extends Omit<ThreadType, 'createdAt' | 'updatedAt'> {
    /** Creation time as a serialized date string. */
    createdAt: string;
    /** Last-update time as a serialized date string. */
    updatedAt: string;
}
|
|
11
|
-
|
|
12
|
+
declare class UpstashKVMemory extends MastraMemory {
|
|
12
13
|
private prefix;
|
|
13
14
|
kv: Redis;
|
|
14
15
|
constructor(config: {
|
|
@@ -68,5 +69,5 @@ export declare class UpstashKVMemory extends MastraMemory {
|
|
|
68
69
|
parseThread(thread: SerializedThreadType): ThreadType;
|
|
69
70
|
parseMessages(messages: MessageType[]): MessageType[];
|
|
70
71
|
}
|
|
71
|
-
|
|
72
|
-
|
|
72
|
+
|
|
73
|
+
export { UpstashKVMemory };
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import { MastraMemory } from '@mastra/core';
|
|
2
|
+
import { Redis } from '@upstash/redis';
|
|
3
|
+
import crypto from 'crypto';
|
|
4
|
+
|
|
5
|
+
// src/kv/upstash.ts
|
|
6
|
+
/**
 * Text used for token estimation of a message. Assistant messages carry an array
 * of parts (the first part's `text` is used); other roles carry `content` as-is.
 * Falls back to "" so a missing content field cannot throw.
 */
function textForTokenEstimate(message) {
  const text = message.role === "assistant" ? message.content?.[0]?.text || "" : message.content;
  return text ?? "";
}
var UpstashKVMemory = class extends MastraMemory {
  /**
   * @param config - `url` / `token` for the Upstash Redis REST client, an optional
   *   key `prefix` (defaults to "mastra"), and an optional `maxTokens` budget that
   *   `getContextWindow` enforces.
   */
  constructor(config) {
    super();
    this.prefix = config.prefix || "mastra";
    this.MAX_CONTEXT_TOKENS = config.maxTokens;
    this.kv = new Redis({
      url: config.url,
      token: config.token
    });
  }
  /** KV key holding the serialized thread record. */
  getThreadKey(threadId) {
    return `${this.prefix}:thread:${threadId}`;
  }
  /** KV list key holding the thread's messages in insertion order. */
  getMessagesKey(threadId) {
    return `${this.prefix}:messages:${threadId}`;
  }
  /** KV key for a tool-call cache entry, addressed by its argument hash. */
  getToolCacheKey(hashedArgs) {
    return `${this.prefix}:tool:${hashedArgs}`;
  }
  /** Returns the thread with Date timestamps, or null when it does not exist. */
  async getThreadById({ threadId }) {
    const thread = await this.kv.get(this.getThreadKey(threadId));
    return thread ? this.parseThread(thread) : null;
  }
  /**
   * Returns every thread whose `resourceid` matches. This scans all thread keys
   * under the prefix (KEYS + GET per key), so cost grows with the total thread count.
   */
  async getThreadsByResourceId({ resourceid }) {
    const pattern = `${this.prefix}:thread:*`;
    const keys = await this.kv.keys(pattern);
    const threads = await Promise.all(keys.map((key) => this.kv.get(key)));
    return threads.filter((thread) => thread?.resourceid === resourceid).map((thread) => this.parseThread(thread));
  }
  /** Stores the thread with ISO-string timestamps and returns the input unchanged. */
  async saveThread({ thread }) {
    const key = this.getThreadKey(thread.id);
    const serializedThread = {
      ...thread,
      createdAt: thread.createdAt.toISOString(),
      updatedAt: thread.updatedAt.toISOString()
    };
    await this.kv.set(key, serializedThread);
    return thread;
  }
  /**
   * Replaces the thread's title and metadata and bumps `updatedAt`.
   * @throws Error when the thread does not exist.
   */
  async updateThread(id, title, metadata) {
    const key = this.getThreadKey(id);
    const thread = await this.kv.get(key);
    if (!thread) {
      throw new Error(`Thread ${id} not found`);
    }
    const updatedThread = {
      ...thread,
      title,
      metadata,
      updatedAt: (/* @__PURE__ */ new Date()).toISOString()
    };
    await this.kv.set(key, updatedThread);
    return this.parseThread(updatedThread);
  }
  /** Removes the thread record and its message list in one DEL round trip. */
  async deleteThread(id) {
    await this.kv.del(this.getThreadKey(id), this.getMessagesKey(id));
  }
  /**
   * Tool Cache
   */
  /** True when a cache entry exists for the hash and its `expireAt` is in the future. */
  async validateToolCallArgs({ hashedArgs }) {
    const cacheKey = this.getToolCacheKey(hashedArgs);
    const cached = await this.kv.get(cacheKey);
    return !!cached && new Date(cached.expireAt) > /* @__PURE__ */ new Date();
  }
  /**
   * Looks up a cached tool result for (threadId, toolName, toolArgs).
   * NOTE(review): `saveMessages` only writes `{ expireAt }` to cache entries, so
   * `cached.result` is never populated by this class and this returns null unless
   * something else writes `result` - confirm intended design.
   */
  async getToolResult({
    threadId,
    toolArgs,
    toolName
  }) {
    const hashedToolArgs = crypto.createHash("sha256").update(JSON.stringify({ args: toolArgs, threadId, toolName })).digest("hex");
    const cacheKey = this.getToolCacheKey(hashedToolArgs);
    const cached = await this.kv.get(cacheKey);
    if (cached && new Date(cached.expireAt) > /* @__PURE__ */ new Date()) {
      return cached.result || null;
    }
    return null;
  }
  /**
   * Returns the thread's messages (oldest first), optionally bounded by inclusive
   * start/end dates. When `MAX_CONTEXT_TOKENS` is set, keeps the newest messages
   * whose cumulative token count fits the budget.
   */
  async getContextWindow({
    threadId,
    startDate,
    endDate,
    // @ts-ignore
    format = "raw"
  }) {
    const messages = await this.kv.lrange(this.getMessagesKey(threadId), 0, -1);
    const filteredMessages = messages.filter((msg) => {
      const created = new Date(msg.createdAt);
      return (!startDate || created >= startDate) && (!endDate || created <= endDate);
    });
    if (!this.MAX_CONTEXT_TOKENS) {
      return this.parseMessages(filteredMessages);
    }
    let totalTokens = 0;
    const newestFirst = [];
    for (let i = filteredMessages.length - 1; i >= 0; i--) {
      const message = filteredMessages[i];
      // Prefer the estimate stored by saveMessages so save-time and read-time
      // budgets agree; fall back to a chars/4 heuristic for entries without one.
      const tokens = typeof message.tokens === "number" ? message.tokens : Math.ceil(textForTokenEstimate(message).length / 4);
      if (totalTokens + tokens > this.MAX_CONTEXT_TOKENS) {
        break;
      }
      totalTokens += tokens;
      newestFirst.push({ ...message, tokens });
    }
    // Collected newest -> oldest; flip once rather than unshifting per message.
    return this.parseMessages(newestFirst.reverse());
  }
  /**
   * Messages
   */
  /** Returns all stored messages for the thread plus their UI representation. */
  async getMessages({
    threadId
  }) {
    const messagesKey = this.getMessagesKey(threadId);
    const messages = await this.kv.lrange(messagesKey, 0, -1);
    const parsedMessages = this.parseMessages(messages);
    const uiMessages = this.convertToUIMessages(parsedMessages);
    return { messages: parsedMessages, uiMessages };
  }
  /**
   * Appends each message to its thread's list. Text messages get a `tokens`
   * estimate. Messages with tool-call args arm a 5-minute cache entry per
   * argument hash unless every hash already has a live entry.
   */
  async saveMessages({ messages }) {
    const processedMessages = [];
    for (const message of messages) {
      const { threadId, toolCallArgs, toolNames, createdAt } = message;
      const messagesKey = this.getMessagesKey(threadId);
      const processedMessage = { ...message };
      if (message.type === "text") {
        processedMessage.tokens = this.estimateTokens(textForTokenEstimate(message));
      }
      if (toolCallArgs?.length) {
        const hashedToolCallArgs = toolCallArgs.map(
          (args, index) => crypto.createHash("sha256").update(JSON.stringify({ args, threadId, toolName: toolNames?.[index] })).digest("hex")
        );
        let validArgExists = true;
        for (const hashedArg of hashedToolCallArgs) {
          const isValid = await this.validateToolCallArgs({ hashedArgs: hashedArg });
          if (!isValid) {
            validArgExists = false;
            break;
          }
        }
        // All messages with the same args share ONE cache key. Overwriting a live
        // entry with expireAt = createdAt would expire it immediately, so only
        // (re)arm the cache when at least one hash has no valid entry.
        if (!validArgExists) {
          const expireAt = new Date(createdAt.getTime() + 5 * 60 * 1e3);
          for (const hashedArg of hashedToolCallArgs) {
            const cacheKey = this.getToolCacheKey(hashedArg);
            await this.kv.set(cacheKey, { expireAt: expireAt.toISOString() });
          }
        }
      }
      await this.kv.rpush(messagesKey, processedMessage);
      processedMessages.push(processedMessage);
    }
    return processedMessages;
  }
  /**
   * Removes the message with `id` from whichever thread list contains it by
   * rewriting that list. Not atomic: the list is deleted, then re-pushed.
   */
  async deleteMessage(id) {
    const pattern = `${this.prefix}:messages:*`;
    const keys = await this.kv.keys(pattern);
    for (const key of keys) {
      const messages = await this.kv.lrange(key, 0, -1);
      const filteredMessages = messages.filter((msg) => msg.id !== id);
      if (messages.length !== filteredMessages.length) {
        await this.kv.del(key);
        if (filteredMessages.length > 0) {
          await this.kv.rpush(key, ...filteredMessages);
        }
      }
    }
  }
  /**
   * Cleanup
   */
  /** Deletes every key under this instance's prefix. */
  async drop() {
    const pattern = `${this.prefix}:*`;
    const keys = await this.kv.keys(pattern);
    if (keys.length > 0) {
      await this.kv.del(...keys);
    }
  }
  /** Converts a stored thread's string timestamps back into Date objects. */
  parseThread(thread) {
    return {
      ...thread,
      createdAt: new Date(thread.createdAt),
      updatedAt: new Date(thread.updatedAt)
    };
  }
  /** Converts each stored message's `createdAt` back into a Date. */
  parseMessages(messages) {
    return messages.map((message) => ({
      ...message,
      createdAt: new Date(message.createdAt)
    }));
  }
};
|
|
203
|
+
|
|
204
|
+
export { UpstashKVMemory };
|
package/dist/postgres/index.d.ts
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
import { MastraMemory,
|
|
1
|
+
import { MastraMemory, ThreadType, MessageResponse, MessageType, AiMessageType } from '@mastra/core';
|
|
2
2
|
import { ToolResultPart } from 'ai';
|
|
3
|
-
|
|
3
|
+
|
|
4
|
+
declare class PgMemory extends MastraMemory {
|
|
4
5
|
private pool;
|
|
5
6
|
hasTables: boolean;
|
|
6
7
|
constructor(config: {
|
|
@@ -57,4 +58,5 @@ export declare class PgMemory extends MastraMemory {
|
|
|
57
58
|
drop(): Promise<void>;
|
|
58
59
|
ensureTablesExist(): Promise<void>;
|
|
59
60
|
}
|
|
60
|
-
|
|
61
|
+
|
|
62
|
+
export { PgMemory };
|
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
import { MastraMemory } from '@mastra/core';
|
|
2
|
+
import crypto from 'crypto';
|
|
3
|
+
import pg from 'pg';
|
|
4
|
+
|
|
5
|
+
// src/postgres/index.ts
|
|
6
|
+
var { Pool } = pg;
// Postgres folds unquoted identifiers to lowercase, so camelCase aliases must be
// double-quoted or rows come back keyed `createdat` / `threadid`.
var THREAD_COLUMNS = `id, title, resourceid, created_at AS "createdAt", updated_at AS "updatedAt", metadata`;
var MESSAGE_COLUMNS = `id, content, role, type, created_at AS "createdAt", thread_id AS "threadId"`;
/**
 * Builds optional inclusive `created_at` bounds as bind parameters (rather than
 * splicing timestamps into the SQL text).
 *
 * @param params - bind-parameter array; bound values are appended to it in place
 * @param startDate - optional inclusive lower bound
 * @param endDate - optional inclusive upper bound
 * @returns SQL fragment (possibly empty) to append after `WHERE thread_id = $1`
 */
function buildCreatedAtFilter(params, startDate, endDate) {
  let clause = "";
  if (startDate) {
    params.push(startDate.toISOString());
    clause += ` AND created_at >= $${params.length}`;
  }
  if (endDate) {
    params.push(endDate.toISOString());
    clause += ` AND created_at <= $${params.length}`;
  }
  return clause;
}
var PgMemory = class extends MastraMemory {
  /**
   * @param config - `connectionString` for the pg Pool and an optional
   *   `maxTokens` budget enforced by `getContextWindow`.
   */
  constructor(config) {
    super();
    this.hasTables = false;
    this.pool = new Pool({ connectionString: config.connectionString });
    this.MAX_CONTEXT_TOKENS = config.maxTokens;
  }
  /**
   * Threads
   */
  /** Returns the thread row (camelCase timestamp keys) or null. */
  async getThreadById({ threadId }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const result = await client.query(
        `SELECT ${THREAD_COLUMNS}
         FROM mastra_threads
         WHERE id = $1`,
        [threadId]
      );
      return result.rows[0] || null;
    } finally {
      client.release();
    }
  }
  /** Returns all thread rows for the given resource id. */
  async getThreadsByResourceId({ resourceid }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const result = await client.query(
        `SELECT ${THREAD_COLUMNS}
         FROM mastra_threads
         WHERE resourceid = $1`,
        [resourceid]
      );
      return result.rows;
    } finally {
      client.release();
    }
  }
  /** Inserts the thread, or updates title/updated_at/resourceid/metadata on id conflict. */
  async saveThread({ thread }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const { id, title, createdAt, updatedAt, resourceid, metadata } = thread;
      const result = await client.query(
        `INSERT INTO mastra_threads (id, title, created_at, updated_at, resourceid, metadata)
         VALUES ($1, $2, $3, $4, $5, $6)
         ON CONFLICT (id) DO UPDATE SET title = $2, updated_at = $4, resourceid = $5, metadata = $6
         RETURNING ${THREAD_COLUMNS}`,
        [id, title, createdAt, updatedAt, resourceid, JSON.stringify(metadata)]
      );
      return result?.rows?.[0];
    } finally {
      client.release();
    }
  }
  /** Sets title and metadata, bumps updated_at, and returns the row in the same shape as the other thread getters. */
  async updateThread(id, title, metadata) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const result = await client.query(
        `UPDATE mastra_threads
         SET title = $1, metadata = $2, updated_at = NOW()
         WHERE id = $3
         RETURNING ${THREAD_COLUMNS}`,
        [title, JSON.stringify(metadata), id]
      );
      return result?.rows?.[0];
    } finally {
      client.release();
    }
  }
  /** Deletes a thread and its messages atomically (messages first because of the FK). */
  async deleteThread(id) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      await client.query("BEGIN");
      await client.query(`DELETE FROM mastra_messages WHERE thread_id = $1`, [id]);
      await client.query(`DELETE FROM mastra_threads WHERE id = $1`, [id]);
      await client.query("COMMIT");
    } catch (error) {
      await client.query("ROLLBACK");
      throw error;
    } finally {
      client.release();
    }
  }
  /**
   * Tool Cache
   */
  /**
   * True when some message's hashed tool-call args contain `hashedArgs` and have
   * not yet expired. Query errors are logged and reported as "not valid".
   */
  async validateToolCallArgs({ hashedArgs }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const toolArgsResult = await client.query(
        `SELECT 1
         FROM mastra_messages
         WHERE tool_call_args::jsonb @> $1
           AND tool_call_args_expire_at > $2
         LIMIT 1`,
        [JSON.stringify([hashedArgs]), (/* @__PURE__ */ new Date()).toISOString()]
      );
      return toolArgsResult.rows.length > 0;
    } catch (error) {
      console.error("error checking if valid arg exists====", error);
      return false;
    } finally {
      client.release();
    }
  }
  /**
   * Finds the earliest unexpired message whose tool-call args hash matches
   * (threadId, toolName, toolArgs), then returns the `result` of the matching
   * part of the tool-result message in the same thread recorded with the same
   * `created_at`. Returns null when nothing matches or on query error.
   */
  async getToolResult({
    threadId,
    toolArgs,
    toolName
  }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const hashedToolArgs = crypto.createHash("sha256").update(JSON.stringify({ args: toolArgs, threadId, toolName })).digest("hex");
      const toolArgsResult = await client.query(
        `SELECT tool_call_ids, tool_call_args, created_at
         FROM mastra_messages
         WHERE tool_call_args::jsonb @> $1
           AND tool_call_args_expire_at > $2
         ORDER BY created_at ASC
         LIMIT 1`,
        [JSON.stringify([hashedToolArgs]), (/* @__PURE__ */ new Date()).toISOString()]
      );
      const callRow = toolArgsResult.rows[0];
      if (!callRow) {
        return null;
      }
      const toolCallArgs = JSON.parse(callRow.tool_call_args);
      const toolCallIds = JSON.parse(callRow.tool_call_ids);
      const toolCallArgsIndex = toolCallArgs.findIndex((arg) => arg === hashedToolArgs);
      const correspondingToolCallId = toolCallArgsIndex === -1 ? void 0 : toolCallIds?.[toolCallArgsIndex];
      // Without a call id the ILIKE below would match the literal "undefined".
      if (correspondingToolCallId == null) {
        return null;
      }
      const toolResult = await client.query(
        `SELECT content
         FROM mastra_messages
         WHERE thread_id = $1
           AND tool_call_ids ILIKE $2
           AND type = 'tool-result'
           AND created_at = $3
         LIMIT 1`,
        [threadId, `%${correspondingToolCallId}%`, new Date(callRow.created_at).toISOString()]
      );
      if (toolResult.rows.length === 0) {
        return null;
      }
      const toolResultContent = JSON.parse(toolResult.rows[0]?.content);
      const requiredToolResult = toolResultContent.find((part) => part.toolCallId === correspondingToolCallId);
      return requiredToolResult ? requiredToolResult.result : null;
    } catch (error) {
      console.error("error getting cached tool result====", error);
      return null;
    } finally {
      client.release();
    }
  }
  /**
   * Returns the thread's messages oldest-first, optionally bounded by inclusive
   * start/end dates. With `MAX_CONTEXT_TOKENS` set, keeps the newest messages whose
   * running token total (NULL tokens, i.e. non-text messages, count as 0) fits the
   * budget. Query errors are logged and yield [].
   */
  async getContextWindow({
    threadId,
    startDate,
    endDate,
    format = "raw"
  }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const params = [threadId];
      const dateFilter = buildCreatedAtFilter(params, startDate, endDate);
      if (this.MAX_CONTEXT_TOKENS) {
        params.push(this.MAX_CONTEXT_TOKENS);
        // COALESCE: a window holding only NULL-token rows would otherwise SUM to
        // NULL and drop the newest tool messages via `NULL <= $n`.
        const limited = await client.query(
          `WITH RankedMessages AS (
             SELECT *,
                    SUM(COALESCE(tokens, 0)) OVER (ORDER BY created_at DESC) AS running_total
             FROM mastra_messages
             WHERE thread_id = $1${dateFilter}
           )
           SELECT ${MESSAGE_COLUMNS}
           FROM RankedMessages
           WHERE running_total <= $${params.length}
           ORDER BY created_at ASC`,
          params
        );
        return this.parseMessages(limited.rows);
      }
      const result = await client.query(
        `SELECT ${MESSAGE_COLUMNS}
         FROM mastra_messages
         WHERE thread_id = $1${dateFilter}
         ORDER BY created_at ASC`,
        params
      );
      return this.parseMessages(result.rows);
    } catch (error) {
      console.error("error getting context window====", error);
      return [];
    } finally {
      client.release();
    }
  }
  /**
   * Messages
   */
  /** Returns all messages for the thread (oldest first) plus their UI representation. */
  async getMessages({
    threadId
  }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      const result = await client.query(
        `SELECT ${MESSAGE_COLUMNS}
         FROM mastra_messages
         WHERE thread_id = $1
         ORDER BY created_at ASC`,
        [threadId]
      );
      const messages = this.parseMessages(result.rows);
      const uiMessages = this.convertToUIMessages(messages);
      return { messages, uiMessages };
    } finally {
      client.release();
    }
  }
  /**
   * Inserts all messages in one transaction. Text messages get a token estimate.
   * Tool-call args are stored hashed with an expiry: `createdAt` if every hash is
   * already cached and valid, otherwise `createdAt + 5 min`.
   * @throws rethrows any insert error after ROLLBACK.
   */
  async saveMessages({ messages }) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      await client.query("BEGIN");
      for (const message of messages) {
        const { id, content, role, createdAt, threadId, toolCallIds, toolCallArgs, type, toolNames } = message;
        let tokens = null;
        if (type === "text") {
          const contentMssg = role === "assistant" ? content[0]?.text || "" : content;
          tokens = this.estimateTokens(contentMssg);
        }
        const hashedToolCallArgs = toolCallArgs ? toolCallArgs.map(
          (args, index) => crypto.createHash("sha256").update(JSON.stringify({ args, threadId, toolName: toolNames?.[index] })).digest("hex")
        ) : null;
        let validArgExists = false;
        if (hashedToolCallArgs?.length) {
          validArgExists = true;
          for (const hashedArgs of hashedToolCallArgs) {
            // NOTE(review): runs on a separate pooled connection, so it cannot see
            // rows inserted earlier in this still-open transaction.
            if (!await this.validateToolCallArgs({ hashedArgs })) {
              validArgExists = false;
              break;
            }
          }
        }
        const toolCallArgsExpireAt = !toolCallArgs ? null : validArgExists ? createdAt : new Date(createdAt.getTime() + 5 * 60 * 1e3);
        await client.query(
          `INSERT INTO mastra_messages (id, content, role, created_at, thread_id, tool_call_ids, tool_call_args, type, tokens, tool_call_args_expire_at)
           VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
          [
            id,
            JSON.stringify(content),
            role,
            createdAt.toISOString(),
            threadId,
            JSON.stringify(toolCallIds),
            JSON.stringify(hashedToolCallArgs),
            type,
            tokens,
            toolCallArgsExpireAt?.toISOString()
          ]
        );
      }
      await client.query("COMMIT");
      return messages;
    } catch (error) {
      await client.query("ROLLBACK");
      throw error;
    } finally {
      client.release();
    }
  }
  /** Deletes a single message by id. */
  async deleteMessage(id) {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      await client.query(`DELETE FROM mastra_messages WHERE id = $1`, [id]);
    } finally {
      client.release();
    }
  }
  /**
   * Table Management
   */
  /** Empties both tables, then ends the pool (the instance is unusable afterwards). */
  async drop() {
    await this.ensureTablesExist();
    const client = await this.pool.connect();
    try {
      await client.query("DELETE FROM mastra_messages");
      await client.query("DELETE FROM mastra_threads");
    } finally {
      client.release();
    }
    await this.pool.end();
  }
  /**
   * Creates `mastra_threads` and `mastra_messages` when absent. The success flag
   * is set only after both steps succeed, so a failed attempt is retried next call.
   */
  async ensureTablesExist() {
    if (this.hasTables) {
      return;
    }
    const client = await this.pool.connect();
    try {
      const threadsResult = await client.query(`
        SELECT EXISTS (
          SELECT 1
          FROM information_schema.tables
          WHERE table_name = 'mastra_threads'
        );
      `);
      if (!threadsResult?.rows?.[0]?.exists) {
        await client.query(`
          CREATE TABLE IF NOT EXISTS mastra_threads (
            id UUID PRIMARY KEY,
            resourceid TEXT,
            title TEXT,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL,
            updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
            metadata JSONB
          );
        `);
      }
      const messagesResult = await client.query(`
        SELECT EXISTS (
          SELECT 1
          FROM information_schema.tables
          WHERE table_name = 'mastra_messages'
        );
      `);
      if (!messagesResult?.rows?.[0]?.exists) {
        await client.query(`
          CREATE TABLE IF NOT EXISTS mastra_messages (
            id UUID PRIMARY KEY,
            content TEXT NOT NULL,
            role VARCHAR(20) NOT NULL,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL,
            tool_call_ids TEXT DEFAULT NULL,
            tool_call_args TEXT DEFAULT NULL,
            tool_call_args_expire_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
            type VARCHAR(20) NOT NULL,
            tokens INTEGER DEFAULT NULL,
            thread_id UUID NOT NULL,
            FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
          );
        `);
      }
      this.hasTables = true;
    } finally {
      client.release();
    }
  }
};
|
|
413
|
+
|
|
414
|
+
export { PgMemory };
|